Skip to content

Commit

Permalink
Fix: handle scipy failures due to renaming of scipy.sparse modules in…
Browse files Browse the repository at this point in the history
… v1.8 (#6)

* Add private module aliases for scipy.sparse

Signed-off-by: Connor Tann <[email protected]>

* More verbose tests

Signed-off-by: Connor Tann <[email protected]>

* More comprehensive scipy fix

Signed-off-by: Connor Tann <[email protected]>

---------

Signed-off-by: Connor Tann <[email protected]>
  • Loading branch information
connortann authored Mar 9, 2024
1 parent 12964ec commit cf53c86
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
15 changes: 14 additions & 1 deletion slicer/slicer_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def default_alias(cls, o):


def _type_name(o: object) -> Tuple[str, str]:
return o.__class__.__module__, o.__class__.__name__
return _handle_module_aliases(o.__class__.__module__), o.__class__.__name__


def _safe_isinstance(
Expand All @@ -616,3 +616,16 @@ def _safe_isinstance(
return o_module == module_name and o_type == type_name
else:
return o_module == module_name and o_type in type_name


def _handle_module_aliases(module_name):
# scipy modules such as "scipy.sparse.csc" were renamed to "scipy.sparse._csc" in v1.8
# Standardise by removing underscores for compatibility with either name
# Else just pass module name unchanged
module_map = {
"scipy.sparse._csc": "scipy.sparse.csc",
"scipy.sparse._csr": "scipy.sparse.csr",
"scipy.sparse._dok": "scipy.sparse.dok",
"scipy.sparse._lil": "scipy.sparse.lil",
}
return module_map.get(module_name, module_name)
10 changes: 7 additions & 3 deletions slicer/test_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def test_slicer_sparse():

candidates = [csc_array, csr_array, dok_array, lil_array]
for candidate in candidates:
print("testing:", type(candidate))
slicer = S(candidate)
actual = slicer[0, 0]
assert ctr_eq(actual.o, 1)
Expand Down Expand Up @@ -341,7 +342,8 @@ def test_operations_1d():
array = np.array(elements)
torch_array = torch.tensor(elements)
containers = [li, tup, array, torch_array, di, series]
for _, ctr in enumerate(containers):
for ctr in containers:
print("testing:", type(ctr))
slicer = AtomicSlicer(ctr)

assert ctr_eq(slicer[0], elements[0])
Expand Down Expand Up @@ -371,7 +373,8 @@ def test_operations_2d():
sparse_lil = lil_matrix(elements)

containers = [li, df, sparse_csc, sparse_csr, sparse_dok, sparse_lil]
for _, ctr in enumerate(containers):
for ctr in containers:
print("testing:", type(ctr))
slicer = AtomicSlicer(ctr)

assert ctr_eq(slicer[0], elements[0])
Expand Down Expand Up @@ -432,7 +435,8 @@ def test_operations_3d():
list_of_multi_arrays,
di_of_multi_arrays,
]
for _, ctr in enumerate(containers):
for ctr in containers:
print("testing:", type(ctr))
slicer = AtomicSlicer(ctr)

assert ctr_eq(slicer[0], elements[0])
Expand Down

0 comments on commit cf53c86

Please sign in to comment.