Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-39217: [Python] RecordBatchReader.from_stream constructor for objects implementing the Arrow PyCapsule protocol #39218

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions python/pyarrow/ipc.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,49 @@ cdef class RecordBatchReader(_Weakrefable):
self.reader = c_reader
return self

@staticmethod
def from_stream(data, schema=None):
    """
    Create RecordBatchReader from an Arrow-compatible stream object.

    Any object exposing a ``__arrow_c_stream__`` method — i.e. one that
    implements the Arrow PyCapsule Protocol for streams — is accepted.

    Parameters
    ----------
    data : Arrow-compatible stream object
        Any object that implements the Arrow PyCapsule Protocol for
        streams.
    schema : Schema, default None
        The schema to which the stream should be cast, if supported
        by the stream object.

    Returns
    -------
    RecordBatchReader
    """
    if not hasattr(data, "__arrow_c_stream__"):
        raise TypeError(
            "Expected an object implementing the Arrow PyCapsule Protocol for "
            "streams (i.e. having a `__arrow_c_stream__` method), "
            f"got {type(data)!r}."
        )

    # Export the requested schema (if any) through the schema half of the
    # protocol so the producer may cast its output; None means "as-is".
    requested = None
    if schema is not None:
        if not hasattr(schema, "__arrow_c_schema__"):
            raise TypeError(
                "Expected an object implementing the Arrow PyCapsule Protocol for "
                "schema (i.e. having a `__arrow_c_schema__` method), "
                f"got {type(schema)!r}."
            )
        requested = schema.__arrow_c_schema__()

    stream_capsule = data.__arrow_c_stream__(requested)
    return RecordBatchReader._import_from_c_capsule(stream_capsule)

@staticmethod
def from_batches(Schema schema not None, batches):
"""
Expand Down
4 changes: 2 additions & 2 deletions python/pyarrow/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3351,8 +3351,8 @@ class ArrayWrapper:
def __init__(self, data):
self.data = data

def __arrow_c_array__(self, requested_type=None):
return self.data.__arrow_c_array__(requested_type)
def __arrow_c_array__(self, requested_schema=None):
return self.data.__arrow_c_array__(requested_schema)

# Can roundtrip through the C array protocol
arr = ArrayWrapper(pa.array([1, 2, 3], type=pa.int64()))
Expand Down
44 changes: 44 additions & 0 deletions python/pyarrow/tests/test_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,3 +1194,47 @@ def make_batches():
with pytest.raises(TypeError):
reader = pa.RecordBatchReader.from_batches(None, batches)
pass


def test_record_batch_reader_from_arrow_stream():
    """RecordBatchReader.from_stream consumes any ``__arrow_c_stream__`` producer."""

    class StreamWrapper:
        # Minimal third-party-style object implementing only the
        # PyCapsule stream protocol, backed by a list of record batches.
        def __init__(self, batches):
            self.batches = batches

        def __arrow_c_stream__(self, requested_schema=None):
            reader = pa.RecordBatchReader.from_batches(
                self.batches[0].schema, self.batches)
            return reader.__arrow_c_stream__(requested_schema)

    batches = [
        pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
        pa.record_batch([pa.array([4, 5, 6], type=pa.int64())], names=['a'])
    ]
    expected = pa.Table.from_batches(batches)
    wrapped = StreamWrapper(batches)

    # A pyarrow Table is itself a stream-like object and roundtrips.
    assert pa.RecordBatchReader.from_stream(expected).read_all() == expected

    # The minimal wrapper above roundtrips as well.
    assert pa.RecordBatchReader.from_stream(wrapped).read_all() == expected

    # Requesting the schema the stream already produces succeeds.
    reader = pa.RecordBatchReader.from_stream(wrapped, schema=batches[0].schema)
    assert reader.read_all() == expected

    # The wrapper cannot cast to a different schema.
    with pytest.raises(NotImplementedError):
        pa.RecordBatchReader.from_stream(
            wrapped, schema=pa.schema([pa.field('a', pa.int32())])
        )

    # Objects lacking the protocol methods raise TypeError.
    with pytest.raises(TypeError):
        pa.RecordBatchReader.from_stream(batches[0]['a'])

    with pytest.raises(TypeError):
        pa.RecordBatchReader.from_stream(expected, schema=batches[0])
12 changes: 6 additions & 6 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,8 @@ class BatchWrapper:
def __init__(self, batch):
self.batch = batch

def __arrow_c_array__(self, requested_type=None):
return self.batch.__arrow_c_array__(requested_type)
def __arrow_c_array__(self, requested_schema=None):
return self.batch.__arrow_c_array__(requested_schema)

data = pa.record_batch([
pa.array([1, 2, 3], type=pa.int64())
Expand All @@ -586,8 +586,8 @@ class BatchWrapper:
def __init__(self, batch):
self.batch = batch

def __arrow_c_array__(self, requested_type=None):
return self.batch.__arrow_c_array__(requested_type)
def __arrow_c_array__(self, requested_schema=None):
return self.batch.__arrow_c_array__(requested_schema)

data = pa.record_batch([
pa.array([1, 2, 3], type=pa.int64())
Expand Down Expand Up @@ -615,10 +615,10 @@ class StreamWrapper:
def __init__(self, batches):
self.batches = batches

def __arrow_c_stream__(self, requested_type=None):
def __arrow_c_stream__(self, requested_schema=None):
reader = pa.RecordBatchReader.from_batches(
self.batches[0].schema, self.batches)
return reader.__arrow_c_stream__(requested_type)
return reader.__arrow_c_stream__(requested_schema)

data = [
pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
Expand Down