diff --git a/cpp/include/cuspatial/detail/intersection/linestring_intersection.cuh b/cpp/include/cuspatial/detail/intersection/linestring_intersection.cuh index 796365bde..2db5492ed 100644 --- a/cpp/include/cuspatial/detail/intersection/linestring_intersection.cuh +++ b/cpp/include/cuspatial/detail/intersection/linestring_intersection.cuh @@ -258,7 +258,12 @@ linestring_intersection_result pairwise_linestring_intersection( stream); points.remove_if(range(point_flags.begin(), point_flags.end()), stream); + + rmm::device_uvector point_flags_int(point_flags.size(), stream); + thrust::copy( + rmm::exec_policy(stream), point_flags.begin(), point_flags.end(), point_flags_int.begin()); } + // Phase 4: Assemble results as union column auto num_union_column_rows = points.geoms->size() + segments.geoms->size(); auto geometry_collection_offsets = diff --git a/cpp/include/cuspatial/detail/intersection/linestring_intersection_with_duplicates.cuh b/cpp/include/cuspatial/detail/intersection/linestring_intersection_with_duplicates.cuh index 941efd426..3279df2f9 100644 --- a/cpp/include/cuspatial/detail/intersection/linestring_intersection_with_duplicates.cuh +++ b/cpp/include/cuspatial/detail/intersection/linestring_intersection_with_duplicates.cuh @@ -50,6 +50,16 @@ namespace detail { namespace intersection_functors { +/** + * @brief Cast `uint8_t` to `X` + * + * @tparam X The type to cast to + */ +template +struct uchar_to_x { + X __device__ operator()(uint8_t c) { return static_cast(c); } +}; + /** @brief Functor to compute the updated offset buffer after `remove_if` operation. * * Given the `i`th row in the `geometry_collection_offset`, find the number of all @@ -292,11 +302,13 @@ struct linestring_intersection_intermediates { rmm::device_uvector reduced_flags(num_pairs(), stream); auto keys_begin = make_geometry_id_iterator(offsets->begin(), offsets->end()); + auto iflags = + thrust::make_transform_iterator(flags.begin(), intersection_functors::uchar_to_x{}); auto [keys_end, flags_end] = thrust::reduce_by_key(rmm::exec_policy(stream), keys_begin, keys_begin + flags.size(), - flags.begin(), + iflags, reduced_keys.begin(), reduced_flags.begin(), thrust::equal_to(), diff --git a/cpp/tests/intersection/linestring_intersection_large_test.cu b/cpp/tests/intersection/linestring_intersection_large_test.cu index 9c5833aa2..51c571392 100644 --- a/cpp/tests/intersection/linestring_intersection_large_test.cu +++ b/cpp/tests/intersection/linestring_intersection_large_test.cu @@ -2027,3 +2027,32 @@ TYPED_TEST(LinestringIntersectionLargeTest, LongInput) CUSPATIAL_RUN_TEST( this->template verify_legal_result, multilinestrings1.range(), multilinestrings2.range()); } + +template +struct coordinate_functor { + cuspatial::vec_2d __device__ operator()(std::size_t i) + { + return cuspatial::vec_2d{static_cast(i), static_cast(i)}; + } +}; + +TYPED_TEST(LinestringIntersectionLargeTest, LongInput_2) +{ + using P = cuspatial::vec_2d; + auto geometry_offset = cuspatial::test::make_device_vector({0, 1}); + auto part_offset = cuspatial::test::make_device_vector({0, 130}); + auto coordinates = rmm::device_uvector

(260, this->stream()); + + thrust::tabulate(rmm::exec_policy(this->stream()), + coordinates.begin(), + thrust::next(coordinates.begin(), 128), + coordinate_functor{}); + + coordinates.set_element(128, P{127.0, 0.0}, this->stream()); + coordinates.set_element(129, P{0.0, 0.0}, this->stream()); + + auto rng = cuspatial::make_multilinestring_range( + 1, geometry_offset.begin(), 1, part_offset.begin(), 130, coordinates.begin()); + + CUSPATIAL_RUN_TEST(this->template verify_legal_result, rng, rng); +}