Skip to content

Commit

Permalink
Merge pull request #293 from bobleesj/op-mul-sub
Browse files Browse the repository at this point in the history
feat: Support *, /, - operations between two DiffractionObjects or scalar
  • Loading branch information
sbillinge authored Dec 30, 2024
2 parents 1ea8e9a + 21711fb commit 90fd625
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 127 deletions.
23 changes: 23 additions & 0 deletions news/op-mul-sub-div.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* addition, multiplication, subtraction, and division operators between two DiffractionObject instances or a scalar value with another DiffractionObject for modifying yarray (intensity) values.

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
115 changes: 33 additions & 82 deletions src/diffpy/utils/diffraction_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,94 +217,45 @@ def __add__(self, other):

__radd__ = __add__

def _check_operation_compatibility(self, other):
if not isinstance(other, (DiffractionObject, int, float)):
raise TypeError(invalid_add_type_emsg)
def __sub__(self, other):
self._check_operation_compatibility(other)
subtracted_do = deepcopy(self)
if isinstance(other, (int, float)):
subtracted_do._all_arrays[:, 0] -= other
if isinstance(other, DiffractionObject):
self_yarray = self.all_arrays[:, 0]
other_yarray = other.all_arrays[:, 0]
if len(self_yarray) != len(other_yarray):
raise ValueError(y_grid_length_mismatch_emsg)
subtracted_do._all_arrays[:, 0] -= other.all_arrays[:, 0]
return subtracted_do

def __sub__(self, other):
subtracted = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
subtracted.on_tth[1] = self.on_tth[1] - other
subtracted.on_q[1] = self.on_q[1] - other
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to subtract two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1]
subtracted.on_q[1] = self.on_q[1] - other.on_q[1]
return subtracted

def __rsub__(self, other):
subtracted = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
subtracted.on_tth[1] = other - self.on_tth[1]
subtracted.on_q[1] = other - self.on_q[1]
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to subtract two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1]
subtracted.on_q[1] = other.on_q[1] - self.on_q[1]
return subtracted
__rsub__ = __sub__

def __mul__(self, other):
multiplied = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
multiplied.on_tth[1] = other * self.on_tth[1]
multiplied.on_q[1] = other * self.on_q[1]
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to multiply two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
return multiplied

def __rmul__(self, other):
multiplied = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
multiplied.on_tth[1] = other * self.on_tth[1]
multiplied.on_q[1] = other * self.on_q[1]
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
return multiplied
self._check_operation_compatibility(other)
multiplied_do = deepcopy(self)
if isinstance(other, (int, float)):
multiplied_do._all_arrays[:, 0] *= other
if isinstance(other, DiffractionObject):
multiplied_do._all_arrays[:, 0] *= other.all_arrays[:, 0]
return multiplied_do

__rmul__ = __mul__

def __truediv__(self, other):
divided = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
divided.on_tth[1] = other / self.on_tth[1]
divided.on_q[1] = other / self.on_q[1]
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to multiply two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
divided.on_tth[1] = self.on_tth[1] / other.on_tth[1]
divided.on_q[1] = self.on_q[1] / other.on_q[1]
return divided

def __rtruediv__(self, other):
divided = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
divided.on_tth[1] = other / self.on_tth[1]
divided.on_q[1] = other / self.on_q[1]
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
divided.on_tth[1] = other.on_tth[1] / self.on_tth[1]
divided.on_q[1] = other.on_q[1] / self.on_q[1]
return divided
self._check_operation_compatibility(other)
divided_do = deepcopy(self)
if isinstance(other, (int, float)):
divided_do._all_arrays[:, 0] /= other
if isinstance(other, DiffractionObject):
divided_do._all_arrays[:, 0] /= other.all_arrays[:, 0]
return divided_do

__rtruediv__ = __truediv__

def _check_operation_compatibility(self, other):
if not isinstance(other, (DiffractionObject, int, float)):
raise TypeError(invalid_add_type_emsg)
if isinstance(other, DiffractionObject):
if self.all_arrays.shape != other.all_arrays.shape:
raise ValueError(y_grid_length_mismatch_emsg)

@property
def all_arrays(self):
Expand Down
191 changes: 146 additions & 45 deletions tests/test_diffraction_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,75 +713,176 @@ def test_copy_object(do_minimal):


