Skip to content

Commit a3ffc52

Browse files
authored
add python binding for approx_distinct aggregate function (#1134)
1 parent f38443d commit a3ffc52

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

python/src/functions.rs

+2
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ define_unary_function!(avg);
224224
define_unary_function!(min);
225225
define_unary_function!(max);
226226
define_unary_function!(count);
227+
define_unary_function!(approx_distinct);
227228

228229
#[pyclass(name = "Volatility", module = "datafusion.functions")]
229230
#[derive(Clone)]
@@ -323,6 +324,7 @@ pub fn init(module: &PyModule) -> PyResult<()> {
323324
module.add_class::<PyVolatility>()?;
324325
module.add_function(wrap_pyfunction!(abs, module)?)?;
325326
module.add_function(wrap_pyfunction!(acos, module)?)?;
327+
module.add_function(wrap_pyfunction!(approx_distinct, module)?)?;
326328
module.add_function(wrap_pyfunction!(array, module)?)?;
327329
module.add_function(wrap_pyfunction!(ascii, module)?)?;
328330
module.add_function(wrap_pyfunction!(asin, module)?)?;

python/tests/test_aggregation.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import pyarrow as pa
19+
import pytest
20+
from datafusion import ExecutionContext
21+
from datafusion import functions as f
22+
23+
24+
@pytest.fixture
25+
def df():
26+
ctx = ExecutionContext()
27+
28+
# create a RecordBatch and a new DataFrame from it
29+
batch = pa.RecordBatch.from_arrays(
30+
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
31+
names=["a", "b"],
32+
)
33+
return ctx.create_dataframe([[batch]])
34+
35+
36+
def test_built_in_aggregation(df):
37+
col_a = f.col("a")
38+
col_b = f.col("b")
39+
df = df.aggregate(
40+
[],
41+
[f.max(col_a), f.min(col_a), f.count(col_a), f.approx_distinct(col_b)],
42+
)
43+
result = df.collect()[0]
44+
assert result.column(0) == pa.array([3])
45+
assert result.column(1) == pa.array([1])
46+
assert result.column(2) == pa.array([3], type=pa.uint64())
47+
assert result.column(3) == pa.array([2], type=pa.uint64())

0 commit comments

Comments
 (0)