diff --git a/sparse/mlir_backend/_constructors.py b/sparse/mlir_backend/_constructors.py index b382c4f2..1e068188 100644 --- a/sparse/mlir_backend/_constructors.py +++ b/sparse/mlir_backend/_constructors.py @@ -364,6 +364,12 @@ def __del__(self): for field in self._obj.get__fields_(): free_memref(field) + def __getitem__(self, key) -> "Tensor": + # imported lazily to avoid cyclic dependency + from ._ops import getitem + + return getitem(self, key) + @_hold_self_ref_in_ret def to_scipy_sparse(self) -> sps.sparray | np.ndarray: return self._obj.to_sps(self.shape) diff --git a/sparse/mlir_backend/_ops.py b/sparse/mlir_backend/_ops.py index 963bbd1c..8feb7cc0 100644 --- a/sparse/mlir_backend/_ops.py +++ b/sparse/mlir_backend/_ops.py @@ -1,4 +1,5 @@ import ctypes +from types import EllipsisType import mlir.execution_engine import mlir.passmanager @@ -85,12 +86,39 @@ def get_reshape_module( def reshape(a, shape): return tensor.reshape(out_tensor_type, a, shape) - reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - if DEBUG: - (CWD / "reshape_module.mlir").write_text(str(module)) - pm.run(module.operation) - if DEBUG: - (CWD / "reshape_module_opt.mlir").write_text(str(module)) + reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + if DEBUG: + (CWD / "reshape_module.mlir").write_text(str(module)) + pm.run(module.operation) + if DEBUG: + (CWD / "reshape_module_opt.mlir").write_text(str(module)) + + return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS]) + + +@fn_cache +def get_slice_module( + in_tensor_type: ir.RankedTensorType, + out_tensor_type: ir.RankedTensorType, + offsets: tuple[int, ...], + sizes: tuple[int, ...], + strides: tuple[int, ...], +) -> ir.Module: + with ir.Location.unknown(ctx): + module = ir.Module.create() + + with ir.InsertionPoint(module.body): + + @func.FuncOp.from_py_func(in_tensor_type) + def getitem(a): + return tensor.extract_slice(out_tensor_type, a, [], [], [], offsets, sizes, strides) + + getitem.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + if DEBUG: + (CWD / "getitem_module.mlir").write_text(str(module)) + pm.run(module.operation) + if DEBUG: + (CWD / "getitem_module_opt.mlir").write_text(str(module)) return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS]) @@ -195,3 +223,80 @@ def broadcast_to(x: Tensor, /, shape: tuple[int, ...], dimensions: list[int]) -> ) return Tensor(ret_obj, shape=shape) + + +def _add_missing_dims(key: tuple, ndim: int) -> tuple: + if len(key) < ndim and Ellipsis not in key: + return key + (...,) + return key + + +def _expand_ellipsis(key: tuple, ndim: int) -> tuple: + if Ellipsis in key: + if len([e for e in key if e is Ellipsis]) > 1: + raise Exception(f"Ellipsis should be used once: {key}") + to_expand = ndim - len(key) + 1 + if to_expand <= 0: + raise Exception(f"Invalid use of Ellipsis in {key}") + idx = key.index(Ellipsis) + return key[:idx] + tuple(slice(None) for _ in range(to_expand)) + key[idx + 1 :] + return key + + +def _decompose_slices( + key: tuple, + shape: tuple[int, ...], +) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + offsets = [] + sizes = [] + strides = [] + + for key_elem, size in zip(key, shape, strict=False): + if isinstance(key_elem, slice): + offset = key_elem.start if key_elem.start is not None else 0 + size = key_elem.stop - offset if key_elem.stop is not None else size - offset + stride = key_elem.step if key_elem.step is not None else 1 + elif isinstance(key_elem, int): + offset = key_elem + size = key_elem + 1 + stride = 1 + offsets.append(offset) + sizes.append(size) + strides.append(stride) + + return tuple(offsets), tuple(sizes), tuple(strides) + + +def _get_new_shape(sizes, strides) -> tuple[int, ...]: + return tuple(size // stride for size, stride in zip(sizes, strides, strict=False)) + + +def getitem( + x: Tensor, + key: int | slice | EllipsisType | tuple[int | slice | EllipsisType, ...], +) -> Tensor: + if not isinstance(key, tuple): + key = (key,) + if None in key: + raise Exception(f"Lazy indexing isn't supported: {key}") + + ret_obj = x._format_class() + + key = _add_missing_dims(key, x.ndim) + key = _expand_ellipsis(key, x.ndim) + offsets, sizes, strides = _decompose_slices(key, x.shape) + + new_shape = _get_new_shape(sizes, strides) + out_tensor_type = x._obj.get_tensor_definition(new_shape) + + slice_module = get_slice_module( + x._obj.get_tensor_definition(x.shape), + out_tensor_type, + offsets, + sizes, + strides, + ) + + slice_module.invoke("getitem", ctypes.pointer(ctypes.pointer(ret_obj)), *x._obj.to_module_arg()) + + return Tensor(ret_obj, shape=out_tensor_type.shape) diff --git a/sparse/mlir_backend/tests/test_simple.py b/sparse/mlir_backend/tests/test_simple.py index 98ac90f2..e717f39b 100644 --- a/sparse/mlir_backend/tests/test_simple.py +++ b/sparse/mlir_backend/tests/test_simple.py @@ -341,3 +341,39 @@ def test_broadcast_to(dtype): assert result.format == "csr" np.testing.assert_allclose(result.todense(), np.repeat(np_arr[np.newaxis], 3, axis=0)) + + +@pytest.mark.skip(reason="https://discourse.llvm.org/t/illegal-operation-when-slicing-csr-csc-coo-tensor/81404") +@parametrize_dtypes +@pytest.mark.parametrize( + "index", + [ + 0, + (2,), + (2, 3), + (..., slice(0, 4, 2)), + (1, slice(1, None, 1)), + # TODO: For below cases we need an update to ownership mechanism. + # `tensor[:, :]` returns the same memref that was passed. + # The mechanism sees the result as MLIR-allocated and frees + # it, while it still can be owned by SciPy/NumPy causing a + # segfault when it frees SciPy/NumPy managed memory. + # ..., + # slice(None), + # (slice(None), slice(None)), + ], +) +def test_indexing_2d(rng, dtype, index): + SHAPE = (20, 30) + DENSITY = 0.5 + + for format in ["csr", "csc", "coo"]: + arr = sps.random_array(SHAPE, density=DENSITY, format=format, dtype=dtype, random_state=rng) + arr.sum_duplicates() + + tensor = sparse.asarray(arr) + + actual = tensor[index].to_scipy_sparse() + expected = arr.todense()[index] + + np.testing.assert_array_equal(actual.todense(), expected)