Prevent a degenerative join in test_dpp_reuse_broadcast_exchange [databricks] (#10168)

NVnavkumar authored Jan 10, 2024
1 parent 4e57f5f commit c92bbc0
Showing 2 changed files with 40 additions and 12 deletions.
41 changes: 32 additions & 9 deletions integration_tests/src/main/python/data_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -329,19 +329,39 @@ def start(self, rand):
         self._start(rand, lambda: self.next_val())

 class RepeatSeqGen(DataGen):
-    """Generate Repeated seq of `length` random items"""
-    def __init__(self, child, length):
-        super().__init__(child.data_type, nullable=False)
-        self.nullable = child.nullable
-        self._child = child
+    """Generate Repeated seq of `length` random items if child is a DataGen,
+    otherwise repeat the provided seq when child is a list.
+    When child is a list:
+        data_type must be specified
+        length must be <= length of child
+    When child is a DataGen:
+        length must be specified
+        data_type must be None or match child's
+    """
+    def __init__(self, child, length=None, data_type=None):
+        if isinstance(child, list):
+            super().__init__(data_type, nullable=False)
+            self.nullable = None in child
+            assert (length is None or length <= len(child))
+            self._length = length if length is not None else len(child)
+            self._child = child[:length] if length is not None else child
+        else:
+            super().__init__(child.data_type, nullable=False)
+            self.nullable = child.nullable
+            assert (data_type is None or data_type == child.data_type)
+            assert (length is not None)
+            self._length = length
+            self._child = child
         self._vals = []
-        self._length = length
         self._index = 0

     def __repr__(self):
         return super().__repr__() + '(' + str(self._child) + ')'

     def _cache_repr(self):
+        if isinstance(self._child, list):
+            return super()._cache_repr() + '(' + str(self._child) + ',' + str(self._length) + ')'
         return super()._cache_repr() + '(' + self._child._cache_repr() + ',' + str(self._length) + ')'

     def _loop_values(self):
@@ -351,9 +371,12 @@ def _loop_values(self):

     def start(self, rand):
         self._index = 0
-        self._child.start(rand)
         self._start(rand, self._loop_values)
-        self._vals = [self._child.gen() for _ in range(0, self._length)]
+        if isinstance(self._child, list):
+            self._vals = self._child
+        else:
+            self._child.start(rand)
+            self._vals = [self._child.gen() for _ in range(0, self._length)]

 class SetValuesGen(DataGen):
     """A set of values that are randomly selected"""
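The diff above gives RepeatSeqGen a second construction mode. Below is a minimal usage sketch of the two modes, assuming the spark-rapids integration_tests environment (data_gen.py on the PYTHONPATH and pyspark installed); the imported names come from that module.

from pyspark.sql.types import IntegerType

from data_gen import RepeatSeqGen, IntegerGen, INT_MIN, INT_MAX

# DataGen child (pre-existing mode): `length` is required and the data
# type is inherited from the child generator.
repeat_random = RepeatSeqGen(IntegerGen(nullable=False), length=3)

# List child (new mode): `data_type` is required, `length` defaults to
# len(child), and nullability is inferred from the presence of None.
repeat_fixed = RepeatSeqGen([None, INT_MIN, -1, 0, 1, INT_MAX],
                            data_type=IntegerType())
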
11 changes: 8 additions & 3 deletions integration_tests/src/main/python/dpp_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, NVIDIA CORPORATION.
+# Copyright (c) 2021-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,20 +14,25 @@

 import pytest

+from pyspark.sql.types import IntegerType

 from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_and_cpu_are_equal_collect
 from conftest import spark_tmp_table_factory
 from data_gen import *
 from marks import ignore_order, allow_non_gpu
 from spark_session import is_before_spark_320, with_cpu_session, is_before_spark_312, is_databricks_runtime, is_databricks113_or_later

+# non-positive values here can produce a degenerative join, so we ensure that most
+# values are positive so the join will produce rows. See https://github.com/NVIDIA/spark-rapids/issues/10147
+value_gen = RepeatSeqGen([None, INT_MIN, -1, 0, 1, INT_MAX], data_type=IntegerType())

 def create_dim_table(table_name, table_format, length=500):
     def fn(spark):
         df = gen_df(spark, [
             ('key', IntegerGen(nullable=False, min_val=0, max_val=9, special_cases=[])),
             ('skey', IntegerGen(nullable=False, min_val=0, max_val=4, special_cases=[])),
             ('ex_key', IntegerGen(nullable=False, min_val=0, max_val=3, special_cases=[])),
-            ('value', int_gen),
+            ('value', value_gen),
             # specify nullable=False for `filter` to avoid generating invalid SQL with
             # expression `filter = None` (https://github.com/NVIDIA/spark-rapids/issues/9817)
             ('filter', RepeatSeqGen(
@@ -49,7 +54,7 @@ def fn(spark):
             ('skey', IntegerGen(nullable=False, min_val=0, max_val=4, special_cases=[])),
             # ex_key is not a partition column
             ('ex_key', IntegerGen(nullable=False, min_val=0, max_val=3, special_cases=[])),
-            ('value', int_gen)], length)
+            ('value', value_gen)], length)
         df.write.format(table_format) \
             .mode("overwrite") \
             .partitionBy('key', 'skey') \
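For intuition, here is a small sketch (hypothetical, again assuming the integration_tests environment) of what the new value_gen repeats: the six-entry sequence cycles through the generated column in order, so positive values (1, INT_MAX) appear regularly and the join between the fact and dim tables produces rows rather than a degenerative join.

import random

from pyspark.sql.types import IntegerType
from data_gen import RepeatSeqGen, INT_MIN, INT_MAX

value_gen = RepeatSeqGen([None, INT_MIN, -1, 0, 1, INT_MAX], data_type=IntegerType())
value_gen.start(random.Random(0))

# The list-backed generator should cycle through its values in order:
# None, INT_MIN, -1, 0, 1, INT_MAX, None, INT_MIN, ...
print([value_gen.gen() for _ in range(12)])
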
