Skip to content

Commit

Permalink
Merge pull request #116 from iguinn/nested_field_mask
Browse files Browse the repository at this point in the history
Nested field mask
  • Loading branch information
gipert authored Nov 21, 2024
2 parents 9d1ad8f + 487fc1a commit 0c68c60
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 42 deletions.
40 changes: 12 additions & 28 deletions src/lgdo/lh5/_serializers/read/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import bisect
import logging
import sys
from collections import defaultdict

import h5py
import numpy as np
Expand Down Expand Up @@ -72,19 +71,8 @@ def _h5_read_lgdo(
obj_buf=obj_buf,
)

# check field_mask and make it a default dict
if field_mask is None:
field_mask = defaultdict(lambda: True)
elif isinstance(field_mask, dict):
default = True
if len(field_mask) > 0:
default = not field_mask[next(iter(field_mask.keys()))]
field_mask = defaultdict(lambda: default, field_mask)
elif isinstance(field_mask, (list, tuple, set)):
field_mask = defaultdict(bool, {field: True for field in field_mask})
elif not isinstance(field_mask, defaultdict):
msg = "bad field_mask type"
raise ValueError(msg, type(field_mask).__name__)
# Convert whatever we input into a defaultdict
field_mask = utils.build_field_mask(field_mask)

if lgdotype is Struct:
return _h5_read_struct(
Expand Down Expand Up @@ -246,18 +234,16 @@ def _h5_read_struct(

# determine fields to be read out
all_fields = dtypeutils.get_struct_fields(attrs["datatype"])
selected_fields = (
[field for field in all_fields if field_mask[field]]
if field_mask is not None
else all_fields
)
selected_fields = utils.eval_field_mask(field_mask, all_fields)

# modify datatype in attrs if a field_mask was used
attrs["datatype"] = "struct{" + ",".join(selected_fields) + "}"
attrs["datatype"] = (
"struct{" + ",".join(field for field, _ in selected_fields) + "}"
)

# loop over fields and read
obj_dict = {}
for field in selected_fields:
for field, submask in selected_fields:
# support for integer keys
field_key = int(field) if attrs.get("int_keys") else str(field)
h5o = h5py.h5o.open(h5g, field.encode("utf-8"))
Expand All @@ -269,6 +255,7 @@ def _h5_read_struct(
n_rows=n_rows,
idx=idx,
use_h5idx=use_h5idx,
field_mask=submask,
decompress=decompress,
)
h5o.close()
Expand Down Expand Up @@ -297,19 +284,15 @@ def _h5_read_table(

# determine fields to be read out
all_fields = dtypeutils.get_struct_fields(attrs["datatype"])
selected_fields = (
[field for field in all_fields if field_mask[field]]
if field_mask is not None
else all_fields
)
selected_fields = utils.eval_field_mask(field_mask, all_fields)

# modify datatype in attrs if a field_mask was used
attrs["datatype"] = "table{" + ",".join(selected_fields) + "}"
attrs["datatype"] = "table{" + ",".join(field for field, _ in selected_fields) + "}"

# read out each of the fields
col_dict = {}
rows_read = []
for field in selected_fields:
for field, submask in selected_fields:
fld_buf = None
if obj_buf is not None:
if not isinstance(obj_buf, Table) or field not in obj_buf:
Expand All @@ -329,6 +312,7 @@ def _h5_read_table(
use_h5idx=use_h5idx,
obj_buf=fld_buf,
obj_buf_start=obj_buf_start,
field_mask=submask,
decompress=decompress,
)
h5o.close()
Expand Down
78 changes: 64 additions & 14 deletions src/lgdo/lh5/_serializers/read/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import logging
from collections import defaultdict
from collections.abc import Collection, Mapping

import h5py
import numpy as np
Expand All @@ -22,6 +24,56 @@ def check_obj_buf_attrs(attrs, new_attrs, fname, oname):
raise LH5DecodeError(msg, fname, oname)


def build_field_mask(field_mask: Mapping[str, bool] | Collection[str]) -> defaultdict:
# check field_mask and make it a default dict
if field_mask is None:
return defaultdict(lambda: True)
if isinstance(field_mask, dict):
default = True
if len(field_mask) > 0:
default = not field_mask[next(iter(field_mask.keys()))]
return defaultdict(lambda: default, field_mask)
if isinstance(field_mask, (list, tuple, set)):
return defaultdict(bool, {field: True for field in field_mask})
if isinstance(field_mask, defaultdict):
return field_mask
msg = "bad field_mask type"
raise ValueError(msg, type(field_mask).__name__)


def eval_field_mask(
field_mask: defaultdict, all_fields: list[str]
) -> list[tuple(str, defaultdict)]:
"""Get list of fields that need to be loaded along with a sub-field-mask
in case we have a nested Table"""

if field_mask is None:
return all_fields

this_field_mask = defaultdict(field_mask.default_factory)
sub_field_masks = {}

for key, val in field_mask.items():
field = key.strip("/")
pos = field.find("/")
if pos < 0:
this_field_mask[field] = val
else:
sub_field = field[pos + 1 :]
field = field[:pos]
this_field_mask[field] = True
sub_mask = sub_field_masks.setdefault(
field, defaultdict(field_mask.default_factory)
)
sub_mask[sub_field] = val

return [
(field, sub_field_masks.get(field))
for field in all_fields
if this_field_mask[field]
]


def read_attrs(h5o, fname, oname):
"""Read all attributes for an hdf5 dataset or group using low level API
and return them as a dict. Assume all are strings or scalar types."""
Expand Down Expand Up @@ -114,6 +166,7 @@ def read_size_in_bytes(h5o, fname, oname, field_mask=None):
h5a.read(type_attr)
type_attr = type_attr.item().decode()
lgdotype = datatype.datatype(type_attr)
field_mask = build_field_mask(field_mask)

# scalars are dim-0 datasets
if lgdotype in (
Expand All @@ -124,24 +177,21 @@ def read_size_in_bytes(h5o, fname, oname, field_mask=None):
):
return int(np.prod(h5o.shape) * h5o.dtype.itemsize)

# structs don't have rows
if lgdotype in (types.Struct, types.Histogram, types.Histogram.Axis):
size = 0
for key in h5o:
obj = h5py.h5o.open(h5o, key)
size += read_size_in_bytes(obj, fname, oname, field_mask)
obj.close()
return size

# tables should have elements with all the same length
if lgdotype in (types.Table, types.WaveformTable):
if lgdotype in (
types.Struct,
types.Histogram,
types.Histogram.Axis,
types.Table,
types.WaveformTable,
):
# read out each of the fields
size = 0
if not field_mask:
field_mask = datatype.get_struct_fields(type_attr)
for field in field_mask:
all_fields = datatype.get_struct_fields(type_attr)
selected_fields = eval_field_mask(field_mask, all_fields)
for field, submask in selected_fields:
obj = h5py.h5o.open(h5o, field.encode())
size += read_size_in_bytes(obj, fname, field)
size += read_size_in_bytes(obj, fname, field, submask)
obj.close()
return size

Expand Down

0 comments on commit 0c68c60

Please sign in to comment.