-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor __add__
operation in DiffractionObject
and add tests
#285
Changes from 5 commits
7db3a4f
3d4841b
5d0ebcc
9741a8e
3f577d7
d561583
ac5a2f3
12848e8
11c4166
846c72a
da70bd6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
**Added:** | ||
|
||
* unit tests for __add__ operation for DiffractionObject | ||
|
||
**Changed:** | ||
|
||
* <news item> | ||
|
||
**Deprecated:** | ||
|
||
* <news item> | ||
|
||
**Removed:** | ||
|
||
* <news item> | ||
|
||
**Fixed:** | ||
|
||
* <news item> | ||
|
||
**Security:** | ||
|
||
* <news item> |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,9 +14,15 @@ | |
XQUANTITIES = ANGLEQUANTITIES + DQUANTITIES + QQUANTITIES | ||
XUNITS = ["degrees", "radians", "rad", "deg", "inv_angs", "inv_nm", "nm-1", "A-1"] | ||
|
||
x_grid_emsg = ( | ||
"objects are not on the same x-grid. You may add them using the self.add method " | ||
"and specifying how to handle the mismatch." | ||
x_grid_length_mismatch_emsg = ( | ||
"The two objects have different x-array lengths. " | ||
"Please ensure the length of the x-value during initialization is identical." | ||
) | ||
|
||
invalid_add_type_emsg = ( | ||
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. " | ||
"Please rerun by adding another DiffractionObject instance or a scalar value. " | ||
"e.g., my_do_1 + my_do_2 or my_do + 10" | ||
) | ||
|
||
|
||
sbillinge marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -169,32 +175,53 @@ def __eq__(self, other): | |
return True | ||
|
||
def __add__(self, other): | ||
summed = deepcopy(self) | ||
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray): | ||
summed.on_tth[1] = self.on_tth[1] + other | ||
summed.on_q[1] = self.on_q[1] + other | ||
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to sum two DiffractionObject objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
else: | ||
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1] | ||
summed.on_q[1] = self.on_q[1] + other.on_q[1] | ||
return summed | ||
"""Add a scalar value or another DiffractionObject to the xarrays of | ||
the DiffractionObject. | ||
|
||
def __radd__(self, other): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. radd i think we don't need? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we may. Are you sure? Anyway, we can test and see. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you are correct - added please see a new test below for |
||
summed = deepcopy(self) | ||
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray): | ||
summed.on_tth[1] = self.on_tth[1] + other | ||
summed.on_q[1] = self.on_q[1] + other | ||
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to sum two Scattering_object objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
Parameters | ||
---------- | ||
other : DiffractionObject or int or float | ||
The object to add to the current DiffractionObject. If `other` is a scalar value, | ||
it will be added to all xarrays. The length of the xarrays must match if `other` is | ||
an instance of DiffractionObject. | ||
|
||
Returns | ||
------- | ||
DiffractionObject | ||
The new and deep-copied DiffractionObject instance after adding values to the xarrays. | ||
|
||
Raises | ||
------ | ||
ValueError | ||
Raised when the length of the xarrays of the two DiffractionObject instances do not match. | ||
TypeError | ||
Raised when the type of `other` is not an instance of DiffractionObject, int, or float. | ||
|
||
Examples | ||
-------- | ||
Add a scalar value to the xarrays of the DiffractionObject instance: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's add the radd example here too. |
||
>>> new_do = my_do + 10.1 | ||
|
||
Add the xarrays of two DiffractionObject instances: | ||
>>> new_do = my_do_1 + my_do_2 | ||
""" | ||
|
||
summed_do = deepcopy(self) | ||
# Add scalar value to all xarrays by broadcasting | ||
if isinstance(other, (int, float)): | ||
summed_do._all_arrays[:, 1] += other | ||
summed_do._all_arrays[:, 2] += other | ||
summed_do._all_arrays[:, 3] += other | ||
# Add xarrays of two DiffractionObject instances | ||
elif isinstance(other, DiffractionObject): | ||
if len(self.on_tth()[0]) != len(other.on_tth()[0]): | ||
raise ValueError(x_grid_length_mismatch_emsg) | ||
summed_do._all_arrays[:, 1] += other.on_q()[0] | ||
summed_do._all_arrays[:, 2] += other.on_tth()[0] | ||
summed_do._all_arrays[:, 3] += other.on_d()[0] | ||
else: | ||
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1] | ||
summed.on_q[1] = self.on_q[1] + other.on_q[1] | ||
return summed | ||
raise TypeError(invalid_add_type_emsg) | ||
return summed_do | ||
|
||
def __sub__(self, other): | ||
subtracted = deepcopy(self) | ||
|
@@ -204,7 +231,7 @@ def __sub__(self, other): | |
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to subtract two Scattering_object objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
raise RuntimeError(x_grid_length_mismatch_emsg) | ||
else: | ||
subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1] | ||
subtracted.on_q[1] = self.on_q[1] - other.on_q[1] | ||
|
@@ -218,7 +245,7 @@ def __rsub__(self, other): | |
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to subtract two Scattering_object objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
raise RuntimeError(x_grid_length_mismatch_emsg) | ||
else: | ||
subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1] | ||
subtracted.on_q[1] = other.on_q[1] - self.on_q[1] | ||
|
@@ -232,19 +259,7 @@ def __mul__(self, other): | |
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to multiply two Scattering_object objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
else: | ||
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1] | ||
multiplied.on_q[1] = self.on_q[1] * other.on_q[1] | ||
return multiplied | ||
|
||
def __rmul__(self, other): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed since identical as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should test and make sure it is ok to remove this. On balance, I probably want to leave the 'r' functionalities |
||
multiplied = deepcopy(self) | ||
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray): | ||
multiplied.on_tth[1] = other * self.on_tth[1] | ||
multiplied.on_q[1] = other * self.on_q[1] | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
raise RuntimeError(x_grid_length_mismatch_emsg) | ||
else: | ||
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1] | ||
multiplied.on_q[1] = self.on_q[1] * other.on_q[1] | ||
|
@@ -258,7 +273,7 @@ def __truediv__(self, other): | |
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to multiply two Scattering_object objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
raise RuntimeError(x_grid_length_mismatch_emsg) | ||
else: | ||
divided.on_tth[1] = self.on_tth[1] / other.on_tth[1] | ||
divided.on_q[1] = self.on_q[1] / other.on_q[1] | ||
|
@@ -270,7 +285,7 @@ def __rtruediv__(self, other): | |
divided.on_tth[1] = other / self.on_tth[1] | ||
divided.on_q[1] = other / self.on_q[1] | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
raise RuntimeError(x_grid_length_mismatch_emsg) | ||
else: | ||
divided.on_tth[1] = other.on_tth[1] / self.on_tth[1] | ||
divided.on_q[1] = other.on_q[1] / self.on_q[1] | ||
|
sbillinge marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -702,3 +702,62 @@ def test_copy_object(do_minimal): | |
do_copy = do.copy() | ||
assert do == do_copy | ||
assert id(do) != id(do_copy) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"starting_all_arrays, scalar_value, expected_all_arrays", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
[ | ||
# Test scalar addition to xarray values (q, tth, d) and expect no change to yarray values | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the add operation is supposed to operate on the yarray leaving xarrays unaffected. |
||
( # C1: Add integer of 5, expect xarray to increase by by 5 | ||
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]), | ||
5, | ||
np.array([[1.0, 5.51763809, 35.0, 17.13818192], [2.0, 6.0, 65.0, 11.28318531]]), | ||
), | ||
( # C2: Add float of 5.1, expect xarray to be added by 5.1 | ||
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]), | ||
5.1, | ||
np.array([[1.0, 5.61763809, 35.1, 17.23818192], [2.0, 6.1, 65.1, 11.38318531]]), | ||
), | ||
], | ||
) | ||
def test_addition_operator_by_scalar(starting_all_arrays, scalar_value, expected_all_arrays, do_minimal_tth): | ||
do = do_minimal_tth | ||
assert np.allclose(do.all_arrays, starting_all_arrays) | ||
do_sum = do + scalar_value | ||
assert np.allclose(do_sum.all_arrays, expected_all_arrays) | ||
|
||
sbillinge marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@pytest.mark.parametrize( | ||
"LHS_all_arrays, RHS_all_arrays, expected_all_arrays_sum", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
[ | ||
# Test addition of two DO objects, expect combined xarray values (q, tth, d) and no change to yarray | ||
sbillinge marked this conversation as resolved.
Show resolved
Hide resolved
|
||
( # C1: Add two DO objects with identical xarray values, expect sum of xarray values | ||
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),), | ||
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),), | ||
np.array([[1.0, 1.03527618, 60.0, 24.27636384], [2.0, 2.0, 120.0, 12.56637061]]), | ||
), | ||
], | ||
) | ||
def test_addition_operator_by_another_do(LHS_all_arrays, RHS_all_arrays, expected_all_arrays_sum, do_minimal_tth): | ||
assert np.allclose(do_minimal_tth.all_arrays, LHS_all_arrays) | ||
do_LHS = do_minimal_tth | ||
do_RHS = do_minimal_tth | ||
do_sum = do_LHS + do_RHS | ||
assert np.allclose(do_LHS.all_arrays, LHS_all_arrays) | ||
assert np.allclose(do_RHS.all_arrays, RHS_all_arrays) | ||
assert np.allclose(do_sum.all_arrays, expected_all_arrays_sum) | ||
|
||
|
||
def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg): | ||
# Add a string to a DO object, expect TypeError, only scalar (int, float) allowed for addition | ||
do_LHS = do_minimal_tth | ||
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)): | ||
do_LHS + "string_value" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. string not allowed |
||
|
||
|
||
def test_addition_operator_invalid_xarray_length(do_minimal, do_minimal_tth, x_grid_size_mismatch_error_msg): | ||
# Combine two DO objects, one with empty xarrays (do_minimal) and the other with non-empty xarrays | ||
do_LHS = do_minimal | ||
do_RHS = do_minimal_tth | ||
with pytest.raises(ValueError, match=re.escape(x_grid_size_mismatch_error_msg)): | ||
do_LHS + do_RHS | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. different lengths of xarray, not allowed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for completeness, lets also add a radd as an example here. We could say "to add 10 to all intensities use..."