Request: add pandas / numpy array support #887

Open
guidopetri opened this issue Aug 29, 2024 · 2 comments

@guidopetri

Is your feature request related to a problem? Please describe.

I'm trying to use snapshots with pandas DataFrames, but it is difficult to build my own extension separately from syrupy and reuse it across multiple projects. Serializing the DataFrame as text is also less than ideal, since this loses dtype and precision information.
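
To make the dtype/precision point concrete, here is a minimal sketch (assuming pandas with default display options, where `display.precision` is 6) showing that two DataFrames can differ in dtype or in low-order digits yet produce identical `repr` text, which is roughly what a text-based snapshot would store:

```python
import pandas as pd

# Same values, different dtypes: repr() does not mention the dtype.
as_float32 = pd.DataFrame({"x": [0.5, 1.25]}, dtype="float32")
as_float64 = pd.DataFrame({"x": [0.5, 1.25]}, dtype="float64")

# Values that differ only beyond the default display precision (6 digits).
close_a = pd.DataFrame({"x": [0.1234567891]})
close_b = pd.DataFrame({"x": [0.1234567892]})

same_text_diff_dtype = repr(as_float32) == repr(as_float64)
same_text_diff_value = repr(close_a) == repr(close_b)

# DataFrame.equals requires matching dtypes and values, so both pairs differ.
print(same_text_diff_dtype, as_float32.equals(as_float64))  # True False
print(same_text_diff_value, close_a.equals(close_b))  # True False
```

A snapshot built from `repr` would accept either member of each pair as a match for the other.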

Describe the solution you'd like

I'd like to be able to do this:

assert pd.DataFrame(['test']) == snapshot

Describe alternatives you've considered

Pandas can convert to other formats, e.g. CSV/JSON. Output in these formats is not very readable, can lose information (e.g. CSV does not store dtypes), and does not support custom Python objects.
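
A quick round-trip illustrates the CSV point. This is a minimal sketch assuming current pandas defaults for `read_csv` (no `dtype` or `parse_dates` arguments): a float32 column comes back as float64 and a datetime column comes back as plain strings.

```python
import io

import pandas as pd

original = pd.DataFrame(
    {
        "score": pd.Series([1.5, 2.25], dtype="float32"),
        "when": pd.to_datetime(["2024-08-29", "2024-08-30"]),
    }
)

# Round-trip through CSV text, as a CSV-based snapshot would.
buffer = io.StringIO()
original.to_csv(buffer, index=False)
buffer.seek(0)
restored = pd.read_csv(buffer)

# The values survive, but the dtypes do not: CSV carries no type information.
print(original.dtypes.to_dict())
print(restored.dtypes.to_dict())
```

The types can be recovered by passing `dtype=` and `parse_dates=` to `read_csv`, but then the schema has to be maintained next to every snapshot.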

Additional context

I've seen #786 and the code that commenter wrote, but to use that extension you have to override private methods, which doesn't seem ideal: if the private methods change, every project using that extension breaks. Additionally, there's no easy way to import that code, so it would have to be maintained separately in each project.

noahnu added the "feature request" and "tool compatibility" labels on Sep 7, 2024
noahnu added this to the syrupy/5.0.0 milestone on Feb 14, 2025
noahnu self-assigned this on Feb 16, 2025
noahnu removed this from the syrupy/5.0.0 milestone on Feb 17, 2025
noahnu removed their assignment on Feb 17, 2025
@noahnu
Collaborator

noahnu commented Feb 17, 2025

> I've seen #786 and the code that that commenter wrote, but to use that extension you have to override private methods, and that doesn't seem ideal - if the private methods change, then all the projects using that extension are now broken.

The "private" methods are actually relatively stable; the underscore prefix is a quirk of how we named the functions when syrupy was first created. I've been thinking of revisiting the naming convention to make it more pythonic, but haven't gotten around to it (and it would be a breaking change). I might roll this into syrupy v5, pending my schedule and other obligations, i.e. rename some methods to make them public.
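
Until the hooks are renamed, one hedge for projects that share such an extension is a small test that fails loudly if a hook it overrides disappears from the base class. A minimal, stdlib-only sketch: `missing_hooks` is a hypothetical helper, the stand-in classes replace syrupy so the demo is self-contained, and the hook names are the ones the extension further down overrides, not a list confirmed by syrupy.

```python
from collections.abc import Iterable


def missing_hooks(base: type, names: Iterable[str]) -> list[str]:
    """Return the names in `names` that `base` does not define as callables.

    Hypothetical helper: an empty list means every hook the shared extension
    overrides still exists on the base class it subclasses.
    """
    return [name for name in names if not callable(getattr(base, name, None))]


# Stand-in base classes instead of syrupy itself.
class OldBase:
    def _read_snapshot_data_from_location(self): ...

    @classmethod
    def _write_snapshot_collection(cls): ...


class RenamedBase:
    # Simulates a future release where the hook was made public.
    def read_snapshot_data_from_location(self): ...


HOOKS = ["_read_snapshot_data_from_location", "_write_snapshot_collection"]
print(missing_hooks(OldBase, HOOKS))  # []
print(missing_hooks(RenamedBase, HOOKS))
```

In a real project the check would run against the installed syrupy, e.g. `assert missing_hooks(SingleFileSnapshotExtension, HOOKS) == []`. It has to inspect the base class rather than the subclass, since the subclass defines the hooks itself.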

noahnu added this to the syrupy/5.0.0 milestone on Feb 17, 2025
@kdebrab

kdebrab commented Feb 26, 2025

FWIW, the following code adds support for pandas DataFrames. It stores snapshots as "parquet", "json" or "hdf" files and compares them with pandas.testing.assert_frame_equal, forwarding that function's arguments such as atol (extended from https://github.com/Roestlab/massdash/blob/dev/massdash/testing/PandasSnapshotExtension.py):

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

import pandas as pd
from syrupy.extensions.single_file import SingleFileSnapshotExtension

if TYPE_CHECKING:
    from syrupy.data import SnapshotCollection
    from syrupy.types import SerializableData


def create_pandas_dataframe_snapshot_extension(
    *, file_extension: Literal["parquet", "json", "hdf"] = "parquet", **assert_frame_equal_options: Any
) -> type[SingleFileSnapshotExtension]:
    """Build a syrupy extension that stores one DataFrame per snapshot file.

    Extra keyword arguments (e.g. atol, rtol) are forwarded to pd.testing.assert_frame_equal.
    """

    class PandasDataFrameSnapshotExtension(SingleFileSnapshotExtension):
        """Handles Pandas Snapshots.

        Snapshots are stored as parquet, json or hdf files and the dataframes are compared using pandas testing methods.
        """

        _file_extension = file_extension
        # Per-format defaults for assert_frame_equal; user-supplied options override them.
        if file_extension == "parquet":
            # work-around for https://github.com/pandas-dev/pandas/issues/41543
            _options = {"check_freq": False}
        elif file_extension == "json":
            # The table-schema JSON round-trip can change dtypes (e.g. float32 -> float64), frequency and names.
            _options = {"check_dtype": False, "check_freq": False, "check_names": False}
        else:
            _options = {}
        _options |= assert_frame_equal_options  # dict |= requires Python 3.9+

        def matches(self, *, serialized_data: pd.DataFrame, snapshot_data: pd.DataFrame) -> bool:
            try:
                pd.testing.assert_frame_equal(serialized_data, snapshot_data, **self._options)
            except AssertionError:
                return False
            return True

        def _read_snapshot_data_from_location(
            self,
            *,
            snapshot_location: str,
            **kwargs: object,
        ) -> pd.DataFrame | None:
            try:
                if self._file_extension == "parquet":
                    return pd.read_parquet(snapshot_location)
                if self._file_extension == "json":
                    return pd.read_json(snapshot_location, orient="table")
                if self._file_extension == "hdf":
                    return pd.read_hdf(snapshot_location)
                raise NotImplementedError(f"file extension {self._file_extension} not supported")
            except FileNotFoundError:
                # No snapshot written yet.
                return None

        @classmethod
        def _write_snapshot_collection(cls, *, snapshot_collection: SnapshotCollection) -> None:
            filepath = snapshot_collection.location
            # Single-file extension: the collection holds exactly one snapshot.
            data: pd.DataFrame = next(iter(snapshot_collection)).data
            if cls._file_extension == "parquet":
                data.to_parquet(filepath)
            elif cls._file_extension == "json":
                data.to_json(filepath, orient="table")
            elif cls._file_extension == "hdf":
                # Arbitrary key; read_hdf can omit it when the file holds a single object.
                data.to_hdf(filepath, key="/blah")
            else:
                raise NotImplementedError(f"file extension {cls._file_extension} not supported")

        def serialize(self, data: SerializableData, **kwargs: object) -> SerializableData:
            # Keep the DataFrame as-is; it is written to disk in _write_snapshot_collection.
            return data

        def diff_lines(self, serialized_data: pd.DataFrame, snapshot_data: pd.DataFrame) -> list[str]:
            try:
                pd.testing.assert_frame_equal(serialized_data, snapshot_data, **self._options)
            except AssertionError as e:
                return str(e).split("\n")
            return []  # frames match: nothing to report

    return PandasDataFrameSnapshotExtension
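
Side note on the relaxed "json" options: a quick round-trip check makes the dtype part visible. This is a minimal sketch assuming pandas' Table Schema JSON (`orient="table"`), which records a float32 column only as a generic `"number"` field, so it is read back as float64. That is one reason a strict `check_dtype` comparison fails for JSON snapshots.

```python
import io

import pandas as pd

df = pd.DataFrame({"b": pd.Series([1.5, 2.25], dtype="float32")})

# Serialize and parse back the same way the extension does for "json".
text = df.to_json(orient="table")
restored = pd.read_json(io.StringIO(text), orient="table")

print(df["b"].dtype, restored["b"].dtype)  # float32 float64

# The values survive; only the dtype changed, so a relaxed comparison passes.
# check_index_type=False tolerates RangeIndex vs. an int64 Index after reading.
pd.testing.assert_series_equal(df["b"], restored["b"], check_dtype=False, check_index_type=False)
```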

Example usage:

import pandas as pd
import pytest

from .testing import create_pandas_dataframe_snapshot_extension  # module holding the factory above


@pytest.fixture
def assert_frame_equal_snapshot(snapshot):
    def assert_frame_equal_snapshot_factory(**kwargs):
        return snapshot.use_extension(create_pandas_dataframe_snapshot_extension(**kwargs))

    return assert_frame_equal_snapshot_factory


def test_pd(assert_frame_equal_snapshot):
    index = pd.date_range("2022-1-1", freq="h", periods=6).set_names("time")
    # also MultiIndex index is supported, e.g.
    # index = pd.MultiIndex.from_product([["A", "B"], ["One", "Two", "Three"]], names=["first", "second"])
    columns = pd.Index(["a", "b"], name="letters")
    df = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6], "b": [1.2, 2.3, 3.4, 4.5, 5.6, 6.7]}, index=index, columns=columns)
    df = df.astype({"b": "float32"})
    assert assert_frame_equal_snapshot(file_extension="parquet", atol=1e-3) == df
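
The `matches` / `diff_lines` pair above comes down to calling `pd.testing.assert_frame_equal` with the stored options. A minimal syrupy-free sketch of that logic (the `frame_diff` name is hypothetical) shows what `atol=1e-3` from the example buys: small float drift still matches, while a real change produces diff lines.

```python
import pandas as pd


def frame_diff(actual: pd.DataFrame, expected: pd.DataFrame, **options) -> list[str]:
    """Return assert_frame_equal's message as lines, or [] if the frames match."""
    try:
        pd.testing.assert_frame_equal(actual, expected, **options)
    except AssertionError as exc:
        return str(exc).split("\n")
    return []


expected = pd.DataFrame({"b": [1.2, 2.3]})
drifted = pd.DataFrame({"b": [1.2004, 2.3]})  # off by 4e-4
changed = pd.DataFrame({"b": [1.2, 9.9]})

print(frame_diff(drifted, expected, atol=1e-3))  # []
print(frame_diff(changed, expected, atol=1e-3)[:2])
```

With the default tolerances (`rtol=1e-5`, `atol=1e-8`), the 4e-4 drift would already count as a mismatch. Making `atol` configurable per snapshot is therefore useful for float32 data.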
