Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Arrow types for efficient JSON data representation in pyarrow #312

Merged
merged 4 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions db_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@

from db_dtypes import core
from db_dtypes.version import __version__
from . import _versions_helpers

from . import _versions_helpers

date_dtype_name = "dbdate"
time_dtype_name = "dbtime"
Expand All @@ -50,7 +50,7 @@
# To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal
# of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0.
if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"):
from db_dtypes.json import JSONArray, JSONDtype
from db_dtypes.json import JSONArray, JSONArrowScalar, JSONArrowType, JSONDtype
else:
JSONArray = None
JSONDtype = None
Expand Down Expand Up @@ -374,6 +374,8 @@ def __sub__(self, other):
"DateDtype",
"JSONDtype",
"JSONArray",
"JSONArrowType",
"JSONArrowScalar",
"TimeArray",
"TimeDtype",
]
40 changes: 40 additions & 0 deletions db_dtypes/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def construct_array_type(cls):
"""Return the array type associated with this dtype."""
return JSONArray

def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> JSONArray:
    """Convert the pyarrow array to the extension array.

    This is the pandas extension-dtype hook invoked when a pyarrow-backed
    column with this dtype is converted back into pandas; the raw Arrow
    array (or chunked array) is simply wrapped in a ``JSONArray``.
    """
    return JSONArray(array)


