Skip to content

Commit

Permalink
Fixing: Corrupted result when joining tables that contain list data types #…
Browse files Browse the repository at this point in the history
…615  (#616)

* inspect issue

* possible fix

* remove cmake temp

* adding test case

* minor improvements
  • Loading branch information
nirandaperera authored Aug 31, 2022
1 parent 68fa598 commit 121b386
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 169 deletions.
1 change: 0 additions & 1 deletion cpp/src/cylon/join/hash_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ Status multi_index_hash_join(const std::shared_ptr<arrow::Table> &ltab,
hash_map.clear();

// copy arrays from the table indices
// todo use arrow::compute::Take for this
return util::build_final_table(row_indices[0], row_indices[1],
ltab, rtab, config.GetLeftTableSuffix(),
config.GetRightTableSuffix(), joined_table, memory_pool);
Expand Down
29 changes: 15 additions & 14 deletions cpp/src/cylon/join/join_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,34 +142,35 @@ Status build_final_table(const std::vector<int64_t> &left_indices,
const std::string &left_table_prefix,
const std::string &right_table_prefix,
std::shared_ptr<arrow::Table> *final_table,
arrow::MemoryPool *memory_pool) {
arrow::MemoryPool *pool) {
const auto &schema =
build_final_table_schema(left_tab, right_tab, left_table_prefix, right_table_prefix);

std::vector<std::shared_ptr<arrow::Array>> data_arrays;
std::vector<std::shared_ptr<arrow::ChunkedArray>> data_arrays;
data_arrays.reserve(left_tab->num_columns() + right_tab->num_columns());

// build arrays for left tab
for (auto &column: left_tab->columns()) {
std::shared_ptr<arrow::Array> destination_col_array;
for (const auto &column: left_tab->columns()) {
std::shared_ptr<arrow::Array> col_array;
RETURN_CYLON_STATUS_IF_ARROW_FAILED(
cylon::util::copy_array_by_indices(left_indices,
cylon::util::GetChunkOrEmptyArray(column, 0),
&destination_col_array,
memory_pool));
data_arrays.push_back(destination_col_array);
cylon::util::GetChunkOrEmptyArray(column, 0, pool),
&col_array,
pool));
data_arrays.push_back(std::make_shared<arrow::ChunkedArray>(std::move(col_array)));
}

// build arrays for right tab
for (auto &column: right_tab->columns()) {
std::shared_ptr<arrow::Array> destination_col_array;
for (const auto &column: right_tab->columns()) {
std::shared_ptr<arrow::Array> col_array;
RETURN_CYLON_STATUS_IF_ARROW_FAILED(
cylon::util::copy_array_by_indices(right_indices,
cylon::util::GetChunkOrEmptyArray(column, 0),
&destination_col_array,
memory_pool));
data_arrays.push_back(destination_col_array);
&col_array,
pool));
data_arrays.push_back(std::make_shared<arrow::ChunkedArray>(std::move(col_array)));
}
*final_table = arrow::Table::Make(schema, data_arrays);
*final_table = arrow::Table::Make(schema, std::move(data_arrays));
return Status::OK();
}

Expand Down
14 changes: 7 additions & 7 deletions cpp/src/cylon/join/join_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ std::shared_ptr<arrow::Schema> build_final_table_schema(const std::shared_ptr<ar
const std::string &right_table_prefix);

Status build_final_table(const std::vector<int64_t> &left_indices,
const std::vector<int64_t> &right_indices,
const std::shared_ptr<arrow::Table> &left_tab,
const std::shared_ptr<arrow::Table> &right_tab,
const std::string &left_table_prefix,
const std::string &right_table_prefix,
std::shared_ptr<arrow::Table> *final_table,
arrow::MemoryPool *memory_pool);
const std::vector<int64_t> &right_indices,
const std::shared_ptr<arrow::Table> &left_tab,
const std::shared_ptr<arrow::Table> &right_tab,
const std::string &left_table_prefix,
const std::string &right_table_prefix,
std::shared_ptr<arrow::Table> *final_table,
arrow::MemoryPool *pool);

