Skip to content

Commit

Permalink
Merge pull request #285 from bobleesj/do-ops
Browse files Browse the repository at this point in the history
Refactor `__add__` operation in `DiffractionObject` and add tests
  • Loading branch information
sbillinge authored Dec 29, 2024
2 parents 3f2d0a4 + da70bd6 commit 1ea8e9a
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 34 deletions.
23 changes: 23 additions & 0 deletions news/add-operations-tests.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* unit tests for __add__ operation for DiffractionObject

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
98 changes: 64 additions & 34 deletions src/diffpy/utils/diffraction_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@
XQUANTITIES = ANGLEQUANTITIES + DQUANTITIES + QQUANTITIES
XUNITS = ["degrees", "radians", "rad", "deg", "inv_angs", "inv_nm", "nm-1", "A-1"]

x_grid_emsg = (
"objects are not on the same x-grid. You may add them using the self.add method "
"and specifying how to handle the mismatch."
y_grid_length_mismatch_emsg = (
"The two objects have different y-array lengths. "
"Please ensure the length of the y-value during initialization is identical."
)

invalid_add_type_emsg = (
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. "
"Please rerun by adding another DiffractionObject instance or a scalar value. "
"e.g., my_do_1 + my_do_2 or my_do + 10 or 10 + my_do"
)


Expand Down Expand Up @@ -169,32 +175,56 @@ def __eq__(self, other):
return True

def __add__(self, other):
summed = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
summed.on_tth[1] = self.on_tth[1] + other
summed.on_q[1] = self.on_q[1] + other
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to sum two DiffractionObject objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
else:
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1]
summed.on_q[1] = self.on_q[1] + other.on_q[1]
return summed
"""Add a scalar value or another DiffractionObject to the yarray of the
DiffractionObject.
def __radd__(self, other):
summed = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
summed.on_tth[1] = self.on_tth[1] + other
summed.on_q[1] = self.on_q[1] + other
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to sum two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
else:
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1]
summed.on_q[1] = self.on_q[1] + other.on_q[1]
return summed
Parameters
----------
other : DiffractionObject or int or float
The object to add to the current DiffractionObject. If `other` is a scalar value,
it will be added to all yarray. The length of the yarray must match if `other` is
an instance of DiffractionObject.
Returns
-------
DiffractionObject
The new and deep-copied DiffractionObject instance after adding values to the yarray.
Raises
------
ValueError
Raised when the length of the yarray of the two DiffractionObject instances do not match.
TypeError
Raised when the type of `other` is not an instance of DiffractionObject, int, or float.
Examples
--------
Add a scalar value to the yarray of the DiffractionObject instance:
>>> new_do = my_do + 10.1
>>> new_do = 10.1 + my_do
Add the yarray of two DiffractionObject instances:
>>> new_do = my_do_1 + my_do_2
"""

self._check_operation_compatibility(other)
summed_do = deepcopy(self)
if isinstance(other, (int, float)):
summed_do._all_arrays[:, 0] += other
if isinstance(other, DiffractionObject):
summed_do._all_arrays[:, 0] += other.all_arrays[:, 0]
return summed_do

__radd__ = __add__

def _check_operation_compatibility(self, other):
if not isinstance(other, (DiffractionObject, int, float)):
raise TypeError(invalid_add_type_emsg)
if isinstance(other, DiffractionObject):
self_yarray = self.all_arrays[:, 0]
other_yarray = other.all_arrays[:, 0]
if len(self_yarray) != len(other_yarray):
raise ValueError(y_grid_length_mismatch_emsg)

def __sub__(self, other):
subtracted = deepcopy(self)
Expand All @@ -204,7 +234,7 @@ def __sub__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to subtract two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1]
subtracted.on_q[1] = self.on_q[1] - other.on_q[1]
Expand All @@ -218,7 +248,7 @@ def __rsub__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to subtract two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1]
subtracted.on_q[1] = other.on_q[1] - self.on_q[1]
Expand All @@ -232,7 +262,7 @@ def __mul__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to multiply two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
Expand All @@ -244,7 +274,7 @@ def __rmul__(self, other):
multiplied.on_tth[1] = other * self.on_tth[1]
multiplied.on_q[1] = other * self.on_q[1]
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
Expand All @@ -258,7 +288,7 @@ def __truediv__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to multiply two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
divided.on_tth[1] = self.on_tth[1] / other.on_tth[1]
divided.on_q[1] = self.on_q[1] / other.on_q[1]
Expand All @@ -270,7 +300,7 @@ def __rtruediv__(self, other):
divided.on_tth[1] = other / self.on_tth[1]
divided.on_q[1] = other / self.on_q[1]
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
divided.on_tth[1] = other.on_tth[1] / self.on_tth[1]
divided.on_q[1] = other.on_q[1] / self.on_q[1]
Expand Down
23 changes: 23 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def do_minimal_tth():
return DiffractionObject(wavelength=2 * np.pi, xarray=np.array([30, 60]), yarray=np.array([1, 2]), xtype="tth")


