Skip to content

Commit

Permalink
feat: Add Arrow types for efficient JSON data representation in pyarr…
Browse files Browse the repository at this point in the history
…ow (#312)

* feat: add ArrowJSONtype to extend pyarrow for JSONDtype

* nit

* add JSONArrowScalar

* fix cover
  • Loading branch information
chelsea-lin authored Jan 17, 2025
1 parent b6c1428 commit d9992fc
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 7 deletions.
6 changes: 4 additions & 2 deletions db_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@

from db_dtypes import core
from db_dtypes.version import __version__
from . import _versions_helpers

from . import _versions_helpers

date_dtype_name = "dbdate"
time_dtype_name = "dbtime"
Expand All @@ -50,7 +50,7 @@
# To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal
# of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0.
if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"):
from db_dtypes.json import JSONArray, JSONDtype
from db_dtypes.json import JSONArray, JSONArrowScalar, JSONArrowType, JSONDtype
else:
JSONArray = None
JSONDtype = None
Expand Down Expand Up @@ -374,6 +374,8 @@ def __sub__(self, other):
"DateDtype",
"JSONDtype",
"JSONArray",
"JSONArrowType",
"JSONArrowScalar",
"TimeArray",
"TimeDtype",
]
40 changes: 40 additions & 0 deletions db_dtypes/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def construct_array_type(cls):
"""Return the array type associated with this dtype."""
return JSONArray

def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> JSONArray:
"""Convert the pyarrow array to the extension array."""
return JSONArray(array)


