Skip to content
Merged
15 changes: 13 additions & 2 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,17 @@ def __extension_duck_array__where(
return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array)


@implements(np.reshape)
def __extension_duck_array__reshape(
    arr: T_ExtensionArray, shape: tuple
) -> T_ExtensionArray:
    """Reshape a 1d-only pandas extension array.

    Only no-op reshapes are supported: the requested shape must be
    one-dimensional and either match ``len(arr)`` or be ``-1``.

    Parameters
    ----------
    arr
        The extension array to "reshape".
    shape
        Requested shape. A tuple is expected; a bare integer or any other
        sequence of integers (as ``np.reshape`` accepts) is also tolerated.

    Returns
    -------
    The input ``arr`` itself (no copy is made).

    Raises
    ------
    NotImplementedError
        If the requested shape is anything other than the array's current
        1d shape (including the empty shape ``()``).
    """
    # Normalise so that ints and lists are handled like tuples instead of
    # failing on subscripting or on tuple equality.
    if isinstance(shape, (int, np.integer)):
        dims = (shape,)
    else:
        dims = tuple(shape)
    # Check the rank first: indexing ``dims[0]`` on an empty shape would
    # otherwise raise IndexError rather than the intended error below.
    if len(dims) == 1 and (dims[0] == len(arr) or dims[0] == -1):
        return arr
    raise NotImplementedError(
        f"Cannot reshape 1d-only pandas extension array to: {shape}"
    )


