From 7310715613a0842067fc1e3fb20909fbe588c53b Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 10 Oct 2024 11:04:25 +0200 Subject: [PATCH] `Dataset.reduce` pass through non-numeric scalars --- xarray/core/dataset.py | 22 +++++++++++----------- xarray/tests/test_dataset.py | 33 +++++++++++++++++++++++++++------ 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a7dedd2ed07..36d0614d11f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7009,17 +7009,17 @@ def reduce( if not reduce_dims: variables[name] = var else: - if ( - # Some reduction functions (e.g. std, var) need to run on variables - # that don't have the reduce dims: PR5393 - not is_extension_array_dtype(var.dtype) - and ( - not reduce_dims - or not numeric_only - or np.issubdtype(var.dtype, np.number) - or (var.dtype == np.bool_) - ) - ): + + is_numeric = (not is_extension_array_dtype(var.dtype)) and ( + np.issubdtype(var.dtype, np.number) or var.dtype == np.bool_ + ) + + # pass through non-numeric scalar + if numeric_only and not is_numeric and var.ndim == 0: + variables[name] = var + + elif not reduce_dims or not numeric_only or is_numeric: + # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because # the former is often more efficient diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 1178498de19..65d2575ef68 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5661,14 +5661,35 @@ def test_reduce_non_numeric(self) -> None: data = np.random.randint(0, 100, size=size).astype(np.str_) data1[v] = (dims, data, {"foo": "variable"}) # var4 is extension array categorical and should be dropped - assert ( - "var4" not in data1.mean() - and "var5" not in data1.mean() - and "var6" not in data1.mean() - ) + + assert "var4" not in data1.mean() + assert "var5" not in data1.mean() + assert "var6" not in data1.mean() + 
assert_equal(data1.mean(), data2.mean()) assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1")) - assert "var5" not in data1.mean(dim="dim2") and "var6" in data1.mean(dim="dim2") + + assert "var5" not in data1.mean(dim="dim2") + assert "var6" in data1.mean(dim="dim2") + + @pytest.mark.parametrize("op", ("sum", "prod", "mean", "std")) + def test_reduce_non_numeric_scalar(self, op) -> None: + # ensure non-numeric scalar is passed through + + data_orig = create_test_data(seed=44, dim_sizes=(1, 2, 3)) + + # add a scalar + data = data_orig.assign(var4="string") + + result = getattr(data, op)() + expected = getattr(data_orig, op)().assign(var4="string") + + assert_equal(result, expected) + + result = getattr(data, op)("dim1") + expected = getattr(data_orig, op)("dim1").assign(var4="string") + + assert_equal(result, expected) @pytest.mark.filterwarnings( "ignore:Once the behaviour of DataArray:DeprecationWarning"