Skip to content

Commit

Permalink
add case change tests to parametrized gentle_asarray test
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram committed Aug 8, 2023
1 parent 0ae8911 commit c96b625
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
14 changes: 10 additions & 4 deletions src/stdatamodels/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ def gentle_asarray(a, dtype, allow_extra_columns=False):
# check if names match but the order is incorrect
if set(out_lower_names) == set(in_lower_names):
# all the columns exist but they are in the wrong order
# reorder the columns
reordered_array = merge_arrays([a[n] for n in out_lower_names], flatten=True)
# reorder the columns, the names might differ in case
reordered_names = sorted(in_dtype.names, key=lambda n: out_lower_names.index(n.lower()))
reordered_array = merge_arrays([a[n] for n in reordered_names], flatten=True)
reordered_subdtypes = [reordered_array.dtype[n] for n in reordered_array.dtype.names]
out_subdtypes = [out_dtype[n] for n in out_dtype.names]
if reordered_subdtypes == out_subdtypes:
Expand Down Expand Up @@ -150,15 +151,20 @@ def gentle_asarray(a, dtype, allow_extra_columns=False):
return _safe_asanyarray(a, new_dtype)

# reorder columns so required columns are first
required_names = [n for n in in_dtype.names if n.lower() in out_lower_names]
required_names.sort(key=lambda n: out_lower_names.index(n.lower()))
extra_names = [n for n in in_dtype.names if n.lower() not in out_lower_names]
names_ordered = out_dtype.names + tuple(extra_names)
names_ordered = tuple(required_names + extra_names)
reordered_array = merge_arrays([a[n] for n in names_ordered], flatten=True)
reordered_array.dtype.names = names_ordered

extra_dtype_descr = [(n, in_dtype[n]) for n in extra_names]
new_dtype = np.dtype(out_dtype.descr + extra_dtype_descr)

# check that required columns have the correct dtype
if reordered_array.dtype.descr[:n_required] == out_dtype.descr:
reordered_subdtypes = [reordered_array.dtype[n] for n in reordered_array.dtype.names]
out_subdtypes = [out_dtype[n] for n in out_dtype.names]
if reordered_subdtypes[:n_required] == out_subdtypes:
return reordered_array.view(new_dtype)
return _safe_asanyarray(reordered_array, new_dtype)

Expand Down
6 changes: 5 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ def test_gentle_asarray_invalid_conversion():
@pytest.mark.parametrize("change_dtype", [True, False], ids=['different_dtype', 'same_dtype'])
@pytest.mark.parametrize("extra_columns", [True, False], ids=['extra_columns', 'no_extra_columns'])
@pytest.mark.parametrize("allow_extra", [True, False], ids=['allow_extra', 'disallow_extra'])
def test_gentle_asarray_structured_dtype_configurations(reorder, change_dtype, extra_columns, allow_extra):
@pytest.mark.parametrize("change_case", [True, False], ids=['changed_case', 'same_case'])
def test_gentle_asarray_structured_dtype_configurations(reorder, change_dtype, extra_columns, allow_extra, change_case):
"""
Test gentle_asarray with a structured array with a few combinations of:
- misordered columns
Expand Down Expand Up @@ -216,6 +217,9 @@ def test_gentle_asarray_structured_dtype_configurations(reorder, change_dtype, e
input_array['s'] = b'a'
input_array['b'] = True
input_array['u'] = 3
if change_case:
input_array.dtype.names = tuple([n.upper() for n in input_array.dtype.names])
input_dtype = input_array.dtype
if not allow_extra and extra_columns:
# if we have extra columns, but don't allow them, gentle_asarray should fail
with pytest.raises(ValueError):
Expand Down

0 comments on commit c96b625

Please sign in to comment.