Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor __add__ operation in DiffractionObject and add tests #285

Merged
merged 11 commits into from
Dec 29, 2024
23 changes: 23 additions & 0 deletions news/add-operations-tests.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* unit tests for __add__ operation for DiffractionObject

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
103 changes: 59 additions & 44 deletions src/diffpy/utils/diffraction_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@
XQUANTITIES = ANGLEQUANTITIES + DQUANTITIES + QQUANTITIES
XUNITS = ["degrees", "radians", "rad", "deg", "inv_angs", "inv_nm", "nm-1", "A-1"]

x_grid_emsg = (
"objects are not on the same x-grid. You may add them using the self.add method "
"and specifying how to handle the mismatch."
x_grid_length_mismatch_emsg = (
"The two objects have different x-array lengths. "
"Please ensure the length of the x-value during initialization is identical."
)

invalid_add_type_emsg = (
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. "
"Please rerun by adding another DiffractionObject instance or a scalar value. "
"e.g., my_do_1 + my_do_2 or my_do + 10"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for completeness, lets also add a radd as an example here. We could say "to add 10 to all intensities use..."

)


sbillinge marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -169,32 +175,53 @@ def __eq__(self, other):
return True

def __add__(self, other):
summed = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
summed.on_tth[1] = self.on_tth[1] + other
summed.on_q[1] = self.on_q[1] + other
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to sum two DiffractionObject objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
else:
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1]
summed.on_q[1] = self.on_q[1] + other.on_q[1]
return summed
"""Add a scalar value or another DiffractionObject to the xarrays of
the DiffractionObject.

def __radd__(self, other):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

radd i think we don't need?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may. Are you sure? Anyway, we can test and see.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are correct - added __radd__ back using __radd__ = __add__

please see a new test below for do + scalar as well as scalar + do

summed = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
summed.on_tth[1] = self.on_tth[1] + other
summed.on_q[1] = self.on_q[1] + other
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to sum two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
Parameters
----------
other : DiffractionObject or int or float
The object to add to the current DiffractionObject. If `other` is a scalar value,
it will be added to all xarrays. The length of the xarrays must match if `other` is
an instance of DiffractionObject.

Returns
-------
DiffractionObject
The new and deep-copied DiffractionObject instance after adding values to the xarrays.

Raises
------
ValueError
Raised when the length of the xarrays of the two DiffractionObject instances do not match.
TypeError
Raised when the type of `other` is not an instance of DiffractionObject, int, or float.

Examples
--------
Add a scalar value to the xarrays of the DiffractionObject instance:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add the radd example here too.

>>> new_do = my_do + 10.1

Add the xarrays of two DiffractionObject instances:
>>> new_do = my_do_1 + my_do_2
"""

summed_do = deepcopy(self)
# Add scalar value to all xarrays by broadcasting
if isinstance(other, (int, float)):
summed_do._all_arrays[:, 1] += other
summed_do._all_arrays[:, 2] += other
summed_do._all_arrays[:, 3] += other
# Add xarrays of two DiffractionObject instances
elif isinstance(other, DiffractionObject):
if len(self.on_tth()[0]) != len(other.on_tth()[0]):
raise ValueError(x_grid_length_mismatch_emsg)
summed_do._all_arrays[:, 1] += other.on_q()[0]
summed_do._all_arrays[:, 2] += other.on_tth()[0]
summed_do._all_arrays[:, 3] += other.on_d()[0]
else:
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1]
summed.on_q[1] = self.on_q[1] + other.on_q[1]
return summed
raise TypeError(invalid_add_type_emsg)
return summed_do

def __sub__(self, other):
subtracted = deepcopy(self)
Expand All @@ -204,7 +231,7 @@ def __sub__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to subtract two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(x_grid_length_mismatch_emsg)
else:
subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1]
subtracted.on_q[1] = self.on_q[1] - other.on_q[1]
Expand All @@ -218,7 +245,7 @@ def __rsub__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to subtract two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(x_grid_length_mismatch_emsg)
else:
subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1]
subtracted.on_q[1] = other.on_q[1] - self.on_q[1]
Expand All @@ -232,19 +259,7 @@ def __mul__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to multiply two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
else:
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
return multiplied

def __rmul__(self, other):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed since identical as __mul__

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should test and make sure it is ok to remove this. On balance, I probably want to leave the 'r' functionalities