@pytest.fixture
def do_minimal_d():
# Create an instance of DiffractionObject with non-empty xarray, yarray, and wavelength values
return DiffractionObject(wavelength=1.54, xarray=np.array([1, 2]), yarray=np.array([1, 2]), xtype="d")


@pytest.fixture
def wavelength_warning_msg():
return (
Expand All @@ -63,3 +69,20 @@ def invalid_q_or_d_or_wavelength_error_msg():
"The supplied input array and wavelength will result in an impossible two-theta. "
"Please check these values and re-instantiate the DiffractionObject with correct values."
)


@pytest.fixture
def invalid_add_type_error_msg():
return (
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. "
"Please rerun by adding another DiffractionObject instance or a scalar value. "
"e.g., my_do_1 + my_do_2 or my_do + 10 or 10 + my_do"
)


@pytest.fixture
def y_grid_size_mismatch_error_msg():
return (
"The two objects have different y-array lengths. "
"Please ensure the length of the y-value during initialization is identical."
)
75 changes: 75 additions & 0 deletions tests/test_diffraction_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,3 +710,78 @@ def test_copy_object(do_minimal):
do_copy = do.copy()
assert do == do_copy
assert id(do) != id(do_copy)


@pytest.mark.parametrize(
"starting_all_arrays, scalar_to_add, expected_all_arrays",
[
# Test scalar addition to yarray values (intensity) and expect no change to xarrays (q, tth, d)
( # C1: Add integer of 5, expect yarray to increase by by 5
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
5,
np.array([[6.0, 0.51763809, 30.0, 12.13818192], [7.0, 1.0, 60.0, 6.28318531]]),
),
( # C2: Add float of 5.1, expect yarray to be added by 5.1
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
5.1,
np.array([[6.1, 0.51763809, 30.0, 12.13818192], [7.1, 1.0, 60.0, 6.28318531]]),
),
],
)
def test_addition_operator_by_scalar(starting_all_arrays, scalar_to_add, expected_all_arrays, do_minimal_tth):
do = do_minimal_tth
assert np.allclose(do.all_arrays, starting_all_arrays)
do_scalar_right_sum = do + scalar_to_add
assert np.allclose(do_scalar_right_sum.all_arrays, expected_all_arrays)
do_scalar_left_sum = scalar_to_add + do
assert np.allclose(do_scalar_left_sum.all_arrays, expected_all_arrays)


@pytest.mark.parametrize(
"do_1_all_arrays, "
"do_2_all_arrays, "
"expected_do_1_all_arrays_with_y_summed, "
"expected_do_2_all_arrays_with_y_summed",
[
# Test addition of two DO objects, expect combined yarray values and no change to xarrays ((q, tth, d)
( # C1: Add two DO objects, expect sum of yarray values
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),),
(np.array([[1.0, 6.28318531, 100.70777771, 1], [2.0, 3.14159265, 45.28748053, 2.0]]),),
(np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),),
(np.array([[2.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),),
),
],
)
def test_addition_operator_by_another_do(
do_1_all_arrays,
do_2_all_arrays,
expected_do_1_all_arrays_with_y_summed,
expected_do_2_all_arrays_with_y_summed,
do_minimal_tth,
do_minimal_d,
):
do_1 = do_minimal_tth
assert np.allclose(do_1.all_arrays, do_1_all_arrays)
do_2 = do_minimal_d
assert np.allclose(do_2.all_arrays, do_2_all_arrays)
assert np.allclose((do_1 + do_2).all_arrays, expected_do_1_all_arrays_with_y_summed)
assert np.allclose((do_2 + do_1).all_arrays, expected_do_2_all_arrays_with_y_summed)


def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
# Add a string to a DO object, expect TypeError, only scalar (int, float) allowed for addition
do = do_minimal_tth
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
do + "string_value"
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
"string_value" + do


def test_addition_operator_invalid_yarray_length(do_minimal, do_minimal_tth, y_grid_size_mismatch_error_msg):
# Combine two DO objects, one with empty xarrays (do_minimal) and the other with non-empty xarrays
do_1 = do_minimal
do_2 = do_minimal_tth
assert len(do_1.all_arrays[:, 0]) == 0
assert len(do_2.all_arrays[:, 0]) == 2
with pytest.raises(ValueError, match=re.escape(y_grid_size_mismatch_error_msg)):
do_1 + do_2

0 comments on commit 1ea8e9a

Please sign in to comment.