class JSONArray(arrays.ArrowExtensionArray):
"""Extension array that handles BigQuery JSON data, leveraging a string-based
Expand Down Expand Up @@ -92,6 +96,10 @@ def __init__(self, values) -> None:
else:
raise NotImplementedError(f"Unsupported pandas version: {pd.__version__}")

def __arrow_array__(self, type=None):
    """Convert to an arrow array. This is required for pyarrow extension.

    NOTE: the ``type`` argument is accepted for protocol compatibility but
    is not forwarded — the result is always built with ``JSONArrowType()``
    so the JSON extension semantics survive the conversion to Arrow.
    """
    return pa.array(self.pa_data, type=JSONArrowType())

@classmethod
def _box_pa(
cls, value, pa_type: pa.DataType | None = None
Expand Down Expand Up @@ -208,6 +216,8 @@ def __getitem__(self, item):
value = self.pa_data[item]
if isinstance(value, pa.ChunkedArray):
return type(self)(value)
elif isinstance(value, pa.ExtensionScalar):
return value.as_py()
else:
scalar = JSONArray._deserialize_json(value.as_py())
if scalar is None:
Expand Down Expand Up @@ -244,3 +254,33 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray:
result[mask] = self._dtype.na_value
result[~mask] = data[~mask].pa_data.to_numpy()
return result


class JSONArrowScalar(pa.ExtensionScalar):
    """Scalar wrapper that deserializes the stored JSON string on access."""

    def as_py(self):
        # A missing/null scalar deserializes from None; otherwise decode the
        # underlying storage string through the JSONArray helper.
        raw = self.value.as_py() if self.value else None
        return JSONArray._deserialize_json(raw)


class JSONArrowType(pa.ExtensionType):
    """Arrow extension type for the `dbjson` Pandas extension type.

    JSON payloads are kept in their serialized text form, so the storage
    type is always ``pa.string()`` and the type itself carries no state.
    """

    def __init__(self) -> None:
        super().__init__(pa.string(), "dbjson")

    def __arrow_ext_serialize__(self) -> bytes:
        # Stateless type: there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized) -> JSONArrowType:
        # Stateless type: the serialized payload is ignored and a fresh
        # instance is returned.
        return cls()

    def to_pandas_dtype(self):
        """Return the pandas extension dtype paired with this Arrow type."""
        return JSONDtype()

    def __arrow_ext_scalar_class__(self):
        """Use JSONArrowScalar so ``as_py`` yields deserialized JSON."""
        return JSONArrowScalar


# Register the type to be included in RecordBatches, sent over IPC and received in
# another Python process. Registration is global per process; re-registering the
# same extension name raises ArrowKeyError (e.g. when this module is reloaded),
# so treat an already-registered type as success.
try:
    pa.register_extension_type(JSONArrowType())
except pa.ArrowKeyError:
    # "dbjson" is already registered in this process -- safe to ignore.
    pass
4 changes: 0 additions & 4 deletions tests/compliance/json/test_json_compliance.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
import pytest


# Inherits the pandas extension-array accumulate compliance suite unchanged;
# no JSON-specific overrides are needed.
class TestJSONArrayAccumulate(base.BaseAccumulateTests):
    pass


class TestJSONArrayCasting(base.BaseCastingTests):
def test_astype_str(self, data):
# Use `json.dumps(str)` instead of passing `str(obj)` directly to the super method.
Expand Down
123 changes: 122 additions & 1 deletion tests/unit/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.

import json
import math

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

import db_dtypes
Expand All @@ -36,7 +38,7 @@
"null_field": None,
"order": {
"items": ["book", "pen", "computer"],
"total": 15.99,
"total": 15,
"address": {"street": "123 Main St", "city": "Anytown"},
},
},
Expand Down Expand Up @@ -114,3 +116,122 @@ def test_as_numpy_array():
]
)
pd._testing.assert_equal(result, expected)


def test_json_arrow_array():
    """__arrow_array__ on a JSONArray yields a pyarrow ExtensionArray."""
    extension_array = db_dtypes.JSONArray._from_sequence(JSON_DATA.values())
    result = extension_array.__arrow_array__()
    assert isinstance(result, pa.ExtensionArray)


def test_json_arrow_storage_type():
    """The extension type is named 'dbjson' and stores JSON as strings."""
    json_type = db_dtypes.JSONArrowType()
    assert json_type.extension_name == "dbjson"
    assert pa.types.is_string(json_type.storage_type)


def test_json_arrow_constructors():
    """wrap_array and ExtensionArray.from_storage build equal arrays."""
    serialized = [
        json.dumps(value, sort_keys=True, separators=(",", ":"))
        for value in JSON_DATA.values()
    ]
    storage = pa.array(serialized, type=pa.string())

    wrapped = db_dtypes.JSONArrowType().wrap_array(storage)
    from_storage = pa.ExtensionArray.from_storage(db_dtypes.JSONArrowType(), storage)

    assert isinstance(wrapped, pa.ExtensionArray)
    assert isinstance(from_storage, pa.ExtensionArray)
    assert wrapped == from_storage


def test_json_arrow_to_pandas():
    """A JSONArrowType array converts to a pandas Series with JSONDtype."""
    serialized = [
        json.dumps(value, sort_keys=True, separators=(",", ":"))
        for value in JSON_DATA.values()
    ]
    series = pa.array(serialized, type=db_dtypes.JSONArrowType()).to_pandas()

    assert isinstance(series.dtypes, db_dtypes.JSONDtype)
    assert series[0]
    assert series[1] == 100
    assert math.isclose(series[2], 0.98)
    assert series[3] == "hello world"
    assert math.isclose(series[4][0], 0.1)
    assert math.isclose(series[4][1], 0.2)
    expected_nested = {
        "null_field": None,
        "order": {
            "items": ["book", "pen", "computer"],
            "total": 15,
            "address": {"street": "123 Main St", "city": "Anytown"},
        },
    }
    assert series[5] == expected_nested
    assert pd.isna(series[6])


def test_json_arrow_to_pylist():
    """to_pylist on a JSONArrowType array yields deserialized Python values."""
    serialized = [
        json.dumps(value, sort_keys=True, separators=(",", ":"))
        for value in JSON_DATA.values()
    ]
    values = pa.array(serialized, type=db_dtypes.JSONArrowType()).to_pylist()

    assert isinstance(values, list)
    assert values[0]
    assert values[1] == 100
    assert math.isclose(values[2], 0.98)
    assert values[3] == "hello world"
    assert math.isclose(values[4][0], 0.1)
    assert math.isclose(values[4][1], 0.2)
    expected_nested = {
        "null_field": None,
        "order": {
            "items": ["book", "pen", "computer"],
            "total": 15,
            "address": {"street": "123 Main St", "city": "Anytown"},
        },
    }
    assert values[5] == expected_nested
    assert values[6] is None


def test_json_arrow_record_batch():
    """JSONArrowType survives an IPC round trip inside a RecordBatch."""
    serialized = [
        json.dumps(value, sort_keys=True, separators=(",", ":"))
        for value in JSON_DATA.values()
    ]
    batch = pa.RecordBatch.from_arrays(
        [pa.array(serialized, type=db_dtypes.JSONArrowType())], ["json_col"]
    )

    # Write the batch to an in-memory IPC stream ...
    sink = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
        writer.write_batch(batch)

    # ... and read it back, as a separate process receiving the stream would.
    with pa.ipc.open_stream(sink.getvalue()) as reader:
        table = reader.read_all()

    json_col = table.column("json_col")
    assert isinstance(json_col.type, db_dtypes.JSONArrowType)

    values = json_col.to_pylist()
    assert isinstance(values, list)
    assert values[0]
    assert values[1] == 100
    assert math.isclose(values[2], 0.98)
    assert values[3] == "hello world"
    assert math.isclose(values[4][0], 0.1)
    assert math.isclose(values[4][1], 0.2)
    expected_nested = {
        "null_field": None,
        "order": {
            "items": ["book", "pen", "computer"],
            "total": 15,
            "address": {"street": "123 Main St", "city": "Anytown"},
        },
    }
    assert values[5] == expected_nested
    assert values[6] is None
Loading