Status build_final_table_inplace_index(
size_t left_inplace_column, size_t right_inplace_column,
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/cylon/util/arrow_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ arrow::Status MakeDummyArray(const std::shared_ptr<arrow::DataType> &type, int64
std::shared_ptr<arrow::Array> *out,
arrow::MemoryPool *pool = arrow::default_memory_pool());

/**
 * Exposes the contents of a numeric std::vector as an Arrow array.
 *
 * The value buffer is produced by arrow::Buffer::Wrap, which (per the Arrow API
 * documentation) references the vector's memory instead of copying it. The returned
 * array therefore only borrows `data`: the vector must outlive the array and must not
 * be resized or reallocated while the array is in use.
 *
 * No validity bitmap is attached (the first buffer is nullptr), so every element is
 * treated as non-null and the null count is set to 0 accordingly.
 *
 * bool is excluded on purpose: std::vector<bool> is a bit-packed specialization that
 * has no contiguous element storage to wrap, so it is rejected at overload resolution
 * rather than failing inside Arrow.
 *
 * @tparam T arithmetic C type other than bool; arrow::CTypeTraits<T> must be defined
 * @param data values to expose; not copied
 * @return array of data.size() elements whose type is arrow::CTypeTraits<T>::ArrowType
 */
template<typename T>
std::enable_if_t<std::is_arithmetic<T>::value && !std::is_same<T, bool>::value,
                 std::shared_ptr<arrow::Array>> WrapNumericVector(const std::vector<T> &data) {
  using ArrowType = typename arrow::CTypeTraits<T>::ArrowType;

  auto values = arrow::Buffer::Wrap(data);
  auto type = arrow::TypeTraits<ArrowType>::type_singleton();
  // Buffer layout for primitive arrays: {validity bitmap, values}.
  auto array_data = arrow::ArrayData::Make(std::move(type),
                                           static_cast<int64_t>(data.size()),
                                           {nullptr, std::move(values)},
                                           /*null_count=*/0);
  return arrow::MakeArray(array_data);
}

} // namespace util
} // namespace cylon
#endif // CYLON_SRC_UTIL_ARROW_UTILS_HPP_
152 changes: 6 additions & 146 deletions cpp/src/cylon/util/copy_arrray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,152 +134,12 @@ arrow::Status copy_array_by_indices(const std::vector<int64_t> &indices,
const std::shared_ptr<arrow::Array> &data_array,
std::shared_ptr<arrow::Array> *copied_array,
arrow::MemoryPool *memory_pool) {
switch (data_array->type()->id()) {
case arrow::Type::BOOL:
return do_copy_numeric_array<arrow::BooleanType>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::UINT8:
return do_copy_numeric_array<arrow::UInt8Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::INT8:
return do_copy_numeric_array<arrow::Int8Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::UINT16:
return do_copy_numeric_array<arrow::UInt16Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::INT16:
return do_copy_numeric_array<arrow::Int16Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::UINT32:
return do_copy_numeric_array<arrow::UInt32Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::INT32:
return do_copy_numeric_array<arrow::Int32Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::UINT64:
return do_copy_numeric_array<arrow::UInt64Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::INT64:
return do_copy_numeric_array<arrow::Int64Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::HALF_FLOAT:
return do_copy_numeric_array<arrow::HalfFloatType>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::FLOAT:
return do_copy_numeric_array<arrow::FloatType>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::DOUBLE:
return do_copy_numeric_array<arrow::DoubleType>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::STRING:
return do_copy_binary_array<arrow::StringType>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::BINARY:
return do_copy_binary_array<arrow::BinaryType>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::FIXED_SIZE_BINARY:
return do_copy_fixed_binary_array(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::LIST: {
auto t_value = std::static_pointer_cast<arrow::ListType>(data_array->type());
switch (t_value->value_type()->id()) {
case arrow::Type::BOOL:
return do_copy_numeric_list<arrow::BooleanType>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::UINT8:
return do_copy_numeric_list<arrow::UInt8Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::INT8:
return do_copy_numeric_list<arrow::Int8Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::UINT16:
return do_copy_numeric_list<arrow::UInt16Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::INT16:
return do_copy_numeric_list<arrow::Int16Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::UINT32:
return do_copy_numeric_list<arrow::UInt32Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::INT32:
return do_copy_numeric_list<arrow::Int32Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::UINT64:
return do_copy_numeric_list<arrow::UInt64Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::INT64:
return do_copy_numeric_list<arrow::Int64Type>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::HALF_FLOAT:
return do_copy_numeric_list<arrow::HalfFloatType>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::FLOAT:
return do_copy_numeric_list<arrow::FloatType>(indices,
data_array,
copied_array,
memory_pool);
case arrow::Type::DOUBLE:
return do_copy_numeric_list<arrow::DoubleType>(indices,
data_array,
copied_array,
memory_pool);
default:
return arrow::Status::Invalid("Un-supported type");
}
}
default:
return arrow::Status::Invalid("Un-supported type");
}
auto idx_array = util::WrapNumericVector(indices);
arrow::compute::ExecContext exec_ctx(memory_pool);
ARROW_ASSIGN_OR_RAISE(*copied_array, arrow::compute::Take(*data_array, *idx_array,
arrow::compute::TakeOptions::NoBoundsCheck(),
&exec_ctx));
return arrow::Status::OK();
}

} // namespace util
Expand Down
38 changes: 38 additions & 0 deletions cpp/test/join_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,43 @@ TEST_CASE("Join testing chunks", "[join]") {
}
}

