
helpers for datadict #392

Merged
merged 12 commits on May 31, 2023
114 changes: 79 additions & 35 deletions plottr/data/datadict.py
@@ -6,6 +6,7 @@
import warnings
import copy as cp
import re
import logging
import pandas as pd
import numpy as np
from functools import reduce
@@ -18,11 +19,12 @@
__license__ = 'MIT'


# TODO: functionality that returns axes values given a set of slices.
# TODO: an easier way to access data and meta values.
# maybe with getattr/setattr?
logger = logging.getLogger(__name__)


# TODO: direct slicing of full datasets. implement getitem/setitem?
# TODO: feature to compare if datadicts are equal is not fully tested yet.
# TODO: it may be more sophisticated to define a dataclass for a data field?
# (or some other dedicated object)


def is_meta_key(key: str) -> bool:
@@ -82,6 +84,7 @@ class DataDictBase(dict):

def __init__(self, **kw: Any):
super().__init__(self, **kw)
self.d_ = DataDictBase._DataAccess(self)

def __eq__(self, other: object) -> bool:
"""Check for content equality of two datadicts."""
@@ -332,6 +335,7 @@ def extract(self: T, data: List[str], include_meta: bool = True,
else:
data = data.copy()

# include all the axes used by the data.
for d in data:
for a in self.axes(d):
if a not in data:
@@ -426,7 +430,7 @@ def structure(self: T, add_shape: bool = False,
for n, v in self.data_items():
if n not in remove_data:
v2 = v.copy()
v2.pop('values')
v2['values'] = []
s[n] = cp.deepcopy(v2)
if 'axes' in s[n]:
for r in remove_data:
@@ -578,6 +582,8 @@ def validate(self) -> bool:
:return: ``True`` if valid, ``False`` if invalid.
:raises: ``ValueError`` if invalid.
"""
self._update_data_access()

msg = '\n'
for n, v in self.data_items():

@@ -618,7 +624,7 @@ def remove_unused_axes(self: T) -> T:
"""
dependents = self.dependents()
unused = []
ret = self.copy()
ret = self #.copy()

for n, v in self.data_items():
used = False
@@ -673,61 +679,93 @@ def reorder_axes(self: T, data_names: Union[str, Sequence[str], None] = None,
:param pos: New axes position in the form ``axis_name = new_position``.
Non-specified axes positions are adjusted automatically.

:return: Dataset with re-ordered axes.
:return: Dataset with re-ordered axes (not a copy).
"""
if data_names is None:
data_names = self.dependents()
if isinstance(data_names, str):
data_names = [data_names]

ret = self.copy()
for n in data_names:
neworder, newaxes = self.reorder_axes_indices(n, **pos)
ret[n]['axes'] = newaxes
self[n]['axes'] = newaxes

ret.validate()
return ret
self.validate()
return self
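# Minimal usage sketch of the new in-place behaviour (assuming ``dd`` is a
# valid DataDict with a dependent 'z' over axes ['x', 'y']; names are
# hypothetical, not part of this PR):
ret = dd.reorder_axes(y=0)      # 'z' now has axes ['y', 'x']
assert ret is dd                # the same object is returned, not a copy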

def copy(self: T) -> T:
"""
Make a copy of the dataset.

:return: A copy of the dataset.
"""
return cp.deepcopy(self)
logger.debug(f'copying a dataset with size {self.nbytes()}')
ret = self.structure()
assert ret is not None

for k, v in self.data_items():
ret[k]['values'] = self.data_vals(k).copy()
return ret
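# Sketch of the reworked copy(), assuming ``dd`` is a valid DataDict with a
# field 'z' backed by a numpy array (names are hypothetical): the structure
# is rebuilt and every value array is copied, so the copy is independent.
dd2 = dd.copy()
assert dd2['z']['values'] is not dd['z']['values']
dd2['z']['values'][0] = -1.0    # does not affect dd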

def astype(self: T, dtype: np.dtype) -> T:
"""
Convert all data values to given dtype.

:param dtype: np dtype.
:return: Copy of the dataset, with values as given type.
:return: The dataset, with values converted to the given dtype (not a copy).
"""
ret = self.copy()
for k, v in ret.data_items():
for k, v in self.data_items():
vals = v['values']
if type(v['values']) not in [np.ndarray, np.ma.core.MaskedArray]:
vals = np.array(v['values'])
ret[k]['values'] = vals.astype(dtype)
self[k]['values'] = vals.astype(dtype)

return ret
return self
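# Sketch of the in-place dtype conversion (``dd`` is a hypothetical DataDict
# with numeric values): astype() now converts this dataset's own arrays and
# returns self; call copy() first if the original dtype must be kept.
same = dd.astype(np.float64)
assert same is dd
kept = dd.copy().astype(np.float32)   # dd itself keeps float64 values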

def mask_invalid(self: T) -> T:
"""
Mask all invalid data in all values.
:return: The dataset with invalid entries (nan/None) masked (not a copy).
"""
ret = self.copy()
for d, _ in self.data_items():
arr = self.data_vals(d)
vals = np.ma.masked_where(num.is_invalid(arr), arr, copy=True)
try:
vals.fill_value = np.nan
except TypeError:
vals.fill_value = -9999
ret[d]['values'] = vals
self[d]['values'] = vals

return ret
return self
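# Sketch, assuming ``dd`` has a field 'z' (hypothetical) containing np.nan:
# invalid entries are now masked on the dataset itself and self is returned.
assert dd.mask_invalid() is dd
assert np.ma.is_masked(dd.data_vals('z'))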

class _DataAccess:
def __init__(self, parent: "DataDictBase") -> None:
self._parent = parent

def __getattribute__(self, __name: str) -> Any:
parent = super(DataDictBase._DataAccess, self).__getattribute__('_parent')

if __name in [k for k, _ in parent.data_items()]:
return parent.data_vals(__name)
else:
return super(DataDictBase._DataAccess, self).__getattribute__(__name)

def __setattr__(self, __name: str, __value: Any) -> None:
# This check makes sure that we can set the parent correctly in the
# constructor.
if hasattr(self, '_parent'):
if __name in [k for k, _ in self._parent.data_items()]:
self._parent[__name]['values'] = __value

# still allow setting arbitrary attributes that are not data fields.
else:
super(DataDictBase._DataAccess, self).__setattr__(__name, __value)
else:
super(DataDictBase._DataAccess, self).__setattr__(__name, __value)

def _update_data_access(self) -> None:
for d, i in self.data_items():
self.d_.__dict__[d] = None
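# Sketch of the new attribute-style access via ``d_`` (field name 'z' is
# hypothetical): data values can be read and assigned as attributes instead
# of going through dd['z']['values'].
vals = dd.d_.z        # same as dd.data_vals('z')
dd.d_.z = vals * 2    # same as dd['z']['values'] = vals * 2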


class DataDict(DataDictBase):
@@ -892,7 +930,7 @@ def expand(self) -> 'DataDict':
ret = DataDict(**struct)

if self.is_expanded():
return self.copy()
return self

ishp = self._inner_shapes()
size = max([int(np.prod(s)) for s in ishp.values()])
@@ -964,8 +1002,6 @@ def remove_invalid_entries(self) -> 'DataDict':
ishp = self._inner_shapes()
idxs = []

ret = self.copy()

# collect rows that are completely invalid
for d in self.dependents():

@@ -1001,10 +1037,10 @@ def remove_invalid_entries(self) -> 'DataDict':
if len(idxs) > 0:
remove_idxs = reduce(np.intersect1d,
tuple(np.array(idxs).astype(int)))
for k, v in ret.data_items():
for k, v in self.data_items():
v['values'] = np.delete(v['values'], remove_idxs, axis=0)

return ret
return self


class MeshgridDataDict(DataDictBase):
Expand Down Expand Up @@ -1129,19 +1165,23 @@ def reorder_axes(self, data_names: Union[str, Sequence[str], None] = None,
data_names = [data_names]

transposed = []
ret: "MeshgridDataDict" = self.copy()
orders = {}
orig_axes = {}
for n in data_names:
orders[n] = self.reorder_axes_indices(n, **pos)
orig_axes[n] = self.axes(n).copy()

for n in data_names:
neworder, newaxes = self.reorder_axes_indices(n, **pos)
ret[n]['axes'] = newaxes
ret[n]['values'] = self[n]['values'].transpose(neworder)
for ax in self.axes(n):
neworder, newaxes = orders[n]
self[n]['axes'] = newaxes
self[n]['values'] = self[n]['values'].transpose(neworder)
for ax in orig_axes[n]:
if ax not in transposed:
ret[ax]['values'] = self[ax]['values'].transpose(neworder)
self[ax]['values'] = self[ax]['values'].transpose(neworder)
transposed.append(ax)

ret.validate()
return ret
self.validate()
return self

def mean(self, axis: str) -> 'MeshgridDataDict':
"""Take the mean over the given axis.
@@ -1234,7 +1274,8 @@ def guess_shape_from_datadict(data: DataDict) -> \
def datadict_to_meshgrid(data: DataDict,
target_shape: Union[Tuple[int, ...], None] = None,
inner_axis_order: Union[None, Sequence[str]] = None,
use_existing_shape: bool = False) \
use_existing_shape: bool = False,
copy: bool = True) \
-> MeshgridDataDict:
"""
Try to make a meshgrid from a dataset.
@@ -1256,6 +1297,9 @@ def datadict_to_meshgrid(data: DataDict,
:param use_existing_shape: if ``True``, simply use the shape that the data
already has. For numpy-array data, this might already be present.
If ``False``, flatten and reshape.
:param copy: if ``True``, make a copy of the data arrays.
If ``False``, the data arrays are modified in-place.

:raises: GriddingError (subclass of ValueError) if the data cannot be gridded.
:returns: The generated ``MeshgridDataDict``.
"""
@@ -1290,7 +1334,7 @@ def datadict_to_meshgrid(data: DataDict,
axlist = data.axes(data.dependents()[0])

for k, v in data.data_items():
vals = num.array1d_to_meshgrid(v['values'], target_shape, copy=True)
vals = num.array1d_to_meshgrid(v['values'], target_shape, copy=copy)

# if an inner axis order is given, we transpose to transform from that
# to the specified order.
4 changes: 2 additions & 2 deletions test/apps/autoplot_app.py
@@ -136,7 +136,7 @@ def data(self) -> Iterable[DataDictBase]:


def main(dataSrc):
plottrlog.LEVEL = logging.INFO
plottrlog.LEVEL = logging.DEBUG

app = QtWidgets.QApplication([])
fc, win = autoplot(plotWidgetClass=plotWidgetClass)
@@ -163,5 +163,5 @@ def main(dataSrc):
# src = ImageDataMovie(10, 2, 101)
src = ImageDataLiveAcquisition(101, 101, 67)
# src = ComplexImage(21, 21)
src.delay = 0.1
src.delay = 0.5
main(src)