Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ v2026.03.0 (unreleased)
New Features
~~~~~~~~~~~~

- Added ``inherit='all_coords'`` option to :py:meth:`DataTree.to_dataset` to inherit
all parent coordinates, not just indexed ones (:issue:`10812`, :pull:`11230`).
By `Alfonso Ladino <https://github.com/aladinor>`_.

Breaking Changes
~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -129,6 +132,8 @@ Internal Changes
runtime behavior. This enables CI integration for type stub validation and helps
prevent type annotation regressions (:issue:`11086`).
By `Kristian Kollsgård <https://github.com/kkollsga>`_.
- Remove ``setup.py`` file (:pull:`11261`).
By `Nick Hodgskin <https://github.com/VeckoTheGecko>`_.

.. _whats-new.2026.02.0:

Expand Down
4 changes: 0 additions & 4 deletions setup.py

This file was deleted.

57 changes: 46 additions & 11 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,28 @@ def _coord_variables(self) -> ChainMap[Hashable, Variable]:
*(p._node_coord_variables_with_index for p in self.parents), # type: ignore[arg-type]
)

@property
def _coord_variables_all(self) -> ChainMap[Hashable, Variable]:
return ChainMap(
self._node_coord_variables,
*(p._node_coord_variables for p in self.parents),
)

def _resolve_inherit(
self, inherit: bool | Literal["all_coords", "indexes"]
) -> tuple[Mapping[Hashable, Variable], dict[Hashable, Index]]:
"""Resolve the inherit parameter to (coord_vars, indexes)."""
if inherit is False:
return self._node_coord_variables, dict(self._node_indexes)
if inherit is True or inherit == "indexes":
return self._coord_variables, dict(self._indexes)
if inherit == "all_coords":
return self._coord_variables_all, dict(self._indexes)
raise ValueError(
f"Invalid value for inherit: {inherit!r}. "
"Expected True, False, 'indexes', or 'all'."
)

@property
def _dims(self) -> ChainMap[Hashable, int]:
return ChainMap(self._node_dims, *(p._node_dims for p in self.parents))
Expand All @@ -596,8 +618,12 @@ def _dims(self) -> ChainMap[Hashable, int]:
def _indexes(self) -> ChainMap[Hashable, Index]:
return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents))

def _to_dataset_view(self, rebuild_dims: bool, inherit: bool) -> DatasetView:
coord_vars = self._coord_variables if inherit else self._node_coord_variables
def _to_dataset_view(
self,
rebuild_dims: bool,
inherit: bool | Literal["all_coords", "indexes"] = True,
) -> DatasetView:
coord_vars, indexes = self._resolve_inherit(inherit)
variables = dict(self._data_variables)
variables |= coord_vars
if rebuild_dims:
Expand Down Expand Up @@ -636,10 +662,10 @@ def _to_dataset_view(self, rebuild_dims: bool, inherit: bool) -> DatasetView:
dims = dict(self._node_dims)
return DatasetView._constructor(
variables=variables,
coord_names=set(self._coord_variables),
coord_names=set(coord_vars),
dims=dims,
attrs=self._attrs,
indexes=dict(self._indexes if inherit else self._node_indexes),
indexes=indexes,
encoding=self._encoding,
close=None,
)
Expand Down Expand Up @@ -669,30 +695,39 @@ def dataset(self, data: Dataset | None = None) -> None:
# xarray-contrib/datatree
ds = dataset

def to_dataset(self, inherit: bool = True) -> Dataset:
def to_dataset(
self, inherit: bool | Literal["all_coords", "indexes"] = True
) -> Dataset:
"""
Return the data in this node as a new xarray.Dataset object.

Parameters
----------
inherit : bool, optional
If False, only include coordinates and indexes defined at the level
of this DataTree node, excluding any inherited coordinates and indexes.
inherit : bool or {"all_coords", "indexes"}, default True
Controls which coordinates are inherited from parent nodes.

- True or "indexes": inherit only indexed coordinates (default).
- "all_coords": inherit all coordinates, including non-index coordinates.
- False: only include coordinates defined at this node.

See Also
--------
DataTree.dataset
"""
coord_vars = self._coord_variables if inherit else self._node_coord_variables
coord_vars, indexes = self._resolve_inherit(inherit)
variables = dict(self._data_variables)
variables |= coord_vars
dims = calculate_dimensions(variables) if inherit else dict(self._node_dims)
dims = (
dict(self._node_dims)
if inherit is False
else calculate_dimensions(variables)
)
return Dataset._construct_direct(
variables,
set(coord_vars),
dims,
None if self._attrs is None else dict(self._attrs),
dict(self._indexes if inherit else self._node_indexes),
indexes,
None if self._encoding is None else dict(self._encoding),
None,
)
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Callable, Hashable, Mapping, Sequence
from functools import partial
from types import EllipsisType
from typing import TYPE_CHECKING, Any, NoReturn, cast
from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -582,7 +582,7 @@ def to_index(self) -> pd.Index:
return self.to_index_variable().to_index()

def to_dict(
self, data: bool | str = "list", encoding: bool = False
self, data: bool | Literal["list", "array"] = "list", encoding: bool = False
) -> dict[str, Any]:
"""Dictionary representation of variable."""
item: dict[str, Any] = {
Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,28 @@ def test_to_dataset_inherited(self) -> None:
assert_identical(tree.to_dataset(inherit=True), base)
assert_identical(subtree.to_dataset(inherit=True), sub_and_base)

def test_to_dataset_inherit_all(self) -> None:
base = xr.Dataset(coords={"a": [1], "b": 2})
sub = xr.Dataset(coords={"c": [3]})
tree = DataTree.from_dict({"/": base, "/sub": sub})
subtree = typing.cast(DataTree, tree["sub"])

expected = xr.Dataset(coords={"a": [1], "b": 2, "c": [3]})
assert_identical(subtree.to_dataset(inherit="all_coords"), expected)
assert_identical(tree.to_dataset(inherit="all_coords"), base)

mid = xr.Dataset(coords={"c": 3.0})
leaf = xr.Dataset(coords={"d": [4]})
deep = DataTree.from_dict({"/": base, "/mid": mid, "/mid/leaf": leaf})
leaf_node = typing.cast(DataTree, deep["/mid/leaf"])
result = leaf_node.to_dataset(inherit="all_coords")
assert set(result.coords) == {"a", "b", "c", "d"}

def test_to_dataset_inherit_invalid(self) -> None:
tree = DataTree()
with pytest.raises(ValueError, match="Invalid value for inherit"):
tree.to_dataset(inherit="invalid") # type: ignore[arg-type]


class TestVariablesChildrenNameCollisions:
def test_parent_already_has_variable_with_childs_name(self) -> None:
Expand Down
Loading