Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions changes/+95437205.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:py:meth:`get_data <opencosmo.Dataset.get_data>` now supports :code:`jax` as an output format.
1 change: 1 addition & 0 deletions changes/+c414feda.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
All :code:`evaluate` methods (e.g. :py:meth:`Dataset.evaluate <opencosmo.Dataset.evaluate>`) now support passing data to the function in any format supported by :py:meth:`get_data <opencosmo.Dataset.get_data>`.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dev = [
"pip>=25.1.1",
"pytest>=8.3.4,<9.0.0",
"pytest-timeout>=2.4.0",
"jax>=0.10.1",
]
docs = [
"sphinx>=9.0.0",
Expand Down
33 changes: 23 additions & 10 deletions python/opencosmo/collection/lightcone/lightcone.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from opencosmo.collection.lightcone.stack import stack_lightcone_datasets_in_schema
from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn
from opencosmo.dataset.evaluate import build_evaluated_column
from opencosmo.dataset.formats import convert_data, verify_format
from opencosmo.dataset.formats import concat_chunks, convert_data, verify_format
from opencosmo.dataset.take import (
get_end_take_index,
get_random_take_index,
Expand Down Expand Up @@ -325,7 +325,13 @@ def get_pixels(self, nside: int = 64):

return lcutils.get_pixels(self, int(level))

def get_data(self, format="astropy", unpack: bool = True, **kwargs):
def get_data(
self,
format="astropy",
unpack: bool = True,
wrap_single: bool = False,
**kwargs,
):
"""
Get the data in this dataset as an astropy table/column or as
numpy array(s). Note that a dataset does not load data from disk into
Expand All @@ -340,14 +346,20 @@ def get_data(self, format="astropy", unpack: bool = True, **kwargs):
units will be attached.

If the dataset only contains a single column, it will be returned as an
astropy.table.Column or a single numpy array.
astropy.table.Column or a single numpy array. Pass :code:`wrap_single=True`
to always return the format's multi-column container (QTable, DataFrame,
dict, ...) regardless of column count.

Parameters
----------
output: str, default="astropy"
The format to output the data in. Currently supported are "astropy", "numpy",
"pandas", "polars", and "arrow"

wrap_single: bool, default=False
If True, always return the format's natural multi-column container even
when only one column is present.

Returns
-------
data: Table | Column | dict[str, ndarray] | ndarray
Expand Down Expand Up @@ -383,11 +395,11 @@ def get_data(self, format="astropy", unpack: bool = True, **kwargs):
key: value[0] if len(value) == 1 else value
for key, value in table.items()
}
return convert_data(output_data, format)
return convert_data(output_data, format, wrap_single=wrap_single)

if format != "astropy":
return convert_data(dict(table), format)
elif len(table.columns) == 1:
return convert_data(dict(table), format, wrap_single=wrap_single)
elif len(table.columns) == 1 and not wrap_single:
return next(iter(dict(table).values()))

return table
Expand Down Expand Up @@ -758,9 +770,10 @@ def evaluate(
The function to evaluate on the rows in the dataset.

format: str, default = "astropy"
The format of the data that is provided to your function. If "astropy", will be a dictionary of
astropy quantities. If "numpy", will be a dictionary of numpy arrays. Note that
this method does not support all the formats available in :py:meth:`get_data <opencosmo.Lightcone.get_data>`
The format in which to provide column data to your function. Supports the same formats
as :py:meth:`get_data <opencosmo.Lightcone.get_data>` ("astropy", "numpy", "pandas",
"polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted
back to numpy before being stored.

vectorize: bool, default = False
Whether to provide the values as full columns (True) or one row at a time (False)
Expand Down Expand Up @@ -834,7 +847,7 @@ def evaluate(
keys = next(iter(result.values())).keys()
output = {}
for key in keys:
output[key] = np.concatenate([r[key] for r in result.values()])
output[key] = concat_chunks([r[key] for r in result.values()], format)
return output

def filter(self, *masks: ColumnMask, **kwargs) -> Self:
Expand Down
8 changes: 5 additions & 3 deletions python/opencosmo/collection/simulation/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,11 @@ def evaluate(
datasets: str | list[str], optional
The datasets to evaluate on. If not provided, will be evaluated on all datasets
format: str, default = "astropy"
The format of the data that is provided to your function. If "astropy", will be a dictionary of
astropy quantities. If "numpy", will be a dictionary of numpy arrays. Note that
this method does not support all the formats available in :py:meth:`get_data <opencosmo.Dataset.get_data>`
The format in which to provide column data to your function. Supports the same formats
as :py:meth:`get_data <opencosmo.Dataset.get_data>` ("astropy", "numpy", "pandas",
"polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted
back to numpy before being stored.

vectorize: bool, default = False
Whether to vectorize the computation. See :py:meth:`StructureCollection.evaluate <opencosmo.StructureCollection.Evaluate>`
and/or :py:meth:`Dataset.evaluate <opencosmo.Dataset.Evaluate>` for more details.
Expand Down
127 changes: 34 additions & 93 deletions python/opencosmo/collection/structure/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from __future__ import annotations

from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence
from typing import TYPE_CHECKING, Any, Callable, Optional

import numpy as np
from astropy.units import Quantity # type: ignore

from opencosmo import dataset as ds
from opencosmo.evaluate import (
insert_data,
make_output_from_first_values,
)
from opencosmo.dataset.formats import concat_chunks, stack_rows

if TYPE_CHECKING:
from opencosmo import StructureCollection
Expand Down Expand Up @@ -82,17 +78,22 @@ def evaluate_into_properties(
kwargs: dict[str, Any],
insert: bool,
):
storage = __make_output(function, collection, format, kwargs, {}, insert)
for i, structure in enumerate(collection.objects()):
if i == 0:
continue
per_column: dict[str, list] = {}
for structure in collection.objects():
input_structure = __make_input(structure, format)

output = function(**input_structure, **kwargs)
if storage is not None:
insert_data(storage, i, output)
if output is None and insert:
raise ValueError(
"You asked to insert these values, but your function returns None!"
)
if not isinstance(output, dict):
output = {function.__name__: output}
for name, value in output.items():
per_column.setdefault(name, []).append(value)

return storage
if not per_column:
return None
return {name: stack_rows(values, format) for name, values in per_column.items()}


def evaluate_into_dataset(
Expand All @@ -103,25 +104,28 @@ def evaluate_into_dataset(
dataset: str,
insert: bool,
):
storage = __make_chunked_output(function, collection, dataset, format, kwargs, {})

per_column: dict[str, list] = {}
for i, structure in enumerate(collection.objects()):
if i == 0:
continue
input_structure = __make_input(structure, format)

output = function(**input_structure, **kwargs)
if output is None and insert:
raise ValueError(
"You asked to insert these values, but your function returns None!"
)
if not isinstance(output, dict):
output = {function.__name__: output}

if storage is not None:
for name, output_arr in output.items():
storage[name].append(output_arr)

if storage is None:
return
output_data = {name: np.concatenate(data) for name, data in storage.items()}
return output_data
if i == 0:
expected_length = len(input_structure[dataset])
if any(len(v) != expected_length for v in output.values()):
raise ValueError(
"If you pass a `dataset` argument, your function should output an array with the same length as that dataset"
)
for name, output_arr in output.items():
per_column.setdefault(name, []).append(output_arr)

if not per_column:
return None
return {name: concat_chunks(data, format) for name, data in per_column.items()}


def __make_input(structure: dict, format: str = "astropy"):
Expand All @@ -130,78 +134,15 @@ def __make_input(structure: dict, format: str = "astropy"):
if isinstance(element, dict):
values[name] = __make_input(element, format)
elif isinstance(element, ds.Dataset):
data = element.get_data(format)
data = element.get_data(format, wrap_single=True)
values[name] = data
elif isinstance(element, Quantity) and format == "numpy":
elif isinstance(element, Quantity) and format != "astropy":
values[name] = element.value
else:
values[name] = element
return values


def __make_output(
function: Callable,
collection: StructureCollection,
format: str = "astropy",
kwargs: dict[str, Any] = {},
iterable_kwargs: dict[str, Sequence] = {},
insert: bool = True,
) -> dict | None:
first_structure = next(collection.take(1, at="start").objects())
first_input = __make_input(first_structure, format)
first_values = function(
**first_input,
**kwargs,
**{name: arr[0] for name, arr in iterable_kwargs.items()},
)
if first_values is None and insert:
raise ValueError(
"You asked to insert these values, but your function returns None!"
)
elif first_values is None:
return None
if not isinstance(first_values, dict):
name = function.__name__
first_values = {name: first_values}
n_rows = len(collection)
return make_output_from_first_values(first_values, n_rows)


def __make_chunked_output(
function: Callable,
collection: StructureCollection,
dataset: str,
format: str = "astropy",
kwargs: dict[str, Any] = {},
iterable_kwargs: dict[str, Sequence] = {},
insert: bool = True,
) -> dict | None:
first_structure = collection.take(1, at="start")
expected_length = len(first_structure[dataset])
first_structure_data = next(iter(first_structure.objects()))

first_input = __make_input(first_structure_data, format)
first_values = function(
**first_input,
**kwargs,
**{name: arr[0] for name, arr in iterable_kwargs.items()},
)
if first_values is None and insert:
raise ValueError(
"You asked to insert these values, but your function returns None!"
)
elif first_values is None:
return None
if not isinstance(first_values, dict):
name = function.__name__
first_values = {name: first_values}
if any(len(fv) != expected_length for fv in first_values.values()):
raise ValueError(
"If you pass a `dataset` argument, your function should output an array with the same length as that dataset"
)
return {name: [fv] for name, fv in first_values.items()}


def __prepare_collection(
spec: dict[str, Optional[list[str]]], collection: StructureCollection
) -> StructureCollection:
Expand Down
20 changes: 13 additions & 7 deletions python/opencosmo/collection/structure/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from opencosmo.collection.lightcone import lightcone as lc
from opencosmo.collection.structure import evaluate
from opencosmo.collection.structure import io as sio
from opencosmo.dataset.formats import verify_format
from opencosmo.index.unary import get_length
from opencosmo.io.schema import FileEntry, make_schema

Expand Down Expand Up @@ -597,8 +598,10 @@ def computation(halo_properties, dm_particles):
collection contains galaxies. If False, simply return the data.

format: str, default = astropy
Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that
this method does not support all the formats available in :py:meth:`get_data <opencosmo.Dataset.get_data>`
The format in which to provide column data to your function. Supports the same formats
as :py:meth:`get_data <opencosmo.Dataset.get_data>` ("astropy", "numpy", "pandas",
"polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted
back to numpy before being stored.

**evaluate_kwargs: any,
Any additional arguments that are required for your function to run. These will be passed directly
Expand All @@ -623,8 +626,7 @@ def computation(halo_properties, dm_particles):
**evaluate_kwargs,
)

if format not in ["astropy", "numpy"]:
raise ValueError(f"Invalid format requested for data: {format}")
verify_format(format)

if dataset is not None and dataset.startswith("galaxies"):
# Nested structure collection, special case
Expand Down Expand Up @@ -701,10 +703,12 @@ def computation(halo_properties, dm_particles):
)
if not insert or output is None:
return output
from opencosmo.dataset.formats import to_numpy_dict

return self.with_new_columns(
**output,
dataset=dataset if dataset is not None else self.__source.dtype,
allow_overwrite=allow_overwrite,
**to_numpy_dict(output), # type: ignore
)

def evaluate_on_dataset(
Expand Down Expand Up @@ -751,8 +755,10 @@ def evaluate_on_dataset(
Whether to provide the values as full columns (True) or one row at a time (False). Ignored if :code:`batch_size` is set.

format: str, default = astropy
Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that
this method does not support all the formats available in :py:meth:`get_data <opencosmo.Dataset.get_data>`
The format in which to provide column data to your function. Supports the same formats
as :py:meth:`get_data <opencosmo.Dataset.get_data>` ("astropy", "numpy", "pandas",
"polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted
back to numpy before being stored.

insert: bool, default = True
If true, the data will be inserted as a column in this dataset. The new column will have the same name
Expand Down
29 changes: 22 additions & 7 deletions python/opencosmo/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,11 +865,6 @@ def get_units(self, units: dict[str, np.ndarray]):
def evaluate(self, data: dict[str, np.ndarray], index: DataIndex | None):
data = {name: data[name] for name in self.__requires}
chunk_sizes = index[1] if isinstance(index, tuple) else None
if self.__format != "astropy":
data = {
name: val.value if isinstance(val, u.Quantity) else val
for name, val in data.items()
}

if self.batch_size > 0:
length = len(next(iter(data.values())))
Expand All @@ -886,13 +881,33 @@ def evaluate(self, data: dict[str, np.ndarray], index: DataIndex | None):
case EvaluateStrategy.VECTORIZE:
return evaluate_vectorized(data, self.__func, self.__kwargs, index)
case EvaluateStrategy.ROW_WISE:
return evaluate_rows(data, self.__func, self.__kwargs)
return evaluate_rows(data, self.__func, self.__kwargs, self.__format)
case EvaluateStrategy.CHUNKED:
if chunk_sizes is None:
raise ValueError(
"Cannot evaluate in CHUNKED strategy with a non-chunked index"
)
return evaluate_chunks(data, self.__func, self.__kwargs, chunk_sizes)
return evaluate_chunks(
data, self.__func, self.__kwargs, chunk_sizes, self.__format
)

def evaluate_for_storage(
self, data: dict[str, np.ndarray], index: DataIndex | None
) -> dict[str, np.ndarray]:
"""
Evaluate and return numpy-formatted output suitable for the column
cache. Input arrives in the numpy/astropy form used internally, so
it is first converted to the user's requested format before the
function runs, and the output is converted back to numpy.
"""
from opencosmo.dataset.formats import to_format_dict, to_numpy_dict

required = {name: data[name] for name in self.__requires}
converted = to_format_dict(required, self.__format)
output = self.evaluate(converted, index)
if not isinstance(output, dict):
output = {next(iter(self.__produces)): output}
return to_numpy_dict(output)

def evaluate_one(self, dataset: Dataset):
match self.__strategy:
Expand Down
Loading
Loading