From 0dff7cf1fd1e68e421f03fde5db32302bed12f7b Mon Sep 17 00:00:00 2001 From: Aaron Zuspan <50475791+aazuspan@users.noreply.github.com> Date: Mon, 16 May 2022 10:28:09 -0700 Subject: [PATCH] In `mosaic`, ignore `axis` if `dim` is given (#149) --- stackstac/ops.py | 6 ++++-- stackstac/tests/test_mosaic.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/stackstac/ops.py b/stackstac/ops.py index 502650e..f5daf94 100644 --- a/stackstac/ops.py +++ b/stackstac/ops.py @@ -174,8 +174,8 @@ def mosaic( dim: The dimension name to mosaic. Default: None. axis: - The axis number to mosaic. Default: 0. Only one of - ``dim`` and ``axis`` can be given. + The axis number to mosaic. Default: 0. If ``dim`` is given, ``axis`` + is ignored. reverse: If False (default), the last item along the dimension is on top. If True, the first item in the dimension is on top. @@ -207,6 +207,8 @@ def mosaic( f"since {nodata} cannot exist in that dtype. " ) + axis = None if dim is not None else axis + func = ( partial(_mosaic_dask, split_every=split_every) if isinstance(arr.data, da.Array) diff --git a/stackstac/tests/test_mosaic.py b/stackstac/tests/test_mosaic.py index 9378f99..a420055 100644 --- a/stackstac/tests/test_mosaic.py +++ b/stackstac/tests/test_mosaic.py @@ -46,9 +46,10 @@ def test_mosaic_dtype_error(dtype: np.dtype): st_stc.raster_dtypes, st_np.array_shapes(max_dims=4, max_side=5), st.booleans(), + st.booleans(), ) def test_fuzz_mosaic( - data: st.DataObject, dtype: np.dtype, shape: Tuple[int, ...], reverse: bool + data: st.DataObject, dtype: np.dtype, shape: Tuple[int, ...], reverse: bool, use_dim: bool, ): """ See if we can break mosaic. @@ -73,8 +74,13 @@ def test_fuzz_mosaic( split_every = data.draw(st.integers(1, darr.numblocks[axis]), label="split_every") xarr = xr.DataArray(darr) + if use_dim: + kwargs = dict(dim=xarr.dims[axis]) + else: + kwargs = dict(axis=axis) + result = mosaic( - xarr, axis=axis, reverse=reverse, nodata=fill_value, split_every=split_every + xarr, reverse=reverse, nodata=fill_value, split_every=split_every, **kwargs ) assert result.dtype == arr.dtype result_np = mosaic(xr.DataArray(arr), axis=axis, reverse=reverse, nodata=fill_value)