From 6eeb91035823962b55d2d639d3145673198e07e9 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Sun, 24 May 2026 17:43:28 -0500 Subject: [PATCH 1/6] Add jax output, allow evaluate to use other formats --- Cargo.lock | 2 +- pyproject.toml | 1 + python/opencosmo/column/column.py | 29 ++++- python/opencosmo/column/evaluate.py | 102 +++++---------- python/opencosmo/dataset/dataset.py | 12 +- python/opencosmo/dataset/evaluate.py | 30 ++--- python/opencosmo/dataset/formats.py | 165 +++++++++++++++++++++++- python/opencosmo/dataset/instantiate.py | 11 +- test/test_formats.py | 14 ++ uv.lock | 105 ++++++++++++++- 10 files changed, 357 insertions(+), 114 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4b5b9a4c..c8611404 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -96,7 +96,7 @@ checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "opencosmo" -version = "1.2.4" +version = "1.2.6" dependencies = [ "numpy", "pyo3", diff --git a/pyproject.toml b/pyproject.toml index 5540c1ba..f12d3c99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dev = [ "pip>=25.1.1", "pytest>=8.3.4,<9.0.0", "pytest-timeout>=2.4.0", + "jax>=0.10.1", ] docs = [ "sphinx>=9.0.0", diff --git a/python/opencosmo/column/column.py b/python/opencosmo/column/column.py index 8ff8c2db..d8dd419a 100644 --- a/python/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -865,11 +865,6 @@ def get_units(self, units: dict[str, np.ndarray]): def evaluate(self, data: dict[str, np.ndarray], index: DataIndex | None): data = {name: data[name] for name in self.__requires} chunk_sizes = index[1] if isinstance(index, tuple) else None - if self.__format != "astropy": - data = { - name: val.value if isinstance(val, u.Quantity) else val - for name, val in data.items() - } if self.batch_size > 0: length = len(next(iter(data.values()))) @@ -886,13 +881,33 @@ def evaluate(self, data: dict[str, np.ndarray], index: DataIndex | None): case EvaluateStrategy.VECTORIZE: return evaluate_vectorized(data, self.__func, self.__kwargs, index) case EvaluateStrategy.ROW_WISE: - return evaluate_rows(data, self.__func, self.__kwargs) + return evaluate_rows(data, self.__func, self.__kwargs, self.__format) case EvaluateStrategy.CHUNKED: if chunk_sizes is None: raise ValueError( "Cannot evaluate in CHUNKED strategy with a non-chunked index" ) - return evaluate_chunks(data, self.__func, self.__kwargs, chunk_sizes) + return evaluate_chunks( + data, self.__func, self.__kwargs, chunk_sizes, self.__format + ) + + def evaluate_for_storage( + self, data: dict[str, np.ndarray], index: DataIndex | None + ) -> dict[str, np.ndarray]: + """ + Evaluate and return numpy-formatted output suitable for the column + cache. Input arrives in the numpy/astropy form used internally, so + it is first converted to the user's requested format before the + function runs, and the output is converted back to numpy. + """ + from opencosmo.dataset.formats import to_format_dict, to_numpy_dict + + required = {name: data[name] for name in self.__requires} + converted = to_format_dict(required, self.__format) + output = self.evaluate(converted, index) + if not isinstance(output, dict): + output = {next(iter(self.__produces)): output} + return to_numpy_dict(output) def evaluate_one(self, dataset: Dataset): match self.__strategy: diff --git a/python/opencosmo/column/evaluate.py b/python/opencosmo/column/evaluate.py index 20decfc1..d4822c84 100644 --- a/python/opencosmo/column/evaluate.py +++ b/python/opencosmo/column/evaluate.py @@ -3,11 +3,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Callable -import astropy.units as u import numpy as np -from opencosmo.evaluate import insert_data - if TYPE_CHECKING: from opencosmo import Dataset @@ -18,77 +15,48 @@ class EvaluateStrategy(Enum): CHUNKED = "chunked" -def evaluate_rows(data: dict[str, np.ndarray], func: Callable, kwargs: dict[str, Any]): +def evaluate_rows( + data: dict[str, Any], + func: Callable, + kwargs: dict[str, Any], + format: str, +): + from opencosmo.dataset.formats import stack_rows + data_length = len(next(iter(data.values()))) - storage = {} + per_column: dict[str, list] = {} for i in range(data_length): iterable_inputs = {name: values[i] for name, values in data.items()} output = func(**iterable_inputs, **kwargs) if not isinstance(output, dict): output = {func.__name__: output} - if i == 0: - storage = __make_row_based_output_from_first_values(output, data_length) - continue - insert_data(storage, i, output) - return storage - - -def __make_row_based_output_from_first_values(values, data_length): - storage = {} - for name, value in values.items(): - try: - shape = (data_length,) + value.shape - except AttributeError: - shape = (data_length,) - try: - dtype = value.dtype - except AttributeError: - dtype = type(value) - column_storage = np.zeros(shape, dtype=dtype) - if isinstance(value, u.Quantity): - column_storage *= value.unit - column_storage[0] = value - storage[name] = column_storage - - return storage + for name, value in output.items(): + per_column.setdefault(name, []).append(value) + return {name: stack_rows(values, format) for name, values in per_column.items()} def evaluate_chunks( - data: dict[str, np.ndarray], + data: dict[str, Any], func: Callable, kwargs: dict[str, Any], chunk_sizes: np.ndarray, + format: str, ): - data_length = len(next(iter(data.values()))) + from opencosmo.dataset.formats import concat_chunks chunk_splits = np.cumsum(chunk_sizes) - storage = {} - input_data = {name: np.split(arr, chunk_splits) for name, arr in data.items()} - for i in range(len(chunk_splits)): - chunk_input_data = {name: split[i] for name, split in input_data.items()} + starts = np.concatenate([[0], chunk_splits[:-1]]) + per_column: dict[str, list] = {} + for start, end in zip(starts, chunk_splits): + chunk_input_data = { + name: arr[int(start) : int(end)] for name, arr in data.items() + } output = func(**chunk_input_data, **kwargs) if not isinstance(output, dict): output = {func.__name__: output} - if i == 0: - storage = __make_chunked_based_output_from_first_values(output, data_length) - continue - for name, values in output.items(): - storage[name][chunk_splits[i - 1] : chunk_splits[i]] = values - return storage - - -def __make_chunked_based_output_from_first_values(values, data_length): - storage = {} - for name, value in values.items(): - shape = (data_length,) + value.shape[1:] - dtype = value.dtype - column_storage = np.zeros(shape, dtype=dtype) - if isinstance(value, u.Quantity): - column_storage *= value.unit - column_storage[0 : len(value)] = value - storage[name] = column_storage - - return storage + for name, value in output.items(): + per_column.setdefault(name, []).append(value) + return {name: concat_chunks(chunks, format) for name, chunks in per_column.items()} def evaluate_vectorized(data, func, kwargs, index): @@ -105,29 +73,25 @@ def do_first_evaluation( kwargs: dict[str, Any], dataset: Dataset, ): + from opencosmo.dataset.formats import fetch_as_dict + eval_strategy = EvaluateStrategy(strategy) + columns = list(dataset.columns) match eval_strategy: case EvaluateStrategy.VECTORIZE: - values = dataset.take(1).get_data(format, unpack=False) - try: - values = dict(values) - except TypeError: - values = {dataset.columns[0]: values} - + values = fetch_as_dict(dataset.take(1), columns, format, unpack=False) return func(**values, **kwargs), eval_strategy case EvaluateStrategy.ROW_WISE: - values = dataset.take(1).get_data(format, unpack=True) - try: - values = dict(values) - except TypeError: - values = {dataset.columns[0]: values} + values = fetch_as_dict(dataset.take(1), columns, format, unpack=False) + values = {name: container[0] for name, container in values.items()} return func(**values, **kwargs), eval_strategy case EvaluateStrategy.CHUNKED: index = dataset.index assert isinstance(index, tuple) first_chunk_size = index[1][0] - first_chunk = dataset.take(first_chunk_size, at="start").get_data(format) - first_chunk = dict(first_chunk) + first_chunk = fetch_as_dict( + dataset.take(first_chunk_size, at="start"), columns, format + ) return func(**first_chunk, **kwargs), eval_strategy diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index ac3b62c6..7066b2ef 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -270,7 +270,7 @@ def get_data( on the data. The method supports output into several different formats, including - "astropy", "numpy", "pandas", "polars", and "pyarrow". Although astropy + "astropy", "numpy", "pandas", "polars", "jax", and "arrow". Although astropy and numpy are core dependencies of OpenCosmo, the remaining formats require you to have the relevant libraries installed in your python environment. This method will check that it can import the necessary @@ -286,7 +286,7 @@ def get_data( ---------- output: str, default="astropy" The format to output the data in. - Currently supported are "astropy", "numpy", "pandas", "polars", "arrow" + Currently supported are "astropy", "numpy", "pandas", "polars", "arrow", "jax" Returns ------- @@ -478,8 +478,11 @@ def baryon_fraction_bias(sod_halo_mass_gas, sod_halo_mass, cosmology): as the function. Otherwise the data will be returned directly. format: str, default = astropy - Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. Unit information is preserved only when the function + returns astropy Quantities; outputs in other formats are stored without unit metadata. allow_overwrite: bool, default = False batch_size: int, default = -1 @@ -496,6 +499,7 @@ def baryon_fraction_bias(sod_halo_mass_gas, sod_halo_mass, cosmology): result : Dataset | dict[str, np.ndarray | astropy.units.Quantity] The new dataset with the evaluated column(s) or the results as numpy arrays or astropy quantities """ + verify_format(format) evaluated_column = build_evaluated_column( self, func, vectorize, insert, format, batch_size, evaluate_kwargs ) diff --git a/python/opencosmo/dataset/evaluate.py b/python/opencosmo/dataset/evaluate.py index 38968dfd..e2f84210 100644 --- a/python/opencosmo/dataset/evaluate.py +++ b/python/opencosmo/dataset/evaluate.py @@ -10,6 +10,7 @@ from opencosmo.column.column import EvaluatedColumn from opencosmo.column.evaluate import EvaluateStrategy, do_first_evaluation +from opencosmo.dataset.formats import concat_chunks, fetch_as_dict from opencosmo.evaluate import ( insert_data, make_output_from_first_values, @@ -27,10 +28,6 @@ def build_evaluated_column( dataset, func, vectorize, insert, format, batch_size, evaluate_kwargs ): - if format not in ["astropy", "numpy"]: - raise ValueError( - f"Evaluate only supports numpy and astropy format, got: {format}" - ) kwarg_columns = set(evaluate_kwargs.keys()).intersection(dataset.columns) if kwarg_columns: raise ValueError( @@ -69,12 +66,7 @@ def visit_dataset( ) -> dict[str, np.ndarray]: if column.batch_size > 0: return visit_dataset_batched(column, dataset) - requires_names = column.requires_names - data = dataset.select(requires_names).get_data(format=column.format) - try: - data = dict(data) - except (TypeError, ValueError): - data = {next(iter(requires_names)): data} + data = fetch_as_dict(dataset, column.requires_names, column.format) output = column.evaluate(data, dataset.index) if not isinstance(output, dict): assert len(column.produces) == 1 @@ -89,24 +81,22 @@ def visit_dataset_batched(column: EvaluatedColumn, dataset: Dataset): output = defaultdict(list) - requires_names = column.requires_names for start, end in np.lib.stride_tricks.sliding_window_view(ranges, 2): - batch_data = ( - dataset.select(requires_names) - .take_range(start, end) - .get_data(format=column.format, unpack=False) + batch_data = fetch_as_dict( + dataset.take_range(start, end), + column.requires_names, + column.format, + unpack=False, ) - try: - batch_data = dict(batch_data) - except TypeError: - batch_data = {next(iter(requires_names)): batch_data} batch_output = column.evaluate(batch_data, None) if batch_output is not None and not isinstance(batch_output, dict): batch_output = {column.produces.pop(): batch_output} for name, column_batch in batch_output.items(): output[name].append(column_batch) - full_output = {name: np.concat(out) for name, out in output.items()} + full_output = { + name: concat_chunks(out, column.format) for name, out in output.items() + } return full_output diff --git a/python/opencosmo/dataset/formats.py b/python/opencosmo/dataset/formats.py index 321916d2..2ac5f711 100644 --- a/python/opencosmo/dataset/formats.py +++ b/python/opencosmo/dataset/formats.py @@ -1,11 +1,15 @@ from __future__ import annotations from importlib import import_module +from typing import TYPE_CHECKING, Any, Iterable import astropy.units as u import numpy as np from astropy.table import Column, QTable +if TYPE_CHECKING: + from opencosmo import Dataset + def verify_format(output_format: str): match output_format: @@ -19,6 +23,8 @@ def verify_format(output_format: str): import_name = "pyarrow" case "polars": import_name = "polars" + case "jax": + import_name = "jax" case _: raise ValueError(f"Unknown data output format {output_format}") @@ -39,13 +45,153 @@ def convert_data(data: dict[str, np.ndarray], output_format: str): case "astropy": return __convert_to_astropy(data) case "numpy": - return __convert_to_numpy(data) + return convert_to_numpy(data) case "pandas": return __convert_to_pandas(data) case "polars": return __convert_to_polars(data) case "arrow": return __convert_to_arrow(data) + case "jax": + return __convert_to_jax(data) + case _: + raise ValueError(f"Unknown data output format {output_format}") + + +def fetch_as_dict( + dataset: Dataset, + requires_names: Iterable[str], + output_format: str, + unpack: bool = True, +) -> dict[str, Any]: + """ + Fetch the requested columns and return them as a {name: container} dict in + the user's requested format. Routes through astropy so that Quantities (with + units) survive into the conversion step; other formats then receive plain + values via to_format_dict. + """ + requires_names = list(requires_names) + raw = dataset.select(requires_names).get_data(format="astropy", unpack=unpack) + if isinstance(raw, QTable): + raw = {name: raw[name] for name in raw.colnames} + elif not isinstance(raw, dict): + raw = {requires_names[0]: raw} + if output_format == "astropy": + return raw + return to_format_dict(raw, output_format) + + +def to_format_dict(data: dict[str, np.ndarray], output_format: str) -> dict: + """ + Convert each column of a numpy/astropy dict to the requested format, + preserving the dict shape. Unlike convert_data, this never wraps the + result in a higher-level container (DataFrame, QTable, ...). Used to + feed user-supplied evaluate functions when the upstream data is + numpy-shaped (e.g. when reading from the column cache). + """ + if output_format == "astropy": + return data + + def strip(value): + return value.value if isinstance(value, (u.Quantity, Column)) else value + + match output_format: + case "numpy": + return {k: strip(v) for k, v in data.items()} + case "jax": + import jax.numpy as jnp + + return {k: jnp.asarray(strip(v)) for k, v in data.items()} + case "pandas": + import pandas as pd + + return {k: pd.Series(strip(v)) for k, v in data.items()} + case "polars": + import polars as pl + + return {k: pl.Series(values=strip(v)) for k, v in data.items()} + case "arrow": + import pyarrow as pa # type: ignore + + return {k: pa.array(strip(v)) for k, v in data.items()} + case _: + raise ValueError(f"Unknown data output format {output_format}") + + +def to_numpy_dict(data: dict) -> dict[str, np.ndarray]: + """ + Convert each value in a dict-of-format-arrays back to a numpy array + suitable for the column cache. Astropy Quantities are preserved so + that downstream unit handling continues to work; other formats are + converted to plain numpy with no unit information. + """ + result: dict[str, np.ndarray] = {} + for name, value in data.items(): + if isinstance(value, (u.Quantity, np.ndarray)): + result[name] = value + else: + result[name] = np.asarray(value) + return result + + +def stack_rows(values: list, output_format: str): + """ + Stack a list of per-row values into a 1-D container in the target format. + Used by row-wise evaluation strategies to assemble output without + preallocation, which would break for formats with immutable arrays + (e.g. jax). + """ + match output_format: + case "astropy": + if values and isinstance(values[0], u.Quantity): + return u.Quantity(values) + return np.array(values) + case "numpy": + return np.array(values) + case "jax": + import jax.numpy as jnp + + return jnp.array(values) + case "pandas": + import pandas as pd + + return pd.Series(values) + case "polars": + import polars as pl + + return pl.Series(values=values) + case "arrow": + import pyarrow as pa # type: ignore + + return pa.array(values) + case _: + raise ValueError(f"Unknown data output format {output_format}") + + +def concat_chunks(chunks: list, output_format: str): + """ + Concatenate a list of per-chunk arrays into a single container in the + target format. + """ + match output_format: + case "astropy" | "numpy": + return np.concatenate(chunks) + case "jax": + import jax.numpy as jnp + + return jnp.concatenate(chunks) + case "pandas": + import pandas as pd + + return pd.concat(chunks, ignore_index=True) + case "polars": + import polars as pl + + return pl.concat(chunks) + case "arrow": + import pyarrow as pa # type: ignore + + return pa.concat_arrays(chunks) case _: raise ValueError(f"Unknown data output format {output_format}") @@ -62,7 +208,7 @@ def __convert_to_astropy(data: dict[str, np.ndarray]) -> QTable: return QTable(data, copy=False) -def __convert_to_numpy( +def convert_to_numpy( data: dict[str, np.ndarray], ) -> dict[str, np.ndarray] | np.ndarray: converted_data = dict( @@ -82,7 +228,7 @@ def __convert_to_numpy( def __convert_to_pandas(data: dict[str, np.ndarray]): import pandas as pd - numpy_data = __convert_to_numpy(data) + numpy_data = convert_to_numpy(data) if isinstance(numpy_data, np.ndarray): # only one column return pd.Series(numpy_data, name=next(iter(data.keys()))) @@ -92,7 +238,7 @@ def __convert_to_pandas(data: dict[str, np.ndarray]): def __convert_to_arrow(data: dict[str, np.ndarray]): import pyarrow as pa # type: ignore - numpy_data = __convert_to_numpy(data) + numpy_data = convert_to_numpy(data) if isinstance(numpy_data, np.ndarray): return pa.array(numpy_data) @@ -106,8 +252,17 @@ def __convert_to_arrow(data: dict[str, np.ndarray]): def __convert_to_polars(data: dict[str, np.ndarray]): import polars as pl - numpy_data = __convert_to_numpy(data) + numpy_data = convert_to_numpy(data) if isinstance(numpy_data, np.ndarray): return pl.Series(name=next(iter(data.keys())), values=numpy_data) return pl.from_dict(data) # type: ignore + + +def __convert_to_jax(data: dict[str, np.ndarray]): + import jax.numpy as jnp + + output_data = convert_to_numpy(data) + if isinstance(output_data, np.ndarray): + return jnp.asarray(output_data) + return {key: jnp.asarray(value) for key, value in output_data.items()} diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py index 60507c80..79d4ce75 100644 --- a/python/opencosmo/dataset/instantiate.py +++ b/python/opencosmo/dataset/instantiate.py @@ -4,7 +4,7 @@ import rustworkx as rx -from opencosmo.column.column import RawColumn +from opencosmo.column.column import EvaluatedColumn, RawColumn from opencosmo.dataset.graph import build_dependency_graph if TYPE_CHECKING: @@ -101,9 +101,12 @@ def build_derived_columns( name: all_data[dep_uuid][name] for name, dep_uuid in producer.dep_map.items() } - output = producer.evaluate(input_data, index) - if not isinstance(output, dict): - output = {next(iter(producer.produces)): output} + if isinstance(producer, EvaluatedColumn): + output = producer.evaluate_for_storage(input_data, index) + else: + output = producer.evaluate(input_data, index) + if not isinstance(output, dict): + output = {next(iter(producer.produces)): output} new_derived[producer.uuid] = output if not producer.no_cache: to_cache[producer.uuid] = output diff --git a/test/test_formats.py b/test/test_formats.py index 533fd383..d5b46d63 100644 --- a/test/test_formats.py +++ b/test/test_formats.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np import pandas as pd import polars as pl @@ -28,6 +29,12 @@ def test_return_pyarrow(input_path): assert all(isinstance(v, pa.Array) for v in data.values()) +def test_return_jax(input_path): + data = oc.open(input_path).get_data("jax") + assert isinstance(data, dict) + assert all(isinstance(v, jnp.ndarray) for v in data.values()) + + def test_return_pandas_single(input_path): dataset = oc.open(input_path) column = np.random.choice(dataset.columns) @@ -48,3 +55,10 @@ def test_return_pyarrow_single(input_path): column = np.random.choice(dataset.columns) data = dataset.select(column).get_data("arrow") assert isinstance(data, pa.Array) + + +def test_return_jax_single(input_path): + dataset = oc.open(input_path) + column = np.random.choice(dataset.columns) + data = dataset.select(column).get_data("jax") + assert isinstance(data, jnp.ndarray) diff --git a/uv.lock b/uv.lock index 0a16bfce..53519f59 100644 --- a/uv.lock +++ b/uv.lock @@ -6,10 +6,14 @@ resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'emscripten'", "python_full_version >= '3.14' and platform_machine == 'arm64' and sys_platform == 'darwin'", "(python_full_version >= '3.14' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", - "python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version < '3.14' and platform_machine == 'arm64' and sys_platform == 'darwin'", - "(python_full_version < '3.14' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and sys_platform == 'emscripten'", + "python_full_version == '3.13.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version < '3.13' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "(python_full_version == '3.13.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "(python_full_version < '3.13' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", ] [[package]] @@ -557,6 +561,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "jax" +version = "0.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jaxlib" }, + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "opt-einsum" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/24/49/b082387119c4a6bc7596296bbdc6bce034628cdd2845ebb27304cbca3624/jax-0.10.1.tar.gz", hash = "sha256:11672410faf8752429eb9a131de203dc488a2a3a012d509baa2b39878008810d", size = 2718178, upload-time = "2026-05-20T14:54:09.441Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl", hash = "sha256:47f3192c76e9e3358de1b106a8af5e943fccb10510903f25d96ea53652729134", size = 3150973, upload-time = "2026-05-20T14:51:30.066Z" }, +] + +[[package]] +name = "jaxlib" +version = "0.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "scipy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/53/b4/fdb6e989b142d8a8d2093f342cbc5323fe0d4a7217fd899c8ddf9e108a5a/jaxlib-0.10.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7a4399c7429c87ee6f7ab1e5712c5548ecfabde974ee9f4a12957fea4e35efb8", size = 60816137, upload-time = "2026-05-20T14:52:53.028Z" }, + { url = "https://files.pythonhosted.org/packages/a7/29/0b4eaaca005708751ff301903a2ca760dcc34175dadb7536498c57a7de85/jaxlib-0.10.1-cp312-cp312-manylinux_2_27_aarch64.whl", hash = "sha256:12126603ba472300c62480f86d20972d563143041039d9dea349ad510aa6123c", size = 80294034, upload-time = "2026-05-20T14:52:56.941Z" }, + { url = "https://files.pythonhosted.org/packages/38/69/2912ab63036e21c72748019e1d8e09e8a1fc3368b3e83fc27898a1858575/jaxlib-0.10.1-cp312-cp312-manylinux_2_27_x86_64.whl", hash = "sha256:f3cdf5b7f48470ab5455ab79aab746419694ccb6b52651cc2ce5fb27def03588", size = 85828774, upload-time = "2026-05-20T14:53:01.749Z" }, + { url = "https://files.pythonhosted.org/packages/ec/8f/993ea419eca6f34fe12613e22a03b93f40e5b1e8e0df18d4060e1313a1fc/jaxlib-0.10.1-cp312-cp312-win_amd64.whl", hash = "sha256:0acf3f8e7dca9074c0327f0f61502845792ca9f82fab23b841b00daa78e85488", size = 64830187, upload-time = "2026-05-20T14:53:05.932Z" }, + { url = "https://files.pythonhosted.org/packages/cf/76/3b637d4def229015a3035a7b44fac0dcf2536ae337540cdbffc651334d4e/jaxlib-0.10.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4167213aa00f14bb0d8fbd90f9ded75e976f71ce8baf8c3c44e04c8fb80ea0c1", size = 60815855, upload-time = "2026-05-20T14:53:11.718Z" }, + { url = "https://files.pythonhosted.org/packages/5b/76/9a1971bc9edb8728a7ba86d2693127ee46add9811230b8452321415fd4e9/jaxlib-0.10.1-cp313-cp313-manylinux_2_27_aarch64.whl", hash = "sha256:e53223d8f6861d33c02dd02343fa464700401ff5363784d37c33407218e328b8", size = 80293947, upload-time = "2026-05-20T14:53:15.408Z" }, + { url = "https://files.pythonhosted.org/packages/20/1d/69a0ba52fb546261e71a7209378ee6059950e9c088a2a18355e01509f474/jaxlib-0.10.1-cp313-cp313-manylinux_2_27_x86_64.whl", hash = "sha256:bb073a1224e659e01e8d32d47c000edb52ec2aa8ba97ec22b2228b3a46e5c167", size = 85829861, upload-time = "2026-05-20T14:53:19.773Z" }, + { url = "https://files.pythonhosted.org/packages/a9/df/48659e2ee57705c63a51525f810fe3e0c87af4ca9f89d4738281a872d58e/jaxlib-0.10.1-cp313-cp313-win_amd64.whl", hash = "sha256:6449f1d4a22324f5f02c843360475783f9fd1d353fe711806cbf4e927d1360ae", size = 64828863, upload-time = "2026-05-20T14:53:24.562Z" }, + { url = "https://files.pythonhosted.org/packages/c7/34/3f7c95ee1b2555d611f836988a49b522c04b8d186e0528f91d45118089bf/jaxlib-0.10.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:72a7db1242d6b7773340e430c244e73a8d13f82ce83933c0529a8b4b1520dcf3", size = 60934063, upload-time = "2026-05-20T14:53:28.549Z" }, + { url = "https://files.pythonhosted.org/packages/3a/84/855a2395d299e00e1d9ffd0aecd516b52b81780f3e4ef537527be6d8c1fc/jaxlib-0.10.1-cp313-cp313t-manylinux_2_27_aarch64.whl", hash = "sha256:649c26ca92e9bbffb3c35226b37893916f20e6b89fb3911ce79b39e5cfb27b46", size = 80402500, upload-time = "2026-05-20T14:53:32.444Z" }, + { url = "https://files.pythonhosted.org/packages/be/35/153e91a9c770a981d525d845b3f4cdb71a1e119681594a33908e9536bdff/jaxlib-0.10.1-cp313-cp313t-manylinux_2_27_x86_64.whl", hash = "sha256:cfe75d8a17e0d33a7bed27f32d7a5344a66e8d4af7073973f396e14ff4a9c503", size = 85941210, upload-time = "2026-05-20T14:53:36.982Z" }, + { url = "https://files.pythonhosted.org/packages/0b/a1/c4d4c0530313c50dd1ba07fff480cfd0c5f18c5ec49742f4a52a6edfd95f/jaxlib-0.10.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:9a657554dbe56f3691d377cb99d00b29df14e5638e579e57965f9f69c32c9315", size = 60826189, upload-time = "2026-05-20T14:53:40.73Z" }, + { url = "https://files.pythonhosted.org/packages/98/71/529b2439b88491e0806ee9fa6191ecf90f4447dfe092e80bd19577c85260/jaxlib-0.10.1-cp314-cp314-manylinux_2_27_aarch64.whl", hash = "sha256:93cd9989404f86a21b50c7ff4e850a31f55509664312080978bde97a664d92a9", size = 80300654, upload-time = "2026-05-20T14:53:44.751Z" }, + { url = "https://files.pythonhosted.org/packages/be/08/0bc4132fb2fe224ee9d83fd60e3650fd4d893b4e5148707a98df8f0333a4/jaxlib-0.10.1-cp314-cp314-manylinux_2_27_x86_64.whl", hash = "sha256:26b94e9640b01968cc14b8353dd6b6540d723f30579c78b1f46a477fc4aa196d", size = 85841241, upload-time = "2026-05-20T14:53:49.13Z" }, + { url = "https://files.pythonhosted.org/packages/5c/de/423d748ce3367bd5ea20d8cc34a7ceb6420da4d41c20b247f54194700d04/jaxlib-0.10.1-cp314-cp314-win_amd64.whl", hash = "sha256:375820799bbf7d515dd4e4d40f3334566b73d3fe64d340afbd6aa897d5d7c486", size = 67301520, upload-time = "2026-05-20T14:53:52.989Z" }, + { url = "https://files.pythonhosted.org/packages/2d/bf/3f7ce089d62f7ac85ea678925471f7ec88038899e67ab02079c1e7d8ad4e/jaxlib-0.10.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:bc52f16bdd61b299efeed7ea8e743e91a118059a463da2ab97196fc05b69ddb7", size = 60935393, upload-time = "2026-05-20T14:53:56.886Z" }, + { url = "https://files.pythonhosted.org/packages/65/6a/38cf1d4ff8c8f74ec8d567e0aa6e3d2082ab6fc580545f6bb51368da20a9/jaxlib-0.10.1-cp314-cp314t-manylinux_2_27_aarch64.whl", hash = "sha256:55b0a473fbd57d31dc3935c4cd5c0c38af5f7c1f41300df27923cf46676972ca", size = 80405741, upload-time = "2026-05-20T14:54:01.94Z" }, + { url = "https://files.pythonhosted.org/packages/16/9e/d3cff171aaf13a09aab26a44d1a27dcbf0d6311e4d855f5d99685965ace3/jaxlib-0.10.1-cp314-cp314t-manylinux_2_27_x86_64.whl", hash = "sha256:4b0cb8ef960a3723037db63f28ffa20083d90fff6d30085e99d7c63cfa08e4c0", size = 85942183, upload-time = "2026-05-20T14:54:05.91Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -832,6 +882,42 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/49/d651878698a0b67f23aa28e17f45a6d6dd3d3f933fa29087fa4ce5947b5a/matplotlib-3.10.8-cp314-cp314t-win_arm64.whl", hash = "sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f", size = 8192560, upload-time = "2025-12-10T22:56:38.008Z" }, ] +[[package]] +name = "ml-dtypes" +version = "0.5.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/4a/c27b42ed9b1c7d13d9ba8b6905dece787d6259152f2309338aed29b2447b/ml_dtypes-0.5.4.tar.gz", hash = "sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453", size = 692314, upload-time = "2025-11-17T22:32:31.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/b8/3c70881695e056f8a32f8b941126cf78775d9a4d7feba8abcb52cb7b04f2/ml_dtypes-0.5.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac", size = 676927, upload-time = "2025-11-17T22:31:48.182Z" }, + { url = "https://files.pythonhosted.org/packages/54/0f/428ef6881782e5ebb7eca459689448c0394fa0a80bea3aa9262cba5445ea/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900", size = 5028464, upload-time = "2025-11-17T22:31:50.135Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cb/28ce52eb94390dda42599c98ea0204d74799e4d8047a0eb559b6fd648056/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff", size = 5009002, upload-time = "2025-11-17T22:31:52.001Z" }, + { url = "https://files.pythonhosted.org/packages/f5/f0/0cfadd537c5470378b1b32bd859cf2824972174b51b873c9d95cfd7475a5/ml_dtypes-0.5.4-cp312-cp312-win_amd64.whl", hash = "sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7", size = 212222, upload-time = "2025-11-17T22:31:53.742Z" }, + { url = "https://files.pythonhosted.org/packages/16/2e/9acc86985bfad8f2c2d30291b27cd2bb4c74cea08695bd540906ed744249/ml_dtypes-0.5.4-cp312-cp312-win_arm64.whl", hash = "sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460", size = 160793, upload-time = "2025-11-17T22:31:55.358Z" }, + { url = "https://files.pythonhosted.org/packages/d9/a1/4008f14bbc616cfb1ac5b39ea485f9c63031c4634ab3f4cf72e7541f816a/ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48", size = 676888, upload-time = "2025-11-17T22:31:56.907Z" }, + { url = "https://files.pythonhosted.org/packages/d3/b7/dff378afc2b0d5a7d6cd9d3209b60474d9819d1189d347521e1688a60a53/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b", size = 5036993, upload-time = "2025-11-17T22:31:58.497Z" }, + { url = "https://files.pythonhosted.org/packages/eb/33/40cd74219417e78b97c47802037cf2d87b91973e18bb968a7da48a96ea44/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d", size = 5010956, upload-time = "2025-11-17T22:31:59.931Z" }, + { url = "https://files.pythonhosted.org/packages/e1/8b/200088c6859d8221454825959df35b5244fa9bdf263fd0249ac5fb75e281/ml_dtypes-0.5.4-cp313-cp313-win_amd64.whl", hash = "sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328", size = 212224, upload-time = "2025-11-17T22:32:01.349Z" }, + { url = "https://files.pythonhosted.org/packages/8f/75/dfc3775cb36367816e678f69a7843f6f03bd4e2bcd79941e01ea960a068e/ml_dtypes-0.5.4-cp313-cp313-win_arm64.whl", hash = "sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175", size = 160798, upload-time = "2025-11-17T22:32:02.864Z" }, + { url = "https://files.pythonhosted.org/packages/4f/74/e9ddb35fd1dd43b1106c20ced3f53c2e8e7fc7598c15638e9f80677f81d4/ml_dtypes-0.5.4-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6", size = 702083, upload-time = "2025-11-17T22:32:04.08Z" }, + { url = "https://files.pythonhosted.org/packages/74/f5/667060b0aed1aa63166b22897fdf16dca9eb704e6b4bbf86848d5a181aa7/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d", size = 5354111, upload-time = "2025-11-17T22:32:05.546Z" }, + { url = "https://files.pythonhosted.org/packages/40/49/0f8c498a28c0efa5f5c95a9e374c83ec1385ca41d0e85e7cf40e5d519a21/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298", size = 5366453, upload-time = "2025-11-17T22:32:07.115Z" }, + { url = "https://files.pythonhosted.org/packages/8c/27/12607423d0a9c6bbbcc780ad19f1f6baa2b68b18ce4bddcdc122c4c68dc9/ml_dtypes-0.5.4-cp313-cp313t-win_amd64.whl", hash = "sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6", size = 225612, upload-time = "2025-11-17T22:32:08.615Z" }, + { url = "https://files.pythonhosted.org/packages/e5/80/5a5929e92c72936d5b19872c5fb8fc09327c1da67b3b68c6a13139e77e20/ml_dtypes-0.5.4-cp313-cp313t-win_arm64.whl", hash = "sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1", size = 164145, upload-time = "2025-11-17T22:32:09.782Z" }, + { url = "https://files.pythonhosted.org/packages/72/4e/1339dc6e2557a344f5ba5590872e80346f76f6cb2ac3dd16e4666e88818c/ml_dtypes-0.5.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22", size = 673781, upload-time = "2025-11-17T22:32:11.364Z" }, + { url = "https://files.pythonhosted.org/packages/04/f9/067b84365c7e83bda15bba2b06c6ca250ce27b20630b1128c435fb7a09aa/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465", size = 5036145, upload-time = "2025-11-17T22:32:12.783Z" }, + { url = "https://files.pythonhosted.org/packages/c6/bb/82c7dcf38070b46172a517e2334e665c5bf374a262f99a283ea454bece7c/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f", size = 5010230, upload-time = "2025-11-17T22:32:14.38Z" }, + { url = "https://files.pythonhosted.org/packages/e9/93/2bfed22d2498c468f6bcd0d9f56b033eaa19f33320389314c19ef6766413/ml_dtypes-0.5.4-cp314-cp314-win_amd64.whl", hash = "sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56", size = 221032, upload-time = "2025-11-17T22:32:15.763Z" }, + { url = "https://files.pythonhosted.org/packages/76/a3/9c912fe6ea747bb10fe2f8f54d027eb265db05dfb0c6335e3e063e74e6e8/ml_dtypes-0.5.4-cp314-cp314-win_arm64.whl", hash = "sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049", size = 163353, upload-time = "2025-11-17T22:32:16.932Z" }, + { url = "https://files.pythonhosted.org/packages/cd/02/48aa7d84cc30ab4ee37624a2fd98c56c02326785750cd212bc0826c2f15b/ml_dtypes-0.5.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9", size = 702085, upload-time = "2025-11-17T22:32:18.175Z" }, + { url = "https://files.pythonhosted.org/packages/5a/e7/85cb99fe80a7a5513253ec7faa88a65306be071163485e9a626fce1b6e84/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7", size = 5355358, upload-time = "2025-11-17T22:32:19.7Z" }, + { url = "https://files.pythonhosted.org/packages/79/2b/a826ba18d2179a56e144aef69e57fb2ab7c464ef0b2111940ee8a3a223a2/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf", size = 5366332, upload-time = "2025-11-17T22:32:21.193Z" }, + { url = "https://files.pythonhosted.org/packages/84/44/f4d18446eacb20ea11e82f133ea8f86e2bf2891785b67d9da8d0ab0ef525/ml_dtypes-0.5.4-cp314-cp314t-win_amd64.whl", hash = "sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1", size = 236612, upload-time = "2025-11-17T22:32:22.579Z" }, + { url = "https://files.pythonhosted.org/packages/ad/3f/3d42e9a78fe5edf792a83c074b13b9b770092a4fbf3462872f4303135f09/ml_dtypes-0.5.4-cp314-cp314t-win_arm64.whl", hash = "sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d", size = 168825, upload-time = "2025-11-17T22:32:23.766Z" }, +] + [[package]] name = "more-itertools" version = "11.0.1" @@ -1047,6 +1133,7 @@ io = [ [package.dev-dependencies] dev = [ + { name = "jax" }, { name = "mypy" }, { name = "pip" }, { name = "pre-commit" }, @@ -1094,6 +1181,7 @@ provides-extras = ["io"] [package.metadata.requires-dev] dev = [ + { name = "jax", specifier = ">=0.10.1" }, { name = "mypy", specifier = ">=1.15,<2.0.0" }, { name = "pip", specifier = ">=25.1.1" }, { name = "pre-commit", specifier = ">=4.2.0,<5.0.0" }, @@ -1119,6 +1207,15 @@ test = [ ] test-mpi = [{ name = "mpi-pytest", specifier = ">=2025.4.0,<2026.0.0" }] +[[package]] +name = "opt-einsum" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/b9/2ac072041e899a52f20cf9510850ff58295003aa75525e58343591b0cbfb/opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac", size = 63004, upload-time = "2024-09-26T14:33:24.483Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932, upload-time = "2024-09-26T14:33:23.039Z" }, +] + [[package]] name = "packaging" version = "26.0" From 5fceb99fcd286566c5d91d433e558f1a200047e1 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Sun, 24 May 2026 17:46:05 -0500 Subject: [PATCH 2/6] Add updated evaluate formats tests --- test/test_evaluate_formats.py | 185 ++++++++++++++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 test/test_evaluate_formats.py diff --git a/test/test_evaluate_formats.py b/test/test_evaluate_formats.py new file mode 100644 index 00000000..4ba425e0 --- /dev/null +++ b/test/test_evaluate_formats.py @@ -0,0 +1,185 @@ +import jax.numpy as jnp +import numpy as np +import pandas as pd +import polars as pl +import pyarrow as pa +import pyarrow.compute as pc +import pytest + +import opencosmo as oc + + +@pytest.fixture +def input_path(snapshot_path): + return snapshot_path / "haloproperties.hdf5" + + +FORMATS = ["jax", "pandas", "polars", "arrow"] +SCALARS = { + "jax": (jnp.ndarray, np.floating, float), + "pandas": (pd.Series, np.floating, float), + "polars": (pl.Series, float, int), + "arrow": (pa.Scalar, float, int), +} + + +def _vectorized_func(format): + """Multiply two columns using each format's native multiplication path.""" + + if format == "arrow": + + def fof_px(fof_halo_mass, fof_halo_com_vx): + return pc.multiply(fof_halo_mass, fof_halo_com_vx) + + else: + + def fof_px(fof_halo_mass, fof_halo_com_vx): + return fof_halo_mass * fof_halo_com_vx + + return fof_px + + +def _row_func(format): + """Multiply two scalars; works for any format because each row is a scalar.""" + + def fof_px(fof_halo_mass, fof_halo_com_vx): + if isinstance(fof_halo_mass, pa.Scalar): + return fof_halo_mass.as_py() * fof_halo_com_vx.as_py() + return float(fof_halo_mass) * float(fof_halo_com_vx) + + return fof_px + + +def _expected(input_path): + data = ( + oc.open(input_path) + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + return data["fof_halo_mass"] * data["fof_halo_com_vx"] + + +def _to_numpy(value): + if isinstance(value, jnp.ndarray): + return np.asarray(value) + if isinstance(value, (pd.Series, pl.Series)): + return value.to_numpy() + if isinstance(value, pa.Array): + return value.to_numpy(zero_copy_only=False) + return np.asarray(value) + + +# --------------------------------------------------------------------------- +# insert = False (return result directly) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_vectorized_noinsert(input_path, format): + ds = oc.open(input_path) + result = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=False, format=format + ) + expected = _expected(input_path) + assert np.allclose(_to_numpy(result["fof_px"]), expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_row_wise_noinsert(input_path, format): + ds = oc.open(input_path).take(500, at="start") + result = ds.evaluate( + _row_func(format), vectorize=False, insert=False, format=format + ) + selected = ( + oc.open(input_path) + .take(500, at="start") + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + expected = selected["fof_halo_mass"] * selected["fof_halo_com_vx"] + assert np.allclose(_to_numpy(result["fof_px"]), expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_batched_noinsert(input_path, format): + ds = oc.open(input_path) + batch_size = 10_000 + result = ds.evaluate( + _vectorized_func(format), + insert=False, + batch_size=batch_size, + format=format, + ) + expected = _expected(input_path) + assert np.allclose(_to_numpy(result["fof_px"]), expected) + + +# --------------------------------------------------------------------------- +# insert = True (converted to numpy, stored in cache) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_vectorized_insert(input_path, format): + ds = oc.open(input_path) + ds = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=True, format=format + ) + assert "fof_px" in ds.columns + data = ds.select("fof_px").get_data("numpy") + expected = _expected(input_path) + assert np.allclose(data, expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_row_wise_insert(input_path, format): + ds = oc.open(input_path).take(500, at="start") + ds = ds.evaluate(_row_func(format), vectorize=False, insert=True, format=format) + assert "fof_px" in ds.columns + data = ds.select("fof_px").get_data("numpy") + selected = ( + oc.open(input_path) + .take(500, at="start") + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + expected = selected["fof_halo_mass"] * selected["fof_halo_com_vx"] + assert np.allclose(data, expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_batched_insert(input_path, format): + ds = oc.open(input_path) + batch_size = 10_000 + ds = ds.evaluate( + _vectorized_func(format), + insert=True, + batch_size=batch_size, + format=format, + ) + assert "fof_px" in ds.columns + data = ds.select("fof_px").get_data("numpy") + expected = _expected(input_path) + assert np.allclose(data, expected) + + +# --------------------------------------------------------------------------- +# Output-type assertions on the not-insert path +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "format,expected_type", + [ + ("jax", jnp.ndarray), + ("pandas", pd.Series), + ("polars", pl.Series), + ("arrow", pa.Array), + ], +) +def test_evaluate_noinsert_returns_native_container(input_path, format, expected_type): + ds = oc.open(input_path) + result = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=False, format=format + ) + assert isinstance(result["fof_px"], expected_type) From 91d1a7850a6d6a5526ad92be216020c3cbde81de Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 27 May 2026 09:21:58 -0500 Subject: [PATCH 3/6] Add support for other formats to StructureCollection.evaluate --- .../collection/lightcone/lightcone.py | 34 ++-- .../collection/structure/evaluate.py | 127 ++++----------- .../collection/structure/structure.py | 22 ++- python/opencosmo/dataset/dataset.py | 17 +- python/opencosmo/dataset/formats.py | 50 +++--- test/test_evaluate_formats.py | 152 ++++++++++++++++++ 6 files changed, 270 insertions(+), 132 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index f82410e7..98766124 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -27,7 +27,7 @@ from opencosmo.collection.lightcone.stack import stack_lightcone_datasets_in_schema from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn from opencosmo.dataset.evaluate import build_evaluated_column -from opencosmo.dataset.formats import convert_data, verify_format +from opencosmo.dataset.formats import concat_chunks, convert_data, verify_format from opencosmo.dataset.take import ( get_end_take_index, get_random_take_index, @@ -325,7 +325,13 @@ def get_pixels(self, nside: int = 64): return lcutils.get_pixels(self, int(level)) - def get_data(self, format="astropy", unpack: bool = True, **kwargs): + def get_data( + self, + format="astropy", + unpack: bool = True, + wrap_single: bool = False, + **kwargs, + ): """ Get the data in this dataset as an astropy table/column or as numpy array(s). Note that a dataset does not load data from disk into @@ -340,7 +346,9 @@ def get_data(self, format="astropy", unpack: bool = True, **kwargs): units will be attached. If the dataset only contains a single column, it will be returned as an - astropy.table.Column or a single numpy array. + astropy.table.Column or a single numpy array. Pass :code:`wrap_single=True` + to always return the format's multi-column container (QTable, DataFrame, + dict, ...) regardless of column count. Parameters ---------- @@ -348,6 +356,10 @@ def get_data(self, format="astropy", unpack: bool = True, **kwargs): The format to output the data in. Currently supported are "astropy", "numpy", "pandas", "polars", and "arrow" + wrap_single: bool, default=False + If True, always return the format's natural multi-column container even + when only one column is present. + Returns ------- data: Table | Column | dict[str, ndarray] | ndarray @@ -383,11 +395,11 @@ def get_data(self, format="astropy", unpack: bool = True, **kwargs): key: value[0] if len(value) == 1 else value for key, value in table.items() } - return convert_data(output_data, format) + return convert_data(output_data, format, wrap_single=wrap_single) if format != "astropy": - return convert_data(dict(table), format) - elif len(table.columns) == 1: + return convert_data(dict(table), format, wrap_single=wrap_single) + elif len(table.columns) == 1 and not wrap_single: return next(iter(dict(table).values())) return table @@ -758,9 +770,11 @@ def evaluate( The function to evaluate on the rows in the dataset. format: str, default = "astropy" - The format of the data that is provided to your function. If "astropy", will be a dictionary of - astropy quantities. If "numpy", will be a dictionary of numpy arrays. Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. Unit information is preserved only when the function + returns astropy Quantities; outputs in other formats are stored without unit metadata. vectorize: bool, default = False Whether to provide the values as full columns (True) or one row at a time (False) @@ -834,7 +848,7 @@ def evaluate( keys = next(iter(result.values())).keys() output = {} for key in keys: - output[key] = np.concatenate([r[key] for r in result.values()]) + output[key] = concat_chunks([r[key] for r in result.values()], format) return output def filter(self, *masks: ColumnMask, **kwargs) -> Self: diff --git a/python/opencosmo/collection/structure/evaluate.py b/python/opencosmo/collection/structure/evaluate.py index 7bed468a..efc80cb8 100644 --- a/python/opencosmo/collection/structure/evaluate.py +++ b/python/opencosmo/collection/structure/evaluate.py @@ -1,16 +1,12 @@ from __future__ import annotations from inspect import Parameter, signature -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, Optional -import numpy as np from astropy.units import Quantity # type: ignore from opencosmo import dataset as ds -from opencosmo.evaluate import ( - insert_data, - make_output_from_first_values, -) +from opencosmo.dataset.formats import concat_chunks, stack_rows if TYPE_CHECKING: from opencosmo import StructureCollection @@ -82,17 +78,22 @@ def evaluate_into_properties( kwargs: dict[str, Any], insert: bool, ): - storage = __make_output(function, collection, format, kwargs, {}, insert) - for i, structure in enumerate(collection.objects()): - if i == 0: - continue + per_column: dict[str, list] = {} + for structure in collection.objects(): input_structure = __make_input(structure, format) - output = function(**input_structure, **kwargs) - if storage is not None: - insert_data(storage, i, output) + if output is None and insert: + raise ValueError( + "You asked to insert these values, but your function returns None!" + ) + if not isinstance(output, dict): + output = {function.__name__: output} + for name, value in output.items(): + per_column.setdefault(name, []).append(value) - return storage + if not per_column: + return None + return {name: stack_rows(values, format) for name, values in per_column.items()} def evaluate_into_dataset( @@ -103,25 +104,28 @@ def evaluate_into_dataset( dataset: str, insert: bool, ): - storage = __make_chunked_output(function, collection, dataset, format, kwargs, {}) - + per_column: dict[str, list] = {} for i, structure in enumerate(collection.objects()): - if i == 0: - continue input_structure = __make_input(structure, format) - output = function(**input_structure, **kwargs) + if output is None and insert: + raise ValueError( + "You asked to insert these values, but your function returns None!" + ) if not isinstance(output, dict): output = {function.__name__: output} - - if storage is not None: - for name, output_arr in output.items(): - storage[name].append(output_arr) - - if storage is None: - return - output_data = {name: np.concatenate(data) for name, data in storage.items()} - return output_data + if i == 0: + expected_length = len(input_structure[dataset]) + if any(len(v) != expected_length for v in output.values()): + raise ValueError( + "If you pass a `dataset` argument, your function should output an array with the same length as that dataset" + ) + for name, output_arr in output.items(): + per_column.setdefault(name, []).append(output_arr) + + if not per_column: + return None + return {name: concat_chunks(data, format) for name, data in per_column.items()} def __make_input(structure: dict, format: str = "astropy"): @@ -130,78 +134,15 @@ def __make_input(structure: dict, format: str = "astropy"): if isinstance(element, dict): values[name] = __make_input(element, format) elif isinstance(element, ds.Dataset): - data = element.get_data(format) + data = element.get_data(format, wrap_single=True) values[name] = data - elif isinstance(element, Quantity) and format == "numpy": + elif isinstance(element, Quantity) and format != "astropy": values[name] = element.value else: values[name] = element return values -def __make_output( - function: Callable, - collection: StructureCollection, - format: str = "astropy", - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, Sequence] = {}, - insert: bool = True, -) -> dict | None: - first_structure = next(collection.take(1, at="start").objects()) - first_input = __make_input(first_structure, format) - first_values = function( - **first_input, - **kwargs, - **{name: arr[0] for name, arr in iterable_kwargs.items()}, - ) - if first_values is None and insert: - raise ValueError( - "You asked to insert these values, but your function returns None!" - ) - elif first_values is None: - return None - if not isinstance(first_values, dict): - name = function.__name__ - first_values = {name: first_values} - n_rows = len(collection) - return make_output_from_first_values(first_values, n_rows) - - -def __make_chunked_output( - function: Callable, - collection: StructureCollection, - dataset: str, - format: str = "astropy", - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, Sequence] = {}, - insert: bool = True, -) -> dict | None: - first_structure = collection.take(1, at="start") - expected_length = len(first_structure[dataset]) - first_structure_data = next(iter(first_structure.objects())) - - first_input = __make_input(first_structure_data, format) - first_values = function( - **first_input, - **kwargs, - **{name: arr[0] for name, arr in iterable_kwargs.items()}, - ) - if first_values is None and insert: - raise ValueError( - "You asked to insert these values, but your function returns None!" - ) - elif first_values is None: - return None - if not isinstance(first_values, dict): - name = function.__name__ - first_values = {name: first_values} - if any(len(fv) != expected_length for fv in first_values.values()): - raise ValueError( - "If you pass a `dataset` argument, your function should output an array with the same length as that dataset" - ) - return {name: [fv] for name, fv in first_values.items()} - - def __prepare_collection( spec: dict[str, Optional[list[str]]], collection: StructureCollection ) -> StructureCollection: diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index 56b69ca5..d5100968 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -21,6 +21,7 @@ from opencosmo.collection.lightcone import lightcone as lc from opencosmo.collection.structure import evaluate from opencosmo.collection.structure import io as sio +from opencosmo.dataset.formats import verify_format from opencosmo.index.unary import get_length from opencosmo.io.schema import FileEntry, make_schema @@ -597,8 +598,11 @@ def computation(halo_properties, dm_particles): collection contains galaxies. If False, simply return the data. format: str, default = astropy - Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. Unit information is preserved only when the function + returns astropy Quantities; outputs in other formats are stored without unit metadata. **evaluate_kwargs: any, Any additional arguments that are required for your function to run. These will be passed directly @@ -623,8 +627,7 @@ def computation(halo_properties, dm_particles): **evaluate_kwargs, ) - if format not in ["astropy", "numpy"]: - raise ValueError(f"Invalid format requested for data: {format}") + verify_format(format) if dataset is not None and dataset.startswith("galaxies"): # Nested structure collection, special case @@ -701,10 +704,12 @@ def computation(halo_properties, dm_particles): ) if not insert or output is None: return output + from opencosmo.dataset.formats import to_numpy_dict + return self.with_new_columns( - **output, dataset=dataset if dataset is not None else self.__source.dtype, allow_overwrite=allow_overwrite, + **to_numpy_dict(output), # type: ignore ) def evaluate_on_dataset( @@ -751,8 +756,11 @@ def evaluate_on_dataset( Whether to provide the values as full columns (True) or one row at a time (False). Ignored if :code:`batch_size` is set. format: str, default = astropy - Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. Unit information is preserved only when the function + returns astropy Quantities; outputs in other formats are stored without unit metadata. insert: bool, default = True If true, the data will be inserted as a column in this dataset. The new column will have the same name diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 7066b2ef..cc591d53 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -260,7 +260,12 @@ def get_metadata(self, columns: str | list[str] = [], ignore_sort: bool = False) return st.get_metadata(self.__state, columns, ignore_sort) def get_data( - self, format="astropy", unpack=True, metadata_columns=[], **kwargs + self, + format="astropy", + unpack=True, + metadata_columns=[], + wrap_single=False, + **kwargs, ) -> OpenCosmoData: """ Get the data in this dataset as an astropy table/column or as @@ -280,7 +285,9 @@ def get_data( If the dataset only contains a single column, it will not be put in a table or dictionary. "astropy", "numpy" and "arrow" will return a single array - in this case, while "polars" and "pandas" will return a Series object. + in this case, while "polars" and "pandas" will return a Series object. Pass + :code:`wrap_single=True` to always return the format's multi-column container + (QTable, DataFrame, dict, ...) regardless of column count. Parameters ---------- @@ -288,6 +295,10 @@ def get_data( The format to output the data in. Currently supported are "astropy", "numpy", "pandas", "polars", "arrow", "jax" + wrap_single: bool, default=False + If True, always return the format's natural multi-column container even + when only one column is present. + Returns ------- data: Any @@ -321,7 +332,7 @@ def get_data( for key, value in data.items() } - return convert_data(data, format) + return convert_data(data, format, wrap_single=wrap_single) def bound(self, region: Region, select_by: Optional[str] = None): """ diff --git a/python/opencosmo/dataset/formats.py b/python/opencosmo/dataset/formats.py index 2ac5f711..ff70f72b 100644 --- a/python/opencosmo/dataset/formats.py +++ b/python/opencosmo/dataset/formats.py @@ -40,20 +40,29 @@ def __verify_import(import_name: str, format_name: str): ) -def convert_data(data: dict[str, np.ndarray], output_format: str): +def convert_data( + data: dict[str, np.ndarray], output_format: str, wrap_single: bool = False +): + """ + If `wrap_single` is True, the result is always the format's natural + multi-column container (QTable, DataFrame, dict[name, array], etc.) even + when there is only one column, instead of collapsing to a bare array / + Series. Used by callers (e.g. evaluate) that want a uniform + `container[colname]` access pattern regardless of column count. + """ match output_format: case "astropy": - return __convert_to_astropy(data) + return __convert_to_astropy(data, wrap_single) case "numpy": - return convert_to_numpy(data) + return convert_to_numpy(data, wrap_single) case "pandas": - return __convert_to_pandas(data) + return __convert_to_pandas(data, wrap_single) case "polars": - return __convert_to_polars(data) + return __convert_to_polars(data, wrap_single) case "arrow": - return __convert_to_arrow(data) + return __convert_to_arrow(data, wrap_single) case "jax": - return __convert_to_jax(data) + return __convert_to_jax(data, wrap_single) case _: raise ValueError(f"Unknown data output format {output_format}") @@ -196,8 +205,10 @@ def concat_chunks(chunks: list, output_format: str): raise ValueError(f"Unknown data output format {output_format}") -def __convert_to_astropy(data: dict[str, np.ndarray]) -> QTable: - if len(data) == 1: +def __convert_to_astropy( + data: dict[str, np.ndarray], wrap_single: bool = False +) -> QTable: + if len(data) == 1 and not wrap_single: return next(iter(data.values())) if any( (isinstance(d, u.Quantity) and d.isscalar) or not isinstance(d, np.ndarray) @@ -210,6 +221,7 @@ def __convert_to_astropy(data: dict[str, np.ndarray]) -> QTable: def convert_to_numpy( data: dict[str, np.ndarray], + wrap_single: bool = False, ) -> dict[str, np.ndarray] | np.ndarray: converted_data = dict( map( @@ -220,25 +232,25 @@ def convert_to_numpy( data.items(), ) ) - if len(converted_data) == 1: + if len(converted_data) == 1 and not wrap_single: return next(iter(converted_data.values())) return converted_data -def __convert_to_pandas(data: dict[str, np.ndarray]): +def __convert_to_pandas(data: dict[str, np.ndarray], wrap_single: bool = False): import pandas as pd - numpy_data = convert_to_numpy(data) - if isinstance(numpy_data, np.ndarray): # only one column + numpy_data = convert_to_numpy(data, wrap_single) + if isinstance(numpy_data, np.ndarray): # only one column, wrap_single=False return pd.Series(numpy_data, name=next(iter(data.keys()))) return pd.DataFrame(numpy_data, copy=True) -def __convert_to_arrow(data: dict[str, np.ndarray]): +def __convert_to_arrow(data: dict[str, np.ndarray], wrap_single: bool = False): import pyarrow as pa # type: ignore - numpy_data = convert_to_numpy(data) + numpy_data = convert_to_numpy(data, wrap_single) if isinstance(numpy_data, np.ndarray): return pa.array(numpy_data) @@ -249,20 +261,20 @@ def __convert_to_arrow(data: dict[str, np.ndarray]): return dict(converted_data) -def __convert_to_polars(data: dict[str, np.ndarray]): +def __convert_to_polars(data: dict[str, np.ndarray], wrap_single: bool = False): import polars as pl - numpy_data = convert_to_numpy(data) + numpy_data = convert_to_numpy(data, wrap_single) if isinstance(numpy_data, np.ndarray): return pl.Series(name=next(iter(data.keys())), values=numpy_data) return pl.from_dict(data) # type: ignore -def __convert_to_jax(data: dict[str, np.ndarray]): +def __convert_to_jax(data: dict[str, np.ndarray], wrap_single: bool = False): import jax.numpy as jnp - output_data = convert_to_numpy(data) + output_data = convert_to_numpy(data, wrap_single) if isinstance(output_data, np.ndarray): return jnp.asarray(output_data) return {key: jnp.asarray(value) for key, value in output_data.items()} diff --git a/test/test_evaluate_formats.py b/test/test_evaluate_formats.py index 4ba425e0..6b1363c9 100644 --- a/test/test_evaluate_formats.py +++ b/test/test_evaluate_formats.py @@ -183,3 +183,155 @@ def test_evaluate_noinsert_returns_native_container(input_path, format, expected _vectorized_func(format), vectorize=True, insert=False, format=format ) assert isinstance(result["fof_px"], expected_type) + + +# --------------------------------------------------------------------------- +# StructureCollection paths +# --------------------------------------------------------------------------- + + +@pytest.fixture +def halo_paths(snapshot_path): + files = ["haloproperties.hdf5", "haloparticles.hdf5"] + return [snapshot_path / f for f in files] + + +def _mean_x(format): + """Per-structure function: mean of dm_particles 'x' coord. Returns a scalar + in the user's format. fof_halo_center_x is a scalar Quantity that has had + its unit stripped for non-astropy formats, so it's plain float.""" + + def offset(halo_properties, dm_particles): + x = dm_particles["x"] + if format == "arrow": + mean_x = pc.mean(x).as_py() + elif format == "polars": + mean_x = x.mean() + elif format == "pandas": + mean_x = float(x.mean()) + elif format == "jax": + mean_x = float(jnp.mean(x)) + else: + mean_x = float(np.mean(x)) + return mean_x - float(halo_properties["fof_halo_center_x"]) + + return offset + + +def _arange_like(format): + """Per-structure function with dataset=`dm_particles`: must return an array + in the user's format with the same length as the input dataset.""" + + def particle_id(x, y, z): + n = len(x) + if format == "jax": + return jnp.arange(n) + if format == "pandas": + return pd.Series(np.arange(n)) + if format == "polars": + return pl.Series(values=np.arange(n)) + if format == "arrow": + return pa.array(np.arange(n)) + return np.arange(n) + + return particle_id + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_into_properties(halo_paths, format): + collection = oc.open(*halo_paths).take(50) + spec = { + "dm_particles": ["x"], + "halo_properties": ["fof_halo_center_x"], + } + collection = collection.evaluate( + _mean_x(format), **spec, format=format, insert=True + ) + data = collection["halo_properties"].select("offset").get_data("numpy") + assert len(data) == 50 + assert np.any(data != 0) + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_into_properties_noinsert(halo_paths, format): + collection = oc.open(*halo_paths).take(50) + spec = { + "dm_particles": ["x"], + "halo_properties": ["fof_halo_center_x"], + } + result = collection.evaluate(_mean_x(format), **spec, format=format, insert=False) + assert "offset" in result + assert len(result["offset"]) == 50 + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_into_dataset(halo_paths, format): + collection = oc.open(*halo_paths).take(20) + collection = collection.evaluate( + _arange_like(format), + dataset="dm_particles", + format=format, + insert=True, + ) + for halo in collection.halos(["dm_particles"]): + pid = halo["dm_particles"].select("particle_id").get_data("numpy") + assert np.all(pid == np.arange(len(pid))) + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_on_dataset(halo_paths, format): + """Routes through Dataset.evaluate via the collection wrapper.""" + collection = oc.open(*halo_paths).take(50) + selected = ( + collection["halo_properties"] + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + collection = collection.evaluate_on_dataset( + _vectorized_func(format), + dataset="halo_properties", + vectorize=True, + format=format, + insert=True, + ) + data = collection["halo_properties"].select("fof_px").get_data("numpy") + expected = selected["fof_halo_mass"] * selected["fof_halo_com_vx"] + assert np.allclose(data, expected) + + +# --------------------------------------------------------------------------- +# Lightcone paths +# --------------------------------------------------------------------------- + + +@pytest.fixture +def lc_paths(lightcone_path): + return [ + lightcone_path / "step_600" / "haloproperties.hdf5", + lightcone_path / "step_601" / "haloproperties.hdf5", + ] + + +@pytest.mark.parametrize("format", FORMATS) +def test_lightcone_evaluate_insert(lc_paths, format): + ds = oc.open(*lc_paths).take(100) + ds = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=True, format=format + ) + for name in ds.keys(): + data = ds[name].select("fof_px").get_data("numpy") + original = ( + ds[name].select(["fof_halo_mass", "fof_halo_com_vx"]).get_data("numpy") + ) + expected = original["fof_halo_mass"] * original["fof_halo_com_vx"] + assert np.allclose(data, expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_lightcone_evaluate_noinsert(lc_paths, format): + ds = oc.open(*lc_paths).take(100) + result = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=False, format=format + ) + assert "fof_px" in result + assert len(result["fof_px"]) == len(ds) From df709ec9dfee70b63840c1b4596965e448f565e5 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 27 May 2026 09:24:31 -0500 Subject: [PATCH 4/6] Update docstrings to better reflect behavior --- python/opencosmo/collection/lightcone/lightcone.py | 3 +-- python/opencosmo/collection/simulation/simulation.py | 8 +++++--- python/opencosmo/collection/structure/structure.py | 6 ++---- python/opencosmo/dataset/dataset.py | 3 +-- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 98766124..94be05ba 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -773,8 +773,7 @@ def evaluate( The format in which to provide column data to your function. Supports the same formats as :py:meth:`get_data ` ("astropy", "numpy", "pandas", "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted - back to numpy before being stored. Unit information is preserved only when the function - returns astropy Quantities; outputs in other formats are stored without unit metadata. + back to numpy before being stored. vectorize: bool, default = False Whether to provide the values as full columns (True) or one row at a time (False) diff --git a/python/opencosmo/collection/simulation/simulation.py b/python/opencosmo/collection/simulation/simulation.py index e3284c16..30d04ef5 100644 --- a/python/opencosmo/collection/simulation/simulation.py +++ b/python/opencosmo/collection/simulation/simulation.py @@ -384,9 +384,11 @@ def evaluate( datasets: str | list[str], optional The datasets to evaluate on. If not provided, will be evaluated on all datasets format: str, default = "astropy" - The format of the data that is provided to your function. If "astropy", will be a dictionary of - astropy quantities. If "numpy", will be a dictionary of numpy arrays. Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. + vectorize: bool, default = False Whether to vectorize the computation. See :py:meth:`StructureCollection.evaluate ` and/or :py:meth:`Dataset.evaluate ` for more details. diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index d5100968..26c4430c 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -601,8 +601,7 @@ def computation(halo_properties, dm_particles): The format in which to provide column data to your function. Supports the same formats as :py:meth:`get_data ` ("astropy", "numpy", "pandas", "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted - back to numpy before being stored. Unit information is preserved only when the function - returns astropy Quantities; outputs in other formats are stored without unit metadata. + back to numpy before being stored. **evaluate_kwargs: any, Any additional arguments that are required for your function to run. These will be passed directly @@ -759,8 +758,7 @@ def evaluate_on_dataset( The format in which to provide column data to your function. Supports the same formats as :py:meth:`get_data ` ("astropy", "numpy", "pandas", "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted - back to numpy before being stored. Unit information is preserved only when the function - returns astropy Quantities; outputs in other formats are stored without unit metadata. + back to numpy before being stored. insert: bool, default = True If true, the data will be inserted as a column in this dataset. The new column will have the same name diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index cc591d53..4dcba8e0 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -492,8 +492,7 @@ def baryon_fraction_bias(sod_halo_mass_gas, sod_halo_mass, cosmology): The format in which to provide column data to your function. Supports the same formats as :py:meth:`get_data ` ("astropy", "numpy", "pandas", "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted - back to numpy before being stored. Unit information is preserved only when the function - returns astropy Quantities; outputs in other formats are stored without unit metadata. + back to numpy before being stored. allow_overwrite: bool, default = False batch_size: int, default = -1 From a8edab8f78b0633288594a0f6e414b2bcc85d7aa Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 27 May 2026 09:27:18 -0500 Subject: [PATCH 5/6] Add changelog --- changes/+95437205.feature.rst | 1 + changes/+c414feda.feature.rst | 1 + 2 files changed, 2 insertions(+) create mode 100644 changes/+95437205.feature.rst create mode 100644 changes/+c414feda.feature.rst diff --git a/changes/+95437205.feature.rst b/changes/+95437205.feature.rst new file mode 100644 index 00000000..f59c1e7f --- /dev/null +++ b/changes/+95437205.feature.rst @@ -0,0 +1 @@ +:py:meth:`get_data ` now supports :code:`jax` as an output format. diff --git a/changes/+c414feda.feature.rst b/changes/+c414feda.feature.rst new file mode 100644 index 00000000..d0e6e288 --- /dev/null +++ b/changes/+c414feda.feature.rst @@ -0,0 +1 @@ +All :code:`evaluate` methods (e.g. :py:meth:`Dataset.evaluate `) now support passing data to the function in any format supported by :py:meth:`get_data `. From 7697925b36bd89a034e2452b5009024b3c1708ee Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 27 May 2026 09:39:31 -0500 Subject: [PATCH 6/6] Delete a bunch of dead code --- python/opencosmo/dataset/evaluate.py | 85 +--------------------------- python/opencosmo/evaluate.py | 48 ---------------- 2 files changed, 1 insertion(+), 132 deletions(-) delete mode 100644 python/opencosmo/evaluate.py diff --git a/python/opencosmo/dataset/evaluate.py b/python/opencosmo/dataset/evaluate.py index e2f84210..2f14e24c 100644 --- a/python/opencosmo/dataset/evaluate.py +++ b/python/opencosmo/dataset/evaluate.py @@ -2,8 +2,7 @@ from collections import defaultdict from inspect import Parameter, signature -from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable import numpy as np from astropy.units import Quantity @@ -11,10 +10,6 @@ from opencosmo.column.column import EvaluatedColumn from opencosmo.column.evaluate import EvaluateStrategy, do_first_evaluation from opencosmo.dataset.formats import concat_chunks, fetch_as_dict -from opencosmo.evaluate import ( - insert_data, - make_output_from_first_values, -) if TYPE_CHECKING: from opencosmo import Dataset @@ -165,84 +160,6 @@ def verify_for_lazy_evaluation( return column -def __visit_rows_in_dataset( - function: Callable, - dataset: Dataset, - format: str, - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, Sequence] = {}, -): - first_row_values = dict(dataset.take(1, at="start").get_data()) - first_row_kwargs = kwargs | {name: arr[0] for name, arr in iterable_kwargs.items()} - storage = __make_output(function, first_row_values | first_row_kwargs, len(dataset)) - for i, row in enumerate(dataset.rows(include_units=format == "astropy")): - if i == 0: - continue - iter_kwargs = {name: arr[i] for name, arr in iterable_kwargs.items()} - output = function(**row, **kwargs, **iter_kwargs) - if storage is not None: - insert_data(storage, i, output) - return storage - - -def __visit_rows_in_data( - function: Callable, - data: dict[str, np.ndarray], - format="astropy", - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, np.ndarray] = {}, -): - data = {key: d for key, d in data.items() if key in signature(function).parameters} - first_row_data = {name: arr[0] for name, arr in data.items()} - first_row_kwargs = kwargs | {name: arr[0] for name, arr in iterable_kwargs.items()} - n_rows = len(next(iter(data.values()))) - storage = __make_output(function, first_row_data | first_row_kwargs, n_rows) - if format == "numpy": - data = { - key: arr.value if isinstance(arr, Quantity) else arr - for key, arr in data.items() - } - - for i in range(1, n_rows): - row = { - name: arr[i] for name, arr in chain(data.items(), iterable_kwargs.items()) - } - output = function(**row, **kwargs) - if storage is not None: - insert_data(storage, i, output) - return storage - - -def __make_output( - function: Callable, - first_input_values: dict[str, Any], - n_rows: int, -) -> dict | None: - first_values = function(**first_input_values) - if first_values is None: - return None - if not isinstance(first_values, dict): - name = function.__name__ - first_values = {name: first_values} - - return make_output_from_first_values(first_values, n_rows) - - -def __visit_vectorize( - function: Callable, - data: dict[str, Iterable] | Iterable, - evaluator_kwargs: dict[str, Any] = {}, -): - pars = signature(function).parameters - - if not isinstance(data, dict) or (len(data) > 1 and len(pars) == 1): - return function(data, **evaluator_kwargs) - - input_data = {pname: data[pname] for pname in pars if pname in data} - - return function(**input_data, **evaluator_kwargs) - - def __verify( function: Callable, data_columns: Iterable[str], kwarg_names: Iterable[str] ): diff --git a/python/opencosmo/evaluate.py b/python/opencosmo/evaluate.py deleted file mode 100644 index 74db256d..00000000 --- a/python/opencosmo/evaluate.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Any - -import astropy.units as u -import numpy as np - -""" -General helper routines for evaluating expressions on datasets and collections -""" - - -def insert_data( - storage: dict[str, np.ndarray], index: int, values_to_insert: dict[str, Any] -): - if isinstance(values_to_insert, dict): - for name, value in values_to_insert.items(): - storage[name][index] = value - return storage - - name = next(iter(storage.keys())) - storage[name][index] = values_to_insert - - -def make_output_from_first_values(first_values: dict, n_rows: int): - storage = {} - new_first_values = {} - for name, value in first_values.items(): - shape: tuple[int, ...] = (n_rows,) - dtype = type(value) - if not isinstance(value, np.ndarray): - new_first_values[name] = value - elif isinstance(value, u.Quantity) and value.isscalar: - dtype = value.value.dtype - new_first_values[name] = value - elif isinstance(value, np.ndarray) and len(value) == 1: - dtype = value.dtype - new_first_values[name] = value[0] - else: - dtype = value.dtype - shape = shape + value.shape - new_first_values[name] = value - - storage[name] = np.zeros(shape, dtype=dtype) - for name, value in new_first_values.items(): - if isinstance(value, u.Quantity): - storage[name] = storage[name] * value.unit - - storage[name][0] = value - return storage