Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor __add__ operation in DiffractionObject and add tests #285

Merged
merged 11 commits into from
Dec 29, 2024
23 changes: 23 additions & 0 deletions news/add-operations-tests.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* unit tests for __add__ operation for DiffractionObject

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
98 changes: 64 additions & 34 deletions src/diffpy/utils/diffraction_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@
XQUANTITIES = ANGLEQUANTITIES + DQUANTITIES + QQUANTITIES
XUNITS = ["degrees", "radians", "rad", "deg", "inv_angs", "inv_nm", "nm-1", "A-1"]

x_grid_emsg = (
"objects are not on the same x-grid. You may add them using the self.add method "
"and specifying how to handle the mismatch."
y_grid_length_mismatch_emsg = (
"The two objects have different y-array lengths. "
"Please ensure the length of the y-value during initialization is identical."
)

invalid_add_type_emsg = (
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. "
"Please rerun by adding another DiffractionObject instance or a scalar value. "
"e.g., my_do_1 + my_do_2 or my_do + 10 or 10 + my_do"
)


sbillinge marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -169,32 +175,56 @@ def __eq__(self, other):
return True

def __add__(self, other):
summed = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
summed.on_tth[1] = self.on_tth[1] + other
summed.on_q[1] = self.on_q[1] + other
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to sum two DiffractionObject objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
else:
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1]
summed.on_q[1] = self.on_q[1] + other.on_q[1]
return summed
"""Add a scalar value or another DiffractionObject to the yarray of the
DiffractionObject.

def __radd__(self, other):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

radd i think we don't need?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may. Are you sure? Anyway, we can test and see.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are correct - added __radd__ back using __radd__ = __add__

please see a new test below for do + scalar as well as scalar + do

summed = deepcopy(self)
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
summed.on_tth[1] = self.on_tth[1] + other
summed.on_q[1] = self.on_q[1] + other
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to sum two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
else:
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1]
summed.on_q[1] = self.on_q[1] + other.on_q[1]
return summed
Parameters
----------
other : DiffractionObject or int or float
The object to add to the current DiffractionObject. If `other` is a scalar value,
it will be added to all yarray. The length of the yarray must match if `other` is
an instance of DiffractionObject.

Returns
-------
DiffractionObject
The new and deep-copied DiffractionObject instance after adding values to the yarray.

Raises
------
ValueError
Raised when the length of the yarray of the two DiffractionObject instances do not match.
TypeError
Raised when the type of `other` is not an instance of DiffractionObject, int, or float.

Examples
--------
Add a scalar value to the yarray of the DiffractionObject instance:
>>> new_do = my_do + 10.1
>>> new_do = 10.1 + my_do
sbillinge marked this conversation as resolved.
Show resolved Hide resolved

Add the yarray of two DiffractionObject instances:
>>> new_do = my_do_1 + my_do_2
"""

self._check_operation_compatibility(other)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created a private func that checks the validity of other

def _check_operation_compatibility(self, other):
        if not isinstance(other, (DiffractionObject, int, float)):
            raise TypeError(invalid_add_type_emsg)
        if isinstance(other, DiffractionObject):
            self_yarray = self.all_arrays[:, 0]
            other_yarray = other.all_arrays[:, 0]
            if len(self_yarray) != len(other_yarray):
                raise ValueError(y_grid_length_mismatch_emsg)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! it may be more readable if we use

if shape(self.all_arrays) != shape(other.all_arrays)

to accomplish the same thing?

summed_do = deepcopy(self)
if isinstance(other, (int, float)):
summed_do._all_arrays[:, 0] += other
if isinstance(other, DiffractionObject):
summed_do._all_arrays[:, 0] += other.all_arrays[:, 0]
return summed_do

__radd__ = __add__
sbillinge marked this conversation as resolved.
Show resolved Hide resolved

def _check_operation_compatibility(self, other):
if not isinstance(other, (DiffractionObject, int, float)):
raise TypeError(invalid_add_type_emsg)
if isinstance(other, DiffractionObject):
self_yarray = self.all_arrays[:, 0]
other_yarray = other.all_arrays[:, 0]
if len(self_yarray) != len(other_yarray):
raise ValueError(y_grid_length_mismatch_emsg)

def __sub__(self, other):
subtracted = deepcopy(self)
Expand All @@ -204,7 +234,7 @@ def __sub__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to subtract two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1]
subtracted.on_q[1] = self.on_q[1] - other.on_q[1]
Expand All @@ -218,7 +248,7 @@ def __rsub__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to subtract two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1]
subtracted.on_q[1] = other.on_q[1] - self.on_q[1]
Expand All @@ -232,7 +262,7 @@ def __mul__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to multiply two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
Expand All @@ -244,7 +274,7 @@ def __rmul__(self, other):
multiplied.on_tth[1] = other * self.on_tth[1]
multiplied.on_q[1] = other * self.on_q[1]
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
Expand All @@ -258,7 +288,7 @@ def __truediv__(self, other):
elif not isinstance(other, DiffractionObject):
raise TypeError("I only know how to multiply two Scattering_object objects")
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
divided.on_tth[1] = self.on_tth[1] / other.on_tth[1]
divided.on_q[1] = self.on_q[1] / other.on_q[1]
Expand All @@ -270,7 +300,7 @@ def __rtruediv__(self, other):
divided.on_tth[1] = other / self.on_tth[1]
divided.on_q[1] = other / self.on_q[1]
elif self.on_tth[0].all() != other.on_tth[0].all():
raise RuntimeError(x_grid_emsg)
raise RuntimeError(y_grid_length_mismatch_emsg)
else:
divided.on_tth[1] = other.on_tth[1] / self.on_tth[1]
divided.on_q[1] = other.on_q[1] / self.on_q[1]
Expand Down
23 changes: 23 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def do_minimal_tth():
return DiffractionObject(wavelength=2 * np.pi, xarray=np.array([30, 60]), yarray=np.array([1, 2]), xtype="tth")


