diff --git a/cpp/src/cylon/join/hash_join.cpp b/cpp/src/cylon/join/hash_join.cpp index bc60573cb..332c06e61 100644 --- a/cpp/src/cylon/join/hash_join.cpp +++ b/cpp/src/cylon/join/hash_join.cpp @@ -207,7 +207,6 @@ Status multi_index_hash_join(const std::shared_ptr <ab, hash_map.clear(); // copy arrays from the table indices - // todo use arrow::compute::Take for this return util::build_final_table(row_indices[0], row_indices[1], ltab, rtab, config.GetLeftTableSuffix(), config.GetRightTableSuffix(), joined_table, memory_pool); diff --git a/cpp/src/cylon/join/join_utils.cpp b/cpp/src/cylon/join/join_utils.cpp index 37db4d569..5c22b759b 100644 --- a/cpp/src/cylon/join/join_utils.cpp +++ b/cpp/src/cylon/join/join_utils.cpp @@ -142,34 +142,35 @@ Status build_final_table(const std::vector &left_indices, const std::string &left_table_prefix, const std::string &right_table_prefix, std::shared_ptr *final_table, - arrow::MemoryPool *memory_pool) { + arrow::MemoryPool *pool) { const auto &schema = build_final_table_schema(left_tab, right_tab, left_table_prefix, right_table_prefix); - std::vector> data_arrays; + std::vector> data_arrays; + data_arrays.reserve(left_tab->num_columns() + right_tab->num_columns()); // build arrays for left tab - for (auto &column: left_tab->columns()) { - std::shared_ptr destination_col_array; + for (const auto &column: left_tab->columns()) { + std::shared_ptr col_array; RETURN_CYLON_STATUS_IF_ARROW_FAILED( cylon::util::copy_array_by_indices(left_indices, - cylon::util::GetChunkOrEmptyArray(column, 0), - &destination_col_array, - memory_pool)); - data_arrays.push_back(destination_col_array); + cylon::util::GetChunkOrEmptyArray(column, 0, pool), + &col_array, + pool)); + data_arrays.push_back(std::make_shared(std::move(col_array))); } // build arrays for right tab - for (auto &column: right_tab->columns()) { - std::shared_ptr destination_col_array; + for (const auto &column: right_tab->columns()) { + std::shared_ptr col_array; RETURN_CYLON_STATUS_IF_ARROW_FAILED( cylon::util::copy_array_by_indices(right_indices, cylon::util::GetChunkOrEmptyArray(column, 0), - &destination_col_array, - memory_pool)); - data_arrays.push_back(destination_col_array); + &col_array, + pool)); + data_arrays.push_back(std::make_shared(std::move(col_array))); } - *final_table = arrow::Table::Make(schema, data_arrays); + *final_table = arrow::Table::Make(schema, std::move(data_arrays)); return Status::OK(); } diff --git a/cpp/src/cylon/join/join_utils.hpp b/cpp/src/cylon/join/join_utils.hpp index 627d92850..09a442df3 100644 --- a/cpp/src/cylon/join/join_utils.hpp +++ b/cpp/src/cylon/join/join_utils.hpp @@ -30,13 +30,13 @@ std::shared_ptr build_final_table_schema(const std::shared_ptr &left_indices, - const std::vector &right_indices, - const std::shared_ptr &left_tab, - const std::shared_ptr &right_tab, - const std::string &left_table_prefix, - const std::string &right_table_prefix, - std::shared_ptr *final_table, - arrow::MemoryPool *memory_pool); + const std::vector &right_indices, + const std::shared_ptr &left_tab, + const std::shared_ptr &right_tab, + const std::string &left_table_prefix, + const std::string &right_table_prefix, + std::shared_ptr *final_table, + arrow::MemoryPool *pool); Status build_final_table_inplace_index( size_t left_inplace_column, size_t right_inplace_column, diff --git a/cpp/src/cylon/util/arrow_utils.hpp b/cpp/src/cylon/util/arrow_utils.hpp index 398baff3d..37bc8399b 100644 --- a/cpp/src/cylon/util/arrow_utils.hpp +++ b/cpp/src/cylon/util/arrow_utils.hpp @@ -161,6 +161,15 @@ arrow::Status MakeDummyArray(const std::shared_ptr &type, int64 std::shared_ptr *out, arrow::MemoryPool *pool = arrow::default_memory_pool()); +template +typename std::enable_if_t::value, + std::shared_ptr> WrapNumericVector(const std::vector &data) { + auto buf = arrow::Buffer::Wrap(data); + auto type = arrow::TypeTraits::ArrowType>::type_singleton(); + auto array_data = arrow::ArrayData::Make(std::move(type), data.size(), {nullptr, std::move(buf)}); + return arrow::MakeArray(array_data); +} + } // namespace util } // namespace cylon #endif // CYLON_SRC_UTIL_ARROW_UTILS_HPP_ diff --git a/cpp/src/cylon/util/copy_arrray.cpp b/cpp/src/cylon/util/copy_arrray.cpp index dfe5f69b3..98ce28f39 100644 --- a/cpp/src/cylon/util/copy_arrray.cpp +++ b/cpp/src/cylon/util/copy_arrray.cpp @@ -134,152 +134,12 @@ arrow::Status copy_array_by_indices(const std::vector &indices, const std::shared_ptr &data_array, std::shared_ptr *copied_array, arrow::MemoryPool *memory_pool) { - switch (data_array->type()->id()) { - case arrow::Type::BOOL: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::UINT8: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::INT8: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::UINT16: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::INT16: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::UINT32: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::INT32: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::UINT64: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::INT64: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::HALF_FLOAT: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::FLOAT: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::DOUBLE: - return do_copy_numeric_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::STRING: - return do_copy_binary_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::BINARY: - return do_copy_binary_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::FIXED_SIZE_BINARY: - return do_copy_fixed_binary_array(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::LIST: { - auto t_value = std::static_pointer_cast(data_array->type()); - switch (t_value->value_type()->id()) { - case arrow::Type::BOOL: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::UINT8: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::INT8: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::UINT16: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::INT16: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::UINT32: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::INT32: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::UINT64: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::INT64: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::HALF_FLOAT: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::FLOAT: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - case arrow::Type::DOUBLE: - return do_copy_numeric_list(indices, - data_array, - copied_array, - memory_pool); - default: - return arrow::Status::Invalid("Un-supported type"); - } - } - default: - return arrow::Status::Invalid("Un-supported type"); - } + auto idx_array = util::WrapNumericVector(indices); + arrow::compute::ExecContext exec_ctx(memory_pool); + ARROW_ASSIGN_OR_RAISE(*copied_array, arrow::compute::Take(*data_array, *idx_array, + arrow::compute::TakeOptions::NoBoundsCheck(), + &exec_ctx)); + return arrow::Status::OK(); } } // namespace util diff --git a/cpp/test/join_test.cpp b/cpp/test/join_test.cpp index ebaeeff4c..a85cf4bf3 100644 --- a/cpp/test/join_test.cpp +++ b/cpp/test/join_test.cpp @@ -121,5 +121,43 @@ TEST_CASE("Join testing chunks", "[join]") { } } +TEST_CASE("Join testing list type", "[join]") { + // todo: list types don't work in a dist env + if (ctx->GetWorldSize() > 1) { + return; + } + + auto schema = arrow::schema({{arrow::field("a", arrow::int64())}, + {arrow::field("b", arrow::list(arrow::float32()))}}); + auto t0 = TableFromJSON(schema, {R"([{"a": 3, "b":[0.025, 1.0]}, + {"a": 26, "b":[0.394]}, + {"a": 51, "b":[0.755, 1.0]}, + {"a": 20, "b":[0.030, 1.0]}, + {"a": 33, "b":[0.318]}])"}); + auto t1 = TableFromJSON(schema, {R"([{"a": 3, "b":[0.025, 1.0]}, + {"a": 26, "b":[0.394]}, + {"a": 51, "b":[0.755, 1.0]}, + {"a": 20, "b":[0.030, 1.0]}, + {"a": 33, "b":[0.318]}])"}); + auto exp_schema = arrow::schema({{arrow::field("l_a", arrow::int64())}, + {arrow::field("l_b", arrow::list(arrow::float32()))}, + {arrow::field("r_a", arrow::int64())}, + {arrow::field("r_b", arrow::list(arrow::float32()))}}); + auto exp_inner = TableFromJSON(exp_schema, {R"([{"l_a": 3, "l_b":[0.025, 1.0], "r_a": 3, "r_b":[0.025, 1.0]}, + {"l_a": 26, "l_b":[0.394], "r_a": 26, "r_b":[0.394]}, + {"l_a": 51, "l_b":[0.755, 1.0], "r_a": 51, "r_b":[0.755, 1.0]}, + {"l_a": 20, "l_b":[0.030, 1.0], "r_a": 20, "r_b":[0.030, 1.0]}, + {"l_a": 33, "l_b":[0.318], "r_a": 33, "r_b":[0.318]}])"}); + + auto config = cylon::join::config::JoinConfig(cylon::join::config::JoinType::INNER, + 0, 0, + cylon::join::config::JoinAlgorithm::HASH, + "l_", + "r_"); + std::shared_ptr res; + CHECK_CYLON_STATUS(cylon::join::JoinTables(t0, t1, config, &res)); + CHECK_ARROW_EQUAL(exp_inner, res); +} + } } \ No newline at end of file diff --git a/python/cylonflow b/python/cylonflow index 25c8eee35..24de10b49 160000 --- a/python/cylonflow +++ b/python/cylonflow @@ -1 +1 @@ -Subproject commit 25c8eee35e7207a3ab8bd771d2eeb7edaf3e9dc9 +Subproject commit 24de10b49c190a7c9403fcddd211b497383b52ab