Skip to content

Commit

Permalink
Add full check of attribute properties in __eq__ method
Browse files Browse the repository at this point in the history
  • Loading branch information
jp-dark authored and ihnorton committed Dec 6, 2023
1 parent 22ff29a commit b822509
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
39 changes: 37 additions & 2 deletions tiledb/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,44 @@ def __init__(
def __eq__(self, other):
if not isinstance(other, Attr):
return False
if self.name != other.name or self.dtype != other.dtype:
if self.isnullable != other.isnullable or self.dtype != other.dtype:
return False
return True
if not self.isnullable:
# Check the fill values are equal.
def equal_or_nan(x, y):
return x == y or (np.isnan(x) and np.isnan(y))

if self.ncells == 1:
if not equal_or_nan(self.fill, other.fill):
return False
elif np.issubdtype(self.dtype, np.bytes_) or np.issubdtype(
self.dtype, np.str_
):
if self.fill != other.fill:
return False
elif self.dtype in {np.dtype("complex64"), np.dtype("complex128")}:
if not (
equal_or_nan(np.real(self.fill), np.real(other.fill))
and equal_or_nan(np.imag(self.fill), np.imag(other.fill))
):
return False
else:
if not all(
equal_or_nan(x, y)
or (
isinstance(x, str)
and x.lower() == "nat"
and isinstance(y, str)
and y.lower() == "nat"
)
for x, y in zip(self.fill[0], other.fill[0])
):
return False
return (
self._internal_name == other._internal_name
and self.isvar == other.isvar
and self.filters == other.filters
)

def dump(self):
"""Dumps a string representation of the Attr object to standard output (stdout)"""
Expand Down
29 changes: 29 additions & 0 deletions tiledb/tests/test_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class AttributeTest(DiskTestCase):
def test_minimal_attribute(self):
attr = tiledb.Attr()
self.assertEqual(attr, attr)
self.assertTrue(attr.isanon)
self.assertEqual(attr.name, "")
self.assertEqual(attr.dtype, np.float_)
Expand All @@ -30,6 +31,7 @@ def test_attribute(self, capfd):
attr.dump()
assert_captured(capfd, "Name: foo")

assert attr == attr
assert attr.name == "foo"
assert attr.dtype == np.float64, "default attribute type is float64"

Expand All @@ -46,6 +48,7 @@ def test_attribute(self, capfd):
)
def test_attribute_fill(self, dtype, fill):
attr = tiledb.Attr("", dtype=dtype, fill=fill)
assert attr == attr
assert np.array(attr.fill, dtype=dtype) == np.array(fill, dtype=dtype)

path = self.path()
Expand All @@ -68,6 +71,7 @@ def test_full_attribute(self, capfd):
attr.dump()
assert_captured(capfd, "Name: foo")

self.assertEqual(attr, attr)
self.assertEqual(attr.name, "foo")
self.assertEqual(attr.dtype, np.int64)
self.assertIsInstance(attr.filters[0], tiledb.ZstdFilter)
Expand All @@ -77,6 +81,7 @@ def test_ncell_attribute(self):
dtype = np.dtype([("", np.int32), ("", np.int32), ("", np.int32)])
attr = tiledb.Attr("foo", dtype=dtype)

self.assertEqual(attr, attr)
self.assertEqual(attr.dtype, dtype)
self.assertEqual(attr.ncells, 3)

Expand Down Expand Up @@ -125,9 +130,27 @@ def test_two_cell_double_attribute(self, fill):
assert attr.fill == attr.fill
assert attr.ncells == 2

def test_ncell_double_attribute(self):
dtype = np.dtype([("", np.double), ("", np.double), ("", np.double)])
fill = np.array((0, np.nan, np.inf), dtype=dtype)
attr = tiledb.Attr("foo", dtype=dtype, fill=fill)

self.assertEqual(attr, attr)
self.assertEqual(attr.dtype, dtype)
self.assertEqual(attr.ncells, 3)

def test_ncell_not_equal_fill_attribute(self):
dtype = np.dtype([("", np.double), ("", np.double), ("", np.double)])
fill1 = np.array((0, np.nan, np.inf), dtype=dtype)
fill2 = np.array((np.nan, -1, np.inf), dtype=dtype)
attr1 = tiledb.Attr("foo", dtype=dtype, fill=fill1)
attr2 = tiledb.Attr("foo", dtype=dtype, fill=fill2)
assert attr1 != attr2

def test_ncell_bytes_attribute(self):
dtype = np.dtype((np.bytes_, 10))
attr = tiledb.Attr("foo", dtype=dtype)
self.assertEqual(attr, attr)
self.assertEqual(attr.dtype, dtype)
self.assertEqual(attr.ncells, 10)

Expand All @@ -143,28 +166,34 @@ def test_bytes_var_attribute(self):
self.assertTrue(attr.isvar)

attr = tiledb.Attr("foo", var=True, dtype="S")
self.assertEqual(attr, attr)
self.assertEqual(attr.dtype, np.dtype("S"))
self.assertTrue(attr.isvar)

attr = tiledb.Attr("foo", var=False, dtype="S1")
self.assertEqual(attr, attr)
self.assertEqual(attr.dtype, np.dtype("S1"))
self.assertFalse(attr.isvar)

attr = tiledb.Attr("foo", dtype="S1")
self.assertEqual(attr, attr)
self.assertEqual(attr.dtype, np.dtype("S1"))
self.assertFalse(attr.isvar)

attr = tiledb.Attr("foo", dtype="S")
self.assertEqual(attr, attr)
self.assertEqual(attr.dtype, np.dtype("S"))
self.assertTrue(attr.isvar)

def test_nullable_attribute(self):
attr = tiledb.Attr("nullable", nullable=True, dtype=np.int32)
self.assertEqual(attr, attr)
self.assertEqual(attr.dtype, np.dtype(np.int32))
self.assertTrue(attr.isnullable)

def test_datetime_attribute(self):
attr = tiledb.Attr("foo", dtype=np.datetime64("", "D"))
self.assertEqual(attr, attr)
assert attr.dtype == np.dtype(np.datetime64("", "D"))
assert attr.dtype != np.dtype(np.datetime64("", "Y"))
assert attr.dtype != np.dtype(np.datetime64)
Expand Down

0 comments on commit b822509

Please sign in to comment.