From 919e0031b229812026544560294bb1bf2fa4756e Mon Sep 17 00:00:00 2001 From: hollymandel Date: Fri, 4 Oct 2024 08:33:20 -0700 Subject: [PATCH 01/13] initial typing --- xarray/core/missing.py | 83 +++++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 29 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 4523e4f8232..549d9754953 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -5,7 +5,7 @@ from collections.abc import Callable, Hashable, Sequence from functools import partial from numbers import Number -from typing import TYPE_CHECKING, Any, get_args +from typing import TYPE_CHECKING, Any, Optional, get_args import numpy as np import pandas as pd @@ -29,11 +29,12 @@ if TYPE_CHECKING: from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset + from xarray.core.variable import IndexVariable def _get_nan_block_lengths( - obj: Dataset | DataArray | Variable, dim: Hashable, index: Variable -): + obj: Dataset | DataArray, dim: Hashable, index: Variable +) -> Any: """ Return an object where each NaN element in 'obj' is replaced by the length of the gap the element is in. @@ -66,12 +67,12 @@ class BaseInterpolator: cons_kwargs: dict[str, Any] call_kwargs: dict[str, Any] f: Callable - method: str + method: str | int - def __call__(self, x): + def __call__(self, x: np.ndarray) -> np.ndarray: return self.f(x, **self.call_kwargs) - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}: method={self.method}" @@ -83,7 +84,14 @@ class NumpyInterpolator(BaseInterpolator): numpy.interp """ - def __init__(self, xi, yi, method="linear", fill_value=None, period=None): + def __init__( + self, + xi: Variable, + yi: np.ndarray, + method: Optional[str] = "linear", + fill_value=None, + period=None, + ): if method != "linear": raise ValueError("only method `linear` is valid for the NumpyInterpolator") @@ -104,8 +112,8 @@ def __init__(self, xi, yi, method="linear", fill_value=None, period=None): self._left = fill_value[0] self._right = fill_value[1] elif is_scalar(fill_value): - self._left = fill_value - self._right = fill_value + self._left = fill_value # type: ignore[assignment] + self._right = fill_value # type: ignore[assignment] else: raise ValueError(f"{fill_value} is not a valid fill_value") @@ -130,14 +138,14 @@ class ScipyInterpolator(BaseInterpolator): def __init__( self, - xi, - yi, - method=None, - fill_value=None, - assume_sorted=True, - copy=False, - bounds_error=False, - order=None, + xi: Variable, + yi: np.ndarray, + method: Optional[str | int] = None, + fill_value: Optional[float | complex] = None, + assume_sorted: bool = True, + copy: bool = False, + bounds_error: bool = False, + order: Optional[int] = None, axis=-1, **kwargs, ): @@ -154,18 +162,13 @@ def __init__( raise ValueError("order is required when method=polynomial") method = order - self.method = method + self.method: str | int = method self.cons_kwargs = kwargs self.call_kwargs = {} nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j - if fill_value is None and method == "linear": - fill_value = nan, nan - elif fill_value is None: - fill_value = nan - self.f = interp1d( xi, yi, @@ -601,7 +604,12 @@ def _floatize_x(x, new_x): return x, new_x -def interp(var, indexes_coords, method: InterpOptions, **kwargs): +def interp( + var: Variable, + indexes_coords: dict[str, IndexVariable], + method: InterpOptions, + **kwargs, +) -> Variable: """Make an interpolation of Variable Parameters @@ -662,7 +670,13 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs): return result -def interp_func(var, x, new_x, method: InterpOptions, kwargs): +def interp_func( + var: np.ndarray, + x: list[IndexVariable], + new_x: list[IndexVariable], + method: InterpOptions, + kwargs: dict, +) -> np.ndarray: """ multi-dimensional interpolation for array-like. Interpolated axes should be located in the last position. @@ -766,9 +780,14 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): return _interpnd(var, x, new_x, func, kwargs) -def _interp1d(var, x, new_x, func, kwargs): +def _interp1d( + var: np.ndarray, + x: IndexVariable, + new_x: IndexVariable, + func: Callable, + kwargs: dict, +) -> np.ndarray: # x, new_x are tuples of size 1. - x, new_x = x[0], new_x[0] rslt = func(x, var, **kwargs)(np.ravel(new_x)) if new_x.ndim > 1: return reshape(rslt, (var.shape[:-1] + new_x.shape)) @@ -777,11 +796,17 @@ def _interp1d(var, x, new_x, func, kwargs): return rslt -def _interpnd(var, x, new_x, func, kwargs): +def _interpnd( + var: np.ndarray, + x: list[IndexVariable], + new_x: list[IndexVariable], + func: Callable, + kwargs: dict, +) -> np.ndarray: x, new_x = _floatize_x(x, new_x) if len(x) == 1: - return _interp1d(var, x, new_x, func, kwargs) + return _interp1d(var, x[0], new_x[0], func, kwargs) # move the interpolation axes to the start position var = var.transpose(range(-len(x), var.ndim - len(x))) From d00cd9391fa8e85f8479a949e224e4e949884887 Mon Sep 17 00:00:00 2001 From: hollymandel Date: Fri, 4 Oct 2024 16:14:39 -0700 Subject: [PATCH 02/13] typing --- pyproject.toml | 3 +- xarray/coding/cftimeindex.py | 2 +- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 10 ++--- xarray/core/extension_array.py | 10 +++-- xarray/core/indexes.py | 12 +++--- xarray/core/missing.py | 66 +++++++++++++++----------------- xarray/core/utils.py | 4 +- xarray/core/variable.py | 2 +- xarray/groupers.py | 2 +- xarray/namedarray/daskmanager.py | 8 ++-- 11 files changed, 61 insertions(+), 60 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c23d12ffba1..a84946d4123 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,7 @@ module = [ "netCDF4.*", "netcdftime.*", "opt_einsum.*", + "pandas.*", "pint.*", "pooch.*", "pyarrow.*", @@ -178,7 +179,7 @@ module = [ "xarray.tests.test_units", "xarray.tests.test_utils", "xarray.tests.test_variable", - "xarray.tests.test_weighted", + "xarray.tests.test_weighted" ] # Use strict = true whenever namedarray has become standalone. In the meantime diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index e85fa2736b2..5b11292ce30 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -517,7 +517,7 @@ def contains(self, key: Any) -> bool: """Needed for .loc based partial-string indexing""" return self.__contains__(key) - def shift( # type: ignore[override] # freq is typed Any, we are more precise + def shift( # freq is typed Any, we are more precise self, periods: int | float, freq: str | timedelta | BaseCFTimeOffset | None = None, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 8d460e492c6..e37e74e48a3 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3032,7 +3032,7 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data if not isinstance(idx, pd.MultiIndex): raise ValueError(f"'{dim}' is not a stacked coordinate") - level_number = idx._get_level_number(level) # type: ignore[attr-defined] + level_number = idx._get_level_number(level) variables = idx.levels[level_number] variable_dim = idx.names[level_number] diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a7dedd2ed07..28f0ce16c61 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6629,7 +6629,7 @@ def interpolate_na( | None ) = None, **kwargs: Any, - ) -> Self: + ) -> Dataset: """Fill in NaNs by interpolating according to different methods. Parameters @@ -6760,7 +6760,7 @@ def interpolate_na( ) return new - def ffill(self, dim: Hashable, limit: int | None = None) -> Self: + def ffill(self, dim: Hashable, limit: int | None = None) -> Dataset: """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -6824,7 +6824,7 @@ def ffill(self, dim: Hashable, limit: int | None = None) -> Self: new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) return new - def bfill(self, dim: Hashable, limit: int | None = None) -> Self: + def bfill(self, dim: Hashable, limit: int | None = None) -> Dataset: """Fill NaN values by propagating values backward *Requires bottleneck.* @@ -7523,7 +7523,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: if isinstance(idx, pd.MultiIndex): dims = tuple( - name if name is not None else "level_%i" % n # type: ignore[redundant-expr] + name if name is not None else "level_%i" % n for n, name in enumerate(idx.names) ) for dim, lev in zip(dims, idx.levels, strict=True): @@ -9829,7 +9829,7 @@ def eval( c (x) float64 40B 0.0 1.25 2.5 3.75 5.0 """ - return pd.eval( # type: ignore[return-value] + return pd.eval( statement, resolvers=[self], target=self, diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index b2efeae7bb0..7a6b30417b0 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -45,7 +45,7 @@ def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): def __extension_duck_array__concatenate( arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None ) -> T_ExtensionArray: - return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined] + return type(arrays[0])._concat_same_type(arrays) @implements(np.where) @@ -57,8 +57,8 @@ def __extension_duck_array__where( and isinstance(y, pd.Categorical) and x.dtype != y.dtype ): - x = x.add_categories(set(y.categories).difference(set(x.categories))) # type: ignore[assignment] - y = y.add_categories(set(x.categories).difference(set(y.categories))) # type: ignore[assignment] + x = x.add_categories(set(y.categories).difference(set(x.categories))) + y = y.add_categories(set(x.categories).difference(set(y.categories))) return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array) @@ -116,7 +116,9 @@ def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: if is_extension_array_dtype(item): return type(self)(item) if np.isscalar(item): - return type(self)(type(self.array)([item])) # type: ignore[call-arg] # only subclasses with proper __init__ allowed + return type(self)( + type(self.array)([item]) + ) # only subclasses with proper __init__ allowed return item def __setitem__(self, key, val): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 5abc2129e3e..5e9af04d6b2 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -740,7 +740,7 @@ def isel( # scalar indexer: drop index return None - return self._replace(self.index[indxr]) # type: ignore[index] + return self._replace(self.index[indxr]) def sel( self, labels: dict[Any, Any], method=None, tolerance=None @@ -926,7 +926,7 @@ def remove_unused_levels_categories(index: T_PDIndex) -> T_PDIndex: return cast(T_PDIndex, new_index) if isinstance(index, pd.CategoricalIndex): - return index.remove_unused_categories() # type: ignore[attr-defined] + return index.remove_unused_categories() return index @@ -1164,7 +1164,7 @@ def create_variables( dtype = None else: level = name - dtype = self.level_coords_dtype[name] # type: ignore[index] # TODO: are Hashables ok? + dtype = self.level_coords_dtype[name] # TODO: are Hashables ok? var = variables.get(name, None) if var is not None: @@ -1174,7 +1174,9 @@ def create_variables( attrs = {} encoding = {} - data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) # type: ignore[arg-type] # TODO: are Hashables ok? + data = PandasMultiIndexingAdapter( + self.index, dtype=dtype, level=level + ) # TODO: are Hashables ok? index_vars[name] = IndexVariable( self.dim, data, @@ -1671,7 +1673,7 @@ def copy_indexes( convert_new_idx = False xr_idx = idx - new_idx = xr_idx._copy(deep=deep, memo=memo) # type: ignore[assignment] + new_idx = xr_idx._copy(deep=deep, memo=memo) idx_vars = xr_idx.create_variables(coords) if convert_new_idx: diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 549d9754953..404767af3c6 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -29,7 +29,6 @@ if TYPE_CHECKING: from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.variable import IndexVariable def _get_nan_block_lengths( @@ -146,7 +145,7 @@ def __init__( copy: bool = False, bounds_error: bool = False, order: Optional[int] = None, - axis=-1, + axis: int = -1, **kwargs, ): from scipy.interpolate import interp1d @@ -167,8 +166,6 @@ def __init__( self.cons_kwargs = kwargs self.call_kwargs = {} - nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j - self.f = interp1d( xi, yi, @@ -192,13 +189,13 @@ class SplineInterpolator(BaseInterpolator): def __init__( self, - xi, - yi, - method="spline", - fill_value=None, - order=3, - nu=0, - ext=None, + xi: Variable, + yi: np.ndarray, + method: Optional[str | int] = "spline", + fill_value: Optional[float | complex] = None, + order: int = 3, + nu: Optional[float] = 0, + ext: Optional[int | str] = None, **kwargs, ): from scipy.interpolate import UnivariateSpline @@ -216,7 +213,9 @@ def __init__( self.f = UnivariateSpline(xi, yi, k=order, **self.cons_kwargs) -def _apply_over_vars_with_dim(func, self, dim=None, **kwargs): +def _apply_over_vars_with_dim( + func: Callable, self: Dataset, dim: Optional[Hashable] = None, **kwargs +) -> Dataset: """Wrapper for datasets""" ds = type(self)(coords=self.coords, attrs=self.attrs) @@ -606,7 +605,7 @@ def _floatize_x(x, new_x): def interp( var: Variable, - indexes_coords: dict[str, IndexVariable], + indexes_coords: dict[Hashable, tuple[Any, Any]], method: InterpOptions, **kwargs, ) -> Variable: @@ -671,9 +670,9 @@ def interp( def interp_func( - var: np.ndarray, - x: list[IndexVariable], - new_x: list[IndexVariable], + var: DataArray, + x: tuple[Variable, ...], + new_x: tuple[Variable, ...], method: InterpOptions, kwargs: dict, ) -> np.ndarray: @@ -683,13 +682,10 @@ def interp_func( Parameters ---------- - var : np.ndarray or dask.array.Array - Array to be interpolated. The final dimension is interpolated. - x : a list of 1d array. - Original coordinates. Should not contain NaN. - new_x : a list of 1d array - New coordinates. Should not contain NaN. - method : string + var : Array to be interpolated. The final dimension is interpolated. + x : Original coordinates. Should not contain NaN. + new_x : New coordinates. Should not contain NaN. + method : {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'pchip', 'akima', 'makima', 'barycentric', 'krogh'} for 1-dimensional interpolation. {'linear', 'nearest'} for multidimensional interpolation @@ -710,7 +706,7 @@ def interp_func( scipy.interpolate.interp1d """ if not x: - return var.copy() + return var.data.copy() if len(x) == 1: func, kwargs = _get_interpolator(method, vectorizeable_only=True, **kwargs) @@ -727,11 +723,11 @@ def interp_func( # blockwise args format x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)] - x_arginds = [item for pair in x_arginds for item in pair] + x_arginds = [item for pair in x_arginds for item in pair] # type: ignore[misc] new_x_arginds = [ [_x, [ndim + index for index in range(_x.ndim)]] for _x in new_x ] - new_x_arginds = [item for pair in new_x_arginds for item in pair] + new_x_arginds = [item for pair in new_x_arginds for item in pair] # type: ignore[misc] args = (var, range(ndim), *x_arginds, *new_x_arginds) @@ -741,13 +737,13 @@ def interp_func( elem for pair in zip(rechunked, args[1::2], strict=True) for elem in pair ) - new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] + new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] # type: ignore[assignment] new_x0_chunks = new_x[0].chunks new_x0_shape = new_x[0].shape new_x0_chunks_is_not_none = new_x0_chunks is not None new_axes = { - ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i] + ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i] # type: ignore[index] for i in range(new_x[0].ndim) } @@ -757,7 +753,7 @@ def interp_func( # scipy.interpolate.interp1d always forces to float. # Use the same check for blockwise as well: if not issubclass(var.dtype.type, np.inexact): - dtype = float + dtype = np.dtype(float) else: dtype = var.dtype @@ -772,18 +768,18 @@ def interp_func( localize=localize, concatenate=True, dtype=dtype, - new_axes=new_axes, + new_axes=new_axes, # type: ignore[arg-type] meta=meta, align_arrays=False, ) - return _interpnd(var, x, new_x, func, kwargs) + return _interpnd(var.data, x, new_x, func, kwargs) def _interp1d( var: np.ndarray, - x: IndexVariable, - new_x: IndexVariable, + x: Variable, + new_x: Variable, func: Callable, kwargs: dict, ) -> np.ndarray: @@ -798,8 +794,8 @@ def _interp1d( def _interpnd( var: np.ndarray, - x: list[IndexVariable], - new_x: list[IndexVariable], + x: tuple[Variable, ...], + new_x: tuple[Variable, ...], func: Callable, kwargs: dict, ) -> np.ndarray: diff --git a/xarray/core/utils.py b/xarray/core/utils.py index e5168342e1e..7c09571d937 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -132,7 +132,7 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index) -> np.dtype: if not is_valid_numpy_dtype(array.dtype): return np.dtype("O") - return array.dtype # type: ignore[return-value] + return array.dtype def maybe_coerce_to_str(index, original_coords): @@ -180,7 +180,7 @@ def equivalent(first: T, second: T) -> bool: return duck_array_ops.array_equiv(first, second) if isinstance(first, list) or isinstance(second, list): return list_equiv(first, second) # type: ignore[arg-type] - return (first == second) or (pd.isnull(first) and pd.isnull(second)) # type: ignore[call-overload] + return (first == second) or (pd.isnull(first) and pd.isnull(second)) def list_equiv(first: Sequence[T], second: Sequence[T]) -> bool: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d8cf0fe7550..492b5a8c68e 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -151,7 +151,7 @@ def as_variable( ) from error elif utils.is_scalar(obj): obj = Variable([], obj) - elif isinstance(obj, pd.Index | IndexVariable) and obj.name is not None: + elif isinstance(obj, pd.Index | IndexVariable) and obj.name is not None: # type: ignore[redundant-expr] obj = Variable(obj.name, obj) elif isinstance(obj, set | dict): raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}") diff --git a/xarray/groupers.py b/xarray/groupers.py index e4cb884e6de..ce614997a4d 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -296,7 +296,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: data = np.asarray(group.data) # Cast _DummyGroup data to array - binned, self.bins = pd.cut( # type: ignore [call-overload] + binned, self.bins = pd.cut( data.ravel(), bins=self.bins, right=self.right, diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index a056f4e00bd..32ec5ce6c88 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -21,13 +21,13 @@ try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray[Any, Any] + DaskArray = np.ndarray[Any, Any] # type: ignore[misc,assignment] dask_available = module_available("dask") -class DaskManager(ChunkManagerEntrypoint["DaskArray"]): +class DaskManager(ChunkManagerEntrypoint["DaskArray"]): # type: ignore[type-var] array_cls: type[DaskArray] available: bool = dask_available @@ -91,7 +91,7 @@ def array_api(self) -> Any: return da - def reduction( + def reduction( # type: ignore[override] self, arr: T_ChunkedArray, func: Callable[..., Any], @@ -113,7 +113,7 @@ def reduction( keepdims=keepdims, ) # type: ignore[no-untyped-call] - def scan( + def scan( # type: ignore[override] self, func: Callable[..., Any], binop: Callable[..., Any], From c5c9db652ae589bf53a1ee1d927532356d919727 Mon Sep 17 00:00:00 2001 From: hollymandel Date: Thu, 10 Oct 2024 08:24:32 -0700 Subject: [PATCH 03/13] reverting ignore flags due to pandas stubs --- xarray/coding/cftimeindex.py | 2 +- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 10 +++++----- xarray/core/extension_array.py | 10 ++++------ xarray/core/indexes.py | 12 +++++------- xarray/core/utils.py | 4 ++-- xarray/core/variable.py | 2 +- xarray/groupers.py | 2 +- xarray/namedarray/daskmanager.py | 8 ++++---- 9 files changed, 24 insertions(+), 28 deletions(-) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 5b11292ce30..e85fa2736b2 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -517,7 +517,7 @@ def contains(self, key: Any) -> bool: """Needed for .loc based partial-string indexing""" return self.__contains__(key) - def shift( # freq is typed Any, we are more precise + def shift( # type: ignore[override] # freq is typed Any, we are more precise self, periods: int | float, freq: str | timedelta | BaseCFTimeOffset | None = None, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e37e74e48a3..8d460e492c6 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3032,7 +3032,7 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data if not isinstance(idx, pd.MultiIndex): raise ValueError(f"'{dim}' is not a stacked coordinate") - level_number = idx._get_level_number(level) + level_number = idx._get_level_number(level) # type: ignore[attr-defined] variables = idx.levels[level_number] variable_dim = idx.names[level_number] diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 28f0ce16c61..a7dedd2ed07 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6629,7 +6629,7 @@ def interpolate_na( | None ) = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """Fill in NaNs by interpolating according to different methods. Parameters @@ -6760,7 +6760,7 @@ def interpolate_na( ) return new - def ffill(self, dim: Hashable, limit: int | None = None) -> Dataset: + def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -6824,7 +6824,7 @@ def ffill(self, dim: Hashable, limit: int | None = None) -> Dataset: new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) return new - def bfill(self, dim: Hashable, limit: int | None = None) -> Dataset: + def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward *Requires bottleneck.* @@ -7523,7 +7523,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: if isinstance(idx, pd.MultiIndex): dims = tuple( - name if name is not None else "level_%i" % n + name if name is not None else "level_%i" % n # type: ignore[redundant-expr] for n, name in enumerate(idx.names) ) for dim, lev in zip(dims, idx.levels, strict=True): @@ -9829,7 +9829,7 @@ def eval( c (x) float64 40B 0.0 1.25 2.5 3.75 5.0 """ - return pd.eval( + return pd.eval( # type: ignore[return-value] statement, resolvers=[self], target=self, diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 7a6b30417b0..b2efeae7bb0 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -45,7 +45,7 @@ def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): def __extension_duck_array__concatenate( arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None ) -> T_ExtensionArray: - return type(arrays[0])._concat_same_type(arrays) + return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined] @implements(np.where) @@ -57,8 +57,8 @@ def __extension_duck_array__where( and isinstance(y, pd.Categorical) and x.dtype != y.dtype ): - x = x.add_categories(set(y.categories).difference(set(x.categories))) - y = y.add_categories(set(x.categories).difference(set(y.categories))) + x = x.add_categories(set(y.categories).difference(set(x.categories))) # type: ignore[assignment] + y = y.add_categories(set(x.categories).difference(set(y.categories))) # type: ignore[assignment] return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array) @@ -116,9 +116,7 @@ def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: if is_extension_array_dtype(item): return type(self)(item) if np.isscalar(item): - return type(self)( - type(self.array)([item]) - ) # only subclasses with proper __init__ allowed + return type(self)(type(self.array)([item])) # type: ignore[call-arg] # only subclasses with proper __init__ allowed return item def __setitem__(self, key, val): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 5e9af04d6b2..5abc2129e3e 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -740,7 +740,7 @@ def isel( # scalar indexer: drop index return None - return self._replace(self.index[indxr]) + return self._replace(self.index[indxr]) # type: ignore[index] def sel( self, labels: dict[Any, Any], method=None, tolerance=None @@ -926,7 +926,7 @@ def remove_unused_levels_categories(index: T_PDIndex) -> T_PDIndex: return cast(T_PDIndex, new_index) if isinstance(index, pd.CategoricalIndex): - return index.remove_unused_categories() + return index.remove_unused_categories() # type: ignore[attr-defined] return index @@ -1164,7 +1164,7 @@ def create_variables( dtype = None else: level = name - dtype = self.level_coords_dtype[name] # TODO: are Hashables ok? + dtype = self.level_coords_dtype[name] # type: ignore[index] # TODO: are Hashables ok? var = variables.get(name, None) if var is not None: @@ -1174,9 +1174,7 @@ def create_variables( attrs = {} encoding = {} - data = PandasMultiIndexingAdapter( - self.index, dtype=dtype, level=level - ) # TODO: are Hashables ok? + data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) # type: ignore[arg-type] # TODO: are Hashables ok? index_vars[name] = IndexVariable( self.dim, data, @@ -1673,7 +1671,7 @@ def copy_indexes( convert_new_idx = False xr_idx = idx - new_idx = xr_idx._copy(deep=deep, memo=memo) + new_idx = xr_idx._copy(deep=deep, memo=memo) # type: ignore[assignment] idx_vars = xr_idx.create_variables(coords) if convert_new_idx: diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 7c09571d937..e5168342e1e 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -132,7 +132,7 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index) -> np.dtype: if not is_valid_numpy_dtype(array.dtype): return np.dtype("O") - return array.dtype + return array.dtype # type: ignore[return-value] def maybe_coerce_to_str(index, original_coords): @@ -180,7 +180,7 @@ def equivalent(first: T, second: T) -> bool: return duck_array_ops.array_equiv(first, second) if isinstance(first, list) or isinstance(second, list): return list_equiv(first, second) # type: ignore[arg-type] - return (first == second) or (pd.isnull(first) and pd.isnull(second)) + return (first == second) or (pd.isnull(first) and pd.isnull(second)) # type: ignore[call-overload] def list_equiv(first: Sequence[T], second: Sequence[T]) -> bool: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 492b5a8c68e..d8cf0fe7550 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -151,7 +151,7 @@ def as_variable( ) from error elif utils.is_scalar(obj): obj = Variable([], obj) - elif isinstance(obj, pd.Index | IndexVariable) and obj.name is not None: # type: ignore[redundant-expr] + elif isinstance(obj, pd.Index | IndexVariable) and obj.name is not None: obj = Variable(obj.name, obj) elif isinstance(obj, set | dict): raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}") diff --git a/xarray/groupers.py b/xarray/groupers.py index ce614997a4d..e4cb884e6de 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -296,7 +296,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: data = np.asarray(group.data) # Cast _DummyGroup data to array - binned, self.bins = pd.cut( + binned, self.bins = pd.cut( # type: ignore [call-overload] data.ravel(), bins=self.bins, right=self.right, diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index 32ec5ce6c88..a056f4e00bd 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -21,13 +21,13 @@ try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray[Any, Any] # type: ignore[misc,assignment] + DaskArray = np.ndarray[Any, Any] dask_available = module_available("dask") -class DaskManager(ChunkManagerEntrypoint["DaskArray"]): # type: ignore[type-var] +class DaskManager(ChunkManagerEntrypoint["DaskArray"]): array_cls: type[DaskArray] available: bool = dask_available @@ -91,7 +91,7 @@ def array_api(self) -> Any: return da - def reduction( # type: ignore[override] + def reduction( self, arr: T_ChunkedArray, func: Callable[..., Any], @@ -113,7 +113,7 @@ def reduction( # type: ignore[override] keepdims=keepdims, ) # type: ignore[no-untyped-call] - def scan( # type: ignore[override] + def scan( self, func: Callable[..., Any], binop: Callable[..., Any], From 407b8ca23890e71fec6d4f2e715c0159c62cbbe7 Mon Sep 17 00:00:00 2001 From: hollymandel Date: Thu, 10 Oct 2024 08:26:15 -0700 Subject: [PATCH 04/13] putting back pandas typing --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a84946d4123..345f4d500a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,7 +125,6 @@ module = [ "netCDF4.*", "netcdftime.*", "opt_einsum.*", - "pandas.*", "pint.*", "pooch.*", "pyarrow.*", From 54945e934c68e467e6f73d4038b3dac8335fca76 Mon Sep 17 00:00:00 2001 From: hollymandel Date: Thu, 10 Oct 2024 08:26:51 -0700 Subject: [PATCH 05/13] comma --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 345f4d500a5..c23d12ffba1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -178,7 +178,7 @@ module = [ "xarray.tests.test_units", "xarray.tests.test_utils", "xarray.tests.test_variable", - "xarray.tests.test_weighted" + "xarray.tests.test_weighted", ] # Use strict = true whenever namedarray has become standalone. In the meantime From 33c3898ebcb8b1af36975dfa838b75cde5c0fcc3 Mon Sep 17 00:00:00 2001 From: hollymandel Date: Thu, 10 Oct 2024 08:33:30 -0700 Subject: [PATCH 06/13] suppress errors to allow self typing --- xarray/core/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a7dedd2ed07..017afa64c2a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6758,7 +6758,7 @@ def interpolate_na( max_gap=max_gap, **kwargs, ) - return new + return new # type: ignore[return-value] def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values forward @@ -6822,7 +6822,7 @@ def ffill(self, dim: Hashable, limit: int | None = None) -> Self: from xarray.core.missing import _apply_over_vars_with_dim, ffill new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) - return new + return new # type: ignore[return-value] def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward @@ -6887,7 +6887,7 @@ def bfill(self, dim: Hashable, limit: int | None = None) -> Self: from xarray.core.missing import _apply_over_vars_with_dim, bfill new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit) - return new + return new # type: ignore[return-value] def combine_first(self, other: Self) -> Self: """Combine two Datasets, default to data_vars of self. From 21c7add553389060c5e8ea33a372dfb83c7e9c88 Mon Sep 17 00:00:00 2001 From: hollymandel Date: Thu, 10 Oct 2024 08:39:01 -0700 Subject: [PATCH 07/13] optional to bar --- xarray/core/missing.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 404767af3c6..8bd074278e8 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -5,7 +5,7 @@ from collections.abc import Callable, Hashable, Sequence from functools import partial from numbers import Number -from typing import TYPE_CHECKING, Any, Optional, get_args +from typing import TYPE_CHECKING, Any, get_args import numpy as np import pandas as pd @@ -87,7 +87,7 @@ def __init__( self, xi: Variable, yi: np.ndarray, - method: Optional[str] = "linear", + method: str | None = "linear", fill_value=None, period=None, ): @@ -139,12 +139,12 @@ def __init__( self, xi: Variable, yi: np.ndarray, - method: Optional[str | int] = None, - fill_value: Optional[float | complex] = None, + method: str | int | None = None, + fill_value: float | complex | None = None, assume_sorted: bool = True, copy: bool = False, bounds_error: bool = False, - order: Optional[int] = None, + order: int | None = None, axis: int = -1, **kwargs, ): @@ -191,11 +191,11 @@ def __init__( self, xi: Variable, yi: np.ndarray, - method: Optional[str | int] = "spline", - fill_value: Optional[float | complex] = None, + method: str | int | None = "spline", + fill_value: float | complex | None = None, order: int = 3, - nu: Optional[float] = 0, - ext: Optional[int | str] = None, + nu: float | None = 0, + ext: int | str | None = None, **kwargs, ): from scipy.interpolate import UnivariateSpline @@ -214,7 +214,7 @@ def __init__( def _apply_over_vars_with_dim( - func: Callable, self: Dataset, dim: Optional[Hashable] = None, **kwargs + func: Callable, self: Dataset, dim: Hashable | None = None, **kwargs ) -> Dataset: """Wrapper for datasets""" ds = type(self)(coords=self.coords, attrs=self.attrs) From da2c75afdb8275bd16e9bb53c6109e3a63d226f2 Mon Sep 17 00:00:00 2001 From: hollymandel Date: Thu, 10 Oct 2024 11:27:41 -0700 Subject: [PATCH 08/13] mapping types --- xarray/core/missing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 8bd074278e8..1ac3ddb6fe2 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -2,7 +2,7 @@ import datetime as dt import warnings -from collections.abc import Callable, Hashable, Sequence +from collections.abc import Callable, Hashable, Mapping, Sequence from functools import partial from numbers import Number from typing import TYPE_CHECKING, Any, get_args @@ -605,7 +605,7 @@ def _floatize_x(x, new_x): def interp( var: Variable, - indexes_coords: dict[Hashable, tuple[Any, Any]], + indexes_coords: Mapping[Any, tuple[Any, Any]], method: InterpOptions, **kwargs, ) -> Variable: @@ -674,7 +674,7 @@ def interp_func( x: tuple[Variable, ...], new_x: tuple[Variable, ...], method: InterpOptions, - kwargs: dict, + kwargs: dict[str, Any], ) -> np.ndarray: """ multi-dimensional interpolation for array-like. Interpolated axes should be @@ -781,7 +781,7 @@ def _interp1d( x: Variable, new_x: Variable, func: Callable, - kwargs: dict, + kwargs: dict[str, Any], ) -> np.ndarray: # x, new_x are tuples of size 1. rslt = func(x, var, **kwargs)(np.ravel(new_x)) @@ -797,7 +797,7 @@ def _interpnd( x: tuple[Variable, ...], new_x: tuple[Variable, ...], func: Callable, - kwargs: dict, + kwargs: dict[str, Any], ) -> np.ndarray: x, new_x = _floatize_x(x, new_x) From 5c603f030349eb51706c1203522ede25a0bcef58 Mon Sep 17 00:00:00 2001 From: hollymandel Date: Mon, 14 Oct 2024 15:27:00 -0700 Subject: [PATCH 09/13] further typing --- xarray/core/dataarray.py | 6 ++-- xarray/core/missing.py | 69 +++++++++++++++++++++++++--------------- 2 files changed, 46 insertions(+), 29 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 8d460e492c6..95d33d2ab65 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3592,7 +3592,7 @@ def interpolate_na( """ from xarray.core.missing import interp_na - return interp_na( + return interp_na( # type: ignore[return-value] self, dim=dim, method=method, @@ -3685,7 +3685,7 @@ def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """ from xarray.core.missing import ffill - return ffill(self, dim, limit=limit) + return ffill(self, dim, limit=limit) # type: ignore[return-value] def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward @@ -3769,7 +3769,7 @@ def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """ from xarray.core.missing import bfill - return bfill(self, dim, limit=limit) + return bfill(self, dim, limit=limit) # type: ignore[return-value] def combine_first(self, other: Self) -> Self: """Combine two DataArray objects, with union of coordinates. diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 1ac3ddb6fe2..b46b084e843 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd +from numpy.typing import ArrayLike from xarray.core import utils from xarray.core.common import _contains_datetime_like_objects, ones_like @@ -20,7 +21,13 @@ timedelta_to_numeric, ) from xarray.core.options import _get_keep_attrs -from xarray.core.types import Interp1dOptions, InterpOptions +from xarray.core.types import ( + Interp1dOptions, + InterpOptions, + ScalarOrArray, + T_DuckArray, + T_Xarray, +) from xarray.core.utils import OrderedSet, is_scalar from xarray.core.variable import Variable, broadcast_variables from xarray.namedarray.parallelcompat import get_chunked_array_type @@ -32,7 +39,7 @@ def _get_nan_block_lengths( - obj: Dataset | DataArray, dim: Hashable, index: Variable + obj: Dataset | DataArray, dim: Hashable, index: T_DuckArray | ArrayLike ) -> Any: """ Return an object where each NaN element in 'obj' is replaced by the @@ -68,7 +75,7 @@ class BaseInterpolator: f: Callable method: str | int - def __call__(self, x: np.ndarray) -> np.ndarray: + def __call__(self, x: ScalarOrArray) -> np.ndarray: # dask array will return np return self.f(x, **self.call_kwargs) def __repr__(self) -> str: @@ -85,11 +92,11 @@ class NumpyInterpolator(BaseInterpolator): def __init__( self, - xi: Variable, - yi: np.ndarray, + xi: T_DuckArray, # passed to np.asarray + yi: T_DuckArray, # passed to np.asarray + requires dtype attribute method: str | None = "linear", - fill_value=None, - period=None, + fill_value: float | complex | None = None, + period: float | None = None, ): if method != "linear": raise ValueError("only method `linear` is valid for the NumpyInterpolator") @@ -137,8 +144,8 @@ class ScipyInterpolator(BaseInterpolator): def __init__( self, - xi: Variable, - yi: np.ndarray, + xi: T_DuckArray, + yi: T_DuckArray, method: str | int | None = None, fill_value: float | complex | None = None, assume_sorted: bool = True, @@ -189,8 +196,8 @@ class SplineInterpolator(BaseInterpolator): def __init__( self, - xi: Variable, - yi: np.ndarray, + xi: T_DuckArray, + yi: T_DuckArray, method: str | int | None = "spline", fill_value: float | complex | None = None, order: int = 3, @@ -229,13 +236,16 @@ def _apply_over_vars_with_dim( def get_clean_interp_index( - arr, dim: Hashable, use_coordinate: Hashable | bool = True, strict: bool = True -): + arr: T_Xarray, + dim: Hashable, + use_coordinate: Hashable | bool = True, + strict: bool = True, +) -> np.ndarray[Any, np.dtype[np.float64]]: """Return index to use for x values in interpolation or curve fitting. Parameters ---------- - arr : DataArray + arr : DataArray or Dataset Array to interpolate or fit to a curve. dim : str Name of dimension along which to fit. @@ -268,13 +278,13 @@ def get_clean_interp_index( index = arr.get_index(dim) else: # string - index = arr.coords[use_coordinate] + index = arr.coords[use_coordinate] # type: ignore[assignment] if index.ndim != 1: raise ValueError( f"Coordinates used for interpolation must be 1D, " f"{use_coordinate} is {index.ndim}D." ) - index = index.to_index() + index = index.to_index() # type: ignore[attr-defined] # TODO: index.name is None for multiindexes # set name for nice error messages below @@ -293,15 +303,16 @@ def get_clean_interp_index( if isinstance(index, CFTimeIndex | pd.DatetimeIndex): offset = type(index[0])(1970, 1, 1) if isinstance(index, CFTimeIndex): - index = index.values - index = Variable( + index = index.values # type: ignore[assignment] + index = Variable( # type: ignore[assignment] data=datetime_to_numeric(index, offset=offset, datetime_unit="ns"), dims=(dim,), ) # raise if index cannot be cast to a float (e.g. MultiIndex) try: - index = index.values.astype(np.float64) + # this step ensures output is ndarray + index = index.values.astype(np.float64) # type: ignore[assignment] except (TypeError, ValueError) as err: # pandas raises a TypeError # xarray/numpy raise a ValueError @@ -310,11 +321,11 @@ def get_clean_interp_index( f"interpolation or curve fitting, got {type(index).__name__}." ) from err - return index + return index # type: ignore[return-value] def interp_na( - self, + self: T_Xarray, dim: Hashable | None = None, use_coordinate: bool | str = True, method: InterpOptions = "linear", @@ -324,7 +335,7 @@ def interp_na( ) = None, keep_attrs: bool | None = None, **kwargs, -): +) -> T_Xarray: """Interpolate values according to different methods.""" from xarray.coding.cftimeindex import CFTimeIndex @@ -392,14 +403,16 @@ def interp_na( return arr -def func_interpolate_na(interpolator, y, x, **kwargs): +def func_interpolate_na( + interpolator: Callable, y: T_Xarray, x: np.ndarray, **kwargs +) -> T_Xarray: """helper function to apply interpolation along 1 dimension""" # reversed arguments are so that attrs are preserved from da, not index # it would be nice if this wasn't necessary, works around: # "ValueError: assignment destination is read-only" in assignment below out = y.copy() - nans = pd.isnull(y) + nans = pd.isnull(y) # type: ignore[call-overload] nonans = ~nans # fast track for no-nans, all nan but one, and all-nans cases @@ -423,7 +436,9 @@ def _bfill(arr, n=None, axis=-1): return np.flip(arr, axis=axis) -def ffill(arr, dim=None, limit=None): +def ffill( + arr: T_Xarray, dim: Hashable | None = None, limit: int | None = None +) -> T_Xarray: """forward fill missing values""" axis = arr.get_axis_num(dim) @@ -441,7 +456,9 @@ def ffill(arr, dim=None, limit=None): ).transpose(*arr.dims) -def bfill(arr, dim=None, limit=None): +def bfill( + arr: T_Xarray, dim: Hashable | None = None, limit: int | None = None +) -> T_Xarray: """backfill missing values""" axis = arr.get_axis_num(dim) From dcd50593e288a655535f05cd004652f30732eff9 Mon Sep 17 00:00:00 2001 From: hollymandel Date: Mon, 14 Oct 2024 19:52:53 -0700 Subject: [PATCH 10/13] more typing --- xarray/core/dataset.py | 2 +- xarray/core/missing.py | 38 +++++++++++++++++++++++++------------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 017afa64c2a..dd129c501ce 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4077,7 +4077,7 @@ def _validate_interp_indexer(x, new_x): # optimization: subset to coordinate range of the target index if method in ["linear", "nearest"]: for k, v in validated_indexers.items(): - obj, newidx = missing._localize(obj, {k: v}) + obj, newidx = missing._localize(obj, {k: v}) # type: ignore[assignment, arg-type] validated_indexers[k] = newidx[k] # optimization: create dask coordinate arrays once per Dataset diff --git a/xarray/core/missing.py b/xarray/core/missing.py index b46b084e843..3319624c6dd 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -368,6 +368,7 @@ def interp_na( # method index = get_clean_interp_index(self, dim, use_coordinate=use_coordinate) + interp_class: type[BaseInterpolator] interp_class, kwargs = _get_interpolator(method, **kwargs) interpolator = partial(func_interpolate_na, interp_class, **kwargs) @@ -476,7 +477,7 @@ def bfill( ).transpose(*arr.dims) -def _import_interpolant(interpolant, method): +def _import_interpolant(interpolant: str, method: InterpOptions): """Import interpolant from scipy.interpolate.""" try: from scipy import interpolate @@ -488,18 +489,16 @@ def _import_interpolant(interpolant, method): def _get_interpolator( method: InterpOptions, vectorizeable_only: bool = False, **kwargs -): +) -> tuple[type[BaseInterpolator], dict[str, Any]]: """helper function to select the appropriate interpolator class returns interpolator class and keyword arguments for the class """ - interp_class: ( - type[NumpyInterpolator] | type[ScipyInterpolator] | type[SplineInterpolator] - ) - interp1d_methods = get_args(Interp1dOptions) valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v)) + interp_class: type[BaseInterpolator] + # prefer numpy.interp for 1d linear interpolation. This function cannot # take higher dimensional data but scipy.interp1d can. if ( @@ -550,7 +549,9 @@ def _get_interpolator( return interp_class, kwargs -def _get_interpolator_nd(method, **kwargs): +def _get_interpolator_nd( + method: InterpOptions, **kwargs +) -> tuple[type[BaseInterpolator], dict[str, Any]]: """helper function to select the appropriate interpolator class returns interpolator class and keyword arguments for the class @@ -570,7 +571,7 @@ def _get_interpolator_nd(method, **kwargs): return interp_class, kwargs -def _get_valid_fill_mask(arr, dim, limit): +def _get_valid_fill_mask(arr: T_Xarray, dim: Hashable, limit: int) -> T_Xarray: """helper function to determine values that can be filled when limit is not None""" kw = {dim: limit + 1} @@ -578,13 +579,21 @@ def _get_valid_fill_mask(arr, dim, limit): new_dim = utils.get_temp_dimname(arr.dims, "_window") return ( arr.isnull() - .rolling(min_periods=1, **kw) + .rolling( + kw, + min_periods=1, + ) .construct(new_dim, fill_value=False) - .sum(new_dim, skipna=False) + .sum(dim=new_dim, skipna=False) # type: ignore[arg-type] ) <= limit -def _localize(var, indexes_coords): +def _localize( + var: Variable, + indexes_coords: dict[ + Any, tuple[Variable, Variable] + ], # indexes_coords altered so can't use mapping type +) -> tuple[Variable, dict[Any, tuple[Variable, Variable]]]: """Speed up for linear and nearest neighbor method. Only consider a subspace that is needed for the interpolation """ @@ -596,8 +605,11 @@ def _localize(var, indexes_coords): index = x.to_index() imin, imax = index.get_indexer([minval, maxval], method="nearest") indexes[dim] = slice(max(imin - 2, 0), imax + 2) - indexes_coords[dim] = (x[indexes[dim]], new_x) - return var.isel(**indexes), indexes_coords + indexes_coords[dim] = ( + x[indexes[dim]], + new_x, + ) # probably the in-place modification is unintentional here? + return var.isel(indexers=indexes), indexes_coords def _floatize_x(x, new_x): From ceabc298a06af19ef2ecadf746a3326997383f5b Mon Sep 17 00:00:00 2001 From: hollymandel Date: Tue, 15 Oct 2024 08:55:48 -0700 Subject: [PATCH 11/13] through 850 --- xarray/core/missing.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 3319624c6dd..b8b7fe846ac 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -75,7 +75,7 @@ class BaseInterpolator: f: Callable method: str | int - def __call__(self, x: ScalarOrArray) -> np.ndarray: # dask array will return np + def __call__(self, x: ScalarOrArray) -> ArrayLike: # dask array will return np return self.f(x, **self.call_kwargs) def __repr__(self) -> str: @@ -612,13 +612,13 @@ def _localize( return var.isel(indexers=indexes), indexes_coords -def _floatize_x(x, new_x): +def _floatize_x(x: tuple[Variable, ...], new_x: tuple[Variable, ...]): """Make x and new_x float. This is particularly useful for datetime dtype. x, new_x: tuple of np.ndarray """ - x = list(x) - new_x = list(new_x) + x = list(x) # type: ignore[assignment] + new_x = list(new_x) # type: ignore[assignment] for i in range(len(x)): if _contains_datetime_like_objects(x[i]): # Scipy casts coordinates to np.float64, which is not accurate @@ -627,8 +627,8 @@ def _floatize_x(x, new_x): # offset (min(x)) and the variation (x - min(x)) can be # represented by float. xmin = x[i].values.min() - x[i] = x[i]._to_numeric(offset=xmin, dtype=np.float64) - new_x[i] = new_x[i]._to_numeric(offset=xmin, dtype=np.float64) + x[i] = x[i]._to_numeric(offset=xmin, dtype=np.float64) # type: ignore[index] + new_x[i] = new_x[i]._to_numeric(offset=xmin, dtype=np.float64) # type: ignore[index] return x, new_x @@ -699,12 +699,12 @@ def interp( def interp_func( - var: DataArray, + var: type[T_DuckArray], x: tuple[Variable, ...], new_x: tuple[Variable, ...], method: InterpOptions, kwargs: dict[str, Any], -) -> np.ndarray: +) -> ArrayLike: """ multi-dimensional interpolation for array-like. Interpolated axes should be located in the last position. @@ -742,8 +742,9 @@ def interp_func( else: func, kwargs = _get_interpolator_nd(method, **kwargs) - if is_chunked_array(var): - chunkmanager = get_chunked_array_type(var) + # is_chunked_array typed using named_array, unsure of relationship to duck arrays + if is_chunked_array(var): # type: ignore[arg-type] + chunkmanager = get_chunked_array_type(var) # duck compatible ndim = var.ndim nconst = ndim - len(x) @@ -786,6 +787,8 @@ def interp_func( else: dtype = var.dtype + # mypy may flag this in the future since _meta is not a property + # of duck arrays--only inside this conditional if a dask array meta = var._meta return chunkmanager.blockwise( @@ -806,12 +809,12 @@ def interp_func( def _interp1d( - var: np.ndarray, - x: Variable, + var: type[T_DuckArray], + x: type[T_DuckArray], new_x: Variable, func: Callable, kwargs: dict[str, Any], -) -> np.ndarray: +) -> ArrayLike: # x, new_x are tuples of size 1. rslt = func(x, var, **kwargs)(np.ravel(new_x)) if new_x.ndim > 1: @@ -822,12 +825,12 @@ def _interp1d( def _interpnd( - var: np.ndarray, - x: tuple[Variable, ...], + var: type[T_DuckArray], + x: tuple[T_DuckArray, ...], new_x: tuple[Variable, ...], func: Callable, kwargs: dict[str, Any], -) -> np.ndarray: +) -> ArrayLike: x, new_x = _floatize_x(x, new_x) if len(x) == 1: From bd18893dfd9bed0d61b1e2cc3699381245155a58 Mon Sep 17 00:00:00 2001 From: hollymandel Date: Tue, 15 Oct 2024 11:28:06 -0700 Subject: [PATCH 12/13] moving duck array --- xarray/core/missing.py | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index b8b7fe846ac..fb2f2e24477 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -2,7 +2,7 @@ import datetime as dt import warnings -from collections.abc import Callable, Hashable, Mapping, Sequence +from collections.abc import Callable, Generator, Hashable, Mapping, Sequence from functools import partial from numbers import Number from typing import TYPE_CHECKING, Any, get_args @@ -34,9 +34,14 @@ from xarray.namedarray.pycompat import is_chunked_array if TYPE_CHECKING: + from typing import TypeVar + from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset + # T_DuckArray can't be a function parameter because covariant + XarrayLike = TypeVar("XarrayLike", Variable, DataArray, Dataset) + def _get_nan_block_lengths( obj: Dataset | DataArray, dim: Hashable, index: T_DuckArray | ArrayLike @@ -593,7 +598,7 @@ def _localize( indexes_coords: dict[ Any, tuple[Variable, Variable] ], # indexes_coords altered so can't use mapping type -) -> tuple[Variable, dict[Any, tuple[Variable, Variable]]]: +) -> tuple[Variable | T_Xarray, dict[Any, tuple[Variable, Variable]]]: """Speed up for linear and nearest neighbor method. Only consider a subspace that is needed for the interpolation """ @@ -809,8 +814,8 @@ def interp_func( def _interp1d( - var: type[T_DuckArray], - x: type[T_DuckArray], + var: XarrayLike, + x: Variable, new_x: Variable, func: Callable, kwargs: dict[str, Any], @@ -825,8 +830,8 @@ def _interp1d( def _interpnd( - var: type[T_DuckArray], - x: tuple[T_DuckArray, ...], + var: XarrayLike, + x: tuple[Variable, ...], new_x: tuple[Variable, ...], func: Callable, kwargs: dict[str, Any], @@ -846,7 +851,13 @@ def _interpnd( return reshape(rslt, rslt.shape[:-1] + new_x[0].shape) -def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True): +def _chunked_aware_interpnd( + var: Variable, + *coords, + interp_func: Callable, + interp_kwargs: dict[str, Any], + localize: bool = True, +) -> ArrayLike: """Wrapper for `_interpnd` through `blockwise` for chunked arrays. The first half arrays in `coords` are original coordinates, @@ -856,11 +867,13 @@ def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=T nconst = len(var.shape) - n_x # _interpnd expect coords to be Variables - x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])] - new_x = [ + x = tuple( + Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x]) + ) + new_x = tuple( Variable([f"dim_{len(var.shape) + dim}" for dim in range(len(_x.shape))], _x) for _x in coords[n_x:] - ] + ) if localize: # _localize expect var to be a Variable @@ -871,7 +884,7 @@ def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=T } # simple speed up for the local interpolation - var, indexes_coords = _localize(var, indexes_coords) + var, indexes_coords = _localize(var, indexes_coords) # type: ignore[assignment] x, new_x = zip(*[indexes_coords[d] for d in indexes_coords], strict=True) # put var back as a ndarray @@ -880,7 +893,9 @@ def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=T return _interpnd(var, x, new_x, interp_func, interp_kwargs) -def decompose_interp(indexes_coords): +def decompose_interp( + indexes_coords: Mapping[Any, tuple[Any, Any]] +) -> Generator[Mapping[Any, tuple[Any, Any]]]: """Decompose the interpolation into a succession of independent interpolation keeping the order""" dest_dims = [ @@ -888,7 +903,7 @@ def decompose_interp(indexes_coords): for dim, dest in indexes_coords.items() ] partial_dest_dims = [] - partial_indexes_coords = {} + partial_indexes_coords: dict[Any, tuple[Any, Any]] = {} for i, index_coords in enumerate(indexes_coords.items()): partial_indexes_coords.update([index_coords]) From fde2b3df061bd51e02c6fa2f47ab3bda5027f6fe Mon Sep 17 00:00:00 2001 From: hollymandel Date: Tue, 15 Oct 2024 11:30:43 -0700 Subject: [PATCH 13/13] undoing code affecting change --- xarray/core/missing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index fb2f2e24477..30a95edf80e 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -810,7 +810,7 @@ def interp_func( align_arrays=False, ) - return _interpnd(var.data, x, new_x, func, kwargs) + return _interpnd(var, x, new_x, func, kwargs) def _interp1d(