@dataclass(frozen=True)
class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin):
"""NEP-18 compliant wrapper for pandas extension arrays.
Expand Down Expand Up @@ -101,10 +112,10 @@ def replace_duck_with_extension_array(args) -> list:

args = tuple(replace_duck_with_extension_array(args))
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
return func(*args, **kwargs)
raise KeyError("Function not registered for pandas extension arrays.")
res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs)
if is_extension_array_dtype(res):
return type(self)[type(res)](res)
return PandasExtensionArray(res)
return res

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
Expand Down
10 changes: 9 additions & 1 deletion xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from xarray.core.datatree_render import RenderDataTree
from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype, ravel
from xarray.core.extension_array import PandasExtensionArray
from xarray.core.indexing import MemoryCachedArray
from xarray.core.options import OPTIONS, _get_boolean_with_default
from xarray.core.treenode import group_subtrees
Expand Down Expand Up @@ -176,6 +177,11 @@ def format_timedelta(t, timedelta_format=None):

def format_item(x, timedelta_format=None, quote_strings=True):
"""Returns a succinct summary of an object as a string"""
if isinstance(x, PandasExtensionArray):
# We want to bypass PandasExtensionArray's repr here
# because its __repr__ is PandasExtensionArray(array=[...])
# and this function is only for single elements.
return str(x.array[0])
if isinstance(x, np.datetime64 | datetime):
return format_timestamp(x)
if isinstance(x, np.timedelta64 | timedelta):
Expand All @@ -194,7 +200,9 @@ def format_items(x):
"""Returns succinct summaries of all items in a sequence as strings"""
x = to_duck_array(x)
timedelta_format = "datetime"
if np.issubdtype(x.dtype, np.timedelta64):
if not isinstance(x, PandasExtensionArray) and np.issubdtype(
x.dtype, np.timedelta64
):
x = astype(x, dtype="timedelta64[ns]")
day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]")
time_needed = x[~pd.isnull(x)] != day_part
Expand Down
8 changes: 8 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,14 @@ def create_test_data(
)
),
)
if has_pyarrow:
obj["var5"] = (
"dim1",
pd.array(
rs.integers(1, 10, size=dim_sizes[0]).tolist(),
dtype="int64[pyarrow]",
),
)
if dim_sizes == _DEFAULT_TEST_DIM_SIZES:
numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64")
else:
Expand Down
16 changes: 9 additions & 7 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
assert_equal,
assert_identical,
requires_dask,
requires_pyarrow,
)
from xarray.tests.test_dataset import create_test_data

Expand Down Expand Up @@ -154,19 +155,20 @@ def test_concat_missing_var() -> None:
assert_identical(actual, expected)


def test_concat_categorical() -> None:
@pytest.mark.parametrize("var", ["var4", pytest.param("var5", marks=requires_pyarrow)])
def test_concat_extension_array(var) -> None:
data1 = create_test_data(use_extension_array=True)
data2 = create_test_data(use_extension_array=True)
concatenated = concat([data1, data2], dim="dim1")
assert (
concatenated["var4"]
== type(data2["var4"].variable.data)._concat_same_type(
assert pd.Series(
concatenated[var]
== type(data2[var].variable.data)._concat_same_type(
[
data1["var4"].variable.data,
data2["var4"].variable.data,
data1[var].variable.data,
data2[var].variable.data,
]
)
).all()
).all() # need to wrap in series because pyarrow bool does not support `all`


def test_concat_missing_multiple_consecutive_var() -> None:
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3637,7 +3637,7 @@ def test_series_categorical_index(self) -> None:

s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list("aabbc")))
arr = DataArray(s)
assert "'a'" in repr(arr) # should not error
assert "a a b b" in repr(arr) # should not error

@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize("data", ["list", "array", True])
Expand Down
26 changes: 15 additions & 11 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
create_test_data,
has_cftime,
has_dask,
has_pyarrow,
raise_if_dask_computes,
requires_bottleneck,
requires_cftime,
Expand Down Expand Up @@ -283,26 +284,28 @@ def test_repr(self) -> None:
data = create_test_data(seed=123, use_extension_array=True)
data.attrs["foo"] = "bar"
# need to insert str dtype at runtime to handle different endianness
var5 = (
"\n var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1"
if has_pyarrow
else ""
)
expected = dedent(
"""\
f"""\
<xarray.Dataset> Size: 2kB
Dimensions: (dim2: 9, dim3: 10, time: 20, dim1: 8)
Coordinates:
* dim2 (dim2) float64 72B 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0
* dim3 (dim3) {} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j'
* time (time) datetime64[{}] 160B 2000-01-01 2000-01-02 ... 2000-01-20
* dim3 (dim3) {data["dim3"].dtype} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j'
* time (time) datetime64[ns] 160B 2000-01-01 2000-01-02 ... 2000-01-20
numbers (dim3) int64 80B 0 1 2 0 0 1 1 2 2 3
Dimensions without coordinates: dim1
Data variables:
var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364
var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423
var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555
var4 (dim1) category 32B 'b' 'c' 'b' 'a' 'c' 'a' 'c' 'a'
var4 (dim1) category 32B b c b a c a c a{var5}
Attributes:
foo: bar""".format(
data["dim3"].dtype,
"ns",
)
foo: bar"""
)
actual = "\n".join(x.rstrip() for x in repr(data).split("\n"))

Expand Down Expand Up @@ -5884,20 +5887,21 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None:
def test_reduce_non_numeric(self) -> None:
data1 = create_test_data(seed=44, use_extension_array=True)
data2 = create_test_data(seed=44)
add_vars = {"var5": ["dim1", "dim2"], "var6": ["dim1"]}
add_vars = {"var6": ["dim1", "dim2"], "var7": ["dim1"]}
for v, dims in sorted(add_vars.items()):
size = tuple(data1.sizes[d] for d in dims)
data = np.random.randint(0, 100, size=size).astype(np.str_)
data1[v] = (dims, data, {"foo": "variable"})
# var4 is extension array categorical and should be dropped
# var4 and var5 are extension arrays and should be dropped
assert (
"var4" not in data1.mean()
and "var5" not in data1.mean()
and "var6" not in data1.mean()
and "var7" not in data1.mean()
)
assert_equal(data1.mean(), data2.mean())
assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1"))
assert "var5" not in data1.mean(dim="dim2") and "var6" in data1.mean(dim="dim2")
assert "var6" not in data1.mean(dim="dim2") and "var7" in data1.mean(dim="dim2")

@pytest.mark.filterwarnings(
"ignore:Once the behaviour of DataArray:DeprecationWarning"
Expand Down
Loading