multiplied = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
multiplied.on_tth[1] = other * self.on_tth[1]
multiplied.on_q[1] = other * self.on_q[1]
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(x_grid_length_mismatch_emsg)
else:
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
Expand All @@ -258,7 +273,7 @@ def __truediv__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to multiply two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(x_grid_length_mismatch_emsg)
else:
divided.on_tth[1] = self.on_tth[1] / other.on_tth[1]
divided.on_q[1] = self.on_q[1] / other.on_q[1]
Expand All @@ -270,7 +285,7 @@ def __rtruediv__(self, other):
divided.on_tth[1] = other / self.on_tth[1]
divided.on_q[1] = other / self.on_q[1]
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(x_grid_length_mismatch_emsg)
else:
divided.on_tth[1] = other.on_tth[1] / self.on_tth[1]
divided.on_q[1] = other.on_q[1] / self.on_q[1]
Expand Down
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,20 @@ def invalid_q_or_d_or_wavelength_error_msg():
"The supplied input array and wavelength will result in an impossible two-theta. "
"Please check these values and re-instantiate the DiffractionObject with correct values."
)


@pytest.fixture
def invalid_add_type_error_msg():
return (
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. "
"Please rerun by adding another DiffractionObject instance or a scalar value. "
"e.g., my_do_1 + my_do_2 or my_do + 10"
)


@pytest.fixture
def x_grid_size_mismatch_error_msg():
return (
"The two objects have different x-array lengths. "
"Please ensure the length of the x-value during initialization is identical."
)
59 changes: 59 additions & 0 deletions tests/test_diffraction_objects.py
sbillinge marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -702,3 +702,62 @@ def test_copy_object(do_minimal):
do_copy = do.copy()
assert do == do_copy
assert id(do) != id(do_copy)


@pytest.mark.parametrize(
"starting_all_arrays, scalar_value, expected_all_arrays",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. test adding do + number

[
# Test scalar addition to xarray values (q, tth, d) and expect no change to yarray values
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the add operation is supposed to operate on the yarray leaving xarrays unaffected.

( # C1: Add integer of 5, expect xarray to increase by by 5
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
5,
np.array([[1.0, 5.51763809, 35.0, 17.13818192], [2.0, 6.0, 65.0, 11.28318531]]),
),
( # C2: Add float of 5.1, expect xarray to be added by 5.1
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
5.1,
np.array([[1.0, 5.61763809, 35.1, 17.23818192], [2.0, 6.1, 65.1, 11.38318531]]),
),
],
)
def test_addition_operator_by_scalar(starting_all_arrays, scalar_value, expected_all_arrays, do_minimal_tth):
do = do_minimal_tth
assert np.allclose(do.all_arrays, starting_all_arrays)
do_sum = do + scalar_value
assert np.allclose(do_sum.all_arrays, expected_all_arrays)

sbillinge marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.parametrize(
"LHS_all_arrays, RHS_all_arrays, expected_all_arrays_sum",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. test adding do1 + do2

[
# Test addition of two DO objects, expect combined xarray values (q, tth, d) and no change to yarray
sbillinge marked this conversation as resolved.
Show resolved Hide resolved
( # C1: Add two DO objects with identical xarray values, expect sum of xarray values
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),),
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),),
np.array([[1.0, 1.03527618, 60.0, 24.27636384], [2.0, 2.0, 120.0, 12.56637061]]),
),
],
)
def test_addition_operator_by_another_do(LHS_all_arrays, RHS_all_arrays, expected_all_arrays_sum, do_minimal_tth):
assert np.allclose(do_minimal_tth.all_arrays, LHS_all_arrays)
do_LHS = do_minimal_tth
do_RHS = do_minimal_tth
do_sum = do_LHS + do_RHS
assert np.allclose(do_LHS.all_arrays, LHS_all_arrays)
assert np.allclose(do_RHS.all_arrays, RHS_all_arrays)
assert np.allclose(do_sum.all_arrays, expected_all_arrays_sum)


def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
# Add a string to a DO object, expect TypeError, only scalar (int, float) allowed for addition
do_LHS = do_minimal_tth
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
do_LHS + "string_value"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

string not allowed



def test_addition_operator_invalid_xarray_length(do_minimal, do_minimal_tth, x_grid_size_mismatch_error_msg):
# Combine two DO objects, one with empty xarrays (do_minimal) and the other with non-empty xarrays
do_LHS = do_minimal
do_RHS = do_minimal_tth
with pytest.raises(ValueError, match=re.escape(x_grid_size_mismatch_error_msg)):
do_LHS + do_RHS
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

different lengths of xarray, not allowed

Loading