// Inner hash join on column 0 of two identical tables that carry a list<float32> column;
// the joined table must reproduce every row once, with "l_"/"r_" prefixed column names.
TEST_CASE("Join testing list type", "[join]") {
  // todo: list types don't work in a dist env
  const bool distributed = ctx->GetWorldSize() > 1;
  if (distributed) {
    return;
  }

  namespace jc = cylon::join::config;

  // Input schema shared by both sides: int64 key "a", list<float32> payload "b".
  const std::vector<std::shared_ptr<arrow::Field>> in_fields{
      arrow::field("a", arrow::int64()),
      arrow::field("b", arrow::list(arrow::float32()))};
  auto in_schema = arrow::schema(in_fields);

  auto left = TableFromJSON(in_schema, {R"([{"a": 3, "b":[0.025, 1.0]},
{"a": 26, "b":[0.394]},
{"a": 51, "b":[0.755, 1.0]},
{"a": 20, "b":[0.030, 1.0]},
{"a": 33, "b":[0.318]}])"});
  auto right = TableFromJSON(in_schema, {R"([{"a": 3, "b":[0.025, 1.0]},
{"a": 26, "b":[0.394]},
{"a": 51, "b":[0.755, 1.0]},
{"a": 20, "b":[0.030, 1.0]},
{"a": 33, "b":[0.318]}])"});

  // Expected output schema: left columns followed by right columns, each prefixed.
  const std::vector<std::shared_ptr<arrow::Field>> out_fields{
      arrow::field("l_a", arrow::int64()),
      arrow::field("l_b", arrow::list(arrow::float32())),
      arrow::field("r_a", arrow::int64()),
      arrow::field("r_b", arrow::list(arrow::float32()))};
  auto out_schema = arrow::schema(out_fields);

  auto expected = TableFromJSON(out_schema, {R"([{"l_a": 3, "l_b":[0.025, 1.0], "r_a": 3, "r_b":[0.025, 1.0]},
{"l_a": 26, "l_b":[0.394], "r_a": 26, "r_b":[0.394]},
{"l_a": 51, "l_b":[0.755, 1.0], "r_a": 51, "r_b":[0.755, 1.0]},
{"l_a": 20, "l_b":[0.030, 1.0], "r_a": 20, "r_b":[0.030, 1.0]},
{"l_a": 33, "l_b":[0.318], "r_a": 33, "r_b":[0.318]}])"});

  // Join on column index 0 of each side using the hash algorithm.
  auto join_cfg = jc::JoinConfig(jc::JoinType::INNER,
                                 0, 0,
                                 jc::JoinAlgorithm::HASH,
                                 "l_",
                                 "r_");

  std::shared_ptr<arrow::Table> joined;
  CHECK_CYLON_STATUS(cylon::join::JoinTables(left, right, join_cfg, &joined));
  CHECK_ARROW_EQUAL(expected, joined);
}

}
}

0 comments on commit 121b386

Please sign in to comment.