@pytest.mark.parametrize(
"starting_all_arrays, scalar_to_add, expected_all_arrays",
"operation, starting_yarray, scalar_value, expected_yarray",
[
# Test scalar addition to yarray values (intensity) and expect no change to xarrays (q, tth, d)
( # C1: Add integer of 5, expect yarray to increase by by 5
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
# Test scalar addition, subtraction, multiplication, and division to y-values by adding a scalar value
# C1: Test scalar addition to y-values (intensity), expect no change to x-values (q, tth, d)
( # 1. Add 5
"add",
np.array([1.0, 2.0]),
5,
np.array([[6.0, 0.51763809, 30.0, 12.13818192], [7.0, 1.0, 60.0, 6.28318531]]),
np.array([6.0, 7.0]),
),
( # C2: Add float of 5.1, expect yarray to be added by 5.1
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
( # 2. Add 5.1
"add",
np.array([1.0, 2.0]),
5.1,
np.array([[6.1, 0.51763809, 30.0, 12.13818192], [7.1, 1.0, 60.0, 6.28318531]]),
np.array([6.1, 7.1]),
),
# C2: Test scalar subtraction to y-values (intensity), expect no change to x-values (q, tth, d)
( # 1. Subtract 1
"sub",
np.array([1.0, 2.0]),
1,
np.array([0.0, 1.0]),
),
( # 2. Subtract 0.5
"sub",
np.array([1.0, 2.0]),
0.5,
np.array([0.5, 1.5]),
),
# C3: Test scalar multiplication to y-values (intensity), expect no change to x-values (q, tth, d)
( # 1. Multiply by 2
"mul",
np.array([1.0, 2.0]),
2,
np.array([2.0, 4.0]),
),
( # 2. Multiply by 2.5
"mul",
np.array([1.0, 2.0]),
2.5,
np.array([2.5, 5.0]),
),
# C4: Test scalar division to y-values (intensity), expect no change to x-values (q, tth, d)
( # 1. Divide by 2
"div",
np.array([1.0, 2.0]),
2,
np.array([0.5, 1.0]),
),
( # 2. Divide by 2.5
"div",
np.array([1.0, 2.0]),
2.5,
np.array([0.4, 0.8]),
),
],
)
def test_addition_operator_by_scalar(starting_all_arrays, scalar_to_add, expected_all_arrays, do_minimal_tth):
def test_scalar_operations(operation, starting_yarray, scalar_value, expected_yarray, do_minimal_tth):
do = do_minimal_tth
assert np.allclose(do.all_arrays, starting_all_arrays)
do_scalar_right_sum = do + scalar_to_add
assert np.allclose(do_scalar_right_sum.all_arrays, expected_all_arrays)
do_scalar_left_sum = scalar_to_add + do
assert np.allclose(do_scalar_left_sum.all_arrays, expected_all_arrays)
expected_xarray_constant = np.array([[0.51763809, 30.0, 12.13818192], [1.0, 60.0, 6.28318531]])
assert np.allclose(do.all_arrays[:, [1, 2, 3]], expected_xarray_constant)
assert np.allclose(do.all_arrays[:, 0], starting_yarray)
if operation == "add":
do_right_op = do + scalar_value
do_left_op = scalar_value + do
elif operation == "sub":
do_right_op = do - scalar_value
do_left_op = scalar_value - do
elif operation == "mul":
do_right_op = do * scalar_value
do_left_op = scalar_value * do
elif operation == "div":
do_right_op = do / scalar_value
do_left_op = scalar_value / do
assert np.allclose(do_right_op.all_arrays[:, 0], expected_yarray)
assert np.allclose(do_left_op.all_arrays[:, 0], expected_yarray)
# Ensure x-values are unchanged
assert np.allclose(do_right_op.all_arrays[:, [1, 2, 3]], expected_xarray_constant)
assert np.allclose(do_left_op.all_arrays[:, [1, 2, 3]], expected_xarray_constant)


@pytest.mark.parametrize(
"do_1_all_arrays, "
"do_2_all_arrays, "
"expected_do_1_all_arrays_with_y_summed, "
"expected_do_2_all_arrays_with_y_summed",
"operation, " "expected_do_1_all_arrays_with_y_modified, " "expected_do_2_all_arrays_with_y_modified",
[
# Test addition of two DO objects, expect combined yarray values and no change to xarrays ((q, tth, d)
( # C1: Add two DO objects, expect sum of yarray values
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),),
(np.array([[1.0, 6.28318531, 100.70777771, 1], [2.0, 3.14159265, 45.28748053, 2.0]]),),
(np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),),
(np.array([[2.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),),
# Test addition, subtraction, multiplication, and division of two DO objects
( # Test addition of two DO objects, expect combined yarray values
"add",
np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),
np.array([[2.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),
),
( # Test subtraction of two DO objects, expect differences in yarray values
"sub",
np.array([[0.0, 0.51763809, 30.0, 12.13818192], [0.0, 1.0, 60.0, 6.28318531]]),
np.array([[0.0, 6.28318531, 100.70777771, 1], [0.0, 3.14159265, 45.28748053, 2.0]]),
),
( # Test multiplication of two DO objects, expect multiplication in yarray values
"mul",
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),
np.array([[1.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),
),
( # Test division of two DO objects, expect division in yarray values
"div",
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [1.0, 1.0, 60.0, 6.28318531]]),
np.array([[1.0, 6.28318531, 100.70777771, 1], [1.0, 3.14159265, 45.28748053, 2.0]]),
),
],
)
def test_addition_operator_by_another_do(
do_1_all_arrays,
do_2_all_arrays,
expected_do_1_all_arrays_with_y_summed,
expected_do_2_all_arrays_with_y_summed,
def test_binary_operator_on_do(
operation,
expected_do_1_all_arrays_with_y_modified,
expected_do_2_all_arrays_with_y_modified,
do_minimal_tth,
do_minimal_d,
):
do_1 = do_minimal_tth
assert np.allclose(do_1.all_arrays, do_1_all_arrays)
do_2 = do_minimal_d
assert np.allclose(do_2.all_arrays, do_2_all_arrays)
assert np.allclose((do_1 + do_2).all_arrays, expected_do_1_all_arrays_with_y_summed)
assert np.allclose((do_2 + do_1).all_arrays, expected_do_2_all_arrays_with_y_summed)

assert np.allclose(
do_1.all_arrays, np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]])
)
assert np.allclose(
do_2.all_arrays, np.array([[1.0, 6.28318531, 100.70777771, 1], [2.0, 3.14159265, 45.28748053, 2.0]])
)

