Commit e713bc3
update cargo.toml in python crate and fix unit test due to hash joins (#483)
* update cargo.toml
* fix group by
* remove unused imports
1 parent e82d053
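This commit pins the Python crate to a newer DataFusion revision in which joins and group-by aggregates execute as hash-partitioned operators, so DataFrame.collect() may return several RecordBatches rather than exactly one. The unit tests are updated to merge batches before asserting. A minimal pyarrow-only sketch of that merge-then-sort pattern (batch contents here are illustrative, not taken from the tests):

import pyarrow as pa

# Two batches, as a partitioned hash aggregate might emit them.
batch1 = pa.RecordBatch.from_arrays(
    [pa.array([1, 2]), pa.array([3, 1])], names=["key", "count"]
)
batch2 = pa.RecordBatch.from_arrays(
    [pa.array([3]), pa.array([2])], names=["key", "count"]
)

# Row order across batches is not guaranteed, so merge into one
# table and sort before comparing against expected values.
table = pa.Table.from_batches([batch1, batch2])
rows = sorted(zip(table.column("key").to_pylist(),
                  table.column("count").to_pylist()))
assert rows == [(1, 3), (2, 1), (3, 2)]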

File tree

5 files changed: +23 -29 lines

  python/Cargo.toml
  python/tests/generic.py
  python/tests/test_df.py
  python/tests/test_sql.py
  python/tests/test_udaf.py

python/Cargo.toml (+1, -1)

@@ -31,7 +31,7 @@ libc = "0.2"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
 rand = "0.7"
 pyo3 = { version = "0.13.2", features = ["extension-module"] }
-datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "2423ff0d" }
+datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "c3fc0c75af5ff2ebb99dba197d9d2ccd83eb5952" }

 [lib]
 name = "datafusion"
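The full 40-character rev pins the bindings to a DataFusion revision that includes the hash-partitioned join and aggregate changes the test updates below accommodate.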

python/tests/generic.py (-6)

@@ -15,15 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.

-import unittest
-import tempfile
 import datetime
-import os.path
-import shutil
-
 import numpy
 import pyarrow
-import datafusion

 # used to write parquet files
 import pyarrow.parquet

python/tests/test_df.py (+9, -15)

@@ -19,11 +19,11 @@

 import pyarrow as pa
 import datafusion
+
 f = datafusion.functions


 class TestCase(unittest.TestCase):
-
     def _prepare(self):
         ctx = datafusion.ExecutionContext()

@@ -51,12 +51,10 @@ def test_select(self):
     def test_filter(self):
         df = self._prepare()

-        df = df \
-            .select(
-                f.col("a") + f.col("b"),
-                f.col("a") - f.col("b"),
-            ) \
-            .filter(f.col("a") > f.lit(2))
+        df = df.select(
+            f.col("a") + f.col("b"),
+            f.col("a") - f.col("b"),
+        ).filter(f.col("a") > f.lit(2))

         # execute and collect the first (and only) batch
         result = df.collect()[0]

@@ -66,12 +64,10 @@ def test_filter(self):

     def test_sort(self):
         df = self._prepare()
-        df = df.sort([
-            f.col("b").sort(ascending=False)
-        ])
+        df = df.sort([f.col("b").sort(ascending=False)])

         table = pa.Table.from_batches(df.collect())
-        expected = {'a': [3, 2, 1], 'b': [6, 5, 4]}
+        expected = {"a": [3, 2, 1], "b": [6, 5, 4]}
         self.assertEqual(table.to_pydict(), expected)

     def test_limit(self):

@@ -111,10 +107,8 @@ def test_join(self):
         df1 = ctx.create_dataframe([[batch]])

         df = df.join(df1, on="a", how="inner")
-        df = df.sort([
-            f.col("a").sort(ascending=True)
-        ])
+        df = df.sort([f.col("a").sort(ascending=True)])
         table = pa.Table.from_batches(df.collect())

-        expected = {'a': [1, 2], 'c': [8, 10], 'b': [4, 5]}
+        expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
         self.assertEqual(table.to_pydict(), expected)

python/tests/test_sql.py (+9, -3)

@@ -82,12 +82,18 @@ def test_execute(self):
         )

         # group by
-        result = ctx.sql(
+        results = ctx.sql(
             "SELECT CAST(a as int), COUNT(a) FROM t GROUP BY CAST(a as int)"
         ).collect()

-        result_keys = result[0].to_pydict()["CAST(a AS Int32)"]
-        result_values = result[0].to_pydict()["COUNT(a)"]
+        # group by returns batches
+        result_keys = []
+        result_values = []
+        for result in results:
+            pydict = result.to_pydict()
+            result_keys.extend(pydict["CAST(a AS Int32)"])
+            result_values.extend(pydict["COUNT(a)"])
+
         result_keys, result_values = (
             list(t) for t in zip(*sorted(zip(result_keys, result_values)))
         )
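An equivalent way to flatten the per-batch group-by output (a sketch, not part of this commit) is to let pyarrow.Table.from_batches do the concatenation; the column names mirror the query above, and the batch contents are hypothetical:

import pyarrow as pa

# Hypothetical stand-in for results = ctx.sql(...).collect(),
# which now yields a list of RecordBatches.
results = [
    pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([2, 1])],
        names=["CAST(a AS Int32)", "COUNT(a)"],
    ),
    pa.RecordBatch.from_arrays(
        [pa.array([3]), pa.array([1])],
        names=["CAST(a AS Int32)", "COUNT(a)"],
    ),
]

# from_batches() concatenates the batches into one logical table.
pydict = pa.Table.from_batches(results).to_pydict()
result_keys = pydict["CAST(a AS Int32)"]
result_values = pydict["COUNT(a)"]
assert sorted(result_keys) == [1, 2, 3]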

python/tests/test_udaf.py (+4, -4)

@@ -16,7 +16,6 @@
 # under the License.

 import unittest
-
 import pyarrow
 import pyarrow.compute
 import datafusion

@@ -86,6 +85,7 @@ def test_group_by(self):
         df = df.aggregate([f.col("b")], [udaf(f.col("a"))])

         # execute and collect the first (and only) batch
-        result = df.collect()[0]
-
-        self.assertEqual(result.column(1), pyarrow.array([1.0 + 2.0, 3.0]))
+        batches = df.collect()
+        arrays = [batch.column(1) for batch in batches]
+        joined = pyarrow.concat_arrays(arrays)
+        self.assertEqual(joined, pyarrow.array([1.0 + 2.0, 3.0]))
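pyarrow.concat_arrays stitches the per-batch result columns back into one contiguous array, so a single assertEqual still covers all groups. A standalone sketch of the call, with illustrative values:

import pyarrow

# One chunk per collected batch.
chunk1 = pyarrow.array([1.0 + 2.0])
chunk2 = pyarrow.array([3.0])
joined = pyarrow.concat_arrays([chunk1, chunk2])
assert joined.equals(pyarrow.array([3.0, 3.0]))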
