diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 868ce6005c0..c1278220cc4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,7 +29,7 @@ New Features - Improved compatibility with OPeNDAP DAP4 data model for backend engine ``pydap``. This includes ``datatree`` support, and removing slashes from dimension names. By `Miguel Jimenez-Urias `_. -- Improved support pandas Extension Arrays. (:issue:`9661`, :pull:`9671`) +- Improved support for pandas categorical extension arrays as indices (i.e., :py:class:`pandas.IntervalIndex`). (:issue:`9661`, :pull:`9671`) By `Ilan Gold `_. - Improved checks and errors raised when trying to align objects with conflicting indexes. It is now possible to align objects each with multiple indexes sharing common dimension(s). @@ -52,6 +52,7 @@ Breaking changes now return objects indexed by :py:meth:`pandas.IntervalArray` objects, instead of numpy object arrays containing tuples. This change enables interval-aware indexing of such Xarray objects. (:pull:`9671`). By `Ilan Gold `_. +- Remove ``PandasExtensionArray`` from :py:attr:`xarray.Variable.data` when the attribute is a :py:class:`pandas.api.extensions.ExtensionArray` (:pull:`10263`). By `Ilan Gold `_. - The html and text ``repr`` for ``DataTree`` are now truncated. Up to 6 children are displayed for each node -- the first 3 and the last 3 children -- with a ``...`` between them. The number of children to include in the display is configurable via options. 
For instance use diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 625078dbea9..88274203361 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7091,21 +7091,21 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]): { **dict(zip(non_extension_array_columns, data, strict=True)), **{ - c: self.variables[c].data.array + c: self.variables[c].data for c in extension_array_columns_same_index }, }, index=index, ) for extension_array_column in extension_array_columns_different_index: - extension_array = self.variables[extension_array_column].data.array + extension_array = self.variables[extension_array_column].data index = self[ self.variables[extension_array_column].dims[0] ].coords.to_index() extension_array_df = pd.DataFrame( {extension_array_column: extension_array}, index=pd.Index(index.array) - if isinstance(index, PandasExtensionArray) + if isinstance(index, PandasExtensionArray) # type: ignore[redundant-expr] else index, ) extension_array_df.index.name = self.variables[extension_array_column].dims[ diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index e8006a4c8c3..269016ddfd1 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -1,13 +1,16 @@ from __future__ import annotations from collections.abc import Callable, Sequence +from dataclasses import dataclass from typing import Generic, cast import numpy as np import pandas as pd +from packaging.version import Version from pandas.api.types import is_extension_array_dtype from xarray.core.types import DTypeLikeSave, T_ExtensionArray +from xarray.core.utils import NDArrayMixin HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} @@ -33,12 +36,12 @@ def __extension_duck_array__issubdtype( def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): if shape[0] == len(arr) and len(shape) == 1: return arr - raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") + raise 
NotImplementedError("Cannot broadcast 1d-only pandas extension array.") @implements(np.stack) def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): - raise NotImplementedError("Cannot stack 1d-only pandas categorical array.") + raise NotImplementedError("Cannot stack 1d-only pandas extension array.") @implements(np.concatenate) @@ -62,21 +65,22 @@ def __extension_duck_array__where( return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array) -class PandasExtensionArray(Generic[T_ExtensionArray]): - array: T_ExtensionArray +@dataclass(frozen=True) +class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin): + """NEP-18 compliant wrapper for pandas extension arrays. + + Parameters + ---------- + array : T_ExtensionArray + The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. + ``` + """ - def __init__(self, array: T_ExtensionArray): - """NEP-18 compliant wrapper for pandas extension arrays. + array: T_ExtensionArray - Parameters - ---------- - array : T_ExtensionArray - The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. 
- ``` - """ - if not isinstance(array, pd.api.extensions.ExtensionArray): - raise TypeError(f"{array} is not an pandas ExtensionArray.") - self.array = array + def __post_init__(self): + if not isinstance(self.array, pd.api.extensions.ExtensionArray): + raise TypeError(f"{self.array} is not an pandas ExtensionArray.") def __array_function__(self, func, types, args, kwargs): def replace_duck_with_extension_array(args) -> list: @@ -105,19 +109,13 @@ def replace_duck_with_extension_array(args) -> list: def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return ufunc(*inputs, **kwargs) - def __repr__(self): - return f"PandasExtensionArray(array={self.array!r})" - - def __getattr__(self, attr: str) -> object: - return getattr(self.array, attr) - def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: item = self.array[key] if is_extension_array_dtype(item): - return type(self)(item) - if np.isscalar(item): - return type(self)(type(self.array)([item])) # type: ignore[call-arg] # only subclasses with proper __init__ allowed - return item + return PandasExtensionArray(item) + if np.isscalar(item) or isinstance(key, int): + return PandasExtensionArray(type(self.array)._from_sequence([item])) # type: ignore[call-arg,attr-defined,unused-ignore] + return PandasExtensionArray(item) def __setitem__(self, key, val): self.array[key] = val @@ -132,3 +130,15 @@ def __ne__(self, other): def __len__(self): return len(self.array) + + @property + def ndim(self) -> int: + return 1 + + def __array__( + self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + ) -> np.ndarray: + if Version(np.__version__) >= Version("2.0.0"): + return np.asarray(self.array, dtype=dtype, copy=copy) + else: + return np.asarray(self.array, dtype=dtype) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index f713f535647..7aa333ffb2e 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -626,6 +626,8 @@ def short_array_repr(array): if 
isinstance(array, AbstractArray): array = array.data + if isinstance(array, pd.api.extensions.ExtensionArray): + return repr(array) array = to_duck_array(array) # default to lower precision so a full (abbreviated) line can fit on diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index aa56006eff3..97e8c8df8af 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -19,7 +19,6 @@ from xarray.core import duck_array_ops from xarray.core.coordinate_transform import CoordinateTransform -from xarray.core.extension_array import PandasExtensionArray from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS from xarray.core.types import T_Xarray @@ -37,6 +36,7 @@ from xarray.namedarray.pycompat import array_type, integer_types, is_chunked_array if TYPE_CHECKING: + from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexes import Index from xarray.core.types import Self from xarray.core.variable import Variable @@ -1797,6 +1797,8 @@ def get_duck_array(self) -> np.ndarray | PandasExtensionArray: # We return an PandasExtensionArray wrapper type that satisfies # duck array protocols. This is what's needed for tests to pass. 
if pd.api.types.is_extension_array_dtype(self.array): + from xarray.core.extension_array import PandasExtensionArray + return PandasExtensionArray(self.array.array) return np.asarray(self) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index cc6b18d3a31..b8b33997780 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -410,12 +410,20 @@ def data(self): Variable.as_numpy Variable.values """ - if is_duck_array(self._data): - return self._data + if isinstance(self._data, PandasExtensionArray): + duck_array = self._data.array elif isinstance(self._data, indexing.ExplicitlyIndexed): - return self._data.get_duck_array() + duck_array = self._data.get_duck_array() + elif is_duck_array(self._data): + duck_array = self._data else: - return self.values + duck_array = self.values + if isinstance(duck_array, PandasExtensionArray): + # even though PandasExtensionArray is a duck array, + # we should not return the PandasExtensionArray wrapper, + # and instead return the underlying data. 
+ return duck_array.array + return duck_array @data.setter def data(self, data: T_DuckArray | ArrayLike) -> None: @@ -1366,7 +1374,7 @@ def set_dims(self, dim, shape=None): elif shape is not None: dims_map = dict(zip(dim, shape, strict=True)) tmp_shape = tuple(dims_map[d] for d in expanded_dims) - expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape) + expanded_data = duck_array_ops.broadcast_to(self._data, tmp_shape) else: indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,) expanded_data = self.data[indexer] diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 31024d72e60..0394151cdeb 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -60,7 +60,9 @@ def assert_writeable(ds): name for name, var in ds.variables.items() if not isinstance(var, IndexVariable) - and not isinstance(var.data, PandasExtensionArray) + and not isinstance( + var.data, PandasExtensionArray | pd.api.extensions.ExtensionArray + ) and not var.data.flags.writeable ] assert not readonly, readonly diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 5f484ec6d07..49c6490d819 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -160,10 +160,10 @@ def test_concat_categorical() -> None: concatenated = concat([data1, data2], dim="dim1") assert ( concatenated["var4"] - == type(data2["var4"].variable.data.array)._concat_same_type( + == type(data2["var4"].variable.data)._concat_same_type( [ - data1["var4"].variable.data.array, - data2["var4"].variable.data.array, + data1["var4"].variable.data, + data2["var4"].variable.data, ] ) ).all() diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f48e0bb6d00..52c60a77066 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1826,6 +1826,12 @@ def test_categorical_reindex(self) -> None: actual = ds.reindex(cat=["foo"])["cat"].values assert (actual == np.array(["foo"])).all() + def 
test_extension_array_reindex_same(self) -> None: + series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype()) + test = xr.Dataset({"test": series}) + res = test.reindex(dim_0=series.index) + align(res, test, join="exact") + def test_categorical_multiindex(self) -> None: i1 = pd.Series([0, 0]) cat = pd.CategoricalDtype(categories=["foo", "baz", "bar"]) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index dcf8349aba4..ff84041f8f1 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -196,8 +196,8 @@ def test_extension_array_pyarrow_concatenate(self, arrow1, arrow2): concatenated = concatenate( (PandasExtensionArray(arrow1), PandasExtensionArray(arrow2)) ) - assert concatenated[2]["x"] == 3 - assert concatenated[3]["y"] + assert concatenated[2].array[0]["x"] == 3 + assert concatenated[3].array[0]["y"] def test___getitem__extension_duck_array(self, categorical1): extension_duck_array = PandasExtensionArray(categorical1) @@ -1094,8 +1094,3 @@ def test_extension_array_singleton_equality(categorical1): def test_extension_array_repr(int1): int_duck_array = PandasExtensionArray(int1) assert repr(int1) in repr(int_duck_array) - - -def test_extension_array_attr(int1): - int_duck_array = PandasExtensionArray(int1) - assert (~int_duck_array.fillna(10)).all() diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 892816c5b6a..f920eda3f76 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -815,7 +815,7 @@ def test_groupby_getitem(dataset) -> None: assert_identical(dataset.cat.sel(y=[1]), dataset.cat.groupby("y")[1]) with pytest.raises( - NotImplementedError, match="Cannot broadcast 1d-only pandas categorical array." + NotImplementedError, match="Cannot broadcast 1d-only pandas extension array." ): dataset.groupby("boo") dataset = dataset.drop_vars(["cat"])