def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
# Add a string to a DO object, expect TypeError, only scalar (int, float) allowed for addition
if operation == "add":
do_1_y_modified = do_1 + do_2
do_2_y_modified = do_2 + do_1
elif operation == "sub":
do_1_y_modified = do_1 - do_2
do_2_y_modified = do_2 - do_1
elif operation == "mul":
do_1_y_modified = do_1 * do_2
do_2_y_modified = do_2 * do_1
elif operation == "div":
do_1_y_modified = do_1 / do_2
do_2_y_modified = do_2 / do_1

assert np.allclose(do_1_y_modified.all_arrays, expected_do_1_all_arrays_with_y_modified)
assert np.allclose(do_2_y_modified.all_arrays, expected_do_2_all_arrays_with_y_modified)


def test_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
# Add a string to a DiffractionObject, expect TypeError
do = do_minimal_tth
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
do + "string_value"
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
"string_value" + do


def test_addition_operator_invalid_yarray_length(do_minimal, do_minimal_tth, y_grid_size_mismatch_error_msg):
# Combine two DO objects, one with empty xarrays (do_minimal) and the other with non-empty xarrays
invalid_value = "string_value"
operations = [
(lambda x, y: x + y), # Test addition
(lambda x, y: x - y), # Test subtraction
(lambda x, y: x * y), # Test multiplication
(lambda x, y: x / y), # Test division
]
for operation in operations:
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
operation(do, invalid_value)
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
operation(invalid_value, do)


@pytest.mark.parametrize("operation", ["add", "sub", "mul", "div"])
def test_operator_invalid_yarray_length(operation, do_minimal, do_minimal_tth, y_grid_size_mismatch_error_msg):
# Add two DO objects with different yarray lengths, expect ValueError
do_1 = do_minimal
do_2 = do_minimal_tth
assert len(do_1.all_arrays[:, 0]) == 0
assert len(do_2.all_arrays[:, 0]) == 2
with pytest.raises(ValueError, match=re.escape(y_grid_size_mismatch_error_msg)):
do_1 + do_2
if operation == "add":
do_1 + do_2
elif operation == "sub":
do_1 - do_2
elif operation == "mul":
do_1 * do_2
elif operation == "div":
do_1 / do_2

0 comments on commit 90fd625

Please sign in to comment.