From 38c54444b7800c7f088b8f11b60f1122870a65ec Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 13 Dec 2023 17:58:36 +0100 Subject: [PATCH 1/2] GH-39217: [Python] RecordBatchReader.from_stream constructor for objects implementing the Arrow PyCapsule protocol --- python/pyarrow/ipc.pxi | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index ae52f5cf34e8b..f8639da08ec78 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -883,6 +883,43 @@ cdef class RecordBatchReader(_Weakrefable): self.reader = c_reader return self + @staticmethod + def from_stream(data, schema=None): + """ + Create RecordBatchReader from an Arrow-compatible stream object. + + This accepts objects implementing the Arrow PyCapsule Protocol for + streams, i.e. objects that have a ``__arrow_c_stream__`` method. + + Parameters + ---------- + data : Arrow-compatible stream object + Any object that implements the Arrow PyCapsule Protocol for + streams. + schema : Schema, default None + The schema to which the stream should be casted, is supported + by the stream object. + + Returns + ------- + RecordBatchReader + """ + + if not hasattr(data, "__arrow_c_stream__"): + raise TypeError( + "Expected an object implementing the Arrow PyCapsule Protocol for " + "streams (i.e. having a `__arrow_c_stream__` method), " + f"got {type(data)!r}." 
+ ) + + if schema is not None: + requested = schema.__arrow_c_schema__() + else: + requested = None + + capsule = data.__arrow_c_stream__(requested) + return RecordBatchReader._import_from_c_capsule(capsule) + @staticmethod def from_batches(Schema schema not None, batches): """ From 8e71540dfe37d14cf6ef05d3ee3ddbe585986ec5 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 21 Dec 2023 16:08:52 +0100 Subject: [PATCH 2/2] add tests and schema keyword validation --- python/pyarrow/ipc.pxi | 8 +++++- python/pyarrow/tests/test_array.py | 4 +-- python/pyarrow/tests/test_ipc.py | 44 ++++++++++++++++++++++++++++++ python/pyarrow/tests/test_table.py | 12 ++++---- 4 files changed, 59 insertions(+), 9 deletions(-) diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index f8639da08ec78..da9636dfc86e1 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -897,7 +897,7 @@ cdef class RecordBatchReader(_Weakrefable): Any object that implements the Arrow PyCapsule Protocol for streams. schema : Schema, default None - The schema to which the stream should be casted, is supported + The schema to which the stream should be cast, if supported by the stream object. Returns @@ -913,6 +913,12 @@ cdef class RecordBatchReader(_Weakrefable): ) if schema is not None: + if not hasattr(schema, "__arrow_c_schema__"): + raise TypeError( + "Expected an object implementing the Arrow PyCapsule Protocol for " + "schema (i.e. having a `__arrow_c_schema__` method), " + f"got {type(schema)!r}." 
+ ) requested = schema.__arrow_c_schema__() else: requested = None diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index 599d15d023a55..f071d30509a86 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -3341,8 +3341,8 @@ class ArrayWrapper: def __init__(self, data): self.data = data - def __arrow_c_array__(self, requested_type=None): - return self.data.__arrow_c_array__(requested_type) + def __arrow_c_array__(self, requested_schema=None): + return self.data.__arrow_c_array__(requested_schema) # Can roundtrip through the C array protocol arr = ArrayWrapper(pa.array([1, 2, 3], type=pa.int64())) diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index 450d26e3b771c..f75ec8158a9da 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -1194,3 +1194,47 @@ def make_batches(): with pytest.raises(TypeError): reader = pa.RecordBatchReader.from_batches(None, batches) pass + + +def test_record_batch_reader_from_arrow_stream(): + + class StreamWrapper: + def __init__(self, batches): + self.batches = batches + + def __arrow_c_stream__(self, requested_schema=None): + reader = pa.RecordBatchReader.from_batches( + self.batches[0].schema, self.batches) + return reader.__arrow_c_stream__(requested_schema) + + data = [ + pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']), + pa.record_batch([pa.array([4, 5, 6], type=pa.int64())], names=['a']) + ] + wrapper = StreamWrapper(data) + + # Can roundtrip a pyarrow stream-like object + expected = pa.Table.from_batches(data) + reader = pa.RecordBatchReader.from_stream(expected) + assert reader.read_all() == expected + + # Can roundtrip through the wrapper. 
+ reader = pa.RecordBatchReader.from_stream(wrapper) + assert reader.read_all() == expected + + # Passing schema works if already that schema + reader = pa.RecordBatchReader.from_stream(wrapper, schema=data[0].schema) + assert reader.read_all() == expected + + # If schema doesn't match, raises NotImplementedError + with pytest.raises(NotImplementedError): + pa.RecordBatchReader.from_stream( + wrapper, schema=pa.schema([pa.field('a', pa.int32())]) + ) + + # Proper type errors for wrong input + with pytest.raises(TypeError): + pa.RecordBatchReader.from_stream(data[0]['a']) + + with pytest.raises(TypeError): + pa.RecordBatchReader.from_stream(expected, schema=data[0]) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index a678f521e38d5..ee036f136c77b 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -558,8 +558,8 @@ class BatchWrapper: def __init__(self, batch): self.batch = batch - def __arrow_c_array__(self, requested_type=None): - return self.batch.__arrow_c_array__(requested_type) + def __arrow_c_array__(self, requested_schema=None): + return self.batch.__arrow_c_array__(requested_schema) data = pa.record_batch([ pa.array([1, 2, 3], type=pa.int64()) @@ -586,8 +586,8 @@ class BatchWrapper: def __init__(self, batch): self.batch = batch - def __arrow_c_array__(self, requested_type=None): - return self.batch.__arrow_c_array__(requested_type) + def __arrow_c_array__(self, requested_schema=None): + return self.batch.__arrow_c_array__(requested_schema) data = pa.record_batch([ pa.array([1, 2, 3], type=pa.int64()) @@ -615,10 +615,10 @@ class StreamWrapper: def __init__(self, batches): self.batches = batches - def __arrow_c_stream__(self, requested_type=None): + def __arrow_c_stream__(self, requested_schema=None): reader = pa.RecordBatchReader.from_batches( self.batches[0].schema, self.batches) - return reader.__arrow_c_stream__(requested_type) + return reader.__arrow_c_stream__(requested_schema) 
data = [ pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),