class JSONArray(arrays.ArrowExtensionArray):
"""Extension array that handles BigQuery JSON data, leveraging a string-based
Expand Down Expand Up @@ -92,6 +96,10 @@ def __init__(self, values) -> None:
else:
raise NotImplementedError(f"Unsupported pandas version: {pd.__version__}")

def __arrow_array__(self, type=None):
"""Convert to an arrow array. This is required for pyarrow extension."""
return pa.array(self.pa_data, type=JSONArrowType())

@classmethod
def _box_pa(
cls, value, pa_type: pa.DataType | None = None
Expand Down Expand Up @@ -208,6 +216,8 @@ def __getitem__(self, item):
value = self.pa_data[item]
if isinstance(value, pa.ChunkedArray):
return type(self)(value)
elif isinstance(value, pa.ExtensionScalar):
return value.as_py()
else:
scalar = JSONArray._deserialize_json(value.as_py())
if scalar is None:
Expand Down Expand Up @@ -244,3 +254,33 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray:
result[mask] = self._dtype.na_value
result[~mask] = data[~mask].pa_data.to_numpy()
return result


class JSONArrowScalar(pa.ExtensionScalar):
def as_py(self):
return JSONArray._deserialize_json(self.value.as_py() if self.value else None)


class JSONArrowType(pa.ExtensionType):
"""Arrow extension type for the `dbjson` Pandas extension type."""

def __init__(self) -> None:
super().__init__(pa.string(), "dbjson")

def __arrow_ext_serialize__(self) -> bytes:
return b""

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> JSONArrowType:
return JSONArrowType()

def to_pandas_dtype(self):
return JSONDtype()

def __arrow_ext_scalar_class__(self):
return JSONArrowScalar


# Register the type to be included in RecordBatches, sent over IPC and received in
# another Python process.
pa.register_extension_type(JSONArrowType())
4 changes: 0 additions & 4 deletions tests/compliance/json/test_json_compliance.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
import pytest


class TestJSONArrayAccumulate(base.BaseAccumulateTests):
pass


class TestJSONArrayCasting(base.BaseCastingTests):
def test_astype_str(self, data):
# Use `json.dumps(str)` instead of passing `str(obj)` directly to the super method.
Expand Down
123 changes: 122 additions & 1 deletion tests/unit/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.

import json
import math

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

import db_dtypes
Expand All @@ -36,7 +38,7 @@
"null_field": None,
"order": {
"items": ["book", "pen", "computer"],
"total": 15.99,
"total": 15,
"address": {"street": "123 Main St", "city": "Anytown"},
},
},
Expand Down Expand Up @@ -114,3 +116,122 @@ def test_as_numpy_array():
]
)
pd._testing.assert_equal(result, expected)


def test_json_arrow_array():
data = db_dtypes.JSONArray._from_sequence(JSON_DATA.values())
assert isinstance(data.__arrow_array__(), pa.ExtensionArray)


def test_json_arrow_storage_type():
arrow_json_type = db_dtypes.JSONArrowType()
assert arrow_json_type.extension_name == "dbjson"
assert pa.types.is_string(arrow_json_type.storage_type)


def test_json_arrow_constructors():
data = [
json.dumps(value, sort_keys=True, separators=(",", ":"))
for value in JSON_DATA.values()
]
storage_array = pa.array(data, type=pa.string())

arr_1 = db_dtypes.JSONArrowType().wrap_array(storage_array)
assert isinstance(arr_1, pa.ExtensionArray)

arr_2 = pa.ExtensionArray.from_storage(db_dtypes.JSONArrowType(), storage_array)
assert isinstance(arr_2, pa.ExtensionArray)

assert arr_1 == arr_2


def test_json_arrow_to_pandas():
data = [
json.dumps(value, sort_keys=True, separators=(",", ":"))
for value in JSON_DATA.values()
]
arr = pa.array(data, type=db_dtypes.JSONArrowType())

s = arr.to_pandas()
assert isinstance(s.dtypes, db_dtypes.JSONDtype)
assert s[0]
assert s[1] == 100
assert math.isclose(s[2], 0.98)
assert s[3] == "hello world"
assert math.isclose(s[4][0], 0.1)
assert math.isclose(s[4][1], 0.2)
assert s[5] == {
"null_field": None,
"order": {
"items": ["book", "pen", "computer"],
"total": 15,
"address": {"street": "123 Main St", "city": "Anytown"},
},
}
assert pd.isna(s[6])


def test_json_arrow_to_pylist():
data = [
json.dumps(value, sort_keys=True, separators=(",", ":"))
for value in JSON_DATA.values()
]
arr = pa.array(data, type=db_dtypes.JSONArrowType())

s = arr.to_pylist()
assert isinstance(s, list)
assert s[0]
assert s[1] == 100
assert math.isclose(s[2], 0.98)
assert s[3] == "hello world"
assert math.isclose(s[4][0], 0.1)
assert math.isclose(s[4][1], 0.2)
assert s[5] == {
"null_field": None,
"order": {
"items": ["book", "pen", "computer"],
"total": 15,
"address": {"street": "123 Main St", "city": "Anytown"},
},
}
assert s[6] is None


def test_json_arrow_record_batch():
data = [
json.dumps(value, sort_keys=True, separators=(",", ":"))
for value in JSON_DATA.values()
]
arr = pa.array(data, type=db_dtypes.JSONArrowType())
batch = pa.RecordBatch.from_arrays([arr], ["json_col"])
sink = pa.BufferOutputStream()

with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
writer.write_batch(batch)

buf = sink.getvalue()

with pa.ipc.open_stream(buf) as reader:
result = reader.read_all()

json_col = result.column("json_col")
assert isinstance(json_col.type, db_dtypes.JSONArrowType)

s = json_col.to_pylist()

assert isinstance(s, list)
assert s[0]
assert s[1] == 100
assert math.isclose(s[2], 0.98)
assert s[3] == "hello world"
assert math.isclose(s[4][0], 0.1)
assert math.isclose(s[4][1], 0.2)
assert s[5] == {
"null_field": None,
"order": {
"items": ["book", "pen", "computer"],
"total": 15,
"address": {"street": "123 Main St", "city": "Anytown"},
},
}
assert s[6] is None

0 comments on commit d9992fc

Please sign in to comment.