Skip to content

Commit

Permalink
Merge pull request #129 from ManuelHu/pickle
Browse files Browse the repository at this point in the history
make lh5 types and exceptions picklable
  • Loading branch information
gipert authored Jan 14, 2025
2 parents f1a5297 + ef99e39 commit 269111a
Show file tree
Hide file tree
Showing 12 changed files with 181 additions and 4 deletions.
6 changes: 6 additions & 0 deletions src/lgdo/lh5/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def __str__(self) -> str:
+ super().__str__()
)

def __reduce__(self) -> tuple: # for pickling.
return self.__class__, (*self.args, self.file, self.obj)


class LH5EncodeError(Exception):
def __init__(
Expand All @@ -32,3 +35,6 @@ def __str__(self) -> str:
f"while writing object {self.group}/{self.name} to file {self.file}: "
+ super().__str__()
)

def __reduce__(self) -> tuple: # for pickling.
return self.__class__, (*self.args, self.file, self.group, self.name)
11 changes: 8 additions & 3 deletions src/lgdo/types/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,13 +418,18 @@ def fill(self, data, w: NDArray = None, keys: Sequence[str] = None) -> None:

def __setitem__(self, name: str, obj: LGDO) -> None:
# do not allow for new attributes on this
msg = "histogram fields cannot be mutated"
raise TypeError(msg)
known_keys = ("binning", "weights", "isdensity")
if name in known_keys and not dict.__contains__(self, name):
# but allow initialization while unpickling (after __init__() this is unreachable)
dict.__setitem__(self, name, obj)
else:
msg = "histogram fields cannot be mutated "
raise TypeError(msg)

def __getattr__(self, name: str) -> None:
# do not allow for new attributes on this
msg = "histogram fields cannot be mutated"
raise TypeError(msg)
raise AttributeError(msg)

def add_field(self, name: str | int, obj: LGDO) -> None: # noqa: ARG002
"""
Expand Down
6 changes: 6 additions & 0 deletions src/lgdo/types/lgdo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
class LGDO(ABC):
"""Abstract base class representing a LEGEND Data Object (LGDO)."""

def __new__(cls, *_args, **_kwargs):
# allow for (un-)pickling LGDO objects.
obj = super().__new__(cls)
obj.attrs = {}
return obj

@abstractmethod
def __init__(self, attrs: dict[str, Any] | None = None) -> None:
self.attrs = {} if attrs is None else dict(attrs)
Expand Down
6 changes: 6 additions & 0 deletions src/lgdo/types/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class Table(Struct):
:meth:`__len__` to access valid data, which returns the ``size`` attribute.
"""

def __new__(cls, *args, **kwargs):
# allow for (un-)pickling LGDO objects.
obj = super().__new__(cls, *args, **kwargs)
obj.size = None
return obj

