Request: add pandas / numpy array support #887
Comments
The "private" methods are actually relatively stable. It's a quirk of how we chose to name the functions when syrupy was first created. I've been thinking of revisiting the naming convention to make it more Pythonic, but I haven't gotten around to it (and it'd be a breaking change).
I might roll this into Syrupy v5, pending my schedule and other obligations, i.e. rename some methods to make them public.
FWIW, the following code adds support for pandas dataframes, creating snapshots in either "parquet", "json" or "hdf" files and using pd.testing.assert_frame_equal to compare them:

    from __future__ import annotations

    from typing import TYPE_CHECKING, Any, Literal

    import pandas as pd
    from syrupy.extensions.single_file import SingleFileSnapshotExtension

    if TYPE_CHECKING:
        from syrupy.data import SnapshotCollection
        from syrupy.types import SerializableData


    def create_pandas_dataframe_snapshot_extension(
        *, file_extension: Literal["parquet", "json", "hdf"] = "parquet", **assert_frame_equal_options: Any
    ) -> type[SingleFileSnapshotExtension]:
        class PandasDataFrameSnapshotExtension(SingleFileSnapshotExtension):
            """Handles Pandas Snapshots.

            Snapshots are stored as parquet, json or hdf files and the dataframes are compared using pandas testing methods.
            """

            _file_extension = file_extension

            if file_extension == "parquet":
                # work-around for https://github.com/pandas-dev/pandas/issues/41543
                _options = {"check_freq": False}
            elif file_extension == "json":
                _options = {"check_dtype": False, "check_freq": False, "check_names": False}
            else:
                _options = {}
            # caller-supplied assert_frame_equal options override the defaults above
            _options |= assert_frame_equal_options

            def matches(self, *, serialized_data: pd.DataFrame, snapshot_data: pd.DataFrame) -> bool:
                try:
                    return pd.testing.assert_frame_equal(serialized_data, snapshot_data, **self._options) is None
                except AssertionError:
                    return False

            def _read_snapshot_data_from_location(
                self,
                *,
                snapshot_location: str,
                **kwargs: object,
            ) -> pd.DataFrame | None:
                try:
                    if self._file_extension == "parquet":
                        return pd.read_parquet(snapshot_location)
                    if self._file_extension == "json":
                        return pd.read_json(snapshot_location, orient="table")
                    if self._file_extension == "hdf":
                        return pd.read_hdf(snapshot_location)
                    # raise (not return) so an unsupported extension fails loudly
                    raise NotImplementedError(f"file extension {self._file_extension} not supported")
                except FileNotFoundError:
                    return None

            @classmethod
            def _write_snapshot_collection(cls, *, snapshot_collection: SnapshotCollection) -> None:
                filepath = snapshot_collection.location
                data: pd.DataFrame = next(iter(snapshot_collection)).data
                if cls._file_extension == "parquet":
                    data.to_parquet(filepath)
                elif cls._file_extension == "json":
                    data.to_json(filepath, orient="table")
                elif cls._file_extension == "hdf":
                    data.to_hdf(filepath, key="/blah")
                else:
                    raise NotImplementedError(f"file extension {cls._file_extension} not supported")

            def serialize(self, data: SerializableData, **kwargs: object) -> SerializableData:
                # the dataframe is passed through unchanged and written in its native file format
                return data

            def diff_lines(self, serialized_data: pd.DataFrame, snapshot_data: pd.DataFrame) -> list[str]:
                try:
                    pd.testing.assert_frame_equal(serialized_data, snapshot_data, **self._options)
                except AssertionError as e:
                    return str(e).split("\n")
                # no differences to report
                return []

        return PandasDataFrameSnapshotExtension

Example usage:

    import pandas as pd
    import pytest

    from .testing import create_pandas_dataframe_snapshot_extension


    @pytest.fixture
    def assert_frame_equal_snapshot(snapshot):
        def assert_frame_equal_snapshot_factory(**kwargs):
            return snapshot.use_extension(create_pandas_dataframe_snapshot_extension(**kwargs))

        return assert_frame_equal_snapshot_factory


    def test_pd(assert_frame_equal_snapshot):
        index = pd.date_range("2022-1-1", freq="h", periods=6).set_names("time")
        # a MultiIndex is also supported, e.g.
        # index = pd.MultiIndex.from_product([["A", "B"], ["One", "Two", "Three"]], names=["first", "second"])
        columns = pd.Index(["a", "b"], name="letters")
        df = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6], "b": [1.2, 2.3, 3.4, 4.5, 5.6, 6.7]}, index=index, columns=columns)
        df = df.astype({"b": "float32"})
        assert assert_frame_equal_snapshot(file_extension="parquet", atol=1e-3) == df
Is your feature request related to a problem? Please describe.
I'm trying to use snapshots with pandas dataframes but it seems difficult to build my own extension separately from syrupy and use it across multiple projects. It also seems less than ideal to serialize the dataframe as text since this loses typing and precision information.
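To make the loss concrete, here is a minimal sketch showing that the default text rendering of a dataframe records no dtypes and rounds floats. The column names and values are made up for illustration, and it assumes pandas' default display options.

```python
import pandas as pd

# Hypothetical example data: an int8 column and a float64 column with many digits.
df = pd.DataFrame({"small": [1, 2], "precise": [0.123456789012, 2.5]})
df = df.astype({"small": "int8"})

# Render the dataframe as plain text, as a text-based snapshot would.
text = df.to_string()
print(text)

# The rendered text carries no dtype information ...
dtype_is_recorded = "int8" in text
# ... and floats are rounded to the display precision (6 decimal places by
# default), so the full value cannot be recovered from the text.
full_value_is_recorded = "0.123456789012" in text
```

Both flags come out `False`, so a snapshot stored this way could not tell an `int8` column from an `int64` one, or detect a change beyond the sixth decimal place.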
Describe the solution you'd like
I'd like to be able to do this:
Describe alternatives you've considered
Pandas can convert dataframes to other formats, e.g. CSV or JSON. However, serializing to these formats is not very readable, can lose information (e.g. CSV does not store types), and does not support custom Python objects.
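As a rough illustration of those two limitations, the sketch below uses only the standard library `csv` and `json` modules rather than pandas. The row contents are hypothetical; the point is only what the formats themselves can represent.

```python
import csv
import io
import json
from datetime import date

# Hypothetical row mixing an int, a float and a date.
row = {"count": 3, "ratio": 0.5, "day": date(2022, 1, 1)}

# CSV round trip: every value comes back as a plain string.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=list(row))
writer.writeheader()
writer.writerow(row)
buffer.seek(0)
restored = next(csv.DictReader(buffer))
csv_types = {key: type(value).__name__ for key, value in restored.items()}

# JSON rejects objects it has no encoding for, such as a date.
try:
    json.dumps(row)
    json_error = None
except TypeError as exc:
    json_error = str(exc)
```

After the CSV round trip every field is a `str`, and `json.dumps` raises a `TypeError` for the `date` value. A type-preserving binary format such as parquet avoids the first problem for the column types it supports.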
Additional context
I've seen #786 and the code that commenter wrote, but to use that extension you have to override private methods, which doesn't seem ideal: if the private methods change, every project using that extension breaks. Additionally, there's no easy way to import that code; it would have to be maintained separately across multiple projects.