feat: tensor aggregators (mean, std) #39

Status: Open. JakubPekar wants to merge 21 commits into main from feature/aggregations.

Commits
All 21 commits authored by JakubPekar:

- 9900fee feat: tensor aggregators (mean, std)
- f2176f4 feat: fixes
- eb4b53c feat: format
- 131e9f7 feat: std ddof + docs
- 73eaa0d feat: docs update lint
- f06fdf8 feat: docs
- 0dab47a feat: param fix
- 746fc43 fix: ddof test
- 0343a79 feat: std stability
- 8bb4c47 chore: version update
- e403c81 chore: lock
- 51891b6 feat: docs
- 3a67fc0 feat: zero factory
- 409fd43 Merge branch 'main' into feature/aggregations
- ec462b7 feat: docs + fix
- ded51b1 feat: CR fixes
- 5215154 fix: test
- b65f7d9 feat: handle nulls
- 75051d7 feat: fix std & test
- 876e7ae fix: lint
- 62d51f6 fix: tests
New documentation page for TensorMean (+3 lines):

```markdown
# ratiopath.ray.aggregate.TensorMean

::: ratiopath.ray.aggregate.TensorMean
```
New documentation page for TensorStd (+3 lines):

```markdown
# ratiopath.ray.aggregate.TensorStd

::: ratiopath.ray.aggregate.TensorStd
```
Package init exporting the aggregators (+5 lines):

```python
from ratiopath.ray.aggregate.tensor_mean import TensorMean
from ratiopath.ray.aggregate.tensor_std import TensorStd


__all__ = ["TensorMean", "TensorStd"]
```
New module `ratiopath.ray.aggregate.tensor_mean` (+146 lines):

```python
from typing import cast

import numpy as np

from ray.data.aggregate import AggregateFnV2
from ray.data.block import Block, BlockAccessor


class TensorMean(AggregateFnV2[dict, np.ndarray | float]):
    """Calculates the mean (average) of a column containing Tensors.

    This aggregator treats the data column as a high-dimensional array where
    **axis 0 represents the batch dimension**. To satisfy the requirements
    of a reduction and prevent memory growth proportional to the number of rows,
    axis 0 must be included in the aggregation.

    Args:
        on: The name of the column containing tensors or numbers.
        axis: The axis or axes along which the reduction is computed.
            - `None`: Global reduction. Collapses all dimensions (including batch)
              to a single scalar.
            - `int`: Aggregates over both the batch (axis 0) AND the specified
              tensor dimension. For example, `axis=1` collapses the batch and
              the first dimension of the tensors.
            - `tuple`: A sequence of axes that **must** explicitly include `0`.
        ignore_nulls: Whether to ignore null values. Defaults to True.
        alias_name: Optional name for the resulting column. Defaults to "mean(<on>)".

    Raises:
        ValueError: If `axis` is provided as a tuple but does not include `0`.

    Note:
        This aggregator is designed for "reduction" operations. If you wish to
        calculate statistics per-row without collapsing the batch dimension,
        use `.map()` instead.

    Example:
        >>> import ray
        >>> import numpy as np
        >>> from ratiopath.ray.aggregate import TensorMean
        >>> # Dataset with 2x2 matrices: total shape (Batch=2, Dim1=2, Dim2=2)
        >>> ds = ray.data.from_items(
        ...     [
        ...         {"m": np.array([[1, 1], [1, 1]])},
        ...         {"m": np.array([[3, 3], [3, 3]])},
        ...     ]
        ... )
        >>> # 1. Global Mean (axis=None) -> Result: 2.0
        >>> ds.aggregate(TensorMean(on="m", axis=None))
        >>>
        >>> # 2. Batch Mean (axis=0) -> Result: np.array([[2, 2], [2, 2]])
        >>> ds.aggregate(TensorMean(on="m", axis=0))
        >>>
        >>> # 3. Mean across Batch and Rows (axis=(0, 1)) -> Result: np.array([2, 2])
        >>> ds.aggregate(TensorMean(on="m", axis=(0, 1)))
    """

    _aggregate_axis: tuple[int, ...] | None = None

    def __init__(
        self,
        on: str,
        axis: int | tuple[int, ...] | None = None,
        ignore_nulls: bool = True,
        alias_name: str | None = None,
    ):
        super().__init__(
            name=alias_name if alias_name else f"mean({on})",
            on=on,
            ignore_nulls=ignore_nulls,
            # Initialize with identity values for summation
            zero_factory=self.zero_factory,
        )

        if axis is not None:
            axes = {0, axis} if isinstance(axis, int) else set(axis)

            if 0 not in axes:
                raise ValueError(
                    f"Invalid axis configuration: {axis}. Axis 0 (the batch dimension) "
                    "must be included to perform a reduction. To process rows "
                    "independently without collapsing the batch, use .map() instead."
                )

            self._aggregate_axis = tuple(axes)

    @staticmethod
    def zero_factory() -> dict:
        return {"sum": 0, "shape": None, "count": 0}

    def aggregate_block(self, block: Block) -> dict:
        block_acc = BlockAccessor.for_block(block)

        # Get exact counts before any NumPy conversion obscures the nulls
        valid_count = cast(
            "int",
            block_acc.count(self._target_col_name, ignore_nulls=True),  # type: ignore [arg-type]
        )
        total_count = cast(
            "int",
            block_acc.count(self._target_col_name, ignore_nulls=False),  # type: ignore [arg-type]
        )

        # Catch nulls immediately if strict mode is on
        if valid_count < total_count and not self._ignore_nulls:
            raise ValueError(
                f"Column '{self._target_col_name}' contains null values, but "
                "ignore_nulls is False."
            )

        if valid_count == 0:
            return self.zero_factory()

        col_np = cast("np.ndarray", block_acc.to_numpy(self._target_col_name))

        # Filter out nulls if necessary
        if valid_count < total_count:
            valid_tensors = [x for x in col_np if x is not None]
            col_np = np.stack(valid_tensors)

        # Perform the partial sum and calculate how many elements contributed
        block_sum = np.sum(col_np, axis=self._aggregate_axis)
        block_count = col_np.size // block_sum.size

        return {
            "sum": block_sum.flatten(),
            "shape": block_sum.shape,
            "count": block_count,
        }

    def combine(self, current_accumulator: dict, new: dict) -> dict:
        return {
            "sum": np.asarray(current_accumulator["sum"]) + np.asarray(new["sum"]),
            "shape": current_accumulator["shape"] or new["shape"],
            "count": current_accumulator["count"] + new["count"],
        }

    def finalize(self, accumulator: dict) -> np.ndarray | float:  # type: ignore [override]
        count = accumulator["count"]

        if count == 0:
            return np.nan

        # Reshape the flattened sum back to original aggregated dimensions
        return np.asarray(accumulator["sum"]).reshape(accumulator["shape"]) / count
```
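The sum/count scheme above can be sketched in plain NumPy. This is a minimal, hypothetical stand-in for illustration only: the free functions `aggregate_block`, `combine`, and `finalize` mirror the methods of the same names, but the real class plugs into Ray Data's `AggregateFnV2` and handles nulls and block access.

```python
import numpy as np

def aggregate_block(block: np.ndarray, axis: tuple[int, ...]) -> dict:
    # Partial sum over the requested axes, plus how many elements fed each slot
    s = np.sum(block, axis=axis)
    return {"sum": s.ravel(), "shape": s.shape, "count": block.size // s.size}

def combine(a: dict, b: dict) -> dict:
    # Sums and counts add element-wise; the first non-None shape wins
    return {
        "sum": np.asarray(a["sum"]) + np.asarray(b["sum"]),
        "shape": a["shape"] or b["shape"],
        "count": a["count"] + b["count"],
    }

def finalize(acc: dict) -> np.ndarray:
    # Restore the aggregated shape, then divide by the contributing count
    return np.asarray(acc["sum"]).reshape(acc["shape"]) / acc["count"]

# A batch of six 2x2 tensors, split into two "blocks" as Ray would
data = np.arange(24, dtype=float).reshape(6, 2, 2)
acc = combine(
    aggregate_block(data[:3], axis=(0,)),
    aggregate_block(data[3:], axis=(0,)),
)
assert np.allclose(finalize(acc), data.mean(axis=0))
```

Because each block contributes only a reduced sum and a scalar count, memory stays constant regardless of how many rows the dataset holds.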
New module `ratiopath.ray.aggregate.tensor_std` (+190 lines):

```python
from typing import cast

import numpy as np

from ray.data.aggregate import AggregateFnV2
from ray.data.block import Block, BlockAccessor


class TensorStd(AggregateFnV2[dict, np.ndarray | float]):
    """Calculates the standard deviation of a column containing Tensors.

    This aggregator treats the data column as a high-dimensional array where
    **axis 0 represents the batch dimension**. To satisfy the requirements
    of a reduction and prevent memory growth proportional to the number of rows,
    axis 0 must be included in the aggregation.

    It uses a parallel variance accumulation algorithm (Chan's method) to maintain
    numerical stability while processing data across multiple Ray blocks.

    Args:
        on: The name of the column containing tensors or numbers.
        axis: The axis or axes along which the reduction is computed.
            - `None`: Global reduction. Collapses all dimensions (including batch)
              to a single scalar.
            - `int`: Aggregates over both the batch (axis 0) AND the specified
              tensor dimension. For example, `axis=1` collapses the batch and
              the first dimension of the tensors.
            - `tuple`: A sequence of axes that **must** explicitly include `0`.
        ddof: Delta Degrees of Freedom. The divisor used in calculations
            is $N - ddof$, where $N$ represents the number of elements.
            Defaults to 1.0 (sample standard deviation).
        ignore_nulls: Whether to ignore null values. Defaults to True.
        alias_name: Optional name for the resulting column. Defaults to "std(<on>)".

    Raises:
        ValueError: If `axis` is provided as a tuple but does not include `0`.

    Note:
        This aggregator is designed for "reduction" operations. If you wish to
        calculate statistics per-row without collapsing the batch dimension,
        use `.map()` instead.

    Example:
        >>> import ray
        >>> import numpy as np
        >>> from ratiopath.ray.aggregate import TensorStd
        >>> ds = ray.data.from_items(
        ...     [
        ...         {"m": np.array([[1, 2], [1, 2]])},
        ...         {"m": np.array([[5, 6], [5, 6]])},
        ...     ]
        ... )
        >>> # 1. Global Std (axis=None) -> All elements reduced to one scalar
        >>> ds.aggregate(TensorStd(on="m", axis=None))
        >>>
        >>> # 2. Batch Std (axis=0) -> Result is a 2x2 matrix of std values
        >>> #    calculated across the dataset rows.
        >>> ds.aggregate(TensorStd(on="m", axis=0))
        >>>
        >>> # 3. Int shorthand (axis=1) -> Internally uses axis=(0, 1)
        >>> #    Collapses batch and the first dimension of the tensor.
        >>> ds.aggregate(TensorStd(on="m", axis=1))
    """

    _aggregate_axis: tuple[int, ...] | None = None

    def __init__(
        self,
        on: str,
        axis: int | tuple[int, ...] | None = None,
        ddof: float = 1.0,
        ignore_nulls: bool = True,
        alias_name: str | None = None,
    ):
        super().__init__(
            name=alias_name if alias_name else f"std({on})",
            on=on,
            ignore_nulls=ignore_nulls,
            zero_factory=self.zero_factory,
        )

        self._ddof = ddof

        if axis is not None:
            axes = {0, axis} if isinstance(axis, int) else set(axis)

            if 0 not in axes:
                raise ValueError(
                    f"Invalid axis configuration: {axis}. Axis 0 (the batch dimension) "
                    "must be included to perform a reduction. To process rows "
                    "independently without collapsing the batch, use .map() instead."
                )

            self._aggregate_axis = tuple(axes)

    @staticmethod
    def zero_factory() -> dict:
        return {"mean": 0, "ssd": 0, "shape": None, "count": 0}

    def aggregate_block(self, block: Block) -> dict:
        block_acc = BlockAccessor.for_block(block)

        # Get exact counts before any NumPy conversion obscures the nulls
        valid_count = cast(
            "int",
            block_acc.count(self._target_col_name, ignore_nulls=True),  # type: ignore [arg-type]
        )
        total_count = cast(
            "int",
            block_acc.count(self._target_col_name, ignore_nulls=False),  # type: ignore [arg-type]
        )

        # Catch nulls immediately if strict mode is on
        if valid_count < total_count and not self._ignore_nulls:
            raise ValueError(
                f"Column '{self._target_col_name}' contains null values, but "
                "ignore_nulls is False."
            )

        if valid_count == 0:
            return self.zero_factory()

        col_np = cast("np.ndarray", block_acc.to_numpy(self._target_col_name))

        # Filter out nulls if necessary
        if valid_count < total_count:
            valid_tensors = [x for x in col_np if x is not None]
            col_np = np.stack(valid_tensors)

        # Partial sum and element count
        block_sum = np.sum(col_np, axis=self._aggregate_axis, keepdims=True)
        block_count = col_np.size // block_sum.size

        mean = block_sum / block_count
        block_ssd = np.sum((col_np - mean) ** 2, axis=self._aggregate_axis)

        return {
            "mean": mean.ravel(),
            "ssd": block_ssd.ravel(),
            "shape": block_ssd.shape,
            "count": block_count,
        }

    def combine(self, current_accumulator: dict, new: dict) -> dict:
        if new["count"] == 0:
            return current_accumulator

        if current_accumulator["count"] == 0:
            return new

        n_current = current_accumulator["count"]
        n_new = new["count"]
        combined_count = n_current + n_new

        mean_current = np.asarray(current_accumulator["mean"])
        mean_new = np.asarray(new["mean"])

        delta = mean_new - mean_current

        # Chan's formula for the combined true mean
        combined_mean = (mean_current * n_current + mean_new * n_new) / combined_count

        combined_ssd = (
            np.asarray(current_accumulator["ssd"])
            + np.asarray(new["ssd"])
            + (delta**2 * n_current * n_new / combined_count)
        )

        return {
            "mean": combined_mean,
            "ssd": combined_ssd,
            "shape": new["shape"],
            "count": combined_count,
        }

    def finalize(self, accumulator: dict) -> np.ndarray | float:  # type: ignore [override]
        count = accumulator["count"]

        if count - self._ddof <= 0:
            return np.nan

        # np.maximum added as a safety net. Floating point jitter can occasionally
        # result in trivially negative numbers (e.g., -1e-16), which crashes np.sqrt
        variance = np.maximum(
            0.0,
            np.asarray(accumulator["ssd"]).reshape(accumulator["shape"])
            / (count - self._ddof),
        )

        return np.sqrt(variance)
```
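The per-block statistics and Chan's merge can be checked against NumPy directly. The following is a minimal, hypothetical stand-in (free functions rather than the class above, no null handling) that verifies the streamed result matches `np.std` with `ddof=1`:

```python
import numpy as np

def block_stats(block: np.ndarray, axis: tuple[int, ...]) -> dict:
    # Per-block partials: mean, sum of squared deviations (ssd), element count
    s = np.sum(block, axis=axis, keepdims=True)
    n = block.size // s.size
    mean = s / n
    ssd = np.sum((block - mean) ** 2, axis=axis)
    return {"mean": mean.ravel(), "ssd": ssd.ravel(), "shape": ssd.shape, "count": n}

def combine(a: dict, b: dict) -> dict:
    # Chan's parallel-variance merge of two partial accumulators
    n = a["count"] + b["count"]
    delta = b["mean"] - a["mean"]
    mean = (a["mean"] * a["count"] + b["mean"] * b["count"]) / n
    ssd = a["ssd"] + b["ssd"] + delta**2 * a["count"] * b["count"] / n
    return {"mean": mean, "ssd": ssd, "shape": a["shape"], "count": n}

def finalize(acc: dict, ddof: float = 1.0) -> np.ndarray:
    # Clamp away negative floating-point jitter before the square root
    var = np.maximum(0.0, acc["ssd"].reshape(acc["shape"]) / (acc["count"] - ddof))
    return np.sqrt(var)

rng = np.random.default_rng(0)
data = rng.normal(size=(10, 2, 2))
acc = combine(
    block_stats(data[:6], axis=(0,)),
    block_stats(data[6:], axis=(0,)),
)
streamed = finalize(acc, ddof=1.0)
assert np.allclose(streamed, np.std(data, axis=0, ddof=1))
```

The merge is what makes the two-pass-per-block approach safe in parallel: each block computes deviations around its own mean (numerically stable), and the `delta**2 * n_a * n_b / n` correction term accounts for the distance between block means when accumulators meet.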