@pytest.fixture
def do_minimal_d():
# Create an instance of DiffractionObject with non-empty xarray, yarray, and wavelength values
return DiffractionObject(wavelength=1.54, xarray=np.array([1, 2]), yarray=np.array([1, 2]), xtype="d")


@pytest.fixture
def wavelength_warning_msg():
return (
Expand All @@ -63,3 +69,20 @@ def invalid_q_or_d_or_wavelength_error_msg():
"The supplied input array and wavelength will result in an impossible two-theta. "
"Please check these values and re-instantiate the DiffractionObject with correct values."
)


@pytest.fixture
def invalid_add_type_error_msg():
return (
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. "
"Please rerun by adding another DiffractionObject instance or a scalar value. "
"e.g., my_do_1 + my_do_2 or my_do + 10 or 10 + my_do"
)


@pytest.fixture
def y_grid_size_mismatch_error_msg():
return (
"The two objects have different y-array lengths. "
"Please ensure the length of the y-value during initialization is identical."
)
75 changes: 75 additions & 0 deletions tests/test_diffraction_objects.py
sbillinge marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -710,3 +710,78 @@ def test_copy_object(do_minimal):
do_copy = do.copy()
assert do == do_copy
assert id(do) != id(do_copy)


@pytest.mark.parametrize(
"starting_all_arrays, scalar_to_add, expected_all_arrays",
[
# Test scalar addition to yarray values (intensity) and expect no change to xarrays (q, tth, d)
( # C1: Add integer of 5, expect yarray to increase by by 5
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
5,
np.array([[6.0, 0.51763809, 30.0, 12.13818192], [7.0, 1.0, 60.0, 6.28318531]]),
),
( # C2: Add float of 5.1, expect yarray to be added by 5.1
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
5.1,
np.array([[6.1, 0.51763809, 30.0, 12.13818192], [7.1, 1.0, 60.0, 6.28318531]]),
),
],
)
def test_addition_operator_by_scalar(starting_all_arrays, scalar_to_add, expected_all_arrays, do_minimal_tth):
do = do_minimal_tth
assert np.allclose(do.all_arrays, starting_all_arrays)
do_scalar_right_sum = do + scalar_to_add
assert np.allclose(do_scalar_right_sum.all_arrays, expected_all_arrays)
do_scalar_left_sum = scalar_to_add + do
assert np.allclose(do_scalar_left_sum.all_arrays, expected_all_arrays)

sbillinge marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.parametrize(
"do_1_all_arrays, "
"do_2_all_arrays, "
"expected_do_1_all_arrays_with_y_summed, "
"expected_do_2_all_arrays_with_y_summed",
[
# Test addition of two DO objects, expect combined yarray values and no change to xarrays ((q, tth, d)
( # C1: Add two DO objects, expect sum of yarray values
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),),
(np.array([[1.0, 6.28318531, 100.70777771, 1], [2.0, 3.14159265, 45.28748053, 2.0]]),),
(np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),),
(np.array([[2.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),),
),
],
)
def test_addition_operator_by_another_do(
do_1_all_arrays,
do_2_all_arrays,
expected_do_1_all_arrays_with_y_summed,
expected_do_2_all_arrays_with_y_summed,
do_minimal_tth,
do_minimal_d,
):
do_1 = do_minimal_tth
assert np.allclose(do_1.all_arrays, do_1_all_arrays)
do_2 = do_minimal_d
assert np.allclose(do_2.all_arrays, do_2_all_arrays)
assert np.allclose((do_1 + do_2).all_arrays, expected_do_1_all_arrays_with_y_summed)
assert np.allclose((do_2 + do_1).all_arrays, expected_do_2_all_arrays_with_y_summed)


def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
# Add a string to a DO object, expect TypeError, only scalar (int, float) allowed for addition
do = do_minimal_tth
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
do + "string_value"
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
"string_value" + do


def test_addition_operator_invalid_yarray_length(do_minimal, do_minimal_tth, y_grid_size_mismatch_error_msg):
# Combine two DO objects, one with empty xarrays (do_minimal) and the other with non-empty xarrays
do_1 = do_minimal
do_2 = do_minimal_tth
assert len(do_1.all_arrays[:, 0]) == 0
assert len(do_2.all_arrays[:, 0]) == 2
with pytest.raises(ValueError, match=re.escape(y_grid_size_mismatch_error_msg)):
do_1 + do_2
Loading