From c93d595c3452df49a7e635c194ab419fd085ae81 Mon Sep 17 00:00:00 2001 From: Oliver Ruebel Date: Mon, 18 Nov 2019 14:09:56 -0800 Subject: [PATCH] Removed redundend shape check functions. Fix #204 (#205) * Removed redundend shape check functions. Fix #204 --- src/hdmf/backends/hdf5/h5tools.py | 6 +++--- src/hdmf/container.py | 6 +++--- src/hdmf/data_utils.py | 29 +---------------------------- src/hdmf/utils.py | 6 ++++-- src/hdmf/validate/validator.py | 7 +++---- tests/unit/test_io_hdf5.py | 4 ++-- 6 files changed, 16 insertions(+), 42 deletions(-) diff --git a/src/hdmf/backends/hdf5/h5tools.py b/src/hdmf/backends/hdf5/h5tools.py index 6b48ff2d7..bac077ec0 100644 --- a/src/hdmf/backends/hdf5/h5tools.py +++ b/src/hdmf/backends/hdf5/h5tools.py @@ -7,8 +7,8 @@ import warnings from ...container import Container -from ...utils import docval, getargs, popargs, call_docval_func -from ...data_utils import AbstractDataChunkIterator, get_shape +from ...utils import docval, getargs, popargs, call_docval_func, get_data_shape +from ...data_utils import AbstractDataChunkIterator from ...build import Builder, GroupBuilder, DatasetBuilder, LinkBuilder, BuildManager,\ RegionBuilder, ReferenceBuilder, TypeMap, ObjectMapper from ...spec import RefSpec, DtypeSpec, NamespaceCatalog, GroupSpec @@ -1028,7 +1028,7 @@ def __list_fill__(cls, parent, name, data, options=None): elif isinstance(dtype, np.dtype): data_shape = (len(data),) else: - data_shape = get_shape(data) + data_shape = get_data_shape(data) # Create the dataset try: dset = parent.create_dataset(name, shape=data_shape, dtype=dtype, **io_settings) diff --git a/src/hdmf/container.py b/src/hdmf/container.py index 2932f1e67..862dfc2fe 100644 --- a/src/hdmf/container.py +++ b/src/hdmf/container.py @@ -2,8 +2,8 @@ from abc import abstractmethod from uuid import uuid4 from six import with_metaclass -from .utils import docval, get_docval, call_docval_func, getargs, ExtenderMeta -from .data_utils import DataIO, get_shape +from .utils import docval, get_docval, call_docval_func, getargs, ExtenderMeta, get_data_shape +from .data_utils import DataIO from warnings import warn import h5py @@ -397,7 +397,7 @@ def shape(self): :return: Shape tuple :rtype: tuple of ints """ - return get_shape(self.__data) + return get_data_shape(self.__data) @docval({'name': 'dataio', 'type': DataIO, 'doc': 'the DataIO to apply to the data held by this Data'}) def set_dataio(self, **kwargs): diff --git a/src/hdmf/data_utils.py b/src/hdmf/data_utils.py index acc623904..1b3e22afc 100644 --- a/src/hdmf/data_utils.py +++ b/src/hdmf/data_utils.py @@ -3,39 +3,12 @@ import numpy as np from warnings import warn -from six import with_metaclass, text_type, binary_type +from six import with_metaclass import copy from .utils import docval, getargs, popargs, docval_macro, get_data_shape -def __get_shape_helper(data): - """Helper function used by get_shape""" - shape = list() - if hasattr(data, '__len__'): - shape.append(len(data)) - if len(data) and not isinstance(data[0], (text_type, binary_type)): - shape.extend(__get_shape_helper(data[0])) - return tuple(shape) - - -def get_shape(data): - """ - Determine the data shape for the given data - :param data: Array for which the data should be determined - :type data: list, ndarray, dict - :return: None in case shape is unknown and shape tuple otherwise - """ - if isinstance(data, dict): - return None - elif hasattr(data, 'shape'): - return data.shape - elif hasattr(data, '__len__') and not isinstance(data, (text_type, binary_type)): - return __get_shape_helper(data) - else: - return None - - @docval_macro('array_data') class AbstractDataChunkIterator(with_metaclass(ABCMeta, object)): """ diff --git a/src/hdmf/utils.py b/src/hdmf/utils.py index fd800c872..86790fe53 100644 --- a/src/hdmf/utils.py +++ b/src/hdmf/utils.py @@ -648,9 +648,11 @@ def __get_shape_helper(local_data): return tuple(shape) if hasattr(data, 'maxshape'): return data.maxshape - if hasattr(data, 'shape'): + elif hasattr(data, 'shape'): return data.shape - if hasattr(data, '__len__') and not isinstance(data, (text_type, binary_type)): + elif isinstance(data, dict): + return None + elif hasattr(data, '__len__') and not isinstance(data, (text_type, binary_type)): if not strict_no_data_load or (isinstance(data, list) or isinstance(data, tuple) or isinstance(data, set)): return __get_shape_helper(data) else: diff --git a/src/hdmf/validate/validator.py b/src/hdmf/validate/validator.py index 57d083607..2e8050cf7 100644 --- a/src/hdmf/validate/validator.py +++ b/src/hdmf/validate/validator.py @@ -4,8 +4,7 @@ import re from itertools import chain -from ..utils import docval, getargs, call_docval_func, pystr -from ..data_utils import get_shape +from ..utils import docval, getargs, call_docval_func, pystr, get_data_shape from ..spec import Spec, AttributeSpec, GroupSpec, DatasetSpec, RefSpec from ..spec.spec import BaseStorageSpec, DtypeHelper @@ -318,7 +317,7 @@ def validate(self, **kwargs): dtype = get_type(value) if not check_type(spec.dtype, dtype): ret.append(DtypeError(self.get_spec_loc(spec), spec.dtype, dtype)) - shape = get_shape(value) + shape = get_data_shape(value) if not check_shape(spec.shape, shape): ret.append(ShapeError(self.get_spec_loc(spec), spec.shape, shape)) return ret @@ -374,7 +373,7 @@ def validate(self, **kwargs): if not check_type(self.spec.dtype, dtype): ret.append(DtypeError(self.get_spec_loc(self.spec), self.spec.dtype, dtype, location=self.get_builder_loc(builder))) - shape = get_shape(data) + shape = get_data_shape(data) if not check_shape(self.spec.shape, shape): if shape is None: ret.append(ExpectedArrayError(self.get_spec_loc(self.spec), self.spec.shape, str(data), diff --git a/tests/unit/test_io_hdf5.py b/tests/unit/test_io_hdf5.py index ea929802d..95b1580b8 100644 --- a/tests/unit/test_io_hdf5.py +++ b/tests/unit/test_io_hdf5.py @@ -5,7 +5,7 @@ from hdmf.backends.hdf5 import HDF5IO from hdmf.build import GroupBuilder, DatasetBuilder, LinkBuilder -from hdmf.data_utils import get_shape +from hdmf.utils import get_data_shape from numbers import Number @@ -239,5 +239,5 @@ def test_dataset_shape(self): io.write_builder(self.builder) builder = io.read_builder() dset = builder['test_bucket']['foo_holder']['foo1']['my_data'].data - self.assertEqual(get_shape(dset), (10,)) + self.assertEqual(get_data_shape(dset), (10,)) io.close()