Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State pickle compression #4430

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions benchmarks/benchmark_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""Benchmarks for pickling and unpickling states."""

import logging
import pickle
import time
import uuid
from typing import Tuple

import pytest
from pytest_benchmark.fixture import BenchmarkFixture
from redis import Redis

from reflex.state import State
from reflex.utils.prerequisites import get_redis_sync

# Module-level logger; the benchmark below uses it to report the serialized payload size.
log = logging.getLogger(__name__)


SLOW_REDIS_MAP: dict[bytes, bytes] = {}


class SlowRedis:
    """Simulate a slow Redis client which uses a global dict and sleeps based on size.

    Values are stored in the module-level ``SLOW_REDIS_MAP``. Every ``set`` and
    ``get`` sleeps for ``base_latency + len(value) / bytes_per_second`` seconds.
    """

    def __init__(
        self, base_latency: float = 0.05, bytes_per_second: float = 1e6
    ) -> None:
        """Initialize the slow Redis client.

        Args:
            base_latency: Fixed delay in seconds added to every set/get.
            bytes_per_second: Simulated throughput; each byte of payload adds
                ``1 / bytes_per_second`` seconds of delay.
        """
        self.base_latency = base_latency
        self.bytes_per_second = bytes_per_second

    def _simulate_transfer(self, value: bytes) -> None:
        """Sleep for the simulated time it takes to transfer ``value``.

        Args:
            value: The payload whose length determines the delay.
        """
        time.sleep((len(value) / self.bytes_per_second) + self.base_latency)

    def set(self, key: bytes, value: bytes) -> None:
        """Set a key-value pair in the slow Redis client.

        Args:
            key: The key.
            value: The value.
        """
        SLOW_REDIS_MAP[key] = value
        self._simulate_transfer(value)

    def get(self, key: bytes) -> bytes:
        """Get a value from the slow Redis client.

        Args:
            key: The key.

        Returns:
            The value.

        Raises:
            KeyError: If the key was never set.
        """
        # The lookup happens before the sleep, so a missing key fails immediately.
        value = SLOW_REDIS_MAP[key]
        self._simulate_transfer(value)
        return value

    def delete(self, *keys: bytes) -> int:
        """Remove keys from the slow Redis client without any simulated delay.

        Args:
            *keys: The keys to remove; keys that are not present are ignored.

        Returns:
            The number of keys that were actually removed.
        """
        removed = 0
        for key in keys:
            if key in SLOW_REDIS_MAP:
                del SLOW_REDIS_MAP[key]
                removed += 1
        return removed


@pytest.mark.parametrize(
    "protocol",
    argvalues=[
        pickle.DEFAULT_PROTOCOL,
        pickle.HIGHEST_PROTOCOL,
    ],
    ids=[
        "pickle_default",
        "pickle_highest",
    ],
)
@pytest.mark.parametrize(
    "redis",
    [
        Redis,
        SlowRedis,
        None,
    ],
    ids=[
        "redis",
        "slow_redis",
        "no_redis",
    ],
)
@pytest.mark.parametrize(
    "should_compress", [True, False], ids=["compress", "no_compress"]
)
@pytest.mark.benchmark(disable_gc=True)
def test_pickle(
    request: pytest.FixtureRequest,
    benchmark: BenchmarkFixture,
    big_state: State,
    big_state_size: Tuple[int, str],
    protocol: int,
    redis: type[Redis] | type[SlowRedis] | None,
    should_compress: bool,
) -> None:
    """Benchmark pickling a big state.

    The measured round trip is dump -> (optional store/fetch) -> load.

    Args:
        request: The pytest fixture request object.
        benchmark: The benchmark fixture.
        big_state: The big state fixture.
        big_state_size: The big state size fixture, as ``(size, label)``.
        protocol: The pickle protocol.
        redis: The client class to round-trip through (``Redis`` or
            ``SlowRedis``), or None to skip the storage step entirely.
        should_compress: Whether to compress the pickled state.
    """
    if should_compress:
        try:
            from blosc2 import compress, decompress
        except ImportError:
            pytest.skip("Blosc is not available.")

        def dump(obj: State) -> bytes:
            return compress(pickle.dumps(obj, protocol=protocol))  # pyright: ignore[reportReturnType]

        def load(data: bytes) -> State:
            return pickle.loads(decompress(data))  # pyright: ignore[reportAny,reportArgumentType]

    else:

        def dump(obj: State) -> bytes:
            return pickle.dumps(obj, protocol=protocol)

        def load(data: bytes) -> State:
            return pickle.loads(data)

    # ``redis`` is a class (or None); resolve it to a usable client instance.
    redis_client = None
    if redis is Redis:
        redis_client = get_redis_sync()
        if redis_client is None:
            pytest.skip("Redis is not available.")
    elif redis is not None:
        redis_client = SlowRedis()

    # Unique key so concurrent or repeated runs never read each other's payload.
    key = str(uuid.uuid4()).encode()

    if redis_client is not None:

        def run(obj: State) -> None:
            _ = redis_client.set(key, dump(obj))
            _ = load(redis_client.get(key))  # pyright: ignore[reportArgumentType]

    else:

        def run(obj: State) -> None:
            _ = load(dump(obj))

    # calculate size before benchmark to not affect it
    out = dump(big_state)
    size = len(out)
    log.warning(f"{protocol=}, {redis=}, {should_compress=}, {size=}")

    redis_group = redis.__name__ if redis is not None else "no_redis"

    # Keep extra_info to plain values (ints/strings/bools) so saved benchmark
    # reports stay readable; the class object itself is not a JSON value.
    benchmark.extra_info["size"] = size
    benchmark.extra_info["redis"] = redis_group
    benchmark.extra_info["pickle_protocol"] = protocol
    benchmark.extra_info["should_compress"] = should_compress
    benchmark.group = f"{redis_group}_{big_state_size[1]}"

    try:
        _ = benchmark(run, big_state)
    finally:
        # Do not leave the (potentially large) payload behind in the store.
        # Only clients that expose a ``delete`` method are cleaned up.
        if redis_client is not None:
            delete = getattr(redis_client, "delete", None)
            if delete is not None:
                _ = delete(key)
44 changes: 44 additions & 0 deletions benchmarks/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""Shared conftest for all benchmark tests."""

