Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow table columns and add_column to specify column class #436

Merged
merged 3 commits into from
Oct 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## HDMF 2.3.0 (Upcoming)

### New features
- Add ability to specify a custom column class (i.e., one other than `VectorData`, `DynamicTableRegion`, or
`VocabData`) for new `DynamicTable` columns using `DynamicTable.__columns__` or `DynamicTable.add_column(...)`. @rly (#436)

### Bug fixes
- Fix handling of empty lists against a spec with text/bytes dtype. @rly (#434)

Expand Down
40 changes: 28 additions & 12 deletions src/hdmf/common/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ def __set_table_attr(self, col):
else:
setattr(self, col.name, col)

__reserved_colspec_keys = ['name', 'description', 'index', 'table', 'required', 'class']

def _init_class_columns(self):
"""
Process all predefined columns specified in class variable __columns__.
Expand All @@ -423,9 +425,10 @@ def _init_class_columns(self):
description=col['description'],
index=col.get('index', False),
table=col.get('table', False),
col_cls=col.get('class', VectorData),
# Pass through extra kwargs for add_column that subclasses may have added
**{k: col[k] for k in col.keys()
if k not in ['name', 'description', 'index', 'table', 'required']})
if k not in DynamicTable.__reserved_colspec_keys})
else:
# track the not yet initialized optional predefined columns
self.__uninit_cols[col['name']] = col
Expand All @@ -445,6 +448,7 @@ def __build_columns(columns, df=None):
for d in columns:
name = d['name']
desc = d.get('description', 'no description')
col_cls = d.get('class', VectorData)
data = None
if df is not None:
data = list(df[name].values)
Expand All @@ -460,17 +464,16 @@ def __build_columns(columns, df=None):
for d in data:
tmp_data.extend(d)
data = tmp_data
vdata = VectorData(name, desc, data=data)
vdata = col_cls(name, desc, data=data)
vindex = VectorIndex("%s_index" % name, index_data, target=vdata)
tmp.append(vindex)
tmp.append(vdata)
else:
if data is None:
data = list()
cls = VectorData
if d.get('table', False):
cls = DynamicTableRegion
tmp.append(cls(name, desc, data=data))
col_cls = DynamicTableRegion
tmp.append(col_cls(name, desc, data=data))
return tmp

def __len__(self):
Expand Down Expand Up @@ -500,10 +503,11 @@ def add_row(self, **kwargs):
self.add_column(col['name'], col['description'],
index=col.get('index', False),
table=col.get('table', False),
col_cls=col.get('class', VectorData),
# Pass through extra keyword arguments for add_column that
# subclasses may have added
**{k: col[k] for k in col.keys()
if k not in ['name', 'description', 'index', 'table', 'required']})
if k not in DynamicTable.__reserved_colspec_keys})
extra_columns.remove(col['name'])

if extra_columns or missing_columns:
Expand Down Expand Up @@ -557,7 +561,11 @@ def __eq__(self, other):
'doc': 'whether or not this column should be indexed', 'default': False},
{'name': 'vocab', 'type': (bool, 'array_data'), 'default': False,
'doc': ('whether or not this column contains data from a '
'controlled vocabulary or the controlled vocabulary')})
'controlled vocabulary or the controlled vocabulary')},
{'name': 'col_cls', 'type': type, 'default': VectorData,
'doc': ('class to use to represent the column data. If table=True, this field is ignored and a '
'DynamicTableRegion object is used. If vocab=True, this field is ignored and a VocabData '
'object is used.')},)
def add_column(self, **kwargs): # noqa: C901
"""
Add a column to this table.
Expand All @@ -567,7 +575,7 @@ def add_column(self, **kwargs): # noqa: C901
:raises ValueError: if the column has already been added to the table
"""
name, data = getargs('name', 'data', kwargs)
index, table, vocab = popargs('index', 'table', 'vocab', kwargs)
index, table, vocab, col_cls = popargs('index', 'table', 'vocab', 'col_cls', kwargs)

if isinstance(index, VectorIndex):
warn("Passing a VectorIndex in for index may lead to unexpected behavior. This functionality will be "
Expand Down Expand Up @@ -600,22 +608,30 @@ def add_column(self, **kwargs): # noqa: C901
% (name, self.__class__.__name__, spec_index))
warn(msg)

spec_col_cls = self.__uninit_cols[name].get('class', VectorData)
if col_cls != spec_col_cls:
msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered "
"col_cls argument. The predefined class spec will be ignored. "
"Please ensure the new column complies with the spec. "
"This will raise an error in a future version of HDMF."
% (name, self.__class__.__name__, spec_col_cls))
warn(msg)

ckwargs = dict(kwargs)
cls = VectorData

# Add table if it's been specified
if table and vocab:
raise ValueError("column '%s' cannot be both a table region and come from a controlled vocabulary" % name)
if table is not False:
cls = DynamicTableRegion
col_cls = DynamicTableRegion
if isinstance(table, DynamicTable):
ckwargs['table'] = table
if vocab is not False:
cls = VocabData
col_cls = VocabData
if isinstance(vocab, (list, tuple, np.ndarray)):
ckwargs['vocabulary'] = vocab

col = cls(**ckwargs)
col = col_cls(**ckwargs)
col.parent = self
columns = [col]
self.__set_table_attr(col)
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/common/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,7 @@ class SubTable(DynamicTable):
{'name': 'col6', 'description': 'optional region', 'table': True},
{'name': 'col7', 'description': 'required, indexed region', 'required': True, 'index': True, 'table': True},
{'name': 'col8', 'description': 'optional, indexed region', 'index': True, 'table': True},
{'name': 'col10', 'description': 'required, indexed vocab column', 'index': True, 'class': VocabData},
)


Expand Down Expand Up @@ -716,6 +717,8 @@ def test_init(self):
self.assertIsNone(table.col6)
self.assertIsNone(table.col8)
self.assertIsNone(table.col8_index)
self.assertIsNone(table.col10)
self.assertIsNone(table.col10_index)

# uninitialized optional predefined columns cannot be accessed in this manner
with self.assertRaisesWith(KeyError, "'col2'"):
Expand Down Expand Up @@ -764,6 +767,9 @@ def test_add_opt_column(self):
table.add_column(name='col8', description='column #8', index=True, table=True)
self.assertEqual(table.col8.description, 'column #8')

table.add_column(name='col10', description='column #10', index=True, col_cls=VocabData)
self.assertIsInstance(table.col10, VocabData)

def test_add_opt_column_mismatched_table_true(self):
"""Test that adding an optional column from __columns__ with non-matched table raises a warning."""
table = SubTable(name='subtable', description='subtable description')
Expand Down Expand Up @@ -814,6 +820,20 @@ def test_add_opt_column_mismatched_index_data(self):
self.assertEqual(table.col2.description, 'column #2')
self.assertEqual(type(table.get('col2')), VectorIndex) # not VectorData

def test_add_opt_column_mismatched_col_cls(self):
"""Test that adding an optional column from __columns__ with non-matched col_cls raises a warning."""
table = SubTable(name='subtable', description='subtable description')
msg = ("Column 'col10' is predefined in SubTable with class=<class 'hdmf.common.table.VocabData'> "
"which does not match the entered col_cls "
"argument. The predefined class spec will be ignored. "
"Please ensure the new column complies with the spec. "
"This will raise an error in a future version of HDMF.")
with self.assertWarnsWith(UserWarning, msg):
table.add_column(name='col10', description='column #10', index=True)
self.assertEqual(table.col10.description, 'column #10')
self.assertEqual(type(table.col10), VectorData)
self.assertEqual(type(table.get('col10')), VectorIndex)

def test_add_opt_column_twice(self):
"""Test that adding an optional column from __columns__ twice fails the second time."""
table = SubTable(name='subtable', description='subtable description')
Expand Down