From 894e83e4fa67c973de22e69705fb1b4873f1176b Mon Sep 17 00:00:00 2001 From: Pramod Date: Mon, 21 Aug 2023 13:28:59 -0700 Subject: [PATCH] Fix int offset check for Spark --- velox/functions/lib/window/NthValue.cpp | 74 +++++++++++-------- .../lib/window/tests/WindowTestBase.cpp | 2 +- .../sparksql/window/tests/SparkWindowTest.cpp | 1 + 3 files changed, 46 insertions(+), 31 deletions(-) diff --git a/velox/functions/lib/window/NthValue.cpp b/velox/functions/lib/window/NthValue.cpp index d1707f806af7..8fd456c620cc 100644 --- a/velox/functions/lib/window/NthValue.cpp +++ b/velox/functions/lib/window/NthValue.cpp @@ -32,34 +32,35 @@ class NthValueFunction : public exec::WindowFunction { : WindowFunction(resultType, pool, nullptr), ignoreNulls_(ignoreNulls) { VELOX_CHECK_EQ(args.size(), 2); VELOX_CHECK_NULL(args[0].constantValue); + auto offsetType = args[1].type; + VELOX_USER_CHECK( + (offsetType->isInteger() || offsetType->isBigint()), + "Invalid offset type: {}", + offsetType->toString()); valueIndex_ = args[0].index.value(); - if (args[1].type->isInteger()) { - VELOX_USER_CHECK( - args[1].constantValue, "Offset must be literal for spark"); + if (args[1].constantValue) { if (args[1].constantValue->isNullAt(0)) { isConstantOffsetNull_ = true; return; } - constantOffset_ = - args[1] - .constantValue->template as>() - ->valueAt(0); - VELOX_USER_CHECK_GE( - constantOffset_.value(), 1, "Offset must be at least 1"); - } else { - if (args[1].constantValue) { - if (args[1].constantValue->isNullAt(0)) { - isConstantOffsetNull_ = true; - return; - } + if (offsetType->isInteger()) { + constantOffset_ = + args[1] + .constantValue->template as>() + ->valueAt(0); + } else { constantOffset_ = args[1] .constantValue->template as>() ->valueAt(0); - VELOX_USER_CHECK_GE( - constantOffset_.value(), 1, "Offset must be at least 1"); + } + VELOX_USER_CHECK_GE( + constantOffset_.value(), 1, "Offset must be at least 1"); + } else { + offsetIndex_ = args[1].index.value(); + if (offsetType->isInteger()) { + offsets_ = BaseVector::create>(INTEGER(), 0, pool); } else { - offsetIndex_ = args[1].index.value(); offsets_ = BaseVector::create>(BIGINT(), 0, pool); } } @@ -149,8 +150,7 @@ class NthValueFunction : public exec::WindowFunction { const vector_size_t* frameStarts, const vector_size_t* frameEnds, vector_size_t leastFrame) { - vector_size_t constantOffsetValue = - static_cast(constantOffset_.value()); + auto constantOffsetValue = constantOffset_.value(); if (ignoreNulls) { auto rawNulls = nulls_->as(); validRows.applyToSelected([&](auto i) { @@ -169,18 +169,19 @@ class NthValueFunction : public exec::WindowFunction { } } - template + template void setRowNumbersApplyLoop( const SelectivityVector& validRows, const vector_size_t* frameStarts, const vector_size_t* frameEnds, vector_size_t leastFrame = 0) { auto rawNulls = nulls_->as(); + auto offsetsVector = offsets_->as>(); validRows.applyToSelected([&](auto i) { - if (offsets_->isNullAt(i)) { + if (offsetsVector->isNullAt(i)) { rowNumbers_[i] = kNullRow; } else { - vector_size_t offset = offsets_->valueAt(i); + T offset = offsetsVector->valueAt(i); VELOX_USER_CHECK_GE(offset, 1, "Offset must be at least 1"); if constexpr (ignoreNulls) { setRowNumberIgnoreNulls( @@ -204,10 +205,21 @@ class NthValueFunction : public exec::WindowFunction { offsetIndex_, partitionOffset_, numRows, 0, offsets_); if (ignoreNulls) { - setRowNumbersApplyLoop( - validRows, frameStarts, frameEnds, leastFrame); + if (offsets_->type()->isInteger()) { + setRowNumbersApplyLoop( + validRows, frameStarts, frameEnds, leastFrame); + } else { + setRowNumbersApplyLoop( + validRows, frameStarts, frameEnds, leastFrame); + } } else { - setRowNumbersApplyLoop(validRows, frameStarts, frameEnds); + if (offsets_->type()->isInteger()) { + setRowNumbersApplyLoop( + validRows, frameStarts, frameEnds); + } else { + setRowNumbersApplyLoop( + validRows, frameStarts, frameEnds); + } } } @@ -222,27 +234,29 @@ class NthValueFunction : public exec::WindowFunction { invalidRows_.applyToSelected([&](auto i) { rowNumbers_[i] = kNullRow; }); } + template inline void setRowNumber( vector_size_t i, const vector_size_t* frameStarts, const vector_size_t* frameEnds, - vector_size_t offset) { + T offset) { auto frameStart = frameStarts[i]; auto frameEnd = frameEnds[i]; auto rowNumber = frameStart + offset - 1; rowNumbers_[i] = rowNumber <= frameEnd ? rowNumber : kNullRow; } + template inline void setRowNumberIgnoreNulls( vector_size_t i, const uint64_t* rawNulls, vector_size_t leastFrame, const vector_size_t* frameStarts, const vector_size_t* frameEnds, - vector_size_t offset) { + T offset) { auto frameStart = frameStarts[i]; auto frameEnd = frameEnds[i]; - vector_size_t nonNullCount = 0; + T nonNullCount = 0; for (auto j = frameStart; j <= frameEnd; j++) { if (!bits::isBitSet(rawNulls, j - leastFrame)) { ++nonNullCount; @@ -271,7 +285,7 @@ class NthValueFunction : public exec::WindowFunction { // This vector is used to extract values of the offset argument column // (if not a constant offset value). - FlatVectorPtr offsets_; + VectorPtr offsets_ = nullptr; // This offset tracks how far along the partition rows have been output. // This can be used to optimize reading offset column values corresponding diff --git a/velox/functions/lib/window/tests/WindowTestBase.cpp b/velox/functions/lib/window/tests/WindowTestBase.cpp index 11e0375f61eb..1d4ee775fc59 100644 --- a/velox/functions/lib/window/tests/WindowTestBase.cpp +++ b/velox/functions/lib/window/tests/WindowTestBase.cpp @@ -100,7 +100,7 @@ RowVectorPtr WindowTestBase::makeRandomInputVector(vector_size_t size) { {makeRandomInputVector(BIGINT(), size, 0.2), makeRandomInputVector(VARCHAR(), size, 0.3), makeFlatVector(size, genRandomFrameValue), - makeFlatVector(size, genRandomFrameValue)}); + makeFlatVector(size, genRandomFrameValue)}); } void WindowTestBase::testWindowFunction( diff --git a/velox/functions/sparksql/window/tests/SparkWindowTest.cpp b/velox/functions/sparksql/window/tests/SparkWindowTest.cpp index 5102c9a1a3c4..df87c432f9ba 100644 --- a/velox/functions/sparksql/window/tests/SparkWindowTest.cpp +++ b/velox/functions/sparksql/window/tests/SparkWindowTest.cpp @@ -25,6 +25,7 @@ namespace { static const std::vector kSparkWindowFunctions = { std::string("nth_value(c0, 1)"), + std::string("nth_value(c0, c3)"), std::string("row_number()"), std::string("rank()"), std::string("dense_rank()")};