Skip to content

Commit

Permalink
Handle ShortDecimal correctly inside importFromArrow (#8957)
Browse files Browse the repository at this point in the history
Summary:
Arrow uses `int128_t` to store ShortDecimal values, while inside velox we use `int64_t`.

`ExportToArrow` already handle it specifically, but `ImportFromArrow` misses this.

This pr tries to fix it.

Pull Request resolved: #8957

Reviewed By: Yuhta

Differential Revision: D55019687

Pulled By: pedroerp

fbshipit-source-id: 2fe32236a0e17a52ef713cff96836a48a37fec56
  • Loading branch information
boneanxs authored and facebook-github-bot committed Mar 21, 2024
1 parent ffce11b commit dc561a3
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 16 deletions.
80 changes: 66 additions & 14 deletions velox/vector/arrow/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,44 @@ void exportToArrowImpl(
out.release = releaseArrowArray;
}

// Parses the velox decimal format from the given arrow format.
// The input format string should be in the form "d:precision,scale<,bitWidth>".
// bitWidth is not required and must be 128 if provided.
TypePtr parseDecimalFormat(const char* format) {
std::string invalidFormatMsg =
"Unable to convert '{}' ArrowSchema decimal format to Velox decimal";
try {
std::string::size_type sz;
std::string formatStr(format);

auto firstCommaIdx = formatStr.find(',', 2);
auto secondCommaIdx = formatStr.find(',', firstCommaIdx + 1);

if (firstCommaIdx == std::string::npos ||
formatStr.size() == firstCommaIdx + 1 ||
(secondCommaIdx != std::string::npos &&
formatStr.size() == secondCommaIdx + 1)) {
VELOX_USER_FAIL(invalidFormatMsg, format);
}

// Parse "d:".
int precision = std::stoi(&format[2], &sz);
int scale = std::stoi(&format[firstCommaIdx + 1], &sz);
// If bitwidth is provided, check if it is equal to 128.
if (secondCommaIdx != std::string::npos) {
int bitWidth = std::stoi(&format[secondCommaIdx + 1], &sz);
VELOX_USER_CHECK_EQ(
bitWidth,
128,
"Conversion failed for '{}'. Velox decimal does not support custom bitwidth.",
format);
}
return DECIMAL(precision, scale);
} catch (std::invalid_argument&) {
VELOX_USER_FAIL(invalidFormatMsg, format);
}
}

TypePtr importFromArrowImpl(
const char* format,
const ArrowSchema& arrowSchema) {
Expand Down Expand Up @@ -1056,20 +1094,9 @@ TypePtr importFromArrowImpl(
}
break;

case 'd': { // decimal types.
try {
std::string::size_type sz;
// Parse "d:".
int precision = std::stoi(&format[2], &sz);
// Parse ",".
int scale = std::stoi(&format[2 + sz + 1], &sz);
return DECIMAL(precision, scale);
} catch (std::invalid_argument&) {
VELOX_USER_FAIL(
"Unable to convert '{}' ArrowSchema decimal format to Velox decimal",
format);
}
}
case 'd':
// decimal types.
return parseDecimalFormat(format);

// Complex types.
case '+': {
Expand Down Expand Up @@ -1612,6 +1639,23 @@ VectorPtr createTimestampVector(
optionalNullCount(nullCount));
}

VectorPtr createShortDecimalVector(
memory::MemoryPool* pool,
const TypePtr& type,
BufferPtr nulls,
const int128_t* input,
size_t length,
int64_t nullCount) {
auto values = AlignedBuffer::allocate<int64_t>(length, pool);
auto rawValues = values->asMutable<int64_t>();
for (size_t i = 0; i < length; ++i) {
memcpy(rawValues + i, input + i, sizeof(int64_t));
}

return createFlatVector<TypeKind::BIGINT>(
pool, type, nulls, length, values, nullCount);
}

bool isREE(const ArrowSchema& arrowSchema) {
return arrowSchema.format[0] == '+' && arrowSchema.format[1] == 'r';
}
Expand Down Expand Up @@ -1691,6 +1735,14 @@ VectorPtr importFromArrowImpl(
static_cast<const int64_t*>(arrowArray.buffers[1]),
arrowArray.length,
arrowArray.null_count);
} else if (type->isShortDecimal()) {
return createShortDecimalVector(
pool,
type,
nulls,
static_cast<const int128_t*>(arrowArray.buffers[1]),
arrowArray.length,
arrowArray.null_count);
} else if (type->isRow()) {
// Row/structs.
return createRowVector(
Expand Down
29 changes: 27 additions & 2 deletions velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,10 @@ class ArrowBridgeArrayImportTest : public ArrowBridgeArrayExportTest {
std::is_same_v<TInput, int64_t> && std::is_same_v<TOutput, Timestamp>) {
assertTimestampVectorContent(
inputValues, output, arrowArray.null_count, format);
} else if constexpr (
std::is_same_v<TInput, int128_t> && std::is_same_v<TOutput, int64_t>) {
assertShortDecimalVectorContent(
inputValues, output, arrowArray.null_count);
} else {
assertVectorContent(inputValues, output, arrowArray.null_count);
}
Expand Down Expand Up @@ -1220,10 +1224,13 @@ class ArrowBridgeArrayImportTest : public ArrowBridgeArrayExportTest {
testArrowImport<double>("g", {-99.9, 4.3, 31.1, 129.11, -12});
testArrowImport<float>("f", {-99.9, 4.3, 31.1, 129.11, -12});

for (const std::string tsString : {"tss:", "tsm:", "tsu:", "tsn:"}) {
for (const auto& tsString : {"tss:", "tsm:", "tsu:", "tsn:"}) {
testArrowImport<Timestamp, int64_t>(
tsString.data(), {0, std::nullopt, Timestamp::kMaxSeconds});
tsString, {0, std::nullopt, Timestamp::kMaxSeconds});
}

testArrowImport<int64_t, int128_t>(
"d:5,2", {1, -1, 0, 12345, -12345, std::nullopt});
}

template <typename TOutput, typename TInput>
Expand Down Expand Up @@ -1304,6 +1311,24 @@ class ArrowBridgeArrayImportTest : public ArrowBridgeArrayExportTest {
}

private:
// Creates short decimals from int128 and asserts the content of actual vector
// with the expected values.
void assertShortDecimalVectorContent(
const std::vector<std::optional<int128_t>>& expectedValues,
const VectorPtr& actual,
size_t nullCount) {
std::vector<std::optional<int64_t>> decValues;
decValues.reserve(expectedValues.size());
for (const auto& value : expectedValues) {
if (value) {
decValues.emplace_back(static_cast<int64_t>(*value));
} else {
decValues.emplace_back(std::nullopt);
}
}
assertVectorContent(decValues, actual, nullCount);
}

// Creates timestamp from bigint and asserts the content of actual vector with
// the expected timestamp values.
void assertTimestampVectorContent(
Expand Down
16 changes: 16 additions & 0 deletions velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,22 @@ TEST_F(ArrowBridgeSchemaImportTest, scalar) {
VELOX_ASSERT_THROW(
*testSchemaImport("d2,15"),
"Unable to convert 'd2,15' ArrowSchema decimal format to Velox decimal");
EXPECT_EQ(*DECIMAL(10, 4), *testSchemaImport("d:10,4,128"));
VELOX_ASSERT_THROW(
*testSchemaImport("d:10,4,256"),
"Conversion failed for 'd:10,4,256'. Velox decimal does not support custom bitwidth.");
VELOX_ASSERT_THROW(
*testSchemaImport("d:10,4,"),
"Unable to convert 'd:10,4,' ArrowSchema decimal format to Velox decimal");
VELOX_ASSERT_THROW(
*testSchemaImport("d:10"),
"Unable to convert 'd:10' ArrowSchema decimal format to Velox decimal");
VELOX_ASSERT_THROW(
*testSchemaImport("d:"),
"Unable to convert 'd:' ArrowSchema decimal format to Velox decimal");
VELOX_ASSERT_THROW(
*testSchemaImport("d:10,"),
"Unable to convert 'd:10,' ArrowSchema decimal format to Velox decimal");
}

TEST_F(ArrowBridgeSchemaImportTest, complexTypes) {
Expand Down

0 comments on commit dc561a3

Please sign in to comment.