From 81b114219cdc5365ed72933cab638ebca178ef57 Mon Sep 17 00:00:00 2001 From: AlexRTer <74372589+AlexRTer@users.noreply.github.com> Date: Thu, 18 Apr 2024 22:14:04 +0200 Subject: [PATCH] [DAPHNE-#687,#695] Onehot bounds check and test This commit fixes the bounds checks of the onehot kernel, adds a zero mode, replaces the assertions, adds test cases and documentation. Closes #695 --- doc/DaphneDSL/Builtins.md | 5 +- src/runtime/local/kernels/OneHot.h | 26 ++++-- test/CMakeLists.txt | 1 + test/runtime/local/kernels/OneHotTest.cpp | 98 +++++++++++++++++++++++ 4 files changed, 120 insertions(+), 10 deletions(-) create mode 100644 test/runtime/local/kernels/OneHotTest.cpp diff --git a/doc/DaphneDSL/Builtins.md b/doc/DaphneDSL/Builtins.md index 05fc6cb0f..ce9bfa483 100644 --- a/doc/DaphneDSL/Builtins.md +++ b/doc/DaphneDSL/Builtins.md @@ -550,9 +550,10 @@ These must be provided in a separate [`.meta`-file](/doc/FileMetaDataFormat.md). The *(1 x m)* row-matrix `info` specifies the details (in the following, *d[j]* is short for `info[0, j]`): - If *d[j]* == -1, then the *j*-th column of `arg` will remain as it is. - - If *d[j]* >= 0, then the *j*-th column of `arg` will be encoded. + - If *d[j]* == 0, then the *j*-th column of `arg` will be omitted from the output. + - If *d[j]* > 0, then the *j*-th column of `arg` will be one-hot encoded as a vector of length *d[j]*. - More precisely, the *j*-th column of `arg` must contain only integral values in the range *[0, d[j] - 1]*, and will be replaced by *d[j]* columns containing only zeros and ones. + More precisely, if *d[j]* > 0, the *j*-th column of `arg` must contain only integral values in the range *[0, d[j] - 1]*, and will be replaced by *d[j]* columns containing only zeros and ones. For each row *i* in `arg`, the value in the `as.scalar(arg[i, j])`-th of those columns is set to 1, while all others are set to 0.
- **`recode`**`(arg:matrix, orderPreserving:bool)` diff --git a/src/runtime/local/kernels/OneHot.h b/src/runtime/local/kernels/OneHot.h index e633e4f54..f1b4bc26e 100644 --- a/src/runtime/local/kernels/OneHot.h +++ b/src/runtime/local/kernels/OneHot.h @@ -24,6 +24,8 @@ #include #include #include +#include +#include // **************************************************************************** // Struct for partial template specialization @@ -54,10 +56,12 @@ void oneHot(DTRes *& res, const DTArg * arg, const DenseMatrix * info, template struct OneHot, DenseMatrix> { static void apply(DenseMatrix *& res, const DenseMatrix * arg, const DenseMatrix * info, DCTX(ctx)) { - assert((info->getNumRows() == 1) && "parameter info must be a row matrix"); + if (info->getNumRows() != 1) + throw std::runtime_error("OneHot: parameter 'info' must be a row matrix"); const size_t numColsArg = arg->getNumCols(); - assert((numColsArg == info->getNumCols()) && "parameter info must provide information for each column of parameter arg"); + if (info->getNumCols() != numColsArg) + throw std::runtime_error("OneHot: parameter 'info' must provide information for each column of parameter arg"); size_t numColsRes = 0; const int64_t * valuesInfo = info->getValues(); @@ -67,9 +71,12 @@ struct OneHot, DenseMatrix> { numColsRes++; else if(numDistinct > 0) numColsRes += numDistinct; - else - assert(false && "invalid info"); + else if (numDistinct != 0) + throw std::runtime_error("OneHot: parameter 'info' must be an integer greater or equal than -1"); } + + if (numColsRes == 0) + throw std::runtime_error("OneHot: parameter 'info' must contain at least one non-zero entry"); const size_t numRows = arg->getNumRows(); @@ -89,11 +96,14 @@ struct OneHot, DenseMatrix> { if(numDistinct == -1) // retain value from argument matrix valuesRes[cRes++] = valuesArg[cArg]; - else { + else if (numDistinct != 0) { // one-hot encode value from argument matrix - for(int64_t d = 0; d < numDistinct; d++) - valuesRes[cRes + 
d] = 0; - valuesRes[cRes + static_cast(valuesArg[cArg])] = 1; + memset(valuesRes + cRes, VT(0), numDistinct * sizeof(VT)); + const size_t argVal = static_cast(valuesArg[cArg]); + if (argVal >= 0 && argVal < static_cast(numDistinct)) + valuesRes[cRes + argVal] = 1; + else + throw std::out_of_range("OneHot: arg values that are encoded (info value != -1) must be positive and smaller than the corresponding info value"); cRes += numDistinct; } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 58aabba5f..ebea775c3 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -115,6 +115,7 @@ set(TEST_SOURCES runtime/local/kernels/NumDistinctApproxTest.cpp runtime/local/kernels/MapTest.cpp runtime/local/kernels/MatMulTest.cpp + runtime/local/kernels/OneHotTest.cpp runtime/local/kernels/OrderTest.cpp runtime/local/kernels/OuterBinaryTest.cpp runtime/local/kernels/QuantizeTest.cpp diff --git a/test/runtime/local/kernels/OneHotTest.cpp b/test/runtime/local/kernels/OneHotTest.cpp new file mode 100644 index 000000000..4a25f9927 --- /dev/null +++ b/test/runtime/local/kernels/OneHotTest.cpp @@ -0,0 +1,98 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include + +#include + +#include + +#include +#include + +#define DATA_TYPES DenseMatrix +#define VALUE_TYPES int64_t, double + +TEMPLATE_PRODUCT_TEST_CASE("OneHot", TAG_KERNELS, (DATA_TYPES), (VALUE_TYPES)) { + using DTArg = TestType; + using VT = typename DTArg::VT; + using DTRes = DTArg; + + auto * arg = genGivenVals(3, { + -1, 0, 1, + -10, 1, VT(1.5), + 100, 2, 1, + }); + + DenseMatrix * info = nullptr; + DTRes * res = nullptr; + + SECTION("normal encoding") { + info = genGivenVals>(1, {-1, 3, 2}); + auto * exp = genGivenVals(3, { + -1, 1, 0, 0, 0, 1, + -10, 0, 1, 0, 0, 1, + 100, 0, 0, 1, 0, 1 + }); + + oneHot(res, arg, info, nullptr); + CHECK(*res == *exp); + + DataObjectFactory::destroy(exp, res); + } + SECTION("normal encoding - skip columns") { + info = genGivenVals>(1, {0, 0, 3}); + auto * exp = genGivenVals(3, { + 0, 1, 0, + 0, 1, 0, + 0, 1, 0 + }); + + oneHot(res, arg, info, nullptr); + CHECK(*res == *exp); + + DataObjectFactory::destroy(exp, res); + } + SECTION("negative example - invalid info shape (not row matrix)") { + info = genGivenVals>(3, {-1, 3, 2}); + REQUIRE_THROWS_AS(oneHot(res, arg, info, nullptr), std::runtime_error); + } + SECTION("negative example - invalid info shape (too small)") { + info = genGivenVals>(1, {-1, 3}); + REQUIRE_THROWS_AS(oneHot(res, arg, info, nullptr), std::runtime_error); + } + SECTION("negative example - invalid info value (int < -1)") { + info = genGivenVals>(1, {-2, 3, 2}); + REQUIRE_THROWS_AS(oneHot(res, arg, info, nullptr), std::runtime_error); + } + SECTION("negative example - empty selection") { + info = genGivenVals>(1, {0, 0, 0}); + REQUIRE_THROWS_AS(oneHot(res, arg, info, nullptr), std::runtime_error); + } + SECTION("negative example - not enough space reserved (0 <= info value < arg value)") { + info = genGivenVals>(1, {-1, 2, 2}); + REQUIRE_THROWS_AS(oneHot(res, arg, info, nullptr), std::out_of_range); + } + SECTION("negative example - out of bounds (arg value 
negative)") { + info = genGivenVals>(1, {3, 3, 3}); + REQUIRE_THROWS_AS(oneHot(res, arg, info, nullptr), std::out_of_range); + } + + DataObjectFactory::destroy(arg, info); +} \ No newline at end of file