from typing import Tuple

import pandas as pd
import pytest

from reflex.state import State
from reflex.testing import AppHarness, AppHarnessProd


Expand All @@ -18,3 +22,43 @@ def app_harness_env(request):
The AppHarness class to use for the test.
"""
return request.param


@pytest.fixture(params=[(10, "SmallState"), (2000, "BigState")], ids=["small", "big"])
def big_state_size(request: pytest.FixtureRequest) -> Tuple[int, str]:
    """The size and label of the state under benchmark.

    Args:
        request: The pytest fixture request object.

    Returns:
        A ``(size, label)`` tuple: the number of items to generate and a
        human-readable name for that size.
    """
    return request.param


@pytest.fixture
def big_state(big_state_size: Tuple[int, str]) -> State:
    """Build a state holding two dictionaries and a list of DataFrames.

    Args:
        big_state_size: The ``(size, label)`` pair; only the size is used here.

    Returns:
        A big state instance.
    """
    n_items, _label = big_state_size

    class BigState(State):
        """A big state."""

        d: dict[str, int]
        d_repeated: dict[str, int]
        df: list[pd.DataFrame]

    def numbered() -> dict[str, int]:
        # Each call builds a separate dict mapping "0".."n-1" to 0..n-1.
        return {str(idx): idx for idx in range(n_items)}

    first = numbered()
    second = numbered()
    # One single-cell DataFrame per item.
    frames = [pd.DataFrame({"a": [idx]}) for idx in range(n_items)]

    return BigState(d=first, df=frames, d_repeated=second)
Loading
Loading