From bd61bc296f54c4b6f77d856c8a1e4434686bc388 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 9 Jan 2025 03:22:57 +0000 Subject: [PATCH 1/4] feat: add ArrowJSONtype to extend pyarrow for JSONDtype --- db_dtypes/__init__.py | 7 +++--- db_dtypes/json.py | 51 ++++++++++++++++++++++++++++++++++++++++- tests/unit/test_json.py | 35 ++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 4 deletions(-) diff --git a/db_dtypes/__init__.py b/db_dtypes/__init__.py index 952643b..6ce652d 100644 --- a/db_dtypes/__init__.py +++ b/db_dtypes/__init__.py @@ -30,8 +30,8 @@ from db_dtypes import core from db_dtypes.version import __version__ -from . import _versions_helpers +from . import _versions_helpers date_dtype_name = "dbdate" time_dtype_name = "dbtime" @@ -50,7 +50,7 @@ # To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal # of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0. if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"): - from db_dtypes.json import JSONArray, JSONDtype + from db_dtypes.json import ArrowJSONType, JSONArray, JSONDtype else: JSONArray = None JSONDtype = None @@ -359,7 +359,7 @@ def __sub__(self, other): ) -if not JSONArray or not JSONDtype: +if not JSONArray or not JSONDtype or not ArrowJSONType: __all__ = [ "__version__", "DateArray", @@ -370,6 +370,7 @@ def __sub__(self, other): else: __all__ = [ "__version__", + "ArrowJSONType", "DateArray", "DateDtype", "JSONDtype", diff --git a/db_dtypes/json.py b/db_dtypes/json.py index c43ebc2..872ebe4 100644 --- a/db_dtypes/json.py +++ b/db_dtypes/json.py @@ -64,6 +64,10 @@ def construct_array_type(cls): """Return the array type associated with this dtype.""" return JSONArray + def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> JSONArray: + """Convert the pyarrow array to the extension array.""" + return JSONArray(array) + class JSONArray(arrays.ArrowExtensionArray): """Extension array that handles BigQuery JSON data, leveraging a string-based @@ -92,6 +96,10 @@ def __init__(self, values) -> None: else: raise NotImplementedError(f"Unsupported pandas version: {pd.__version__}") + def __arrow_array__(self): + """Convert to an arrow array. This is required for pyarrow extension.""" + return self.pa_data + @classmethod def _box_pa( cls, value, pa_type: pa.DataType | None = None @@ -151,7 +159,12 @@ def _serialize_json(value): def _deserialize_json(value): """A static method that converts a JSON string back into its original value.""" if not pd.isna(value): - return json.loads(value) + # Attempt to interpret the value as a JSON object. + # If it's not valid JSON, treat it as a regular string. 
+ try: + return json.loads(value) + except json.JSONDecodeError: + return value else: return value @@ -244,3 +257,39 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray: result[mask] = self._dtype.na_value result[~mask] = data[~mask].pa_data.to_numpy() return result + + +class ArrowJSONType(pa.ExtensionType): + """Arrow extension type for the `dbjson` Pandas extension type.""" + + def __init__(self) -> None: + super().__init__(pa.string(), "dbjson") + + def __arrow_ext_serialize__(self) -> bytes: + # No parameters are necessary + return b"" + + def __eq__(self, other): + if isinstance(other, pyarrow.BaseExtensionType): + return type(self) == type(other) + else: + return NotImplemented + + def __ne__(self, other) -> bool: + return not self == other + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowJSONType: + # return an instance of this subclass + return ArrowJSONType() + + def __hash__(self) -> int: + return hash(str(self)) + + def to_pandas_dtype(self): + return JSONDtype() + + +# Register the type to be included in RecordBatches, sent over IPC and received in +# another Python process. +pa.register_extension_type(ArrowJSONType()) diff --git a/tests/unit/test_json.py b/tests/unit/test_json.py index 112b50c..750ddbc 100644 --- a/tests/unit/test_json.py +++ b/tests/unit/test_json.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd +import pyarrow as pa import pytest import db_dtypes @@ -114,3 +115,37 @@ def test_as_numpy_array(): ] ) pd._testing.assert_equal(result, expected) + + +def test_arrow_json_storage_type(): + arrow_json_type = db_dtypes.ArrowJSONType() + assert arrow_json_type.extension_name == "dbjson" + assert pa.types.is_string(arrow_json_type.storage_type) + + +def test_arrow_json_constructors(): + storage_array = pa.array( + ["0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string() + ) + arr_1 = db_dtypes.ArrowJSONType().wrap_array(storage_array) + assert isinstance(arr_1, pa.ExtensionArray) + + arr_2 = pa.ExtensionArray.from_storage(db_dtypes.ArrowJSONType(), storage_array) + assert isinstance(arr_2, pa.ExtensionArray) + + assert arr_1 == arr_2 + + +def test_arrow_json_to_pandas(): + storage_array = pa.array( + [None, "0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string() + ) + arr = db_dtypes.ArrowJSONType().wrap_array(storage_array) + + s = arr.to_pandas() + assert isinstance(s.dtypes, db_dtypes.JSONDtype) + assert pd.isna(s[0]) + assert s[1] == 0 + assert s[2] == "str" + assert s[3]["b"] == 2 + assert s[4]["a"] == [1, 2, 3] From 8e7f62cee3e70b29be992e12b026bc5b040ecf62 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 9 Jan 2025 21:42:12 +0000 Subject: [PATCH 2/4] nit --- db_dtypes/json.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/db_dtypes/json.py b/db_dtypes/json.py index 872ebe4..ddbf7c7 100644 --- a/db_dtypes/json.py +++ b/db_dtypes/json.py @@ -269,15 +269,6 @@ def __arrow_ext_serialize__(self) -> bytes: # No parameters are necessary return b"" - def __eq__(self, other): - if isinstance(other, pyarrow.BaseExtensionType): - return type(self) == type(other) - else: - return NotImplemented - - def __ne__(self, other) -> bool: - return not self == other - @classmethod def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowJSONType: # return an instance of this subclass From fdad61eb2e7936b94157ad1f6668762e128b4074 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 16 Jan 2025 18:16:01 +0000 Subject: [PATCH 3/4] add JSONArrowScalar --- db_dtypes/__init__.py 
| 7 +- db_dtypes/json.py | 20 ++- tests/compliance/json/test_json_compliance.py | 4 - tests/unit/test_json.py | 119 +++++++++++++++--- 4 files changed, 118 insertions(+), 32 deletions(-) diff --git a/db_dtypes/__init__.py b/db_dtypes/__init__.py index 6ce652d..d5b05dc 100644 --- a/db_dtypes/__init__.py +++ b/db_dtypes/__init__.py @@ -50,7 +50,7 @@ # To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal # of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0. if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"): - from db_dtypes.json import ArrowJSONType, JSONArray, JSONDtype + from db_dtypes.json import JSONArray, JSONArrowScalar, JSONArrowType, JSONDtype else: JSONArray = None JSONDtype = None @@ -359,7 +359,7 @@ def __sub__(self, other): ) -if not JSONArray or not JSONDtype or not ArrowJSONType: +if not JSONArray or not JSONDtype: __all__ = [ "__version__", "DateArray", @@ -370,11 +370,12 @@ def __sub__(self, other): else: __all__ = [ "__version__", - "ArrowJSONType", "DateArray", "DateDtype", "JSONDtype", "JSONArray", + "JSONArrowType", + "JSONArrowScalar", "TimeArray", "TimeDtype", ] diff --git a/db_dtypes/json.py b/db_dtypes/json.py index ddbf7c7..d08d1cb 100644 --- a/db_dtypes/json.py +++ b/db_dtypes/json.py @@ -221,6 +221,8 @@ def __getitem__(self, item): value = self.pa_data[item] if isinstance(value, pa.ChunkedArray): return type(self)(value) + elif isinstance(value, pa.ExtensionScalar): + return value.as_py() else: scalar = JSONArray._deserialize_json(value.as_py()) if scalar is None: @@ -259,20 +261,23 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray: return result -class ArrowJSONType(pa.ExtensionType): +class JSONArrowScalar(pa.ExtensionScalar): + def as_py(self): + return JSONArray._deserialize_json(self.value.as_py() if self.value else None) + + +class JSONArrowType(pa.ExtensionType): """Arrow extension type for the `dbjson` Pandas extension type.""" def __init__(self) -> None: super().__init__(pa.string(), "dbjson") def __arrow_ext_serialize__(self) -> bytes: - # No parameters are necessary return b"" @classmethod - def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowJSONType: - # return an instance of this subclass - return ArrowJSONType() + def __arrow_ext_deserialize__(cls, storage_type, serialized) -> JSONArrowType: + return JSONArrowType() def __hash__(self) -> int: return hash(str(self)) @@ -280,7 +285,10 @@ def __hash__(self) -> int: def to_pandas_dtype(self): return JSONDtype() + def __arrow_ext_scalar_class__(self): + return JSONArrowScalar + # Register the type to be included in RecordBatches, sent over IPC and received in # another Python process. -pa.register_extension_type(ArrowJSONType()) +pa.register_extension_type(JSONArrowType()) diff --git a/tests/compliance/json/test_json_compliance.py b/tests/compliance/json/test_json_compliance.py index 2a8e69a..9a0d0ef 100644 --- a/tests/compliance/json/test_json_compliance.py +++ b/tests/compliance/json/test_json_compliance.py @@ -22,10 +22,6 @@ import pytest -class TestJSONArrayAccumulate(base.BaseAccumulateTests): - pass - - class TestJSONArrayCasting(base.BaseCastingTests): def test_astype_str(self, data): # Use `json.dumps(str)` instead of passing `str(obj)` directly to the super method. 
diff --git a/tests/unit/test_json.py b/tests/unit/test_json.py index 750ddbc..949f1bd 100644 --- a/tests/unit/test_json.py +++ b/tests/unit/test_json.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import math import numpy as np import pandas as pd @@ -37,7 +38,7 @@ "null_field": None, "order": { "items": ["book", "pen", "computer"], - "total": 15.99, + "total": 15, "address": {"street": "123 Main St", "city": "Anytown"}, }, }, @@ -117,35 +118,115 @@ def test_as_numpy_array(): pd._testing.assert_equal(result, expected) -def test_arrow_json_storage_type(): - arrow_json_type = db_dtypes.ArrowJSONType() +def test_json_arrow_storage_type(): + arrow_json_type = db_dtypes.JSONArrowType() assert arrow_json_type.extension_name == "dbjson" assert pa.types.is_string(arrow_json_type.storage_type) -def test_arrow_json_constructors(): - storage_array = pa.array( - ["0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string() - ) - arr_1 = db_dtypes.ArrowJSONType().wrap_array(storage_array) +def test_json_arrow_constructors(): + data = [ + json.dumps(value, sort_keys=True, separators=(",", ":")) + for value in JSON_DATA.values() + ] + storage_array = pa.array(data, type=pa.string()) + + arr_1 = db_dtypes.JSONArrowType().wrap_array(storage_array) assert isinstance(arr_1, pa.ExtensionArray) - arr_2 = pa.ExtensionArray.from_storage(db_dtypes.ArrowJSONType(), storage_array) + arr_2 = pa.ExtensionArray.from_storage(db_dtypes.JSONArrowType(), storage_array) assert isinstance(arr_2, pa.ExtensionArray) assert arr_1 == arr_2 -def test_arrow_json_to_pandas(): - storage_array = pa.array( - [None, "0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string() - ) - arr = db_dtypes.ArrowJSONType().wrap_array(storage_array) +def test_json_arrow_to_pandas(): + data = [ + json.dumps(value, sort_keys=True, separators=(",", ":")) + for value in JSON_DATA.values() + ] + arr = pa.array(data, type=db_dtypes.JSONArrowType()) s = arr.to_pandas() assert isinstance(s.dtypes, db_dtypes.JSONDtype) - assert pd.isna(s[0]) - assert s[1] == 0 - assert s[2] == "str" - assert s[3]["b"] == 2 - assert s[4]["a"] == [1, 2, 3] + assert s[0] + assert s[1] == 100 + assert math.isclose(s[2], 0.98) + assert s[3] == "hello world" + assert math.isclose(s[4][0], 0.1) + assert math.isclose(s[4][1], 0.2) + assert s[5] == { + "null_field": None, + "order": { + "items": ["book", "pen", "computer"], + "total": 15, + "address": {"street": "123 Main St", "city": "Anytown"}, + }, + } + assert pd.isna(s[6]) + + +def test_json_arrow_to_pylist(): + data = [ + json.dumps(value, sort_keys=True, separators=(",", ":")) + for value in JSON_DATA.values() + ] + arr = pa.array(data, type=db_dtypes.JSONArrowType()) + + s = arr.to_pylist() + assert isinstance(s, list) + assert s[0] + assert s[1] == 100 + assert math.isclose(s[2], 0.98) + assert s[3] == "hello world" + assert math.isclose(s[4][0], 0.1) + assert math.isclose(s[4][1], 0.2) + assert s[5] == { + "null_field": None, + "order": { + "items": ["book", "pen", "computer"], + "total": 15, + "address": {"street": "123 Main St", "city": "Anytown"}, + }, + } + assert s[6] is None + + +def test_json_arrow_record_batch(): + data = [ + json.dumps(value, sort_keys=True, separators=(",", ":")) + for value in JSON_DATA.values() + ] + arr = pa.array(data, type=db_dtypes.JSONArrowType()) + batch = pa.RecordBatch.from_arrays([arr], ["json_col"]) + sink = pa.BufferOutputStream() + + with pa.RecordBatchStreamWriter(sink, batch.schema) as writer: + writer.write_batch(batch) + + buf = sink.getvalue() + + with 
pa.ipc.open_stream(buf) as reader: + result = reader.read_all() + + json_col = result.column("json_col") + assert isinstance(json_col.type, db_dtypes.JSONArrowType) + + s = json_col.to_pylist() + + assert isinstance(s, list) + assert s[0] + assert s[1] == 100 + assert math.isclose(s[2], 0.98) + assert s[3] == "hello world" + assert math.isclose(s[4][0], 0.1) + assert math.isclose(s[4][1], 0.2) + assert s[5] == { + "null_field": None, + "order": { + "items": ["book", "pen", "computer"], + "total": 15, + "address": {"street": "123 Main St", "city": "Anytown"}, + }, + } + assert s[6] is None From 14a6dcdd0312ef06c5dcb206af962ebe63ded3ca Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 17 Jan 2025 23:25:44 +0000 Subject: [PATCH 4/4] fix cover --- db_dtypes/json.py | 14 +++----------- tests/unit/test_json.py | 5 +++++ 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/db_dtypes/json.py b/db_dtypes/json.py index d08d1cb..145eec3 100644 --- a/db_dtypes/json.py +++ b/db_dtypes/json.py @@ -96,9 +96,9 @@ def __init__(self, values) -> None: else: raise NotImplementedError(f"Unsupported pandas version: {pd.__version__}") - def __arrow_array__(self): + def __arrow_array__(self, type=None): """Convert to an arrow array. This is required for pyarrow extension.""" - return self.pa_data + return pa.array(self.pa_data, type=JSONArrowType()) @classmethod def _box_pa( @@ -159,12 +159,7 @@ def _serialize_json(value): def _deserialize_json(value): """A static method that converts a JSON string back into its original value.""" if not pd.isna(value): - # Attempt to interpret the value as a JSON object. - # If it's not valid JSON, treat it as a regular string. - try: - return json.loads(value) - except json.JSONDecodeError: - return value + return json.loads(value) else: return value @@ -279,9 +274,6 @@ def __arrow_ext_serialize__(self) -> bytes: def __arrow_ext_deserialize__(cls, storage_type, serialized) -> JSONArrowType: return JSONArrowType() - def __hash__(self) -> int: - return hash(str(self)) - def to_pandas_dtype(self): return JSONDtype() diff --git a/tests/unit/test_json.py b/tests/unit/test_json.py index 949f1bd..055eef0 100644 --- a/tests/unit/test_json.py +++ b/tests/unit/test_json.py @@ -118,6 +118,11 @@ def test_as_numpy_array(): pd._testing.assert_equal(result, expected) +def test_json_arrow_array(): + data = db_dtypes.JSONArray._from_sequence(JSON_DATA.values()) + assert isinstance(data.__arrow_array__(), pa.ExtensionArray) + + def test_json_arrow_storage_type(): arrow_json_type = db_dtypes.JSONArrowType() assert arrow_json_type.extension_name == "dbjson"
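
As a quick orientation for reviewers, below is a minimal usage sketch of the extension type this series adds. It is illustrative only and not part of the patches: the sample values and the "json_col" column name are made up, and the IPC round-trip mirrors what test_json_arrow_record_batch exercises.

    import json

    import pyarrow as pa

    import db_dtypes  # importing registers the "dbjson" Arrow extension type

    # Build an Arrow array of JSON documents on top of string storage.
    data = [json.dumps({"a": 1}), json.dumps([1, 2, 3]), json.dumps("hello world")]
    arr = pa.array(data, type=db_dtypes.JSONArrowType())

    # Round-trip through the Arrow IPC stream format.
    batch = pa.RecordBatch.from_arrays([arr], ["json_col"])
    sink = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
        writer.write_batch(batch)
    with pa.ipc.open_stream(sink.getvalue()) as reader:
        table = reader.read_all()

    json_col = table.column("json_col")
    print(json_col.to_pylist())        # [{'a': 1}, [1, 2, 3], 'hello world'], via JSONArrowScalar.as_py()
    print(json_col.to_pandas().dtype)  # JSONDtype ("dbjson"), via JSONArrowType.to_pandas_dtype()

Because pa.register_extension_type(JSONArrowType()) runs when db_dtypes.json is imported, the IPC reader reconstructs the column as the dbjson extension type rather than plain strings, and to_pylist()/to_pandas() yield deserialized JSON values instead of raw JSON text.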