diff --git a/cgmanifest.json b/cgmanifest.json index e2b2c73a60a57..62ad0a3031625 100644 --- a/cgmanifest.json +++ b/cgmanifest.json @@ -450,7 +450,7 @@ { "component": { "git": { - "commitHash": "a11f5002af58a03d5902b13ef65c84cedb499024", + "commitHash": "573070aeeb77e267da2579ac1d75d92c688bbe97", "repositoryUrl": "https://github.com/microsoft/FeaturizersLibrary.git" }, "type": "git" diff --git a/cmake/external/featurizers.cmake b/cmake/external/featurizers.cmake index acc9b25c564d1..4700e85f032db 100644 --- a/cmake/external/featurizers.cmake +++ b/cmake/external/featurizers.cmake @@ -3,7 +3,7 @@ # This source code should not depend on the onnxruntime and may be built independently set(featurizers_URL "https://github.com/microsoft/FeaturizersLibrary.git") -set(featurizers_TAG "a11f5002af58a03d5902b13ef65c84cedb499024") +set(featurizers_TAG "573070aeeb77e267da2579ac1d75d92c688bbe97") set(featurizers_pref FeaturizersLibrary) set(featurizers_ROOT ${PROJECT_SOURCE_DIR}/external/${featurizers_pref}) @@ -24,6 +24,7 @@ if (WIN32) BINARY_DIR ${featurizers_BINARY_DIR} CMAKE_ARGS -Dfeaturizers_MSVC_STATIC_RUNTIME=${onnxruntime_MSVC_STATIC_RUNTIME} INSTALL_COMMAND "" + ) else() ExternalProject_Add(featurizers_lib diff --git a/onnxruntime/core/graph/featurizers_ops/featurizers_defs.cc b/onnxruntime/core/graph/featurizers_ops/featurizers_defs.cc index 5c12ef18bce8f..c46822774918d 100644 --- a/onnxruntime/core/graph/featurizers_ops/featurizers_defs.cc +++ b/onnxruntime/core/graph/featurizers_ops/featurizers_defs.cc @@ -41,6 +41,7 @@ static void RegisterMinMaxScalarFeaturizerVer1(); static void RegisterMissingDummiesFeaturizerVer1(); static void RegisterRobustScalarFeaturizerVer1(); static void RegisterStringFeaturizerVer1(); +static void RegisterTimeSeriesImputerFeaturizerVer1(); // ---------------------------------------------------------------------- // ---------------------------------------------------------------------- @@ -55,6 +56,7 @@ void RegisterMSFeaturizersSchemas() { RegisterMissingDummiesFeaturizerVer1(); RegisterRobustScalarFeaturizerVer1(); RegisterStringFeaturizerVer1(); + RegisterTimeSeriesImputerFeaturizerVer1(); } // ---------------------------------------------------------------------- @@ -212,7 +214,7 @@ void RegisterDateTimeFeaturizerVer1() { case 0: propagateElemTypeFromDtypeToOutput(ctx, ONNX_NAMESPACE::TensorProto_DataType_INT32, output); break; - case 1: // fall through + case 1: // fall through case 2: case 3: case 4: @@ -223,11 +225,11 @@ void RegisterDateTimeFeaturizerVer1() { case 9: propagateElemTypeFromDtypeToOutput(ctx, ONNX_NAMESPACE::TensorProto_DataType_UINT8, output); break; - case 10: // fall through + case 10: // fall through case 11: propagateElemTypeFromDtypeToOutput(ctx, ONNX_NAMESPACE::TensorProto_DataType_UINT16, output); break; - case 12: // fall through + case 12: // fall through case 13: case 14: propagateElemTypeFromDtypeToOutput(ctx, ONNX_NAMESPACE::TensorProto_DataType_UINT8, output); @@ -595,7 +597,6 @@ void RegisterRobustScalarFeaturizerVer1() { input_elem_type == ONNX_NAMESPACE::TensorProto_DataType_UINT32 || input_elem_type == ONNX_NAMESPACE::TensorProto_DataType_UINT64 || input_elem_type == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) { - ctx.getOutputType(0)->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); propagateElemTypeFromDtypeToOutput(ctx, ONNX_NAMESPACE::TensorProto_DataType_DOUBLE, 0); } else { fail_type_inference("input 1 is expected to have a accepted type"); @@ -648,7 +649,178 @@ void RegisterStringFeaturizerVer1() { .TypeAndShapeInferenceFunction( [](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromDtypeToOutput(ctx, ONNX_NAMESPACE::TensorProto_DataType_STRING, 0); - propagateShapeFromInputToOutput(ctx, 1, 0); + if (hasInputShape(ctx, 1)) { + propagateShapeFromInputToOutput(ctx, 1, 0); + } + }); +} + +void RegisterTimeSeriesImputerFeaturizerVer1() { + static const char* doc = R"DOC( + Imputes rows and column values such that the generated output does not contain any + time gaps per grain (based on the time gaps encountered during training) and that + all missing column values are populated according to a strategy (forward fill, + backward fill, mode, etc.). + + This Featurizer is unique in that it will produce 0:N rows per invocation, depending upon the + input data. + + C++-style pseudo signature: + template + std::vector< + std::tuple< + bool, // true if the row was added + std::chrono::system_clock::time_point, + std::tuple, + std::tuple + > + > execute( + std::chrono::system_clock::time_point const &value, + std::tuple const &grain, + std::tuple const &colData + ); + + Examples: + During training, the time period was found to be 1 day... + + Input: + +------+-------+------------------+-------------------+ + | time | grain | forward fill col | backward fill col | + +======+=======+==================+===================+ + | 1 | A | 10 | None | + +------+-------+------------------+-------------------+ + | 2 | A | None | 200 | + +------+-------+------------------+-------------------+ + | 1 | B | -10 | -100 | + +------+-------+------------------+-------------------+ + | 4 | A | 40 | 400 | + +------+-------+------------------+-------------------+ + | 6 | A | 60 | 600 | + +------+-------+------------------+-------------------+ + | 3 | B | -30 | -300 | + +------+-------+------------------+-------------------+ + + Output: + +-------+------+-------+------------------+-------------------+ + | Added | time | grain | forward fill col | backward fill col | + +=======+======+=======+==================+===================+ + | false | 1 | A | 10 | 200 (from 2) | + +-------+------+-------+------------------+-------------------+ + | false | 2 | A | 10 (from 1) | 200 | + +-------+------+-------+------------------+-------------------+ + | true | 3 | A | 10 (from 2) | 400 (from 4) | + +-------+------+-------+------------------+-------------------+ + | false | 4 | A | 40 | 400 | + +-------+------+-------+------------------+-------------------+ + | true | 5 | A | 40 (from 4) | 600 (from 6) | + +-------+------+-------+------------------+-------------------+ + | false | 6 | A | 60 | 600 | + +-------+------+-------+------------------+-------------------+ + | false | 1 | B | -10 | -100 | + +-------+------+-------+------------------+-------------------+ + | true | 2 | B | -10 (from 1) | -300 (from 3) | + +-------+------+-------+------------------+-------------------+ + | false | 3 | B | -30 | -300 | + +-------+------+-------+------------------+-------------------+ + )DOC"; + + MS_FEATURIZERS_OPERATOR_SCHEMA(TimeSeriesImputerTransformer) + .SinceVersion(1) + .SetDomain(kMSFeaturizersDomain) + .SetDoc(doc) + .Input( + 0, + "State", + "State generated during training that is used for prediction", + "T0") + .Input( + 1, + "Times", + "Tensor of timestamps in seconds since epoch [R] where R is a number of rows.", + "T1") + .Input( + 2, + "Keys", + "Composite keys tensor of shape [R][K]. R is the same as Input(1)", + "T2") + .Input( + 3, + "Data", + "It is a data tensor of shape [R][C] where R - rows and C - columns. R must be the same with Input(1)", + "T2") + .Output( + 0, + "Added", + "Tensor of boolean with a shape of [IR]. Contains a boolean for each row in the result where true represents added row.", + "T3") + .Output( + 1, + "ImputedTimes", + "This is a tensor of timestamps in seconds since epoch of shape [IR], where IR is the number of output rows.", + "T1") + .Output( + 2, + "ImputedKeys", + "Contains keys along with the imputed keys. Tensor of shape [IR][K].", + "T2") + .Output( + 3, + "ImputedData", + "Tensor of shape [IR][C] where IR is the number of rows in the output." + "C is the number of columns.", + "T2") + .TypeConstraint( + "T0", + {"tensor(uint8)"}, + "No information is available") + .TypeConstraint( + "T1", + {"tensor(int64)"}, + "Represents number of seconds since epoch") + .TypeConstraint( + "T2", + {"tensor(string)"}, + "Output data") + .TypeConstraint( + "T3", + {"tensor(bool)"}, + "Boolean Tensor") + .TypeAndShapeInferenceFunction( + [](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromDtypeToOutput(ctx, ONNX_NAMESPACE::TensorProto_DataType_BOOL, 0); + propagateElemTypeFromDtypeToOutput(ctx, ONNX_NAMESPACE::TensorProto_DataType_INT64, 1); + // Number of output rows is not known + ONNX_NAMESPACE::TensorShapeProto shape_0_1; + shape_0_1.add_dim(); + ONNX_NAMESPACE::updateOutputShape(ctx, 0, shape_0_1); + ONNX_NAMESPACE::updateOutputShape(ctx, 1, shape_0_1); + + // Keys + propagateElemTypeFromInputToOutput(ctx, 2, 2); + // Keys shape + if (hasInputShape(ctx, 2)) { + const auto& input2_shape = getInputShape(ctx, 2); + if (input2_shape.dim_size() != 2) { + fail_shape_inference("Expecting keys to have 2 dimensions"); + } + ONNX_NAMESPACE::TensorShapeProto shape; + shape.add_dim(); + *shape.add_dim() = input2_shape.dim(1); + ONNX_NAMESPACE::updateOutputShape(ctx, 2, shape); + } + + // Data shape + propagateElemTypeFromInputToOutput(ctx, 3, 3); + if (hasInputShape(ctx, 3)) { + const auto& input3_shape = getInputShape(ctx, 3); + if (input3_shape.dim_size() != 2) { + fail_shape_inference("Expecting data to have 2 dimensions"); + } + ONNX_NAMESPACE::TensorShapeProto shape; + shape.add_dim(); + *shape.add_dim() = input3_shape.dim(1); + ONNX_NAMESPACE::updateOutputShape(ctx, 3, shape); + } }); } diff --git a/onnxruntime/featurizers_ops/cpu/time_series_imputer_transformer.cc b/onnxruntime/featurizers_ops/cpu/time_series_imputer_transformer.cc new file mode 100644 index 0000000000000..ea8f4755eab37 --- /dev/null +++ b/onnxruntime/featurizers_ops/cpu/time_series_imputer_transformer.cc @@ -0,0 +1,244 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/framework/data_types.h" +#include "core/framework/op_kernel.h" + +#include +#include + +#include "Featurizers/TimeSeriesImputerFeaturizer.h" +#include "Archive.h" + +namespace ft = Microsoft::Featurizer::Featurizers; + +namespace onnxruntime { +namespace featurizers { + +namespace timeseries_imputer_details { + +inline std::chrono::system_clock::time_point ToTimePoint(int64_t secs) { + return std::chrono::system_clock::from_time_t(secs); +} + +inline int64_t ToSecs(const std::chrono::system_clock::time_point& tp) { + using namespace std::chrono; + return duration_cast(tp.time_since_epoch()).count(); +} + +template +struct ToString { + std::string operator()(T val) const { + return std::to_string(val); + } +}; + +template <> +struct ToString { + const std::string& operator()(const std::string& val) const { + return val; + } +}; + +template +struct ToStringOptional { + nonstd::optional operator()(T val) const { + nonstd::optional result; + if (std::isnan(val)) { + return result; + } + result = std::to_string(val); + return result; + } +}; + +template <> +struct ToStringOptional { + nonstd::optional operator()(std::string val) const { + return (val.empty()) ? nonstd::optional() : nonstd::optional(std::move(val)); + } +}; + +template +struct FromString; + +template <> +struct FromString { + const std::string& operator()(const std::string& val) const { + return val; + } +}; + +template <> +struct FromString { + float operator()(const std::string& val) const { + char* str_end = nullptr; + const char* str = val.c_str(); + float result = std::strtof(str, &str_end); + if (str == str_end) { + ORT_THROW("Resulting key string is not convertible to float: ", val); + } + return result; + } +}; + +template <> +struct FromString { + double operator()(const std::string& val) const { + const char* str = val.c_str(); + char* str_end = nullptr; + double result = std::strtod(str, &str_end); + if (str == str_end) { + ORT_THROW("Resulting key string is not convertible to double: ", val); + } + return result; + } +}; +template +struct FromStringOptional { + T operator()(const nonstd::optional& val) const { + if (val.has_value()) { + return FromString()(*val); + } + return std::numeric_limits::quiet_NaN(); + } +}; + +template <> +struct FromStringOptional { + std::string operator()(const nonstd::optional& val) const { + if (val.has_value()) { + return *val; + } + return std::string(); + } +}; +} // namespace timeseries_imputer_details + +template +struct TimeSeriesImputerTransformerImpl { + void operator()(OpKernelContext* ctx, int64_t rows) { + const auto& state = *ctx->Input(0); + const uint8_t* const state_data = state.template Data(); + + const auto& times = *ctx->Input(1); + const auto& keys = *ctx->Input(2); + const auto& data = *ctx->Input(3); + + const int64_t keys_per_row = keys.Shape()[1]; + const int64_t columns = data.Shape()[1]; + + using namespace timeseries_imputer_details; + + using OutputType = std::tuple, std::vector>>; + std::vector output_rows; + std::function callback_fn; + callback_fn = [&output_rows](OutputType value) -> void { + output_rows.emplace_back(std::move(value)); + }; + + Microsoft::Featurizer::Archive archive(state_data, state.Shape().Size()); + ft::Components::TimeSeriesImputerEstimator::Transformer transformer(archive); + + const int64_t* times_data = times.template Data(); + const T* const keys_data = keys.template Data(); + const T* const data_data = data.template Data(); + + // for each row get timestamp, get all keys, get all data and feed it + for (int64_t row = 0; row < rows; ++row) { + const T* const key_row_data = keys_data + (row * keys_per_row); + const T* const keys_row_end = key_row_data + keys_per_row; + std::vector str_keys; + std::transform(key_row_data, keys_row_end, std::back_inserter(str_keys), + ToString()); + + std::vector> str_data; + const T* const data_row = data_data + (row * columns); + const T* const data_row_end = data_row + columns; + std::transform(data_row, data_row_end, std::back_inserter(str_data), + ToStringOptional()); + + auto tuple_row = std::make_tuple(ToTimePoint(*times_data), std::move(str_keys), std::move(str_data)); + + transformer.execute(tuple_row, callback_fn); + ++times_data; + } + + transformer.flush(callback_fn); + + // Compute output shapes now + // Number of outputs is the number of rows, + int64_t output_rows_num = static_cast(output_rows.size()); + TensorShape rows_shape({output_rows_num}); + TensorShape keys_shape({output_rows_num, keys_per_row}); + TensorShape data_shape({output_rows_num, columns}); + + auto* added_output = ctx->Output(0, rows_shape)->template MutableData(); + auto* time_output = ctx->Output(1, rows_shape)->template MutableData(); + auto* keys_output = ctx->Output(2, keys_shape)->template MutableData(); + auto* data_output = ctx->Output(3, data_shape)->template MutableData(); + + for (const auto& out : output_rows) { + *added_output++ = std::get<0>(out); + *time_output++ = ToSecs(std::get<1>(out)); + const auto& imputed_keys = std::get<2>(out); + ORT_ENFORCE(static_cast(imputed_keys.size()) == keys_per_row, + "resulting number of keys: ", imputed_keys.size(), " expected: ", keys_per_row); + const auto& imputed_data = std::get<3>(out); + ORT_ENFORCE(static_cast(imputed_data.size()) == columns, + "resulting number of columns: ", imputed_data.size(), " expected: ", columns); + keys_output = std::transform(imputed_keys.cbegin(), imputed_keys.cend(), keys_output, + FromString()); + data_output = std::transform(imputed_data.cbegin(), imputed_data.cend(), data_output, + FromStringOptional()); + } + } +}; + +class TimeSeriesImputerTransformer final : public OpKernel { + public: + explicit TimeSeriesImputerTransformer(const OpKernelInfo& info) : OpKernel(info) { + } + + static Status CheckBatches(int64_t rows, const TensorShape& shape) { + if (shape.NumDimensions() == 2) { + ORT_RETURN_IF_NOT(rows == shape[0], "Number of rows does not match"); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect shape of [R][C]"); + } + return Status::OK(); + } + + Status Compute(OpKernelContext* ctx) const override { + const auto& times = *ctx->Input(1); + const auto& times_shape = times.Shape(); + ORT_RETURN_IF_NOT(times_shape.NumDimensions() == 1, "Times must have shape [B][R] or [R]"); + int64_t rows = times_shape[0]; + + const auto& keys = *ctx->Input(2); + ORT_RETURN_IF_ERROR(CheckBatches(rows, keys.Shape())); + const auto& data = *ctx->Input(3); + ORT_RETURN_IF_ERROR(CheckBatches(rows, data.Shape())); + + auto data_type = data.GetElementType(); + ORT_RETURN_IF_NOT(keys.GetElementType() == data_type, "Keys and data must have the same datatype"); + + TimeSeriesImputerTransformerImpl()(ctx, rows); + return Status::OK(); + } +}; + +ONNX_OPERATOR_KERNEL_EX( + TimeSeriesImputerTransformer, + kMSFeaturizersDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T0", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + TimeSeriesImputerTransformer); +} // namespace featurizers +} // namespace onnxruntime diff --git a/onnxruntime/featurizers_ops/cpu_featurizers_kernels.cc b/onnxruntime/featurizers_ops/cpu_featurizers_kernels.cc index 5f246c894b2b9..2fa7168880743 100644 --- a/onnxruntime/featurizers_ops/cpu_featurizers_kernels.cc +++ b/onnxruntime/featurizers_ops/cpu_featurizers_kernels.cc @@ -19,6 +19,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSFeaturizersDomai class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSFeaturizersDomain, 1, MissingDummiesTransformer); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSFeaturizersDomain, 1, RobustScalarTransformer); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSFeaturizersDomain, 1, StringTransformer); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSFeaturizersDomain, 1, TimeSeriesImputerTransformer); Status RegisterCpuMSFeaturizersKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -31,6 +32,7 @@ Status RegisterCpuMSFeaturizersKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/test/featurizers_ops/time_series_imputer_transformer_test.cc b/onnxruntime/test/featurizers_ops/time_series_imputer_transformer_test.cc new file mode 100644 index 0000000000000..d8a351d248088 --- /dev/null +++ b/onnxruntime/test/featurizers_ops/time_series_imputer_transformer_test.cc @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +#include "Featurizers/TimeSeriesImputerFeaturizer.h" +#include "Featurizers/TestHelpers.h" +#include "Archive.h" + +namespace NS = Microsoft::Featurizer; + +namespace onnxruntime { +namespace test { + +inline std::chrono::system_clock::time_point GetTimePoint(std::chrono::system_clock::time_point tp, int unitsToAdd, std::string = "days") { + return tp + std::chrono::minutes(unitsToAdd * (60 * 24)); +} + +inline int64_t GetTimeSecs(std::chrono::system_clock::time_point tp) { + using namespace std::chrono; + return time_point_cast(tp).time_since_epoch().count(); +} + +using InputType = std::tuple< + std::chrono::system_clock::time_point, + std::vector, + std::vector>>; + +using TransformedType = std::vector< + std::tuple< + bool, + std::chrono::system_clock::time_point, + std::vector, + std::vector>>>; + +std::vector GetStream(const std::vector>& trainingBatches, + const std::vector& colsToImputeDataTypes, + bool supressError, NS::Featurizers::Components::TimeSeriesImputeStrategy tsImputeStrategy) { + using TSImputerEstimator = NS::Featurizers::TimeSeriesImputerEstimator; + + NS::AnnotationMapsPtr const pAllColumnAnnotations(NS::CreateTestAnnotationMapsPtr(1)); + TSImputerEstimator estimator(pAllColumnAnnotations, colsToImputeDataTypes, supressError, tsImputeStrategy); + + NS::TestHelpers::Train(estimator, trainingBatches); + TSImputerEstimator::TransformerUniquePtr pTransformer(estimator.create_transformer()); + + NS::Archive ar; + pTransformer->save(ar); + return ar.commit(); +} + +static void AddInputs(OpTester& test, const std::vector>& trainingBatches, + const std::vector& inferenceBatches, const std::vector& colsToImputeDataTypes, + bool supressError, NS::Featurizers::Components::TimeSeriesImputeStrategy tsImputeStrategy) { + auto stream = GetStream( + trainingBatches, + colsToImputeDataTypes, + supressError, + tsImputeStrategy); + + auto dim = static_cast(stream.size()); + test.AddInput("State", {dim}, stream); + + std::vector times; + std::vector keys; + std::vector data; + + using namespace std::chrono; + for (const auto& infb : inferenceBatches) { + times.push_back(time_point_cast(std::get<0>(infb)).time_since_epoch().count()); + keys.insert(keys.end(), std::get<1>(infb).cbegin(), std::get<1>(infb).cend()); + std::transform(std::get<2>(infb).cbegin(), std::get<2>(infb).cend(), std::back_inserter(data), + [](const nonstd::optional& opt) -> std::string { + if (opt.has_value()) return *opt; + return std::string(); + }); + } + + // Should have equal amount of keys per row + ASSERT_TRUE(keys.size() % times.size() == 0); + ASSERT_TRUE(data.size() % times.size() == 0); + test.AddInput("Times", {static_cast(times.size())}, times); + test.AddInput("Keys", {static_cast(times.size()), static_cast(keys.size() / times.size())}, keys); + test.AddInput("Data", {static_cast(times.size()), static_cast(data.size() / times.size())}, data); +} + +void AddOutputs(OpTester& test, const std::initializer_list& added, const std::initializer_list& times, + const std::vector& keys, const std::vector& data) { + ASSERT_TRUE(keys.size() % times.size() == 0); + ASSERT_TRUE(data.size() % times.size() == 0); + + std::vector times_int64; + std::transform(times.begin(), times.end(), std::back_inserter(times_int64), GetTimeSecs); + + test.AddOutput("Added", {static_cast(added.size())}, added); + test.AddOutput("ImputedTimes", {static_cast(times.size())}, times_int64); + test.AddOutput("ImputedKeys", {static_cast(times.size()), static_cast(keys.size() / times.size())}, keys); + test.AddOutput("ImputedData", {static_cast(times.size()), static_cast(data.size() / times.size())}, data); +} + +TEST(FeaturizersTests, RowImputation_1_grain_no_gaps) { + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + auto tp_0 = GetTimePoint(now, 0); + auto tp_1 = GetTimePoint(now, 1); + auto tp_2 = GetTimePoint(now, 2); + auto tuple_1 = std::make_tuple(tp_0, std::vector{"a"}, std::vector>{"14.5", "18"}); + auto tuple_2 = std::make_tuple(tp_1, std::vector{"a"}, std::vector>{nonstd::optional{}, "12"}); + auto tuple_3 = std::make_tuple(tp_2, std::vector{"a"}, std::vector>{"15.0", nonstd::optional{}}); + + std::vector inferenceBatches = {tuple_1, + tuple_2, + tuple_3}; + + OpTester test("TimeSeriesImputerTransformer", 1, onnxruntime::kMSFeaturizersDomain); + + AddInputs(test, {inferenceBatches}, inferenceBatches, + {NS::TypeId::Float64, NS::TypeId::Float64}, false, NS::Featurizers::Components::TimeSeriesImputeStrategy::Forward); + AddOutputs(test, {false, false, false}, {tp_0, tp_1, tp_2}, + {"a", "a", "a"}, {"14.5", "18", "14.5", "12", "15.0", "12"}); + + test.Run(); +} + +TEST(FeaturizersTests, RowImputation_1_grain_2_gaps) { + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + auto tp_0 = GetTimePoint(now, 0); + auto tp_1 = GetTimePoint(now, 1); + auto tp_2 = GetTimePoint(now, 2); + auto tp_3 = GetTimePoint(now, 3); + + auto tuple_0 = std::make_tuple(tp_0, std::vector{"a"}, std::vector>{"14.5", "18"}); + auto tuple_1 = std::make_tuple(tp_1, std::vector{"a"}, std::vector>{nonstd::optional{}, "12"}); + auto tuple_3 = std::make_tuple(tp_3, std::vector{"a"}, std::vector>{nonstd::optional{}, "15.0"}); + + OpTester test("TimeSeriesImputerTransformer", 1, onnxruntime::kMSFeaturizersDomain); + AddInputs(test, {{tuple_0, tuple_1}}, {tuple_0, tuple_3}, + {NS::TypeId::Float64, NS::TypeId::Float64}, false, NS::Featurizers::Components::TimeSeriesImputeStrategy::Forward); + + AddOutputs(test, {false, true, true, false}, {tp_0, tp_1, tp_2, tp_3}, + {"a", "a", "a", "a"}, {"14.5", "18", "14.5", "18", "14.5", "18", "14.5", "15.0"}); + test.Run(); +} + +TEST(FeaturizersTests, RowImputation_2_grains_no_gaps_input_interleaved) { + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + auto tp_0 = GetTimePoint(now, 0); + auto tp_1 = GetTimePoint(now, 1); + auto tp_5 = GetTimePoint(now, 5); + auto tp_6 = GetTimePoint(now, 6); + + auto tuple_0 = std::make_tuple(tp_0, std::vector{"a"}, std::vector>{"14.5", "18"}); + auto tuple_5 = std::make_tuple(tp_5, std::vector{"b"}, std::vector>{"14.5", "18"}); + auto tuple_5_inf = std::make_tuple(GetTimePoint(now, 5), std::vector{"b"}, std::vector>{"114.5", "118"}); + auto tuple_1 = std::make_tuple(tp_1, std::vector{"a"}, std::vector>{nonstd::optional{}, "12"}); + auto tuple_6 = std::make_tuple(tp_6, std::vector{"b"}, std::vector>{nonstd::optional{}, "12"}); + auto tuple_6_inf = std::make_tuple(GetTimePoint(now, 6), std::vector{"b"}, std::vector>{nonstd::optional{}, "112"}); + + OpTester test("TimeSeriesImputerTransformer", 1, onnxruntime::kMSFeaturizersDomain); + AddInputs(test, {{tuple_0, tuple_5, tuple_1, tuple_6}}, {tuple_0, tuple_5_inf, tuple_1, tuple_6_inf}, + {NS::TypeId::Float64, NS::TypeId::Float64}, false, NS::Featurizers::Components::TimeSeriesImputeStrategy::Forward); + + AddOutputs(test, {false, false, false, false}, {tp_0, tp_5, tp_1, tp_6}, + {"a", "b", "a", "b"}, {"14.5", "18", "114.5", "118", "14.5", "12", "114.5", "112"}); + test.Run(); +} + +TEST(FeaturizersTests, RowImputation_2_grains_1_gap_input_interleaved) { + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + auto tp_0 = GetTimePoint(now, 0); + auto tp_1 = GetTimePoint(now, 1); + auto tp_2 = GetTimePoint(now, 2); + auto tp_5 = GetTimePoint(now, 5); + auto tp_6 = GetTimePoint(now, 6); + auto tp_7 = GetTimePoint(now, 7); + + auto tuple_0 = std::make_tuple(tp_0, std::vector{"a"}, std::vector>{"14.5", "18"}); + auto tuple_2 = std::make_tuple(GetTimePoint(now, 2), std::vector{"a"}, std::vector>{nonstd::optional{}, "12"}); + auto tuple_5 = std::make_tuple(tp_5, std::vector{"b"}, std::vector>{"14.5", "18"}); + auto tuple_5_inf = std::make_tuple(tp_5, std::vector{"b"}, std::vector>{"114.5", "118"}); + auto tuple_1 = std::make_tuple(tp_1, std::vector{"a"}, std::vector>{nonstd::optional{}, "12"}); + auto tuple_6 = std::make_tuple(tp_6, std::vector{"b"}, std::vector>{nonstd::optional{}, "12"}); + auto tuple_7 = std::make_tuple(GetTimePoint(now, 7), std::vector{"b"}, std::vector>{nonstd::optional{}, "112"}); + + OpTester test("TimeSeriesImputerTransformer", 1, onnxruntime::kMSFeaturizersDomain); + AddInputs(test, {{tuple_0, tuple_5, tuple_1, tuple_6}}, {tuple_0, tuple_5_inf, tuple_2, tuple_7}, + {NS::TypeId::Float64, NS::TypeId::Float64}, false, NS::Featurizers::Components::TimeSeriesImputeStrategy::Forward); + + AddOutputs(test, {false, false, true, false, true, false}, {tp_0, tp_5, tp_1, tp_2, tp_6, tp_7}, + {"a", "b", "a", "a", "b", "b"}, {"14.5", "18", "114.5", "118", "14.5", "18", "14.5", "12", "114.5", "118", "114.5", "112"}); + + test.Run(); +} + +} // namespace test +} // namespace onnxruntime