def __init__(
self,
col_dict: Mapping[str, LGDO] | pd.DataFrame | ak.Array | None = None,
Expand Down
17 changes: 17 additions & 0 deletions tests/lh5/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

import pickle

from lgdo.lh5.exceptions import LH5DecodeError, LH5EncodeError


def test_pickle():
# test (un-)pickling of LH5 exceptions; e.g. for multiprocessing use.

ex = LH5EncodeError("message", "file", "group", "name")
ex = pickle.loads(pickle.dumps(ex))
assert isinstance(ex, LH5EncodeError)

ex = LH5DecodeError("message", "file", "obj")
ex = pickle.loads(pickle.dumps(ex))
assert isinstance(ex, LH5DecodeError)
13 changes: 13 additions & 0 deletions tests/types/test_array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pickle

import awkward as ak
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -61,3 +63,14 @@ def test_view():

with pytest.raises(ValueError):
a.view_as("ak", with_units=True)


def test_pickle():
obj = Array(nda=np.array([1, 2, 3, 4]))
obj.attrs["attr1"] = 1

ex = pickle.loads(pickle.dumps(obj))
assert isinstance(ex, Array)
assert ex.attrs["attr1"] == 1
assert ex.attrs["datatype"] == obj.attrs["datatype"]
assert np.all(ex.nda == np.array([1, 2, 3, 4]))
50 changes: 50 additions & 0 deletions tests/types/test_encoded.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pickle

import awkward as ak
import awkward_pandas as akpd
import numpy as np
Expand Down Expand Up @@ -285,3 +287,51 @@ def test_aoeesa_view_as():

with pytest.raises(TypeError):
df = voev.view_as("np")


def test_aoeesa_pickle():
obj = ArrayOfEncodedEqualSizedArrays(
encoded_data=VectorOfVectors(
flattened_data=Array(nda=np.array([1, 2, 3, 4, 5, 2, 4, 8, 9, 7, 5, 3, 1])),
cumulative_length=Array(nda=np.array([2, 5, 6, 10, 13])),
),
decoded_size=99,
)

ex = pickle.loads(pickle.dumps(obj))

desired = [
[1, 2],
[3, 4, 5],
[2],
[4, 8, 9, 7],
[5, 3, 1],
]

for i, v in enumerate(ex):
assert np.array_equal(v, desired[i])


def test_voev_pickle():
obj = VectorOfEncodedVectors(
encoded_data=VectorOfVectors(
flattened_data=Array(nda=np.array([1, 2, 3, 4, 5, 2, 4, 8, 9, 7, 5, 3, 1])),
cumulative_length=Array(nda=np.array([2, 5, 6, 10, 13])),
),
decoded_size=Array(shape=5, fill_val=6),
attrs={"units": "s"},
)

ex = pickle.loads(pickle.dumps(obj))

desired = [
[1, 2],
[3, 4, 5],
[2],
[4, 8, 9, 7],
[5, 3, 1],
]

for i, (v, s) in enumerate(ex):
assert np.array_equal(v, desired[i])
assert s == 6
14 changes: 13 additions & 1 deletion tests/types/test_histogram.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import pickle

import hist
import numpy as np
Expand Down Expand Up @@ -266,7 +267,7 @@ def test_view_as_np():
def test_not_like_table():
h = Histogram(np.array([1, 1]), (np.array([0, 1, 2]),))
assert h.form_datatype() == "struct{binning,weights,isdensity}"
with pytest.raises(TypeError):
with pytest.raises(AttributeError):
x = h.x # noqa: F841
with pytest.raises(TypeError):
h["x"] = Scalar(1.0)
Expand Down Expand Up @@ -392,3 +393,14 @@ def test_histogram_fill():

with pytest.raises(ValueError, match="data must be"):
h.fill(np.ones(shape=(5, 5)))


def test_pickle():
obj = Histogram(np.array([1, 1]), (Histogram.Axis.from_range_edges([0, 1, 2]),))
obj.attrs["attr1"] = 1

ex = pickle.loads(pickle.dumps(obj))
assert isinstance(ex, Histogram)
assert ex.attrs["attr1"] == 1
assert ex.attrs["datatype"] == obj.attrs["datatype"]
assert np.all(ex.weights == obj.weights)
13 changes: 13 additions & 0 deletions tests/types/test_scalar.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pickle

import pytest

import lgdo
Expand Down Expand Up @@ -33,3 +35,14 @@ def test_getattrs():

def test_equality():
assert lgdo.Scalar(value=42) == lgdo.Scalar(value=42)


def test_pickle():
obj = lgdo.Scalar(value=10)
obj.attrs["attr1"] = 1

ex = pickle.loads(pickle.dumps(obj))
assert isinstance(ex, lgdo.Scalar)
assert ex.attrs["attr1"] == 1
assert ex.attrs["datatype"] == obj.attrs["datatype"]
assert ex.value == 10
14 changes: 14 additions & 0 deletions tests/types/test_struct.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pickle

import pytest

import lgdo
Expand Down Expand Up @@ -78,3 +80,15 @@ def test_remove_field():

struct.remove_field("array1", delete=True)
assert list(struct.keys()) == []


def test_pickle():
obj_dict = {"scalar1": lgdo.Scalar(value=10)}
attrs = {"attr1": 1}
struct = lgdo.Struct(obj_dict=obj_dict, attrs=attrs)

ex = pickle.loads(pickle.dumps(struct))
assert isinstance(ex, lgdo.Struct)
assert ex.attrs["attr1"] == 1
assert ex.attrs["datatype"] == struct.attrs["datatype"]
assert ex["scalar1"].value == 10
18 changes: 18 additions & 0 deletions tests/types/test_table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import pickle
import warnings

import awkward as ak
Expand Down Expand Up @@ -221,3 +222,20 @@ def test_remove_column():

tbl.remove_column("c")
assert list(tbl.keys()) == ["b"]


def test_pickle():
col_dict = {
"a": lgdo.Array(nda=np.array([1, 2, 3, 4])),
"b": lgdo.Array(nda=np.array([5, 6, 7, 8])),
"c": lgdo.Array(nda=np.array([9, 10, 11, 12])),
}
obj = Table(col_dict=col_dict)
obj.attrs["attr1"] = 1

ex = pickle.loads(pickle.dumps(obj))
assert isinstance(ex, Table)
assert ex.attrs["attr1"] == 1
assert ex.attrs["datatype"] == obj.attrs["datatype"]
for key, val in col_dict.items():
assert ex[key] == val
17 changes: 17 additions & 0 deletions tests/types/test_vectorofvectors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import pickle
from collections import namedtuple

import awkward as ak
Expand Down Expand Up @@ -441,3 +442,19 @@ def test_lh5_iterator_view_as(lgnd_test_data):

for obj, _, _ in it:
assert ak.is_valid(obj.view_as("ak"))


def test_pickle(testvov):
obj = testvov.v2d
ex = pickle.loads(pickle.dumps(obj))

desired = [
np.array([1, 2]),
np.array([3, 4, 5]),
np.array([2]),
np.array([4, 8, 9, 7]),
np.array([5, 3, 1]),
]

for i in range(len(desired)):
assert np.array_equal(desired[i], ex[i])

0 comments on commit 269111a

Please sign in to comment.