From 24f99645c8a1b91c524fd2b335f863c3f8c89840 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Apr 2026 11:38:51 -0500 Subject: [PATCH 001/139] Implement some index operations --- Cargo.lock | 229 ++++++++++++++++++ Cargo.toml | 14 ++ pyproject.toml | 13 +- {src => python}/opencosmo/__init__.py | 0 .../opencosmo/analysis/__init__.py | 0 {src => python}/opencosmo/analysis/cli.py | 0 {src => python}/opencosmo/analysis/diffsky.py | 0 .../opencosmo/analysis/install/__init__.py | 0 .../opencosmo/analysis/install/install.py | 0 .../opencosmo/analysis/install/source.py | 0 .../analysis/install/specs/__init__.py | 0 .../analysis/install/specs/diffsky.json | 0 .../analysis/install/specs/vizualize.json | 0 .../opencosmo/analysis/install/versions.py | 0 {src => python}/opencosmo/analysis/mpi.py | 0 .../opencosmo/analysis/yt_utils.py | 8 +- {src => python}/opencosmo/analysis/yt_viz.py | 10 +- .../opencosmo/collection/__init__.py | 0 .../collection/lightcone/__init__.py | 0 .../collection/lightcone/coordinates.py | 0 .../collection/lightcone/healpix_map.py | 0 .../collection/lightcone/lightcone.py | 0 .../opencosmo/collection/lightcone/stack.py | 0 .../opencosmo/collection/protocols.py | 0 .../collection/simulation/__init__.py | 0 .../collection/simulation/simulation.py | 0 .../collection/structure/__init__.py | 0 .../collection/structure/evaluate.py | 0 .../opencosmo/collection/structure/handler.py | 11 +- .../opencosmo/collection/structure/io.py | 0 .../collection/structure/structure.py | 0 {src => python}/opencosmo/column/__init__.py | 0 {src => python}/opencosmo/column/cache.py | 0 {src => python}/opencosmo/column/column.py | 0 {src => python}/opencosmo/column/evaluate.py | 0 {src => python}/opencosmo/column/select.py | 0 {src => python}/opencosmo/column/stock.py | 0 {src => python}/opencosmo/cosmology.py | 0 {src => python}/opencosmo/dataset/__init__.py | 0 {src => python}/opencosmo/dataset/build.py | 0 {src => python}/opencosmo/dataset/dataset.py | 4 +- {src => 
python}/opencosmo/dataset/derived.py | 0 {src => python}/opencosmo/dataset/evaluate.py | 0 {src => python}/opencosmo/dataset/formats.py | 0 {src => python}/opencosmo/dataset/im.py | 0 {src => python}/opencosmo/dataset/mpi.py | 0 {src => python}/opencosmo/dataset/state.py | 0 {src => python}/opencosmo/evaluate.py | 0 {src => python}/opencosmo/file.py | 0 {src => python}/opencosmo/handler/empty.py | 1 - {src => python}/opencosmo/handler/hdf5.py | 2 - .../opencosmo/handler/protocols.py | 1 - {src => python}/opencosmo/header.py | 0 {src => python}/opencosmo/index/__init__.py | 0 {src => python}/opencosmo/index/build.py | 0 {src => python}/opencosmo/index/get.py | 0 {src => python}/opencosmo/index/in_range.py | 37 +-- {src => python}/opencosmo/index/mask.py | 0 {src => python}/opencosmo/index/project.py | 0 {src => python}/opencosmo/index/take.py | 0 python/opencosmo/index/unary.py | 31 +++ {src => python}/opencosmo/io/__init__.py | 0 {src => python}/opencosmo/io/io.py | 0 {src => python}/opencosmo/io/iopen.py | 0 {src => python}/opencosmo/io/mpi.py | 0 {src => python}/opencosmo/io/parquet.py | 0 {src => python}/opencosmo/io/protocols.py | 0 {src => python}/opencosmo/io/schema.py | 0 {src => python}/opencosmo/io/serial.py | 0 {src => python}/opencosmo/io/updaters.py | 0 {src => python}/opencosmo/io/verify.py | 0 {src => python}/opencosmo/io/writer.py | 0 {src => python}/opencosmo/mpi.py | 0 .../opencosmo/parameters/__init__.py | 0 .../opencosmo/parameters/cosmology.py | 0 .../opencosmo/parameters/diffsky.py | 0 {src => python}/opencosmo/parameters/dtype.py | 0 {src => python}/opencosmo/parameters/file.py | 0 {src => python}/opencosmo/parameters/hacc.py | 0 .../opencosmo/parameters/lightcone.py | 0 .../opencosmo/parameters/origin.py | 0 .../opencosmo/parameters/parameters.py | 0 {src => python}/opencosmo/parameters/units.py | 0 {src => python}/opencosmo/parameters/utils.py | 0 {src => python}/opencosmo/spatial/__init__.py | 0 {src => python}/opencosmo/spatial/builders.py 
| 0 {src => python}/opencosmo/spatial/check.py | 0 {src => python}/opencosmo/spatial/healpix.py | 0 {src => python}/opencosmo/spatial/models.py | 0 {src => python}/opencosmo/spatial/octree.py | 0 .../opencosmo/spatial/protocols.py | 0 {src => python}/opencosmo/spatial/region.py | 0 .../opencosmo/spatial/relations.py | 0 {src => python}/opencosmo/spatial/tree.py | 0 {src => python}/opencosmo/spatial/utils.py | 4 +- {src => python}/opencosmo/units/__init__.py | 0 {src => python}/opencosmo/units/convention.py | 0 {src => python}/opencosmo/units/converters.py | 0 {src => python}/opencosmo/units/get.py | 0 {src => python}/opencosmo/units/handler.py | 0 src/index.rs | 129 ++++++++++ src/lib.rs | 21 ++ src/opencosmo/index/unary.py | 61 ----- 103 files changed, 457 insertions(+), 119 deletions(-) create mode 100644 Cargo.lock create mode 100644 Cargo.toml rename {src => python}/opencosmo/__init__.py (100%) rename {src => python}/opencosmo/analysis/__init__.py (100%) rename {src => python}/opencosmo/analysis/cli.py (100%) rename {src => python}/opencosmo/analysis/diffsky.py (100%) rename {src => python}/opencosmo/analysis/install/__init__.py (100%) rename {src => python}/opencosmo/analysis/install/install.py (100%) rename {src => python}/opencosmo/analysis/install/source.py (100%) rename {src => python}/opencosmo/analysis/install/specs/__init__.py (100%) rename {src => python}/opencosmo/analysis/install/specs/diffsky.json (100%) rename {src => python}/opencosmo/analysis/install/specs/vizualize.json (100%) rename {src => python}/opencosmo/analysis/install/versions.py (100%) rename {src => python}/opencosmo/analysis/mpi.py (100%) rename {src => python}/opencosmo/analysis/yt_utils.py (97%) rename {src => python}/opencosmo/analysis/yt_viz.py (99%) rename {src => python}/opencosmo/collection/__init__.py (100%) rename {src => python}/opencosmo/collection/lightcone/__init__.py (100%) rename {src => python}/opencosmo/collection/lightcone/coordinates.py (100%) rename {src => 
python}/opencosmo/collection/lightcone/healpix_map.py (100%) rename {src => python}/opencosmo/collection/lightcone/lightcone.py (100%) rename {src => python}/opencosmo/collection/lightcone/stack.py (100%) rename {src => python}/opencosmo/collection/protocols.py (100%) rename {src => python}/opencosmo/collection/simulation/__init__.py (100%) rename {src => python}/opencosmo/collection/simulation/simulation.py (100%) rename {src => python}/opencosmo/collection/structure/__init__.py (100%) rename {src => python}/opencosmo/collection/structure/evaluate.py (100%) rename {src => python}/opencosmo/collection/structure/handler.py (97%) rename {src => python}/opencosmo/collection/structure/io.py (100%) rename {src => python}/opencosmo/collection/structure/structure.py (100%) rename {src => python}/opencosmo/column/__init__.py (100%) rename {src => python}/opencosmo/column/cache.py (100%) rename {src => python}/opencosmo/column/column.py (100%) rename {src => python}/opencosmo/column/evaluate.py (100%) rename {src => python}/opencosmo/column/select.py (100%) rename {src => python}/opencosmo/column/stock.py (100%) rename {src => python}/opencosmo/cosmology.py (100%) rename {src => python}/opencosmo/dataset/__init__.py (100%) rename {src => python}/opencosmo/dataset/build.py (100%) rename {src => python}/opencosmo/dataset/dataset.py (99%) rename {src => python}/opencosmo/dataset/derived.py (100%) rename {src => python}/opencosmo/dataset/evaluate.py (100%) rename {src => python}/opencosmo/dataset/formats.py (100%) rename {src => python}/opencosmo/dataset/im.py (100%) rename {src => python}/opencosmo/dataset/mpi.py (100%) rename {src => python}/opencosmo/dataset/state.py (100%) rename {src => python}/opencosmo/evaluate.py (100%) rename {src => python}/opencosmo/file.py (100%) rename {src => python}/opencosmo/handler/empty.py (99%) rename {src => python}/opencosmo/handler/hdf5.py (99%) rename {src => python}/opencosmo/handler/protocols.py (99%) rename {src => 
python}/opencosmo/header.py (100%) rename {src => python}/opencosmo/index/__init__.py (100%) rename {src => python}/opencosmo/index/build.py (100%) rename {src => python}/opencosmo/index/get.py (100%) rename {src => python}/opencosmo/index/in_range.py (51%) rename {src => python}/opencosmo/index/mask.py (100%) rename {src => python}/opencosmo/index/project.py (100%) rename {src => python}/opencosmo/index/take.py (100%) create mode 100644 python/opencosmo/index/unary.py rename {src => python}/opencosmo/io/__init__.py (100%) rename {src => python}/opencosmo/io/io.py (100%) rename {src => python}/opencosmo/io/iopen.py (100%) rename {src => python}/opencosmo/io/mpi.py (100%) rename {src => python}/opencosmo/io/parquet.py (100%) rename {src => python}/opencosmo/io/protocols.py (100%) rename {src => python}/opencosmo/io/schema.py (100%) rename {src => python}/opencosmo/io/serial.py (100%) rename {src => python}/opencosmo/io/updaters.py (100%) rename {src => python}/opencosmo/io/verify.py (100%) rename {src => python}/opencosmo/io/writer.py (100%) rename {src => python}/opencosmo/mpi.py (100%) rename {src => python}/opencosmo/parameters/__init__.py (100%) rename {src => python}/opencosmo/parameters/cosmology.py (100%) rename {src => python}/opencosmo/parameters/diffsky.py (100%) rename {src => python}/opencosmo/parameters/dtype.py (100%) rename {src => python}/opencosmo/parameters/file.py (100%) rename {src => python}/opencosmo/parameters/hacc.py (100%) rename {src => python}/opencosmo/parameters/lightcone.py (100%) rename {src => python}/opencosmo/parameters/origin.py (100%) rename {src => python}/opencosmo/parameters/parameters.py (100%) rename {src => python}/opencosmo/parameters/units.py (100%) rename {src => python}/opencosmo/parameters/utils.py (100%) rename {src => python}/opencosmo/spatial/__init__.py (100%) rename {src => python}/opencosmo/spatial/builders.py (100%) rename {src => python}/opencosmo/spatial/check.py (100%) rename {src => 
python}/opencosmo/spatial/healpix.py (100%) rename {src => python}/opencosmo/spatial/models.py (100%) rename {src => python}/opencosmo/spatial/octree.py (100%) rename {src => python}/opencosmo/spatial/protocols.py (100%) rename {src => python}/opencosmo/spatial/region.py (100%) rename {src => python}/opencosmo/spatial/relations.py (100%) rename {src => python}/opencosmo/spatial/tree.py (100%) rename {src => python}/opencosmo/spatial/utils.py (88%) rename {src => python}/opencosmo/units/__init__.py (100%) rename {src => python}/opencosmo/units/convention.py (100%) rename {src => python}/opencosmo/units/converters.py (100%) rename {src => python}/opencosmo/units/get.py (100%) rename {src => python}/opencosmo/units/handler.py (100%) create mode 100644 src/index.rs create mode 100644 src/lib.rs delete mode 100644 src/opencosmo/index/unary.py diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 00000000..4b5b9a4c --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,229 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "libc" +version = "0.2.184" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48f5d2a454e16a5ea0f4ced81bd44e4cfc7bd3a507b61887c99fd3538b28e4af" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "numpy" +version = "0.28.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "778da78c64ddc928ebf5ad9df5edf0789410ff3bdbf3619aed51cd789a6af1e2" +dependencies = [ + "libc", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "pyo3", + "pyo3-build-config", + "rustc-hash", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "opencosmo" +version = "1.2.4" +dependencies = [ + "numpy", + "pyo3", +] + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "portable-atomic-util" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91fd8e38a3b50ed1167fb981cd6fd60147e091784c427b8f7183a7ee32c31c12" +dependencies = [ + "libc", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", +] + +[[package]] +name = "pyo3-build-config" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e368e7ddfdeb98c9bca7f8383be1648fd84ab466bf2bc015e94008db6d35611e" +dependencies = [ + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7f29e10af80b1f7ccaf7f69eace800a03ecd13e883acfacc1e5d0988605f651e" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df6e520eff47c45997d2fc7dd8214b25dd1310918bbb2642156ef66a67f29813" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4cdc218d835738f81c2338f822078af45b4afdf8b2e33cbb5916f108b813acb" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rustc-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..9cb40fba --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "opencosmo" +version = "1.2.4" +edition = "2021" + +[lib] +name = "_lib" +# "cdylib" is necessary to produce a shared library for Python to import from. +crate-type = ["cdylib"] + +[dependencies] +numpy = "0.28" +pyo3 = "0.28.2" + diff --git a/pyproject.toml b/pyproject.toml index 8460bb7f..d17a0241 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,13 +64,16 @@ io = [ "pyarrow>=21.0.0", ] -[build-system] -requires = ["uv_build>=0.8.2,<0.9.0", "pip"] -build-backend = "uv_build" -[tool.uv.build-backend] -source-include = ["LICENSE.md"] +[build-system] +requires = ["maturin>=1.0,<2.0"] +build-backend = "maturin" + +[tool.maturin] +module-name = "opencosmo._lib" +include = ["LICENSE.md"] +python-source = "python" [[tool.mypy.overrides]] module = ["h5py", "astropy.*", "healpy", "healsparse", "numba"] diff --git a/src/opencosmo/__init__.py b/python/opencosmo/__init__.py similarity index 100% rename from src/opencosmo/__init__.py rename to python/opencosmo/__init__.py diff --git a/src/opencosmo/analysis/__init__.py b/python/opencosmo/analysis/__init__.py similarity index 100% rename from src/opencosmo/analysis/__init__.py rename to python/opencosmo/analysis/__init__.py diff --git a/src/opencosmo/analysis/cli.py b/python/opencosmo/analysis/cli.py similarity index 100% rename from src/opencosmo/analysis/cli.py rename to python/opencosmo/analysis/cli.py diff --git a/src/opencosmo/analysis/diffsky.py b/python/opencosmo/analysis/diffsky.py similarity index 100% rename from src/opencosmo/analysis/diffsky.py rename to python/opencosmo/analysis/diffsky.py diff --git a/src/opencosmo/analysis/install/__init__.py b/python/opencosmo/analysis/install/__init__.py similarity index 100% rename from src/opencosmo/analysis/install/__init__.py rename to 
python/opencosmo/analysis/install/__init__.py diff --git a/src/opencosmo/analysis/install/install.py b/python/opencosmo/analysis/install/install.py similarity index 100% rename from src/opencosmo/analysis/install/install.py rename to python/opencosmo/analysis/install/install.py diff --git a/src/opencosmo/analysis/install/source.py b/python/opencosmo/analysis/install/source.py similarity index 100% rename from src/opencosmo/analysis/install/source.py rename to python/opencosmo/analysis/install/source.py diff --git a/src/opencosmo/analysis/install/specs/__init__.py b/python/opencosmo/analysis/install/specs/__init__.py similarity index 100% rename from src/opencosmo/analysis/install/specs/__init__.py rename to python/opencosmo/analysis/install/specs/__init__.py diff --git a/src/opencosmo/analysis/install/specs/diffsky.json b/python/opencosmo/analysis/install/specs/diffsky.json similarity index 100% rename from src/opencosmo/analysis/install/specs/diffsky.json rename to python/opencosmo/analysis/install/specs/diffsky.json diff --git a/src/opencosmo/analysis/install/specs/vizualize.json b/python/opencosmo/analysis/install/specs/vizualize.json similarity index 100% rename from src/opencosmo/analysis/install/specs/vizualize.json rename to python/opencosmo/analysis/install/specs/vizualize.json diff --git a/src/opencosmo/analysis/install/versions.py b/python/opencosmo/analysis/install/versions.py similarity index 100% rename from src/opencosmo/analysis/install/versions.py rename to python/opencosmo/analysis/install/versions.py diff --git a/src/opencosmo/analysis/mpi.py b/python/opencosmo/analysis/mpi.py similarity index 100% rename from src/opencosmo/analysis/mpi.py rename to python/opencosmo/analysis/mpi.py diff --git a/src/opencosmo/analysis/yt_utils.py b/python/opencosmo/analysis/yt_utils.py similarity index 97% rename from src/opencosmo/analysis/yt_utils.py rename to python/opencosmo/analysis/yt_utils.py index 9d724eed..41c201ed 100644 --- 
a/src/opencosmo/analysis/yt_utils.py +++ b/python/opencosmo/analysis/yt_utils.py @@ -92,9 +92,11 @@ def astropy_to_yt(array): return unyt_array(array.data, "dimensionless") if "littleh" in str(array.unit): - raise RuntimeError("cannot convert factors of littleh to yt convention, " - "try converting the opencosmo dataset to comoving units " - "(e.g. set `ds = ds.with_units(\"comoving\"))`") + raise RuntimeError( + "cannot convert factors of littleh to yt convention, " + "try converting the opencosmo dataset to comoving units " + '(e.g. set `ds = ds.with_units("comoving"))`' + ) return unyt_array.from_astropy(array) diff --git a/src/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py similarity index 99% rename from src/opencosmo/analysis/yt_viz.py rename to python/opencosmo/analysis/yt_viz.py index e96ae001..e86c989b 100644 --- a/src/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -13,8 +13,8 @@ from opencosmo.analysis import create_yt_dataset if TYPE_CHECKING: - from matplotlib.figure import Figure from matplotlib.colors import Normalize + from matplotlib.figure import Figure from yt.visualization.plot_window import NormalPlot # ruff: noqa: E501 @@ -279,7 +279,7 @@ def halo_projection_array( weight_field: Optional[Tuple[str, str]] = None, projection_axis: Optional[str] = "z", cmap: Optional[str] = "gray", - cmap_norm: Optional[Normalize] = None, # type: ignore + cmap_norm: Optional[Normalize] = None, # type: ignore zlim: Optional[Tuple[float, float]] = None, params: Optional[Dict[str, Any]] = None, length_scale: Optional[str] = None, @@ -306,7 +306,7 @@ def halo_projection_array( halo_ids : int or 2D array of int Unique ID of the halo(s) to be visualized. The shape of `halo_ids` sets the layout of the figure (e.g., if `halo_ids` is a 2x3 array, the outputted figure will be a 2x3 - array of projections). To leave a panel in the outputted figure blank, set the corresponding + array of projections). 
To leave a panel in the outputted figure blank, set the corresponding entry into the `halo_ids` array to `None`. If `int`, a single panel is output while preserving formatting. data : opencosmo.StructureCollection OpenCosmo StructureCollection dataset containing both halo properties and particle data @@ -354,7 +354,7 @@ def halo_projection_array( - ``"zlims"``: 2D array of colorbar limits (log-scaled) - ``"labels"``: 2D array of panel labels (or None) - ``"cmaps"``: 2D array of Matplotlib colormaps for each panel - - ``"cmap_norms"``: 2D array of colormap normalization method (e.g. matplotlib.colors.LogNorm()) + - ``"cmap_norms"``: 2D array of colormap normalization method (e.g. matplotlib.colors.LogNorm()) - ``"widths"``: 2D array of widths in units of R200 text_color : str, optional Set the color of all text annotations. Default is "gray" @@ -448,7 +448,7 @@ def halo_projection_array( if len(data) > 1: data_id = data.filter(oc.col("unique_tag") == halo_id) else: - if data["halo_properties"].data["unique_tag"] != halo_id: # type: ignore + if data["halo_properties"].data["unique_tag"] != halo_id: # type: ignore raise RuntimeError(f"Halo ID {halo_id} not in dataset!") data_id = data halo_data = next(iter(data_id.objects())) diff --git a/src/opencosmo/collection/__init__.py b/python/opencosmo/collection/__init__.py similarity index 100% rename from src/opencosmo/collection/__init__.py rename to python/opencosmo/collection/__init__.py diff --git a/src/opencosmo/collection/lightcone/__init__.py b/python/opencosmo/collection/lightcone/__init__.py similarity index 100% rename from src/opencosmo/collection/lightcone/__init__.py rename to python/opencosmo/collection/lightcone/__init__.py diff --git a/src/opencosmo/collection/lightcone/coordinates.py b/python/opencosmo/collection/lightcone/coordinates.py similarity index 100% rename from src/opencosmo/collection/lightcone/coordinates.py rename to python/opencosmo/collection/lightcone/coordinates.py diff --git 
a/src/opencosmo/collection/lightcone/healpix_map.py b/python/opencosmo/collection/lightcone/healpix_map.py similarity index 100% rename from src/opencosmo/collection/lightcone/healpix_map.py rename to python/opencosmo/collection/lightcone/healpix_map.py diff --git a/src/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py similarity index 100% rename from src/opencosmo/collection/lightcone/lightcone.py rename to python/opencosmo/collection/lightcone/lightcone.py diff --git a/src/opencosmo/collection/lightcone/stack.py b/python/opencosmo/collection/lightcone/stack.py similarity index 100% rename from src/opencosmo/collection/lightcone/stack.py rename to python/opencosmo/collection/lightcone/stack.py diff --git a/src/opencosmo/collection/protocols.py b/python/opencosmo/collection/protocols.py similarity index 100% rename from src/opencosmo/collection/protocols.py rename to python/opencosmo/collection/protocols.py diff --git a/src/opencosmo/collection/simulation/__init__.py b/python/opencosmo/collection/simulation/__init__.py similarity index 100% rename from src/opencosmo/collection/simulation/__init__.py rename to python/opencosmo/collection/simulation/__init__.py diff --git a/src/opencosmo/collection/simulation/simulation.py b/python/opencosmo/collection/simulation/simulation.py similarity index 100% rename from src/opencosmo/collection/simulation/simulation.py rename to python/opencosmo/collection/simulation/simulation.py diff --git a/src/opencosmo/collection/structure/__init__.py b/python/opencosmo/collection/structure/__init__.py similarity index 100% rename from src/opencosmo/collection/structure/__init__.py rename to python/opencosmo/collection/structure/__init__.py diff --git a/src/opencosmo/collection/structure/evaluate.py b/python/opencosmo/collection/structure/evaluate.py similarity index 100% rename from src/opencosmo/collection/structure/evaluate.py rename to python/opencosmo/collection/structure/evaluate.py 
diff --git a/src/opencosmo/collection/structure/handler.py b/python/opencosmo/collection/structure/handler.py similarity index 97% rename from src/opencosmo/collection/structure/handler.py rename to python/opencosmo/collection/structure/handler.py index 4b50f85f..88a07cf7 100644 --- a/src/opencosmo/collection/structure/handler.py +++ b/python/opencosmo/collection/structure/handler.py @@ -40,9 +40,13 @@ def create_start_size(data, start_name, size_name): start = data.pop(start_name, None) size = data.pop(size_name, None) - if start is None: + if start is None or size is None: return None valid = size > 0 + + start = start.astype(np.int64) + size = size.astype(np.int64) + if isinstance(start, np.ndarray): return (start[valid], size[valid]) if size == 0: @@ -55,6 +59,7 @@ def create_idx(data, idx_name): if idx is None: return None + idx = idx.astype(np.int64) valid = idx >= 0 if isinstance(idx, np.ndarray): @@ -265,7 +270,7 @@ def rebuild_row_index( index_into_original: np.ndarray, ): valid_rows = original_metadata_column >= 0 - index = np.full(len(original_metadata_column), -1, dtype=int) + index = np.full(len(original_metadata_column), -1, dtype=np.int64) index[valid_rows] = np.arange(0, sum(valid_rows)) index_to_take = index[index_into_original] index_to_take = index_to_take[index_to_take >= 0] @@ -279,7 +284,7 @@ def rebuild_chunk_index( original_size_column: np.ndarray, index_into_original: np.ndarray, ): - chunk_boundaries = np.zeros(len(original_size_column) + 1, dtype=int) + chunk_boundaries = np.zeros(len(original_size_column) + 1, dtype=np.int64) _ = np.cumsum(original_size_column, out=chunk_boundaries[1:]) valid_rows = original_size_column[index_into_original] > 0 diff --git a/src/opencosmo/collection/structure/io.py b/python/opencosmo/collection/structure/io.py similarity index 100% rename from src/opencosmo/collection/structure/io.py rename to python/opencosmo/collection/structure/io.py diff --git a/src/opencosmo/collection/structure/structure.py 
b/python/opencosmo/collection/structure/structure.py similarity index 100% rename from src/opencosmo/collection/structure/structure.py rename to python/opencosmo/collection/structure/structure.py diff --git a/src/opencosmo/column/__init__.py b/python/opencosmo/column/__init__.py similarity index 100% rename from src/opencosmo/column/__init__.py rename to python/opencosmo/column/__init__.py diff --git a/src/opencosmo/column/cache.py b/python/opencosmo/column/cache.py similarity index 100% rename from src/opencosmo/column/cache.py rename to python/opencosmo/column/cache.py diff --git a/src/opencosmo/column/column.py b/python/opencosmo/column/column.py similarity index 100% rename from src/opencosmo/column/column.py rename to python/opencosmo/column/column.py diff --git a/src/opencosmo/column/evaluate.py b/python/opencosmo/column/evaluate.py similarity index 100% rename from src/opencosmo/column/evaluate.py rename to python/opencosmo/column/evaluate.py diff --git a/src/opencosmo/column/select.py b/python/opencosmo/column/select.py similarity index 100% rename from src/opencosmo/column/select.py rename to python/opencosmo/column/select.py diff --git a/src/opencosmo/column/stock.py b/python/opencosmo/column/stock.py similarity index 100% rename from src/opencosmo/column/stock.py rename to python/opencosmo/column/stock.py diff --git a/src/opencosmo/cosmology.py b/python/opencosmo/cosmology.py similarity index 100% rename from src/opencosmo/cosmology.py rename to python/opencosmo/cosmology.py diff --git a/src/opencosmo/dataset/__init__.py b/python/opencosmo/dataset/__init__.py similarity index 100% rename from src/opencosmo/dataset/__init__.py rename to python/opencosmo/dataset/__init__.py diff --git a/src/opencosmo/dataset/build.py b/python/opencosmo/dataset/build.py similarity index 100% rename from src/opencosmo/dataset/build.py rename to python/opencosmo/dataset/build.py diff --git a/src/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py similarity 
index 99% rename from src/opencosmo/dataset/dataset.py rename to python/opencosmo/dataset/dataset.py index d0baf19f..a3531e2c 100644 --- a/src/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -20,7 +20,7 @@ from opencosmo.column import Column from opencosmo.dataset.evaluate import build_evaluated_column, visit_dataset from opencosmo.dataset.formats import convert_data, verify_format -from opencosmo.index import into_array, mask, project +from opencosmo.index import empty, into_array, mask, project from opencosmo.spatial import check from opencosmo.units.converters import get_scale_factor @@ -349,7 +349,7 @@ def bound(self, region: Region, select_by: Optional[str] = None): check_region = region if not self.__state.region.intersects(check_region): - new_state = self.__state.take_rows(np.array([])) + new_state = self.__state.take_rows(empty()) return Dataset(self.__header, new_state, self.__tree) if not self.__state.region.contains(check_region): diff --git a/src/opencosmo/dataset/derived.py b/python/opencosmo/dataset/derived.py similarity index 100% rename from src/opencosmo/dataset/derived.py rename to python/opencosmo/dataset/derived.py diff --git a/src/opencosmo/dataset/evaluate.py b/python/opencosmo/dataset/evaluate.py similarity index 100% rename from src/opencosmo/dataset/evaluate.py rename to python/opencosmo/dataset/evaluate.py diff --git a/src/opencosmo/dataset/formats.py b/python/opencosmo/dataset/formats.py similarity index 100% rename from src/opencosmo/dataset/formats.py rename to python/opencosmo/dataset/formats.py diff --git a/src/opencosmo/dataset/im.py b/python/opencosmo/dataset/im.py similarity index 100% rename from src/opencosmo/dataset/im.py rename to python/opencosmo/dataset/im.py diff --git a/src/opencosmo/dataset/mpi.py b/python/opencosmo/dataset/mpi.py similarity index 100% rename from src/opencosmo/dataset/mpi.py rename to python/opencosmo/dataset/mpi.py diff --git a/src/opencosmo/dataset/state.py 
b/python/opencosmo/dataset/state.py similarity index 100% rename from src/opencosmo/dataset/state.py rename to python/opencosmo/dataset/state.py diff --git a/src/opencosmo/evaluate.py b/python/opencosmo/evaluate.py similarity index 100% rename from src/opencosmo/evaluate.py rename to python/opencosmo/evaluate.py diff --git a/src/opencosmo/file.py b/python/opencosmo/file.py similarity index 100% rename from src/opencosmo/file.py rename to python/opencosmo/file.py diff --git a/src/opencosmo/handler/empty.py b/python/opencosmo/handler/empty.py similarity index 99% rename from src/opencosmo/handler/empty.py rename to python/opencosmo/handler/empty.py index e46161da..f18a4f6f 100644 --- a/src/opencosmo/handler/empty.py +++ b/python/opencosmo/handler/empty.py @@ -7,7 +7,6 @@ if TYPE_CHECKING: import numpy as np - from opencosmo.index import DataIndex diff --git a/src/opencosmo/handler/hdf5.py b/python/opencosmo/handler/hdf5.py similarity index 99% rename from src/opencosmo/handler/hdf5.py rename to python/opencosmo/handler/hdf5.py index b5a3e5d0..bd4f44ee 100644 --- a/src/opencosmo/handler/hdf5.py +++ b/python/opencosmo/handler/hdf5.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Iterable, Optional import numpy as np - from opencosmo.index import ( SimpleIndex, from_size, @@ -21,7 +20,6 @@ if TYPE_CHECKING: import h5py - from opencosmo.header import OpenCosmoHeader from opencosmo.index import DataIndex from opencosmo.io.schema import Schema diff --git a/src/opencosmo/handler/protocols.py b/python/opencosmo/handler/protocols.py similarity index 99% rename from src/opencosmo/handler/protocols.py rename to python/opencosmo/handler/protocols.py index faf119dc..e46335f1 100644 --- a/src/opencosmo/handler/protocols.py +++ b/python/opencosmo/handler/protocols.py @@ -4,7 +4,6 @@ if TYPE_CHECKING: import numpy as np - from opencosmo.header import OpenCosmoHeader from opencosmo.index import DataIndex from opencosmo.io.schema import Schema diff --git a/src/opencosmo/header.py 
b/python/opencosmo/header.py similarity index 100% rename from src/opencosmo/header.py rename to python/opencosmo/header.py diff --git a/src/opencosmo/index/__init__.py b/python/opencosmo/index/__init__.py similarity index 100% rename from src/opencosmo/index/__init__.py rename to python/opencosmo/index/__init__.py diff --git a/src/opencosmo/index/build.py b/python/opencosmo/index/build.py similarity index 100% rename from src/opencosmo/index/build.py rename to python/opencosmo/index/build.py diff --git a/src/opencosmo/index/get.py b/python/opencosmo/index/get.py similarity index 100% rename from src/opencosmo/index/get.py rename to python/opencosmo/index/get.py diff --git a/src/opencosmo/index/in_range.py b/python/opencosmo/index/in_range.py similarity index 51% rename from src/opencosmo/index/in_range.py rename to python/opencosmo/index/in_range.py index 7f067802..e9aa7698 100644 --- a/src/opencosmo/index/in_range.py +++ b/python/opencosmo/index/in_range.py @@ -2,9 +2,10 @@ from typing import TYPE_CHECKING -import numba as nb import numpy as np +from opencosmo._lib import index as idx + if TYPE_CHECKING: from numpy.typing import NDArray @@ -20,7 +21,7 @@ def n_in_range( case np.ndarray(): return __n_in_range_simple(index, range_starts, range_sizes) case (np.ndarray(), np.ndarray()): - return __n_in_range_chunked(*index, range_starts, range_sizes) + return idx.n_in_range_chunked(*index, range_starts, range_sizes) case _: raise ValueError(f"Unknown index type {type(index)}") @@ -40,35 +41,3 @@ def __n_in_range_simple( start_idxs = np.searchsorted(index_to_search, start, "left") end_idxs = np.searchsorted(index_to_search, ends, "left") return end_idxs - start_idxs - - -@nb.njit -def __n_in_range_chunked( - starts: NDArray[np.int_], - sizes: NDArray[np.int_], - range_starts: NDArray[np.int_], - range_sizes: NDArray[np.int_], -) -> NDArray[np.int_]: - """ - Return the number of elements in this index that fall within - a specified data range. 
Used to mask spatial index. - - - As with numpy, this is the half-open range [start, end) - """ - if len(range_starts) != len(range_sizes): - raise ValueError("Start and size arrays must have the same length") - if np.any(range_sizes < 0): - raise ValueError("Sizes must greater than or equal to zero") - if len(starts) == 0: - return np.zeros_like(range_starts) - - index_ranges = np.vstack((starts, starts + sizes)) - output = np.zeros_like(range_starts) - for i in range(len(range_starts)): - index_chunk_ranges = np.clip( - index_ranges, a_min=range_starts[i], a_max=range_starts[i] + range_sizes[i] - ) - output[i] = np.sum(index_chunk_ranges[1] - index_chunk_ranges[0]) - - return output diff --git a/src/opencosmo/index/mask.py b/python/opencosmo/index/mask.py similarity index 100% rename from src/opencosmo/index/mask.py rename to python/opencosmo/index/mask.py diff --git a/src/opencosmo/index/project.py b/python/opencosmo/index/project.py similarity index 100% rename from src/opencosmo/index/project.py rename to python/opencosmo/index/project.py diff --git a/src/opencosmo/index/take.py b/python/opencosmo/index/take.py similarity index 100% rename from src/opencosmo/index/take.py rename to python/opencosmo/index/take.py diff --git a/python/opencosmo/index/unary.py b/python/opencosmo/index/unary.py new file mode 100644 index 00000000..4989d169 --- /dev/null +++ b/python/opencosmo/index/unary.py @@ -0,0 +1,31 @@ +import numpy as np +from numpy.typing import NDArray + +from opencosmo._lib import index as idx + +""" +Implementations for unary operations on indices +""" + +SimpleIndex = NDArray[np.int_] +ChunkedIndex = tuple[NDArray[np.int_], NDArray[np.int_]] + + +def get_length(index: SimpleIndex | ChunkedIndex): + match index: + case np.ndarray(): + return len(index) + case (np.ndarray(), np.ndarray()): + return int(np.sum(index[1])) + case _: + raise TypeError(f"Invalid index type {type(index)}") + + +def get_range(index: SimpleIndex | ChunkedIndex): + match index: + 
case np.ndarray(): + return idx.get_simple_range(index) + case (np.ndarray(), np.ndarray()): + return idx.get_chunked_range(*index) + case _: + raise ValueError(f"Unknown index type {type(index)}") diff --git a/src/opencosmo/io/__init__.py b/python/opencosmo/io/__init__.py similarity index 100% rename from src/opencosmo/io/__init__.py rename to python/opencosmo/io/__init__.py diff --git a/src/opencosmo/io/io.py b/python/opencosmo/io/io.py similarity index 100% rename from src/opencosmo/io/io.py rename to python/opencosmo/io/io.py diff --git a/src/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py similarity index 100% rename from src/opencosmo/io/iopen.py rename to python/opencosmo/io/iopen.py diff --git a/src/opencosmo/io/mpi.py b/python/opencosmo/io/mpi.py similarity index 100% rename from src/opencosmo/io/mpi.py rename to python/opencosmo/io/mpi.py diff --git a/src/opencosmo/io/parquet.py b/python/opencosmo/io/parquet.py similarity index 100% rename from src/opencosmo/io/parquet.py rename to python/opencosmo/io/parquet.py diff --git a/src/opencosmo/io/protocols.py b/python/opencosmo/io/protocols.py similarity index 100% rename from src/opencosmo/io/protocols.py rename to python/opencosmo/io/protocols.py diff --git a/src/opencosmo/io/schema.py b/python/opencosmo/io/schema.py similarity index 100% rename from src/opencosmo/io/schema.py rename to python/opencosmo/io/schema.py diff --git a/src/opencosmo/io/serial.py b/python/opencosmo/io/serial.py similarity index 100% rename from src/opencosmo/io/serial.py rename to python/opencosmo/io/serial.py diff --git a/src/opencosmo/io/updaters.py b/python/opencosmo/io/updaters.py similarity index 100% rename from src/opencosmo/io/updaters.py rename to python/opencosmo/io/updaters.py diff --git a/src/opencosmo/io/verify.py b/python/opencosmo/io/verify.py similarity index 100% rename from src/opencosmo/io/verify.py rename to python/opencosmo/io/verify.py diff --git a/src/opencosmo/io/writer.py b/python/opencosmo/io/writer.py 
similarity index 100% rename from src/opencosmo/io/writer.py rename to python/opencosmo/io/writer.py diff --git a/src/opencosmo/mpi.py b/python/opencosmo/mpi.py similarity index 100% rename from src/opencosmo/mpi.py rename to python/opencosmo/mpi.py diff --git a/src/opencosmo/parameters/__init__.py b/python/opencosmo/parameters/__init__.py similarity index 100% rename from src/opencosmo/parameters/__init__.py rename to python/opencosmo/parameters/__init__.py diff --git a/src/opencosmo/parameters/cosmology.py b/python/opencosmo/parameters/cosmology.py similarity index 100% rename from src/opencosmo/parameters/cosmology.py rename to python/opencosmo/parameters/cosmology.py diff --git a/src/opencosmo/parameters/diffsky.py b/python/opencosmo/parameters/diffsky.py similarity index 100% rename from src/opencosmo/parameters/diffsky.py rename to python/opencosmo/parameters/diffsky.py diff --git a/src/opencosmo/parameters/dtype.py b/python/opencosmo/parameters/dtype.py similarity index 100% rename from src/opencosmo/parameters/dtype.py rename to python/opencosmo/parameters/dtype.py diff --git a/src/opencosmo/parameters/file.py b/python/opencosmo/parameters/file.py similarity index 100% rename from src/opencosmo/parameters/file.py rename to python/opencosmo/parameters/file.py diff --git a/src/opencosmo/parameters/hacc.py b/python/opencosmo/parameters/hacc.py similarity index 100% rename from src/opencosmo/parameters/hacc.py rename to python/opencosmo/parameters/hacc.py diff --git a/src/opencosmo/parameters/lightcone.py b/python/opencosmo/parameters/lightcone.py similarity index 100% rename from src/opencosmo/parameters/lightcone.py rename to python/opencosmo/parameters/lightcone.py diff --git a/src/opencosmo/parameters/origin.py b/python/opencosmo/parameters/origin.py similarity index 100% rename from src/opencosmo/parameters/origin.py rename to python/opencosmo/parameters/origin.py diff --git a/src/opencosmo/parameters/parameters.py 
b/python/opencosmo/parameters/parameters.py similarity index 100% rename from src/opencosmo/parameters/parameters.py rename to python/opencosmo/parameters/parameters.py diff --git a/src/opencosmo/parameters/units.py b/python/opencosmo/parameters/units.py similarity index 100% rename from src/opencosmo/parameters/units.py rename to python/opencosmo/parameters/units.py diff --git a/src/opencosmo/parameters/utils.py b/python/opencosmo/parameters/utils.py similarity index 100% rename from src/opencosmo/parameters/utils.py rename to python/opencosmo/parameters/utils.py diff --git a/src/opencosmo/spatial/__init__.py b/python/opencosmo/spatial/__init__.py similarity index 100% rename from src/opencosmo/spatial/__init__.py rename to python/opencosmo/spatial/__init__.py diff --git a/src/opencosmo/spatial/builders.py b/python/opencosmo/spatial/builders.py similarity index 100% rename from src/opencosmo/spatial/builders.py rename to python/opencosmo/spatial/builders.py diff --git a/src/opencosmo/spatial/check.py b/python/opencosmo/spatial/check.py similarity index 100% rename from src/opencosmo/spatial/check.py rename to python/opencosmo/spatial/check.py diff --git a/src/opencosmo/spatial/healpix.py b/python/opencosmo/spatial/healpix.py similarity index 100% rename from src/opencosmo/spatial/healpix.py rename to python/opencosmo/spatial/healpix.py diff --git a/src/opencosmo/spatial/models.py b/python/opencosmo/spatial/models.py similarity index 100% rename from src/opencosmo/spatial/models.py rename to python/opencosmo/spatial/models.py diff --git a/src/opencosmo/spatial/octree.py b/python/opencosmo/spatial/octree.py similarity index 100% rename from src/opencosmo/spatial/octree.py rename to python/opencosmo/spatial/octree.py diff --git a/src/opencosmo/spatial/protocols.py b/python/opencosmo/spatial/protocols.py similarity index 100% rename from src/opencosmo/spatial/protocols.py rename to python/opencosmo/spatial/protocols.py diff --git a/src/opencosmo/spatial/region.py 
b/python/opencosmo/spatial/region.py similarity index 100% rename from src/opencosmo/spatial/region.py rename to python/opencosmo/spatial/region.py diff --git a/src/opencosmo/spatial/relations.py b/python/opencosmo/spatial/relations.py similarity index 100% rename from src/opencosmo/spatial/relations.py rename to python/opencosmo/spatial/relations.py diff --git a/src/opencosmo/spatial/tree.py b/python/opencosmo/spatial/tree.py similarity index 100% rename from src/opencosmo/spatial/tree.py rename to python/opencosmo/spatial/tree.py diff --git a/src/opencosmo/spatial/utils.py b/python/opencosmo/spatial/utils.py similarity index 88% rename from src/opencosmo/spatial/utils.py rename to python/opencosmo/spatial/utils.py index 200c91b5..bb006236 100644 --- a/src/opencosmo/spatial/utils.py +++ b/python/opencosmo/spatial/utils.py @@ -21,8 +21,8 @@ def combine_upwards( raise ValueError("Recieved invalid number of counts!") group = target.require_group(f"level_{level}") - new_starts = np.insert(np.cumsum(counts, dtype=np.int32), 0, 0)[:-1] - counts = counts.astype(np.int32) # This should be fixed + new_starts = np.insert(np.cumsum(counts, dtype=np.int64), 0, 0)[:-1] + counts = counts.astype(np.int64) # This should be fixed group.create_dataset("start", data=new_starts) group.create_dataset("size", data=counts) diff --git a/src/opencosmo/units/__init__.py b/python/opencosmo/units/__init__.py similarity index 100% rename from src/opencosmo/units/__init__.py rename to python/opencosmo/units/__init__.py diff --git a/src/opencosmo/units/convention.py b/python/opencosmo/units/convention.py similarity index 100% rename from src/opencosmo/units/convention.py rename to python/opencosmo/units/convention.py diff --git a/src/opencosmo/units/converters.py b/python/opencosmo/units/converters.py similarity index 100% rename from src/opencosmo/units/converters.py rename to python/opencosmo/units/converters.py diff --git a/src/opencosmo/units/get.py b/python/opencosmo/units/get.py 
similarity index 100% rename from src/opencosmo/units/get.py rename to python/opencosmo/units/get.py diff --git a/src/opencosmo/units/handler.py b/python/opencosmo/units/handler.py similarity index 100% rename from src/opencosmo/units/handler.py rename to python/opencosmo/units/handler.py diff --git a/src/index.rs b/src/index.rs new file mode 100644 index 00000000..36467eda --- /dev/null +++ b/src/index.rs @@ -0,0 +1,129 @@ +use pyo3::prelude::*; +#[pymodule] +pub(crate) mod index { + use numpy::ndarray::{Array1, ArrayView1}; + use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1}; + use pyo3::exceptions::{PyTypeError, PyValueError}; + use pyo3::prelude::*; + use std::io::Result; + use std::iter::zip; + fn unpack_index_array<'py>(index: &Bound<'py, PyAny>) -> PyResult> { + let index_data = index + .cast::>() + .map_err(|_| PyTypeError::new_err("Indices should be a 1-d array of i64"))?; + + Ok(index_data.readonly()) + } + + fn unpack_chunked_index<'py>( + start: &Bound<'py, PyAny>, + size: &Bound<'py, PyAny>, + ) -> PyResult<(PyReadonlyArray1<'py, i64>, PyReadonlyArray1<'py, i64>)> { + let start_arr = unpack_index_array(start)?; + let size_arr = unpack_index_array(size)?; + if start_arr.len()? != size_arr.len()? 
{ + return Err(PyValueError::new_err( + "start and size must be the same length in a chunked index!", + )); + } + Ok((start_arr, size_arr)) + } + + #[pyfunction(name = "get_simple_range")] + pub(crate) fn get_simple_range_py(index: &Bound<'_, PyAny>) -> PyResult<(i64, i64)> { + let index_arr = unpack_index_array(index)?; + get_simple_range(index_arr.as_array()) + } + #[pyfunction(name = "get_chunked_range")] + pub(crate) fn get_chunked_range_py( + start: &Bound<'_, PyAny>, + size: &Bound<'_, PyAny>, + ) -> PyResult<(i64, i64)> { + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + get_chunked_range(start_arr.as_array(), size_arr.as_array()) + } + #[pyfunction(name = "n_in_range_chunked")] + pub(crate) fn n_in_range_chunked_py<'py>( + py: Python<'py>, + start: &Bound<'_, PyAny>, + size: &Bound<'_, PyAny>, + range_start: &Bound<'_, PyAny>, + range_size: &Bound<'_, PyAny>, + ) -> PyResult>> { + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + let (range_start_arr, range_size_arr) = unpack_chunked_index(range_start, range_size)?; + let result = n_in_range_chunked( + start_arr.as_array(), + size_arr.as_array(), + range_start_arr.as_array(), + range_size_arr.as_array(), + ); + Ok(result.into_pyarray(py)) + } + + fn n_in_range_chunked( + start: ArrayView1<'_, i64>, + size: ArrayView1<'_, i64>, + range_start: ArrayView1<'_, i64>, + range_size: ArrayView1<'_, i64>, + ) -> Array1 { + let mut output = Array1::::zeros(range_start.len()); + if start.len() == 0 { + return output; + } + let end = &start + &size; + for (i, (&rst, &rsi)) in zip(range_start, range_size).enumerate() { + let chunk_end = rst + rsi; + let total = zip(start, &end) + .filter(|(s, e)| !(**s > chunk_end || **e < rst)) + .map(|(&s, &e)| { + let mut cr = (s, e); + if chunk_end < e { + cr = (cr.0, chunk_end) + } + if rst > s { + cr = (rst, cr.1) + } + cr.1 - cr.0 + }) + .sum(); + output[i] = total; + } + output + } + fn get_simple_range(index: ArrayView1<'_, i64>) -> 
PyResult<(i64, i64)> { + if index.len() == 0 { + return Ok((0, 0)); + } + let mut index_range = (index[0], index[0]); + for &item in index { + if item < index_range.0 { + index_range = (item, index_range.1) + } + if item > index_range.1 { + index_range = (index_range.0, item) + } + } + Ok(index_range) + } + + fn get_chunked_range( + start: ArrayView1<'_, i64>, + size: ArrayView1<'_, i64>, + ) -> PyResult<(i64, i64)> { + if start.len() == 0 { + return Ok((0, 0)); + } + let mut index_range = (start[0], start[0] + size[0]); + for (&st, &si) in zip(start, size) { + let end = st + si; + if st < index_range.0 { + index_range = (st, index_range.1); + } + if end > index_range.1 { + index_range = (index_range.0, end); + } + } + Ok(index_range) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 00000000..67e1177a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,21 @@ +/// A Python module implemented in Rust. The name of this function must match +/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to +/// import the module. +/// +/// +use pyo3::prelude::*; +mod index; + +#[pyo3::pymodule] +mod _lib { + use pyo3::prelude::*; + + #[pymodule_export] + use crate::index::index; + + /// Formats the sum of two numbers as string. 
+ #[pyfunction] + fn sum_as_string(a: usize, b: usize) -> PyResult { + Ok((a + b).to_string()) + } +} diff --git a/src/opencosmo/index/unary.py b/src/opencosmo/index/unary.py deleted file mode 100644 index c260b158..00000000 --- a/src/opencosmo/index/unary.py +++ /dev/null @@ -1,61 +0,0 @@ -import numba as nb -import numpy as np -from numpy.typing import NDArray - -""" -Implementations for unary operations on indices -""" - -SimpleIndex = NDArray[np.int_] -ChunkedIndex = tuple[NDArray[np.int_], NDArray[np.int_]] - - -def get_length(index: SimpleIndex | ChunkedIndex): - match index: - case np.ndarray(): - return len(index) - case (np.ndarray(), np.ndarray()): - return int(np.sum(index[1])) - case _: - raise TypeError(f"Invalid index type {type(index)}") - - -def get_range(index: SimpleIndex | ChunkedIndex): - match index: - case np.ndarray(): - return __get_simple_range(index) - case (np.ndarray(), np.ndarray()): - return __get_chunked_range(*index) - case _: - raise ValueError(f"Unknown index type {type(index)}") - - -@nb.njit -def __get_simple_range(index: SimpleIndex): - if len(index) == 0: - return (0, 0) - - min = index[0] - max = index[0] - for val in index[1:]: - if val < min: - min = val - if val > max: - max = val - return (min, max) - - -@nb.njit -def __get_chunked_range(starts: NDArray[np.int_], sizes: NDArray[np.int_]): - if len(starts) == 0: - return (0, 0) - min = starts[0] - max = min + sizes[0] - for i in range(1, len(starts)): - start = starts[i] - end = start + sizes[i] - if start < min: - min = start - if end > max: - max = end - return (min, max) From 42ff07a87953d76db6d06fea1e3ab5cc088794ea Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Apr 2026 12:45:42 -0500 Subject: [PATCH 002/139] Some take operations --- .../opencosmo/collection/structure/handler.py | 10 +- python/opencosmo/index/mask.py | 16 +-- python/opencosmo/index/take.py | 26 +--- src/index.rs | 119 ++++++++++++++---- 4 files changed, 105 insertions(+), 66 deletions(-) diff 
--git a/python/opencosmo/collection/structure/handler.py b/python/opencosmo/collection/structure/handler.py index 88a07cf7..10a065ce 100644 --- a/python/opencosmo/collection/structure/handler.py +++ b/python/opencosmo/collection/structure/handler.py @@ -217,8 +217,12 @@ def rebuild_datasets( assert len(size_column) == 1 size_column_name = size_column[0] size_column_data = metadata[size_column_name] + new_datasets[name] = rebuild_chunk_index( - new_source, dataset, size_column_data, index_into_original + new_source, + dataset, + size_column_data.astype(np.int64), + index_into_original.astype(np.int64), ) return new_datasets @@ -251,8 +255,8 @@ def resort(self, source: oc.Dataset, datasets: dict[str, oc.Dataset]): else: size_column = [name for name in self.columns[name] if "size" in name] assert len(size_column) == 1 - size_column_data = meta[size_column[0]] - chunk_boundaries = np.zeros(len(size_column_data) + 1, dtype=int) + size_column_data = meta[size_column[0]].astype(np.int64) + chunk_boundaries = np.zeros(len(size_column_data) + 1, dtype=np.int64) _ = np.cumsum(size_column_data, out=chunk_boundaries[1:]) starts = chunk_boundaries[sort_index] sizes = size_column_data[sort_index] diff --git a/python/opencosmo/index/mask.py b/python/opencosmo/index/mask.py index 9dc95f26..7ccdfd92 100644 --- a/python/opencosmo/index/mask.py +++ b/python/opencosmo/index/mask.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING -import numba as nb import numpy as np if TYPE_CHECKING: from numpy.typing import NDArray +from opencosmo._lib import index as idxlib from opencosmo.index.unary import get_length @@ -46,16 +46,4 @@ def into_array(index: np.ndarray | tuple): if len(index[0]) == 1: return np.arange(index[0][0], index[0][0] + index[1][0]) - return __chunked_into_array(*index) - - -@nb.njit -def __chunked_into_array(starts: NDArray[np.int_], sizes: NDArray[np.int_]): - output = np.zeros(np.sum(sizes), dtype=np.int64) - rs = 0 - for i in range(len(starts)): - output[rs : rs + 
sizes[i]] = np.arange( - starts[i], starts[i] + sizes[i], dtype=np.int64 - ) - rs += sizes[i] - return output + return idxlib.chunked_into_array(*index) diff --git a/python/opencosmo/index/take.py b/python/opencosmo/index/take.py index a83f137a..a2e27a65 100644 --- a/python/opencosmo/index/take.py +++ b/python/opencosmo/index/take.py @@ -3,6 +3,8 @@ import numba as nb # type: ignore import numpy as np +from opencosmo._lib import index as idxlib + SimpleIndex = np.ndarray ChunkedIndex = tuple[np.ndarray, np.ndarray] @@ -12,7 +14,7 @@ def take(from_, by): case (np.ndarray(), np.ndarray()): return __take_simple_from_simple(from_, by) case (np.ndarray(), (np.ndarray(), np.ndarray())): - return __take_chunked_from_simple(from_, by) + return idxlib.take_chunked_from_simple(from_, by) case ((np.ndarray(), np.ndarray()), np.ndarray()): return __take_simple_from_chunked(from_, by) case ((np.ndarray(), np.ndarray()), (np.ndarray(), np.ndarray())): @@ -31,28 +33,6 @@ def __take_simple_from_simple(from_: np.ndarray, by: np.ndarray): return from_[by] -def __take_chunked_from_simple(from_: SimpleIndex, by: ChunkedIndex): - output = np.zeros(by[1].sum(), dtype=int) - output = __cfs_helper(from_, *by, output) - return output - - -@nb.njit -def __cfs_helper(arr, starts, sizes, storage): - rs = 0 - for i in range(len(starts)): - cstart = starts[i] - csize = sizes[i] - storage[rs : rs + csize] = arr[cstart : cstart + csize] - rs += csize - return storage - - -@nb.njit -def __cfc_helper(from_starts, from_sizes, by_starts, by_sizes): - pass - - @nb.njit def prefix_sum(arr): out = np.empty(len(arr) + 1, dtype=arr.dtype) diff --git a/src/index.rs b/src/index.rs index 36467eda..f6742e54 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,12 +1,13 @@ use pyo3::prelude::*; #[pymodule] pub(crate) mod index { + use numpy::ndarray::s; use numpy::ndarray::{Array1, ArrayView1}; use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1}; use pyo3::exceptions::{PyTypeError, PyValueError}; 
use pyo3::prelude::*; - use std::io::Result; use std::iter::zip; + fn unpack_index_array<'py>(index: &Bound<'py, PyAny>) -> PyResult> { let index_data = index .cast::>() @@ -34,6 +35,22 @@ pub(crate) mod index { let index_arr = unpack_index_array(index)?; get_simple_range(index_arr.as_array()) } + fn get_simple_range(index: ArrayView1<'_, i64>) -> PyResult<(i64, i64)> { + if index.len() == 0 { + return Ok((0, 0)); + } + let mut index_range = (index[0], index[0]); + for &item in index { + if item < index_range.0 { + index_range = (item, index_range.1) + } + if item > index_range.1 { + index_range = (index_range.0, item) + } + } + Ok(index_range) + } + #[pyfunction(name = "get_chunked_range")] pub(crate) fn get_chunked_range_py( start: &Bound<'_, PyAny>, @@ -42,6 +59,27 @@ pub(crate) mod index { let (start_arr, size_arr) = unpack_chunked_index(start, size)?; get_chunked_range(start_arr.as_array(), size_arr.as_array()) } + + fn get_chunked_range( + start: ArrayView1<'_, i64>, + size: ArrayView1<'_, i64>, + ) -> PyResult<(i64, i64)> { + if start.len() == 0 { + return Ok((0, 0)); + } + let mut index_range = (start[0], start[0] + size[0]); + for (&st, &si) in zip(start, size) { + let end = st + si; + if st < index_range.0 { + index_range = (st, index_range.1); + } + if end > index_range.1 { + index_range = (index_range.0, end); + } + } + Ok(index_range) + } + #[pyfunction(name = "n_in_range_chunked")] pub(crate) fn n_in_range_chunked_py<'py>( py: Python<'py>, @@ -91,39 +129,68 @@ pub(crate) mod index { } output } - fn get_simple_range(index: ArrayView1<'_, i64>) -> PyResult<(i64, i64)> { - if index.len() == 0 { - return Ok((0, 0)); - } - let mut index_range = (index[0], index[0]); - for &item in index { - if item < index_range.0 { - index_range = (item, index_range.1) - } - if item > index_range.1 { - index_range = (index_range.0, item) - } + #[pyfunction(name = "chunked_into_array")] + pub(crate) fn chunked_into_array_py<'py>( + py: Python<'py>, + start: &Bound<'_, 
PyAny>, + size: &Bound<'_, PyAny>, + ) -> PyResult>> { + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + let output = chunked_into_array(start_arr.as_array(), size_arr.as_array()); + Ok(output.into_pyarray(py)) + } + fn chunked_into_array(start: ArrayView1<'_, i64>, size: ArrayView1<'_, i64>) -> Array1 { + let total_length = size.sum(); + let mut output = Array1::::zeros(total_length as usize); + let mut rs: i64 = 0; + for (&st, &si) in zip(start, size) { + let range = Array1::from_iter(st..st + si); + output + .slice_mut(s![rs as usize..(rs + si) as usize]) + .assign(&range); + rs += si; } - Ok(index_range) + output } - fn get_chunked_range( + #[pyfunction(name = "take_chunked_from_simple")] + fn take_chunked_from_simple_py<'py>( + py: Python<'py>, + simple: &Bound<'_, PyAny>, + start: &Bound<'_, PyAny>, + size: &Bound<'_, PyAny>, + ) -> PyResult>> { + let simple_arr = unpack_index_array(simple)?; + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + let result = take_chunked_from_simple( + simple_arr.as_array(), + start_arr.as_array(), + size_arr.as_array(), + ); + Ok(result?.into_pyarray(py)) + } + + fn take_chunked_from_simple( + simple: ArrayView1<'_, i64>, start: ArrayView1<'_, i64>, size: ArrayView1<'_, i64>, - ) -> PyResult<(i64, i64)> { - if start.len() == 0 { - return Ok((0, 0)); - } - let mut index_range = (start[0], start[0] + size[0]); + ) -> Result, PyErr> { + let total_length = size.sum(); + let mut output = Array1::::zeros(total_length as usize); + let mut rs: i64 = 0; for (&st, &si) in zip(start, size) { let end = st + si; - if st < index_range.0 { - index_range = (st, index_range.1); - } - if end > index_range.1 { - index_range = (index_range.0, end); + if end as usize > simple.len() { + return Err(PyValueError::new_err( + "The chunked index is outside of the range of the simple index!", + )); } + let to_insert = simple.slice(s![st as usize..end as usize]); + output + .slice_mut(s![rs as usize..(rs + si) as usize]) 
+ .assign(&to_insert); + rs += si } - Ok(index_range) + Ok(output) } } From 599bdf57523de31e4896979da908e606bc52976f Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Apr 2026 13:43:30 -0500 Subject: [PATCH 003/139] some more implementation --- pyproject.toml | 22 ++++++++----------- .../collection/lightcone/lightcone.py | 1 + .../opencosmo/collection/lightcone/stack.py | 2 ++ python/opencosmo/index/take.py | 2 +- python/opencosmo/io/io.py | 2 ++ python/opencosmo/io/mpi.py | 1 + python/opencosmo/spatial/tree.py | 8 +++++-- src/lib.rs | 2 +- test/parallel/test_lc_mpi.py | 6 ++--- 9 files changed, 26 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d17a0241..ba37b01d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ test = [ test-mpi = [ "mpi-pytest>=2025.4.0,<2026.0.0" ] + [project.scripts] opencosmo = "opencosmo.analysis.cli:cli" @@ -64,8 +65,6 @@ io = [ "pyarrow>=21.0.0", ] - - [build-system] requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" @@ -75,27 +74,19 @@ module-name = "opencosmo._lib" include = ["LICENSE.md"] python-source = "python" +[tool.uv] +cache-keys = [{file = "pyproject.toml"}, {file = "Cargo.toml"}, {file = "**/*.rs"}] + [[tool.mypy.overrides]] module = ["h5py", "astropy.*", "healpy", "healsparse", "numba"] ignore_missing_imports = true -[tool.towncrier] -directory = "changes" -package = "opencosmo" -name = "opencosmo" -filename = "docs/source/changelog.rst" [tool.pytest.ini_options] testpaths = ["test"] log_cli = true log_cli_level = "INFO" -[tool.pyrefly] -project-includes = [ - "**/*.py*", -] -replace-imports-with-any = ["astropy.units", "astropy.constants"] - [lint.pycodestyle] max-doc-length = 90 @@ -104,6 +95,11 @@ future-annotations = true preview = false extend-select = ["TC"] +[tool.towncrier] +directory = "changes" +package = "opencosmo" +name = "opencosmo" +filename = "docs/source/changelog.rst" [tool.towncrier.fragment.improvement] name = "Improvements" diff --git 
a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index e721ee65..86baf202 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -646,6 +646,7 @@ def make_schema(self, name: str = "", _min_size=100_000) -> Schema: } children.update(child_schemas) + print("DONE") return make_schema(name, FileEntry.LIGHTCONE, children=children) def bound(self, region: Region, select_by: Optional[str] = None): diff --git a/python/opencosmo/collection/lightcone/stack.py b/python/opencosmo/collection/lightcone/stack.py index 22ac5be1..b746a754 100644 --- a/python/opencosmo/collection/lightcone/stack.py +++ b/python/opencosmo/collection/lightcone/stack.py @@ -115,6 +115,8 @@ def stack_lightcone_datasets_in_schema( [schema.children["data"] for schema in schemas] ) + print("ASDF") + assert False order = get_stacked_lightcone_order(ds_list, max_level) updater = partial(update_order, order=order) diff --git a/python/opencosmo/index/take.py b/python/opencosmo/index/take.py index a2e27a65..6c41e2b0 100644 --- a/python/opencosmo/index/take.py +++ b/python/opencosmo/index/take.py @@ -14,7 +14,7 @@ def take(from_, by): case (np.ndarray(), np.ndarray()): return __take_simple_from_simple(from_, by) case (np.ndarray(), (np.ndarray(), np.ndarray())): - return idxlib.take_chunked_from_simple(from_, by) + return idxlib.take_chunked_from_simple(from_, *by) case ((np.ndarray(), np.ndarray()), np.ndarray()): return __take_simple_from_chunked(from_, by) case ((np.ndarray(), np.ndarray()), (np.ndarray(), np.ndarray())): diff --git a/python/opencosmo/io/io.py b/python/opencosmo/io/io.py index 4305eab2..6f311856 100644 --- a/python/opencosmo/io/io.py +++ b/python/opencosmo/io/io.py @@ -143,7 +143,9 @@ def write(path: Path, dataset: Writeable, overwrite=False, **schema_kwargs) -> N path = resolve_path(path, existance_requirement) + print("Schema") schema = 
dataset.make_schema(**schema_kwargs) + print("Done") if mpiio is not None: return mpiio.write_parallel(path, schema) diff --git a/python/opencosmo/io/mpi.py b/python/opencosmo/io/mpi.py index cc990ad3..9166227a 100644 --- a/python/opencosmo/io/mpi.py +++ b/python/opencosmo/io/mpi.py @@ -80,6 +80,7 @@ def write_parallel(file: Path, file_schema: Schema): has_data = [i for i, state in enumerate(results) if state == CombineState.VALID] if len(has_data) == 0: raise ValueError("No ranks have any data to write!") + print("ASDF") group = comm.Get_group() new_group = group.Incl(has_data) diff --git a/python/opencosmo/spatial/tree.py b/python/opencosmo/spatial/tree.py index 5eadad05..1ad5f40a 100644 --- a/python/opencosmo/spatial/tree.py +++ b/python/opencosmo/spatial/tree.py @@ -239,8 +239,12 @@ def query(self, region: Region) -> tuple[ChunkedIndex, ChunkedIndex]: return (contains_start, contains_size), (intersects_start, intersects_size) def apply_index(self, index: DataIndex, min_counts: int = 100) -> Tree: - max_level_starts = self.__columns[f"level_{self.__max_level}/start"][:] - max_level_sizes = self.__columns[f"level_{self.__max_level}/size"][:] + max_level_starts = self.__columns[f"level_{self.__max_level}/start"][:].astype( + np.int64 + ) + max_level_sizes = self.__columns[f"level_{self.__max_level}/size"][:].astype( + np.int64 + ) n = n_in_range(index, max_level_starts, max_level_sizes) target = h5py.File(f"{uuid1()}.hdf5", "w", driver="core", backing_store=False) result = combine_upwards( diff --git a/src/lib.rs b/src/lib.rs index 67e1177a..6cebfacf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ use pyo3::prelude::*; mod index; -#[pyo3::pymodule] +#[pymodule] mod _lib { use pyo3::prelude::*; diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index c5d9023d..183885b1 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -3,14 +3,13 @@ import astropy.units as u import numpy as np +import opencosmo as oc 
import pytest from astropy.coordinates import SkyCoord from healpy import pix2ang from mpi4py import MPI -from pytest_mpi.parallel_assert import parallel_assert - -import opencosmo as oc from opencosmo.mpi import get_comm_world +from pytest_mpi.parallel_assert import parallel_assert IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" @@ -410,6 +409,7 @@ def test_write_some_missing(core_path_487, core_path_475, per_test_dir): original_data_length = comm.allgather(len(original_data)) ds = ds.with_new_columns(gal_id=np.arange(len(ds))) + print("hi") oc.write(per_test_dir / "lightcone.hdf5", ds) ds = oc.open(per_test_dir / "lightcone.hdf5", synth_cores=True) written_data = ds.select("early_index").get_data("numpy") From dd2a01610a30c6b435c5bb9dd094035480852243 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Apr 2026 15:52:41 -0500 Subject: [PATCH 004/139] Implement chunked_from_chunked --- .../collection/lightcone/lightcone.py | 1 - .../opencosmo/collection/lightcone/stack.py | 2 - python/opencosmo/index/take.py | 90 +------------------ python/opencosmo/io/io.py | 2 - python/opencosmo/io/iopen.py | 8 +- python/opencosmo/io/mpi.py | 1 - python/opencosmo/spatial/tree.py | 8 +- src/index.rs | 66 ++++++++++++++ test/parallel/test_lc_mpi.py | 4 +- 9 files changed, 80 insertions(+), 102 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 86baf202..e721ee65 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -646,7 +646,6 @@ def make_schema(self, name: str = "", _min_size=100_000) -> Schema: } children.update(child_schemas) - print("DONE") return make_schema(name, FileEntry.LIGHTCONE, children=children) def bound(self, region: Region, select_by: Optional[str] = None): diff --git a/python/opencosmo/collection/lightcone/stack.py b/python/opencosmo/collection/lightcone/stack.py index b746a754..22ac5be1 100644 
--- a/python/opencosmo/collection/lightcone/stack.py +++ b/python/opencosmo/collection/lightcone/stack.py @@ -115,8 +115,6 @@ def stack_lightcone_datasets_in_schema( [schema.children["data"] for schema in schemas] ) - print("ASDF") - assert False order = get_stacked_lightcone_order(ds_list, max_level) updater = partial(update_order, order=order) diff --git a/python/opencosmo/index/take.py b/python/opencosmo/index/take.py index 6c41e2b0..2eddd0c0 100644 --- a/python/opencosmo/index/take.py +++ b/python/opencosmo/index/take.py @@ -1,6 +1,5 @@ from __future__ import annotations -import numba as nb # type: ignore import numpy as np from opencosmo._lib import index as idxlib @@ -18,7 +17,8 @@ def take(from_, by): case ((np.ndarray(), np.ndarray()), np.ndarray()): return __take_simple_from_chunked(from_, by) case ((np.ndarray(), np.ndarray()), (np.ndarray(), np.ndarray())): - return __take_chunked_from_chunked(from_, by) + print("Chunked from chunked") + return idxlib.take_chunked_from_chunked(*from_, *by) def __take_simple_from_chunked(from_: ChunkedIndex, by: SimpleIndex): @@ -31,89 +31,3 @@ def __take_simple_from_chunked(from_: ChunkedIndex, by: SimpleIndex): def __take_simple_from_simple(from_: np.ndarray, by: np.ndarray): return from_[by] - - -@nb.njit -def prefix_sum(arr): - out = np.empty(len(arr) + 1, dtype=arr.dtype) - total = 0 - out[0] = 0 - for i in range(len(arr)): - total += arr[i] - out[i + 1] = total - return out - - -@nb.njit -def find_chunk(prefix, x): - """ - Returns index i such that prefix[i] <= x < prefix[i+1]. - """ - lo = 0 - hi = len(prefix) - 1 # prefix has length N+1 - while lo + 1 < hi: - mid = (lo + hi) // 2 - if prefix[mid] <= x: - lo = mid - else: - hi = mid - return lo - - -@nb.njit -def resolve_spanning_numba( - start1, size1, start2, size2, out_start, out_size, out_owner -): - """ - Resolves index2 slices into data-level chunks. - Returns the number of output segments written. 
- """ - prefix = prefix_sum(size1) - out_pos = 0 - - for j in range(len(start2)): - logical = start2[j] - remaining = size2[j] - - while remaining > 0: - # Find which chunk in index1 we are inside - i1 = find_chunk(prefix, logical) - - # Where inside that chunk? - offset = logical - prefix[i1] - - # How many logical units remain in this chunk? - chunk_left = size1[i1] - offset - - # How much we take - take = chunk_left if chunk_left < remaining else remaining - - # Emit result - out_start[out_pos] = start1[i1] + offset - out_size[out_pos] = take - out_owner[out_pos] = j - - out_pos += 1 - - # Advance - logical += take - remaining -= take - - return out_pos - - -def __take_chunked_from_chunked(from_: ChunkedIndex, by: ChunkedIndex): - if len(from_[0]) == 0 and from_[0][0] == 0: - return by - - max_out = len(by[1]) * len(from_[1]) - out_start = np.empty(max_out, dtype=np.int64) - out_size = np.empty(max_out, dtype=np.int64) - out_owner = np.empty(max_out, dtype=np.int64) - - n = resolve_spanning_numba( - from_[0], from_[1], by[0], by[1], out_start, out_size, out_owner - ) - out_start = np.resize(out_start, (n,)) - out_size = np.resize(out_size, (n,)) - return out_start, out_size diff --git a/python/opencosmo/io/io.py b/python/opencosmo/io/io.py index 6f311856..4305eab2 100644 --- a/python/opencosmo/io/io.py +++ b/python/opencosmo/io/io.py @@ -143,9 +143,7 @@ def write(path: Path, dataset: Writeable, overwrite=False, **schema_kwargs) -> N path = resolve_path(path, existance_requirement) - print("Schema") schema = dataset.make_schema(**schema_kwargs) - print("Done") if mpiio is not None: return mpiio.write_parallel(path, schema) diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index 0cdfa1a8..1f7e6031 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -396,9 +396,11 @@ def __find_datasets_under_group( columns = [ nds_[1] for nds_ in filter( - lambda nds: ds_group_parent in nds[0] - and "header" not in nds[0] - and 
isinstance(nds[1], h5py.Dataset), + lambda nds: ( + ds_group_parent in nds[0] + and "header" not in nds[0] + and isinstance(nds[1], h5py.Dataset) + ), file_map.items(), ) ] diff --git a/python/opencosmo/io/mpi.py b/python/opencosmo/io/mpi.py index 9166227a..cc990ad3 100644 --- a/python/opencosmo/io/mpi.py +++ b/python/opencosmo/io/mpi.py @@ -80,7 +80,6 @@ def write_parallel(file: Path, file_schema: Schema): has_data = [i for i, state in enumerate(results) if state == CombineState.VALID] if len(has_data) == 0: raise ValueError("No ranks have any data to write!") - print("ASDF") group = comm.Get_group() new_group = group.Incl(has_data) diff --git a/python/opencosmo/spatial/tree.py b/python/opencosmo/spatial/tree.py index 1ad5f40a..a1ee88f3 100644 --- a/python/opencosmo/spatial/tree.py +++ b/python/opencosmo/spatial/tree.py @@ -150,7 +150,9 @@ def partition_index(n_partitions: int, counts: h5py.Group, min_level: int): split_level_indices = full_region_indices - return np.array_split(split_level_indices, n_partitions), split_level + return np.array_split( + split_level_indices.astype(np.int64), n_partitions + ), split_level class Tree: @@ -201,8 +203,8 @@ def partition( n_partitions, counts, min_level ) partitions = [] - start = self.__columns[f"level_{split_level}/start"][:] - size = self.__columns[f"level_{split_level}/size"][:] + start = self.__columns[f"level_{split_level}/start"][:].astype(np.int64) + size = self.__columns[f"level_{split_level}/size"][:].astype(np.int64) for index_ in partition_indices: if len(index_) == 0: continue diff --git a/src/index.rs b/src/index.rs index f6742e54..1073b590 100644 --- a/src/index.rs +++ b/src/index.rs @@ -193,4 +193,70 @@ pub(crate) mod index { } Ok(output) } + + #[pyfunction(name = "take_chunked_from_chunked")] + fn take_chunked_from_chunked_py<'py>( + py: Python<'py>, + start: &Bound<'_, PyAny>, + size: &Bound<'_, PyAny>, + take_start: &Bound<'_, PyAny>, + take_size: &Bound<'_, PyAny>, + ) -> PyResult<(Bound<'py, 
PyArray1>, Bound<'py, PyArray1>)> { + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + let (take_start_arr, take_size_arr) = unpack_chunked_index(take_start, take_size)?; + let result = take_chunked_from_chunked( + start_arr.as_array(), + size_arr.as_array(), + take_start_arr.as_array(), + take_size_arr.as_array(), + )?; + Ok((result.0.into_pyarray(py), result.1.into_pyarray(py))) + } + fn take_chunked_from_chunked( + start: ArrayView1<'_, i64>, + size: ArrayView1<'_, i64>, + take_start: ArrayView1<'_, i64>, + take_size: ArrayView1<'_, i64>, + ) -> Result<(Array1, Array1), PyErr> { + // assumption: everything is sorted + let mut output_start: Vec = Vec::new(); + let mut output_size: Vec = Vec::new(); + if size.sum() < take_start[take_start.len() - 1] + take_size[take_size.len() - 1] { + return Err(PyValueError::new_err( + "You can't take more elements than exist in an index!", + )); + } + let mut chunk_index: usize = 0; + let mut cs = 0; + for (&tstart, &tsize) in zip(take_start, take_size) { + while cs + size[chunk_index] < tstart { + cs += size[chunk_index]; + chunk_index += 1; + } + // chunk_index now points to the chunk that we need to start at + // cs is equal to the cumulative size of all chunks that we've passed + let mut start_in_chunk = tstart - cs; + let mut chunk_taken = 0; + let mut chunk_completed = false; + while !chunk_completed { + let mut size_in_chunk = size[chunk_index] - start_in_chunk; + chunk_completed = size_in_chunk >= (tsize - chunk_taken); + if chunk_completed { + size_in_chunk = tsize - chunk_taken; + } + + output_start.push(start[chunk_index] + start_in_chunk); + output_size.push(size_in_chunk); + chunk_taken += size_in_chunk; + if !chunk_completed { + cs += size[chunk_index]; + chunk_index += 1; + start_in_chunk = 0; + } + } + } + let output_start_arr = Array1::from_vec(output_start); + let output_size_arr = Array1::from_vec(output_size); + Ok((output_start_arr, output_size_arr)) + } } diff --git 
a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index 183885b1..f126bc35 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -3,7 +3,6 @@ import astropy.units as u import numpy as np -import opencosmo as oc import pytest from astropy.coordinates import SkyCoord from healpy import pix2ang @@ -11,6 +10,8 @@ from opencosmo.mpi import get_comm_world from pytest_mpi.parallel_assert import parallel_assert +import opencosmo as oc + IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" @@ -409,7 +410,6 @@ def test_write_some_missing(core_path_487, core_path_475, per_test_dir): original_data_length = comm.allgather(len(original_data)) ds = ds.with_new_columns(gal_id=np.arange(len(ds))) - print("hi") oc.write(per_test_dir / "lightcone.hdf5", ds) ds = oc.open(per_test_dir / "lightcone.hdf5", synth_cores=True) written_data = ds.select("early_index").get_data("numpy") From 7bf2205490f434ef2081800500ff86473e28f835 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Apr 2026 16:19:06 -0500 Subject: [PATCH 005/139] Fixes for unsorted chunks --- src/index.rs | 78 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 29 deletions(-) diff --git a/src/index.rs b/src/index.rs index 1073b590..b0c99152 100644 --- a/src/index.rs +++ b/src/index.rs @@ -212,51 +212,71 @@ pub(crate) mod index { )?; Ok((result.0.into_pyarray(py), result.1.into_pyarray(py))) } + fn find_chunk(prefix: &[i64], x: i64) -> usize { + let mut lo = 0usize; + let mut hi = prefix.len() - 1; + while lo + 1 < hi { + let mid = (lo + hi) / 2; + if prefix[mid] <= x { + lo = mid; + } else { + hi = mid; + } + } + lo + } + fn take_chunked_from_chunked( start: ArrayView1<'_, i64>, size: ArrayView1<'_, i64>, take_start: ArrayView1<'_, i64>, take_size: ArrayView1<'_, i64>, ) -> Result<(Array1, Array1), PyErr> { - // assumption: everything is sorted + if take_start.len() == 0 { + return Ok((Array1::::zeros(0), Array1::::zeros(0))); + } let mut 
output_start: Vec = Vec::new(); let mut output_size: Vec = Vec::new(); - if size.sum() < take_start[take_start.len() - 1] + take_size[take_size.len() - 1] { - return Err(PyValueError::new_err( - "You can't take more elements than exist in an index!", - )); + + let mut prefix = vec![0i64; size.len() + 1]; + for i in 0..size.len() { + prefix[i + 1] = prefix[i] + size[i]; } - let mut chunk_index: usize = 0; - let mut cs = 0; + let total = prefix[size.len()]; + for (&tstart, &tsize) in zip(take_start, take_size) { - while cs + size[chunk_index] < tstart { - cs += size[chunk_index]; - chunk_index += 1; + if tstart + tsize > total { + return Err(PyValueError::new_err( + "You can't take more elements than exist in an index!", + )); } - // chunk_index now points to the chunk that we need to start at - // cs is equal to the cumulative size of all chunks that we've passed + let mut chunk_index = find_chunk(&prefix, tstart); + let mut cs = prefix[chunk_index]; let mut start_in_chunk = tstart - cs; - let mut chunk_taken = 0; - let mut chunk_completed = false; - while !chunk_completed { - let mut size_in_chunk = size[chunk_index] - start_in_chunk; - chunk_completed = size_in_chunk >= (tsize - chunk_taken); - if chunk_completed { - size_in_chunk = tsize - chunk_taken; - } + let mut chunk_taken = 0i64; + + loop { + let size_in_chunk = size[chunk_index] - start_in_chunk; + let remaining = tsize - chunk_taken; + let (take, chunk_completed) = if size_in_chunk >= remaining { + (remaining, true) + } else { + (size_in_chunk, false) + }; output_start.push(start[chunk_index] + start_in_chunk); - output_size.push(size_in_chunk); - chunk_taken += size_in_chunk; - if !chunk_completed { - cs += size[chunk_index]; - chunk_index += 1; - start_in_chunk = 0; + output_size.push(take); + chunk_taken += take; + + if chunk_completed { + break; } + cs += size[chunk_index]; + chunk_index += 1; + start_in_chunk = 0; } } - let output_start_arr = Array1::from_vec(output_start); - let output_size_arr = 
Array1::from_vec(output_size); - Ok((output_start_arr, output_size_arr)) + + Ok((Array1::from_vec(output_start), Array1::from_vec(output_size))) } } From a5940ba497aa9b8d2739d17253357d9797c672ff Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Apr 2026 16:32:17 -0500 Subject: [PATCH 006/139] Add changelog, fix type-check --- .github/workflows/lint.yaml | 2 +- changes/+80e6e808.improvement.rst | 1 + pyproject.toml | 1 - uv.lock | 54 ------------------------------- 4 files changed, 2 insertions(+), 56 deletions(-) create mode 100644 changes/+80e6e808.improvement.rst diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 55fbfc80..b8c79b1d 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -34,6 +34,6 @@ jobs: run: uv run ruff check . - name: Run mypy type checker - run: uv run mypy src/opencosmo + run: uv run mypy python/opencosmo diff --git a/changes/+80e6e808.improvement.rst b/changes/+80e6e808.improvement.rst new file mode 100644 index 00000000..be818652 --- /dev/null +++ b/changes/+80e6e808.improvement.rst @@ -0,0 +1 @@ +Difficult indexing routines have been rewritten in native code, allowing the removal of numba as a dependency. 
diff --git a/pyproject.toml b/pyproject.toml index ba37b01d..2aa2f2b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ dependencies = [ "deprecated>=1.2.38,<2.0.0", "numpy>=2.0,<2.5", "click (>=8.2.1,<9.0.0)", - "numba>=0.64.0", "rustworkx>=0.17.1", ] diff --git a/uv.lock b/uv.lock index 05a5317c..db4a9575 100644 --- a/uv.lock +++ b/uv.lock @@ -715,30 +715,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b2/c8/d148e041732d631fc76036f8b30fae4e77b027a1e95b7a84bb522481a940/librt-0.8.1-cp314-cp314t-win_arm64.whl", hash = "sha256:bf512a71a23504ed08103a13c941f763db13fb11177beb3d9244c98c29fb4a61", size = 48755, upload-time = "2026-02-17T16:12:47.943Z" }, ] -[[package]] -name = "llvmlite" -version = "0.47.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/01/88/a8952b6d5c21e74cbf158515b779666f692846502623e9e3c39d8e8ba25f/llvmlite-0.47.0.tar.gz", hash = "sha256:62031ce968ec74e95092184d4b0e857e444f8fdff0b8f9213707699570c33ccc", size = 193614, upload-time = "2026-03-31T18:29:53.497Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/48/4b7fe0e34c169fa2f12532916133e0b219d2823b540733651b34fdac509a/llvmlite-0.47.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:306a265f408c259067257a732c8e159284334018b4083a9e35f67d19792b164f", size = 37232769, upload-time = "2026-03-31T18:28:43.735Z" }, - { url = "https://files.pythonhosted.org/packages/e6/4b/e3f2cd17822cf772a4a51a0a8080b0032e6d37b2dbe8cfb724eac4e31c52/llvmlite-0.47.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5853bf26160857c0c2573415ff4efe01c4c651e59e2c55c2a088740acfee51cd", size = 56275178, upload-time = "2026-03-31T18:28:48.342Z" }, - { url = "https://files.pythonhosted.org/packages/b6/55/a3b4a543185305a9bdf3d9759d53646ed96e55e7dfd43f53e7a421b8fbae/llvmlite-0.47.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:003bcf7fa579e14db59c1a1e113f93ab8a06b56a4be31c7f08264d1d4072d077", size = 55128632, upload-time = "2026-03-31T18:28:52.901Z" }, - { url = "https://files.pythonhosted.org/packages/2f/f5/d281ae0f79378a5a91f308ea9fdb9f9cc068fddd09629edc0725a5a8fde1/llvmlite-0.47.0-cp312-cp312-win_amd64.whl", hash = "sha256:f3079f25bdc24cd9d27c4b2b5e68f5f60c4fdb7e8ad5ee2b9b006007558f9df7", size = 38138692, upload-time = "2026-03-31T18:28:57.147Z" }, - { url = "https://files.pythonhosted.org/packages/77/6f/4615353e016799f80fa52ccb270a843c413b22361fadda2589b2922fb9b0/llvmlite-0.47.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a3c6a735d4e1041808434f9d440faa3d78d9b4af2ee64d05a66f351883b6ceec", size = 37232771, upload-time = "2026-03-31T18:29:01.324Z" }, - { url = "https://files.pythonhosted.org/packages/31/b8/69f5565f1a280d032525878a86511eebed0645818492feeb169dfb20ae8e/llvmlite-0.47.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2699a74321189e812d476a43d6d7f652f51811e7b5aad9d9bba842a1c7927acb", size = 56275178, upload-time = "2026-03-31T18:29:05.748Z" }, - { url = "https://files.pythonhosted.org/packages/d6/da/b32cafcb926fb0ce2aa25553bf32cb8764af31438f40e2481df08884c947/llvmlite-0.47.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c6951e2b29930227963e53ee152441f0e14be92e9d4231852102d986c761e40", size = 55128632, upload-time = "2026-03-31T18:29:11.235Z" }, - { url = "https://files.pythonhosted.org/packages/46/9f/4898b44e4042c60fafcb1162dfb7014f6f15b1ec19bf29cfea6bf26df90d/llvmlite-0.47.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2e9adf8698d813a9a5efb2d4370caf344dbc1e145019851fee6a6f319ba760e", size = 38138695, upload-time = "2026-03-31T18:29:15.43Z" }, - { url = "https://files.pythonhosted.org/packages/1c/d4/33c8af00f0bf6f552d74f3a054f648af2c5bc6bece97972f3bfadce4f5ec/llvmlite-0.47.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:de966c626c35c9dff5ae7bf12db25637738d0df83fc370cf793bc94d43d92d14", size = 
37232773, upload-time = "2026-03-31T18:29:19.453Z" }, - { url = "https://files.pythonhosted.org/packages/64/1d/a760e993e0c0ba6db38d46b9f48f6c7dceb8ac838824997fb9e25f97bc04/llvmlite-0.47.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ddbccff2aeaff8670368340a158abefc032fe9b3ccf7d9c496639263d00151aa", size = 56275176, upload-time = "2026-03-31T18:29:24.149Z" }, - { url = "https://files.pythonhosted.org/packages/84/3b/e679bc3b29127182a7f4aa2d2e9e5bea42adb93fb840484147d59c236299/llvmlite-0.47.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d4a7b778a2e144fc64468fb9bf509ac1226c9813a00b4d7afea5d988c4e22fca", size = 55128631, upload-time = "2026-03-31T18:29:29.536Z" }, - { url = "https://files.pythonhosted.org/packages/be/f7/19e2a09c62809c9e63bbd14ce71fb92c6ff7b7b3045741bb00c781efc3c9/llvmlite-0.47.0-cp314-cp314-win_amd64.whl", hash = "sha256:694e3c2cdc472ed2bd8bd4555ca002eec4310961dd58ef791d508f57b5cc4c94", size = 39153826, upload-time = "2026-03-31T18:29:33.681Z" }, - { url = "https://files.pythonhosted.org/packages/40/a1/581a8c707b5e80efdbbe1dd94527404d33fe50bceb71f39d5a7e11bd57b7/llvmlite-0.47.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:92ec8a169a20b473c1c54d4695e371bde36489fc1efa3688e11e99beba0abf9c", size = 37232772, upload-time = "2026-03-31T18:29:37.952Z" }, - { url = "https://files.pythonhosted.org/packages/11/03/16090dd6f74ba2b8b922276047f15962fbeea0a75d5601607edb301ba945/llvmlite-0.47.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fa1cbd800edd3b20bc141521f7fd45a6185a5b84109aa6855134e81397ffe72b", size = 56275178, upload-time = "2026-03-31T18:29:42.58Z" }, - { url = "https://files.pythonhosted.org/packages/f5/cb/0abf1dd4c5286a95ffe0c1d8c67aec06b515894a0dd2ac97f5e27b82ab0b/llvmlite-0.47.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f6725179b89f03b17dabe236ff3422cb8291b4c1bf40af152826dfd34e350ae8", size = 55128632, upload-time 
= "2026-03-31T18:29:46.939Z" }, - { url = "https://files.pythonhosted.org/packages/4f/79/d3bbab197e86e0ff4f9c07122895b66a3e0d024247fcff7f12c473cb36d9/llvmlite-0.47.0-cp314-cp314t-win_amd64.whl", hash = "sha256:6842cf6f707ec4be3d985a385ad03f72b2d724439e118fcbe99b2929964f0453", size = 39153839, upload-time = "2026-03-31T18:29:51.004Z" }, -] - [[package]] name = "markupsafe" version = "3.0.3" @@ -986,34 +962,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, ] -[[package]] -name = "numba" -version = "0.65.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "llvmlite" }, - { name = "numpy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/49/61/7299643b9c18d669e04be7c5bcb64d985070d07553274817b45b049e7bfe/numba-0.65.0.tar.gz", hash = "sha256:edad0d9f6682e93624c00125a471ae4df186175d71fd604c983c377cdc03e68b", size = 2764131, upload-time = "2026-04-01T03:52:01.946Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6c/2f/8bd31a1ea43c01ac215283d83aa5f8d5acbe7a36c85b82f1757bfe9ccb31/numba-0.65.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b27ee4847e1bfb17e9604d100417ee7c1d10f15a6711c6213404b3da13a0b2aa", size = 2680705, upload-time = "2026-04-01T03:51:32.597Z" }, - { url = "https://files.pythonhosted.org/packages/73/36/88406bd58600cc696417b8e5dd6a056478da808f3eaf48d18e2421e0c2d9/numba-0.65.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a52d92ffd297c10364bce60cd1fcb88f99284ab5df085f2c6bcd1cb33b529a6f", size = 3801411, upload-time = "2026-04-01T03:51:34.321Z" }, - { url = 
"https://files.pythonhosted.org/packages/0c/61/ce753a1d7646dd477e16d15e89473703faebb8995d2f71d7ad69a540b565/numba-0.65.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da8e371e328c06d0010c3d8b44b21858652831b85bcfba78cb22c042e22dbd8e", size = 3501622, upload-time = "2026-04-01T03:51:36.348Z" }, - { url = "https://files.pythonhosted.org/packages/7d/86/db87a5393f1b1fabef53ac3ba4e6b938bb27e40a04ad7cc512098fcae032/numba-0.65.0-cp312-cp312-win_amd64.whl", hash = "sha256:59bb9f2bb9f1238dfd8e927ba50645c18ae769fef4f3d58ea0ea22a2683b91f5", size = 2749979, upload-time = "2026-04-01T03:51:37.88Z" }, - { url = "https://files.pythonhosted.org/packages/8b/f8/eee0f1ff456218db036bfc9023995ec1f85a9dc8f2422f1594f6a87829e0/numba-0.65.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:c6334094563a456a695c812e6846288376ca02327cf246cdcc83e1bb27862367", size = 2680679, upload-time = "2026-04-01T03:51:39.491Z" }, - { url = "https://files.pythonhosted.org/packages/1b/8f/3d116e4b8e92f6abace431afa4b2b944f4d65bdee83af886f5c4b263df95/numba-0.65.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b8a9008411615c69d083d1dcf477f75a5aa727b30beb16e139799e2be945cdfd", size = 3809537, upload-time = "2026-04-01T03:51:41.42Z" }, - { url = "https://files.pythonhosted.org/packages/b5/2c/6a3ca4128e253cb67affe06deb47688f51ce968f5111e2a06d010e6f1fa6/numba-0.65.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af96c0cba53664efcb361528b8c75e011a6556c859c7e08424c2715201c6cf7a", size = 3508615, upload-time = "2026-04-01T03:51:43.444Z" }, - { url = "https://files.pythonhosted.org/packages/96/0e/267f9a36fb282c104a971d7eecb685b411c47dce2a740fe69cf5fc2945d9/numba-0.65.0-cp313-cp313-win_amd64.whl", hash = "sha256:6254e73b9c929dc736a1fbd3d6f5680789709a5067cae1fa7198707385129c04", size = 2749938, upload-time = "2026-04-01T03:51:45.218Z" }, - { url = 
"https://files.pythonhosted.org/packages/56/a4/90edb01e9176053578e343d7a7276bc28356741ee67059aed8ed2c1a4e59/numba-0.65.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:ee336b398a6fca51b1f626034de99f50cb1bd87d537a166275158a3cee744b82", size = 2680878, upload-time = "2026-04-01T03:51:46.91Z" }, - { url = "https://files.pythonhosted.org/packages/24/8d/e12d6ff4b9119db3cbf7b2db1ce257576441bd3c76388c786dea74f20b02/numba-0.65.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:05c0a9fdf75d85f57dee47b719e8d6415707b80aae45d75f63f9dc1b935c29f7", size = 3778456, upload-time = "2026-04-01T03:51:48.552Z" }, - { url = "https://files.pythonhosted.org/packages/17/89/abcd83e76f6a773276fe76244140671bcc5bf820f6e2ae1a15362ae4c8c9/numba-0.65.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:583680e0e8faf124d362df23b4b593f3221a8996341a63d1b664c122401bec2f", size = 3478464, upload-time = "2026-04-01T03:51:50.527Z" }, - { url = "https://files.pythonhosted.org/packages/73/5b/fbce55ce3d933afbc7ade04df826853e4a846aaa47d58d2fbb669b8f2d08/numba-0.65.0-cp314-cp314-win_amd64.whl", hash = "sha256:add297d3e1c08dd884f44100152612fa41e66a51d15fdf91307f9dde31d06830", size = 2752012, upload-time = "2026-04-01T03:51:52.691Z" }, - { url = "https://files.pythonhosted.org/packages/1e/ab/af705f4257d9388fb2fd6d7416573e98b6ca9c786e8b58f02720978557bd/numba-0.65.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:194a243ba53a9157c8538cbb3166ec015d785a8c5d584d06cdd88bee902233c7", size = 2683961, upload-time = "2026-04-01T03:51:54.281Z" }, - { url = "https://files.pythonhosted.org/packages/ff/e5/8267b0adb0c01b52b553df5062fbbb42c30ed5362d08b85cc913a36f838f/numba-0.65.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c7fa502960f7a2f3f5cb025bc7bff888a3551277b92431bfdc5ba2f11a375749", size = 3816373, upload-time = "2026-04-01T03:51:56.18Z" }, - { url = 
"https://files.pythonhosted.org/packages/b0/f5/b8397ca360971669a93706b9274592b6864e4367a37d498fbbcb62aa2d48/numba-0.65.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5046c63f783ca3eb6195f826a50797465e7c4ce811daa17c9bea47e310c9b964", size = 3532782, upload-time = "2026-04-01T03:51:58.387Z" }, - { url = "https://files.pythonhosted.org/packages/f5/21/1e73fa16bf0393ebb74c5bb208d712152ffdfc84600a8e93a3180317856e/numba-0.65.0-cp314-cp314t-win_amd64.whl", hash = "sha256:46fd679ae4f68c7a5d5721efbd29ecee0b0f3013211591891d79b51bfdf73113", size = 2757611, upload-time = "2026-04-01T03:52:00.083Z" }, -] - [[package]] name = "numpy" version = "2.4.4" @@ -1087,7 +1035,6 @@ dependencies = [ { name = "hdf5plugin" }, { name = "healpy" }, { name = "healsparse" }, - { name = "numba" }, { name = "numpy" }, { name = "pydantic" }, { name = "rustworkx" }, @@ -1138,7 +1085,6 @@ requires-dist = [ { name = "hdf5plugin", specifier = ">=5.0.0,!=5.1" }, { name = "healpy", specifier = ">=1.19.0,<2.0.0" }, { name = "healsparse", specifier = "==1.11.1" }, - { name = "numba", specifier = ">=0.64.0" }, { name = "numpy", specifier = ">=2.0,<2.5" }, { name = "pyarrow", marker = "extra == 'io'", specifier = ">=21.0.0" }, { name = "pydantic", specifier = ">=2.10.6,<3.0.0" }, From c673ec58337773eba76b609693b1fc7484106e1e Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Apr 2026 20:16:14 -0500 Subject: [PATCH 007/139] Make index typing more consistent --- python/opencosmo/index/__init__.py | 10 ++++++---- python/opencosmo/index/build.py | 22 +++++++++++++++------- python/opencosmo/index/get.py | 17 +++++++++-------- python/opencosmo/index/in_range.py | 16 ++++++++-------- python/opencosmo/index/mask.py | 12 ++++++++---- python/opencosmo/index/project.py | 14 +++++++++----- python/opencosmo/index/take.py | 15 +++++++++------ python/opencosmo/index/unary.py | 15 +++++++++------ python/opencosmo/spatial/octree.py | 4 +++- python/opencosmo/spatial/tree.py | 17 
+++++++---------- 10 files changed, 83 insertions(+), 59 deletions(-) diff --git a/python/opencosmo/index/__init__.py b/python/opencosmo/index/__init__.py index 5d3f64f3..4bba8eab 100644 --- a/python/opencosmo/index/__init__.py +++ b/python/opencosmo/index/__init__.py @@ -1,7 +1,6 @@ import numpy as np -from numpy.typing import NDArray -from .build import concatenate, empty, from_size, single_chunk +from .build import concatenate, empty, from_size, from_start_size_group, single_chunk from .get import get_data from .in_range import n_in_range from .mask import into_array, mask @@ -9,8 +8,10 @@ from .take import take from .unary import get_length, get_range -SimpleIndex = NDArray[np.int_] -ChunkedIndex = tuple[NDArray[np.int_], NDArray[np.int_]] +IndexArray = np.ndarray[tuple[int], np.dtype[np.int64]] + +SimpleIndex = IndexArray +ChunkedIndex = tuple[IndexArray, IndexArray] DataIndex = SimpleIndex | ChunkedIndex @@ -32,4 +33,5 @@ "take", "get_length", "get_range", + "from_start_size_group", ] diff --git a/python/opencosmo/index/build.py b/python/opencosmo/index/build.py index 4038bce8..5d831f06 100644 --- a/python/opencosmo/index/build.py +++ b/python/opencosmo/index/build.py @@ -7,25 +7,33 @@ from .mask import into_array if TYPE_CHECKING: - from . import DataIndex + import h5py + from . 
import ChunkedIndex, DataIndex, SimpleIndex -def from_size(size: int): + +def from_size(size: int) -> ChunkedIndex: return (np.array([0], dtype=np.int64), np.array([size], dtype=np.int64)) -def single_chunk(start: int, size: int): +def single_chunk(start: int, size: int) -> ChunkedIndex: return (np.array([start], dtype=np.int64), np.array([size], np.int64)) -def empty(): +def empty() -> ChunkedIndex: return (np.array([], dtype=np.int64), np.array([], dtype=np.int64)) -def from_range(start: int, end: int): +def from_range(start: int, end: int) -> ChunkedIndex: size = end - start return (np.array([start], dtype=np.int64), np.array([size], np.int64)) -def concatenate(*indices: DataIndex): - np.concatenate(list(map(into_array, indices))) +def concatenate(*indices: DataIndex) -> SimpleIndex: + return np.concatenate(list(map(into_array, indices))) + + +def from_start_size_group(group: h5py) -> ChunkedIndex: + start = group["start"][:].astype(np.int64) + size = group["size"][:].astype(np.int64) + return (start, size) diff --git a/python/opencosmo/index/get.py b/python/opencosmo/index/get.py index 6baa8eb7..98a66905 100644 --- a/python/opencosmo/index/get.py +++ b/python/opencosmo/index/get.py @@ -9,22 +9,22 @@ from .unary import get_length if TYPE_CHECKING: - from numpy.typing import NDArray + from opencosmo.index import ChunkedIndex, DataIndex, SimpleIndex -def get_data(data: h5py.Dataset | np.ndarray, index: np.ndarray | tuple): +def get_data(data: h5py.Dataset | np.ndarray, index: DataIndex): if get_length(index) == 0: return np.array([]) match index: case np.ndarray(): return get_data_simple(data, index) case (np.ndarray(), np.ndarray()): - return get_data_chunked(data, *index) + return get_data_chunked(data, index) case _: raise ValueError(f"Got invalid index of type {type(index)}") -def get_data_simple(data: h5py.Dataset | np.ndarray, index: NDArray[np.int_]): +def get_data_simple(data: h5py.Dataset | np.ndarray, index: SimpleIndex): if isinstance(data, np.ndarray): 
return data[index] @@ -41,14 +41,15 @@ def get_data_simple(data: h5py.Dataset | np.ndarray, index: NDArray[np.int_]): return buffer[index - min_] -def get_data_chunked( - data: h5py.Dataset | np.ndarray, starts: NDArray[np.int_], sizes: NDArray[np.int_] -): +def get_data_chunked(data: h5py.Dataset | np.ndarray, index: ChunkedIndex): """ We assume that starts are ordered, and chunks are non-overlapping """ + starts = index[0] + sizes = index[1] + unit = None if isinstance(data, u.Quantity): unit = data.unit @@ -66,7 +67,7 @@ def get_data_chunked( else: storage[dest_slice] = data[source_slice] - running_index += size + running_index += int(size) if unit is not None: storage *= unit diff --git a/python/opencosmo/index/in_range.py b/python/opencosmo/index/in_range.py index e9aa7698..165a6836 100644 --- a/python/opencosmo/index/in_range.py +++ b/python/opencosmo/index/in_range.py @@ -7,28 +7,28 @@ from opencosmo._lib import index as idx if TYPE_CHECKING: - from numpy.typing import NDArray + from opencosmo.index import DataIndex, IndexArray, SimpleIndex def n_in_range( - index: NDArray[np.int_] | tuple, - range_starts: int | NDArray[np.int_], - range_sizes: int | NDArray[np.int_], -): + index: DataIndex, + range_starts: int | IndexArray, + range_sizes: int | IndexArray, +) -> IndexArray: range_starts = np.atleast_1d(range_starts) range_sizes = np.atleast_1d(range_sizes) match index: case np.ndarray(): return __n_in_range_simple(index, range_starts, range_sizes) case (np.ndarray(), np.ndarray()): - return idx.n_in_range_chunked(*index, range_starts, range_sizes) + return idx.n_in_range_chunked(index[0], index[1], range_starts, range_sizes) case _: raise ValueError(f"Unknown index type {type(index)}") def __n_in_range_simple( - index: NDArray[np.int_], start: NDArray[np.int_], size: NDArray[np.int_] -) -> NDArray[np.int_]: + index: SimpleIndex, start: IndexArray, size: IndexArray +) -> IndexArray: if len(start) != len(size): raise ValueError("Start and size arrays must 
have the same length") if np.any(size < 0): diff --git a/python/opencosmo/index/mask.py b/python/opencosmo/index/mask.py index 7ccdfd92..5f85cc98 100644 --- a/python/opencosmo/index/mask.py +++ b/python/opencosmo/index/mask.py @@ -7,11 +7,13 @@ if TYPE_CHECKING: from numpy.typing import NDArray + from opencosmo.index import ChunkedIndex, DataIndex, SimpleIndex + from opencosmo._lib import index as idxlib from opencosmo.index.unary import get_length -def mask(index, boolean_mask): +def mask(index: DataIndex, boolean_mask: NDArray[np.bool_]) -> SimpleIndex: match index: case np.ndarray(): return __mask_simple(index, boolean_mask) @@ -21,7 +23,7 @@ def mask(index, boolean_mask): raise TypeError(f"Unknown index type {type(index)}") -def __mask_simple(index: NDArray[np.int_], boolean_mask: NDArray[np.bool_]): +def __mask_simple(index: SimpleIndex, boolean_mask: NDArray[np.bool_]) -> SimpleIndex: if (lm := len(boolean_mask)) > len(index): raise ValueError( "Boolean mask must be less than or equal to the length of the index itself" @@ -30,12 +32,12 @@ def __mask_simple(index: NDArray[np.int_], boolean_mask: NDArray[np.bool_]): return index[:lm][boolean_mask] -def __mask_chunked(index: tuple, boolean_mask: NDArray[np.bool_]): +def __mask_chunked(index: ChunkedIndex, boolean_mask: NDArray[np.bool_]) -> SimpleIndex: array = into_array(index) return array[boolean_mask] -def into_array(index: np.ndarray | tuple): +def into_array(index: DataIndex) -> SimpleIndex: if get_length(index) == 0: return np.array([], dtype=np.int64) @@ -47,3 +49,5 @@ def into_array(index: np.ndarray | tuple): return np.arange(index[0][0], index[0][0] + index[1][0]) return idxlib.chunked_into_array(*index) + case _: + raise ValueError(f"Expected a DataIndex, got {type(index)}") diff --git a/python/opencosmo/index/project.py b/python/opencosmo/index/project.py index d2b6d63b..5957be77 100644 --- a/python/opencosmo/index/project.py +++ b/python/opencosmo/index/project.py @@ -10,7 +10,7 @@ from 
opencosmo.index import ChunkedIndex, DataIndex, SimpleIndex -def project(source: DataIndex, other: DataIndex): +def project(source: DataIndex, other: DataIndex) -> DataIndex: match (source, other): case (tuple(), np.ndarray()): return __project_simple_on_chunked(source, other) @@ -20,22 +20,26 @@ def project(source: DataIndex, other: DataIndex): return __project_simple_on_simple(source, other) case (np.ndarray(), tuple()): return __project_chunked_on_simple(source, other) + case _: + raise TypeError(f"Invalid index types: {type(source)}, {type(other)}") -def __project_simple_on_simple(source, other: DataIndex): +def __project_simple_on_simple(source: SimpleIndex, other: SimpleIndex) -> SimpleIndex: isin = np.isin(source, other) return np.where(isin)[0] -def __project_chunked_on_simple(source, other: DataIndex): +def __project_chunked_on_simple(source: SimpleIndex, other: ChunkedIndex) -> DataIndex: return project(source, into_array(other)) -def __project_simple_on_chunked(source: ChunkedIndex, other: SimpleIndex): +def __project_simple_on_chunked(source: ChunkedIndex, other: SimpleIndex) -> DataIndex: return project(into_array(source), other) -def __project_chunked_on_chunked(source: ChunkedIndex, other: ChunkedIndex): +def __project_chunked_on_chunked( + source: ChunkedIndex, other: ChunkedIndex +) -> ChunkedIndex: source_ends = source[0] + source[1] other_ends = other[0] + other[1] diff --git a/python/opencosmo/index/take.py b/python/opencosmo/index/take.py index 2eddd0c0..b4d9e284 100644 --- a/python/opencosmo/index/take.py +++ b/python/opencosmo/index/take.py @@ -1,14 +1,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np from opencosmo._lib import index as idxlib -SimpleIndex = np.ndarray -ChunkedIndex = tuple[np.ndarray, np.ndarray] +if TYPE_CHECKING: + from opencosmo.index import ChunkedIndex, DataIndex, SimpleIndex -def take(from_, by): +def take(from_: DataIndex, by: DataIndex) -> DataIndex: match (from_, by): 
case (np.ndarray(), np.ndarray()): return __take_simple_from_simple(from_, by) @@ -17,11 +19,12 @@ def take(from_, by): case ((np.ndarray(), np.ndarray()), np.ndarray()): return __take_simple_from_chunked(from_, by) case ((np.ndarray(), np.ndarray()), (np.ndarray(), np.ndarray())): - print("Chunked from chunked") return idxlib.take_chunked_from_chunked(*from_, *by) + case _: + raise TypeError(f"Invalid index types: {type(from_)}, {type(by)}") -def __take_simple_from_chunked(from_: ChunkedIndex, by: SimpleIndex): +def __take_simple_from_chunked(from_: ChunkedIndex, by: SimpleIndex) -> SimpleIndex: cumulative = np.insert(np.cumsum(from_[1]), 0, 0)[:-1] indices_into_chunks = np.argmax(by[:, np.newaxis] < cumulative, axis=1) - 1 @@ -29,5 +32,5 @@ def __take_simple_from_chunked(from_: ChunkedIndex, by: SimpleIndex): return output -def __take_simple_from_simple(from_: np.ndarray, by: np.ndarray): +def __take_simple_from_simple(from_: SimpleIndex, by: SimpleIndex) -> SimpleIndex: return from_[by] diff --git a/python/opencosmo/index/unary.py b/python/opencosmo/index/unary.py index 4989d169..8abe2526 100644 --- a/python/opencosmo/index/unary.py +++ b/python/opencosmo/index/unary.py @@ -1,17 +1,20 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np -from numpy.typing import NDArray from opencosmo._lib import index as idx +if TYPE_CHECKING: + from opencosmo.index import DataIndex + """ Implementations for unary operations on indices """ -SimpleIndex = NDArray[np.int_] -ChunkedIndex = tuple[NDArray[np.int_], NDArray[np.int_]] - -def get_length(index: SimpleIndex | ChunkedIndex): +def get_length(index: DataIndex) -> int: match index: case np.ndarray(): return len(index) @@ -21,7 +24,7 @@ def get_length(index: SimpleIndex | ChunkedIndex): raise TypeError(f"Invalid index type {type(index)}") -def get_range(index: SimpleIndex | ChunkedIndex): +def get_range(index: DataIndex) -> tuple[int, int]: match index: case np.ndarray(): return 
idx.get_simple_range(index) diff --git a/python/opencosmo/spatial/octree.py b/python/opencosmo/spatial/octree.py index ba87a290..0a00d47e 100644 --- a/python/opencosmo/spatial/octree.py +++ b/python/opencosmo/spatial/octree.py @@ -117,7 +117,9 @@ def __init__(self, root: Octant): self.root = root def get_partition_region(self, index: SimpleIndex, level: int): - octants = [get_octant(idx, level, 2 * self.root.halfwidth) for idx in index] + octants = [ + get_octant(int(idx), level, 2 * self.root.halfwidth) for idx in index + ] return get_region(octants) @classmethod diff --git a/python/opencosmo/spatial/tree.py b/python/opencosmo/spatial/tree.py index a1ee88f3..6ffeed01 100644 --- a/python/opencosmo/spatial/tree.py +++ b/python/opencosmo/spatial/tree.py @@ -13,7 +13,7 @@ MPI = None # type: ignore -from opencosmo.index import from_size, get_data, n_in_range +from opencosmo.index import from_size, from_start_size_group, get_data, n_in_range from opencosmo.io.schema import FileEntry, make_schema from opencosmo.io.writer import ( ColumnCombineStrategy, @@ -203,8 +203,7 @@ def partition( n_partitions, counts, min_level ) partitions = [] - start = self.__columns[f"level_{split_level}/start"][:].astype(np.int64) - size = self.__columns[f"level_{split_level}/size"][:].astype(np.int64) + start, size = from_start_size_group(self.__columns[f"level_{split_level}"]) for index_ in partition_indices: if len(index_) == 0: continue @@ -225,8 +224,9 @@ def query(self, region: Region) -> tuple[ChunkedIndex, ChunkedIndex]: intersects = [] for level, (cidx, iidx) in indices.items(): level_key = f"level_{level}" - level_starts = self.__columns[f"{level_key}/start"] - level_sizes = self.__columns[f"{level_key}/size"] + level_starts, level_sizes = from_start_size_group( + self.__columns[f"{level_key}"] + ) c_starts = get_data(level_starts, cidx) c_sizes = get_data(level_sizes, cidx) i_starts = get_data(level_starts, iidx) @@ -241,11 +241,8 @@ def query(self, region: Region) -> 
tuple[ChunkedIndex, ChunkedIndex]: return (contains_start, contains_size), (intersects_start, intersects_size) def apply_index(self, index: DataIndex, min_counts: int = 100) -> Tree: - max_level_starts = self.__columns[f"level_{self.__max_level}/start"][:].astype( - np.int64 - ) - max_level_sizes = self.__columns[f"level_{self.__max_level}/size"][:].astype( - np.int64 + max_level_starts, max_level_sizes = from_start_size_group( + self.__columns[f"level_{self.__max_level}"] ) n = n_in_range(index, max_level_starts, max_level_sizes) target = h5py.File(f"{uuid1()}.hdf5", "w", driver="core", backing_store=False) From 0caad1d52030db24eff00223b70e18513871f830 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Apr 2026 20:18:52 -0500 Subject: [PATCH 008/139] Add index stubs --- python/opencosmo/_lib/__init__.pyi | 1 + python/opencosmo/_lib/index.pyi | 14 ++++++++++++++ 2 files changed, 15 insertions(+) create mode 100644 python/opencosmo/_lib/__init__.pyi create mode 100644 python/opencosmo/_lib/index.pyi diff --git a/python/opencosmo/_lib/__init__.pyi b/python/opencosmo/_lib/__init__.pyi new file mode 100644 index 00000000..c417140d --- /dev/null +++ b/python/opencosmo/_lib/__init__.pyi @@ -0,0 +1 @@ +from . import index as index diff --git a/python/opencosmo/_lib/index.pyi b/python/opencosmo/_lib/index.pyi new file mode 100644 index 00000000..4a05fbde --- /dev/null +++ b/python/opencosmo/_lib/index.pyi @@ -0,0 +1,14 @@ +from opencosmo.index import IndexArray + +def take_chunked_from_chunked( + start: IndexArray, size: IndexArray, take_start: IndexArray, take_size: IndexArray +) -> tuple[IndexArray, IndexArray]: ... +def get_simple_range(index: IndexArray) -> tuple[int, int]: ... +def get_chunked_range(start: IndexArray, size: IndexArray) -> tuple[int, int]: ... +def n_in_range_chunked( + start: IndexArray, size: IndexArray, range_start: IndexArray, range_size: IndexArray +) -> IndexArray: ... 
+def chunked_into_array(start: IndexArray, size: IndexArray) -> IndexArray: ... +def take_chunked_from_simple( + simple: IndexArray, start: IndexArray, size: IndexArray +) -> IndexArray: ... From 39947a30c9e06bd8324ab2f3f68fe009d8ed3ee8 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Apr 2026 20:40:59 -0500 Subject: [PATCH 009/139] Update tree handling --- python/opencosmo/io/iopen.py | 14 ++++++++------ python/opencosmo/spatial/tree.py | 6 ++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index 1f7e6031..361b573c 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -50,6 +50,7 @@ class DatasetTarget(TypedDict): header: OpenCosmoHeader dataset_group: h5py.Group columns: list[h5py.Dataset] + spatial_index: Optional[h5py.Group] class FileType(Enum): @@ -389,6 +390,7 @@ def __find_datasets_under_group( file_map.keys(), ) ) + for ds_group_name in known_dataset_groups: ds_group_parent = ds_group_name.rsplit("/", maxsplit=1)[0] ds_group_parent += "/" @@ -399,16 +401,19 @@ def __find_datasets_under_group( lambda nds: ( ds_group_parent in nds[0] and "header" not in nds[0] + and "index" not in nds[0] and isinstance(nds[1], h5py.Dataset) ), file_map.items(), ) ] + index_group = file_map.get(f"{ds_group_parent}index") target = DatasetTarget( header=header, dataset_group=file_map[ds_group_name].parent, columns=columns, + spatial_index=index_group, ) if evaluate_load_conditions(target, open_kwargs): known_datasets.append(target) @@ -471,21 +476,18 @@ def open_single_dataset( assert header is not None - index_columns = { - col.name: col for col in columns if col.name.split("/")[-3] == "index" - } try: box_size = header.with_units("scalefree").simulation["box_size"].value except AttributeError: box_size = None - try: + if target["spatial_index"] is not None: tree = open_tree( - index_columns, + target["spatial_index"], box_size, header.file.is_lightcone, ) - 
except (ValueError, AttributeError): + else: tree = None if header.file.region is not None: diff --git a/python/opencosmo/spatial/tree.py b/python/opencosmo/spatial/tree.py index 6ffeed01..3c1fb05a 100644 --- a/python/opencosmo/spatial/tree.py +++ b/python/opencosmo/spatial/tree.py @@ -31,7 +31,7 @@ def open_tree( - tree_columns: dict[str, h5py.Dataset], + tree_group: h5py.Group, box_size: Optional[int], is_lightcone: bool = False, ): @@ -55,9 +55,7 @@ def open_tree( else: spatial_index = OctTreeIndex.from_box_size(box_size) - tree_columns = {name.split("index/")[-1]: col for name, col in tree_columns.items()} - - return Tree(spatial_index, tree_columns) + return Tree(spatial_index, tree_group) def read_tree(file: h5py.File | h5py.Group, box_size: int): From a0c572f17c6f1a6858bff9eab5f7bcae5f9a0d89 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Apr 2026 21:13:28 -0500 Subject: [PATCH 010/139] Fix index exclusion --- python/opencosmo/io/iopen.py | 4 ++-- test/parallel/test_lc_mpi.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index 361b573c..eaef0ff7 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -400,9 +400,9 @@ def __find_datasets_under_group( for nds_ in filter( lambda nds: ( ds_group_parent in nds[0] - and "header" not in nds[0] - and "index" not in nds[0] and isinstance(nds[1], h5py.Dataset) + and "header" not in nds[0] + and f"{ds_group_parent}index" not in nds[0] ), file_map.items(), ) diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index f126bc35..9f0f12dc 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -3,6 +3,7 @@ import astropy.units as u import numpy as np +import opencosmo as oc import pytest from astropy.coordinates import SkyCoord from healpy import pix2ang @@ -10,8 +11,6 @@ from opencosmo.mpi import get_comm_world from pytest_mpi.parallel_assert import 
parallel_assert -import opencosmo as oc - IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" @@ -403,6 +402,7 @@ def test_diffsky_stack_with_synths(core_path_487, core_path_475, per_test_dir): def test_write_some_missing(core_path_487, core_path_475, per_test_dir): comm = MPI.COMM_WORLD ds = oc.open(core_path_487, core_path_475, synth_cores=False) + assert "early_index" in ds.columns if comm.Get_rank() == 0: ds = ds.with_redshift_range(0, 0.02) assert len(ds.keys()) == 1 From 7ae143135cc2c9392f05f888e8c5310d3d557618 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 7 Apr 2026 08:31:41 -0500 Subject: [PATCH 011/139] Updates to contributing documentation --- CONTRIBUTING.md | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 66f3ea6a..ebdaa2a4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,15 +14,20 @@ If you find that the documentation is not quite up to par, feel free to create a In addition to raising issues on the repository or adding documentation, we welcome PRs that fix bugs, improve performance, or add new features. The remainder of this document describes how to go about setting up a development environment and prepping your PR for merging. -### Step 0: Install UV +### Step 0: Install UV and Build Dependencies We use [uv](https://docs.astral.sh/uv/) to manage dependencies and provide a consistent execution environment for packaging and testing. If you do not already have uv installed on your machine, [follow the documentation](https://docs.astral.sh/uv/getting-started/installation/) for your system. +Becuase `opencosmo` includes some Rust extension modules, it uses [maturin](https://github.com/pyo3/maturin) as its build system and requires a Rust compiler. `uv` should install `maturin`, which is capable of bootstrapping itself. 
This should all happen automatically in step 2 below, but if for some reason this does not work you may need to [install a Rust compiler manually](https://rust-lang.org/tools/install/). + + ### Step 1: Create a Fork of the OpenCosmo Repo and Clone You should preform your development in your own personal fork of the repository. You can create one [at this link](https://github.com/ArgonneCPAC/OpenCosmo/fork). Once you have a fork, you can clone it to the machine you will be doing your development on. -### Step 2: Create a Virtual Environment with UV +If you're only planning to make a single contribution, it's fine to commit to the main branch of your fork. If you intend to make many contributions, we recommend creating a branch for each feature. + +### Step 2: Create a Virtual Environment with UV and Install Dependencies In addition to managing dependencies, uv manages a virtual environment with the appropriate versions of all packages. From the root of the repository, run the command: @@ -32,6 +37,8 @@ uv sync This will create a virtual environment in the repository and install all necessary dependencies, as well as extra dependencies required for development work. +As part of the initial `uv sync`, `maturin` will build the Rust extension modules and install them in the correct place. This may take some time, but unless you are modifying the Rust code it will only have to happen once. + ### Step 2.1: Install Parallel HDF5 (Optional) If you plan to work on features that involve parallel I/O, you will need to install parallel HDF5 to run parallel tests. To start, install the additional packages necessary for developing and testing parallel features @@ -49,7 +56,9 @@ Note that running with uv is required to ensure that the parallel version of HDF ### Step 3: Add a Commit & Create a PR -I know what you're thinking: I haven't actually implemented my changes yet. Why am I already submitting a PR? Two reasons.
First, so we know what people are working on so we're not doing duplicate work. Second, so that your PR starts running through the CI pipeline. +I know what you're thinking: I haven't actually implemented my changes yet. Why am I already submitting a PR? Two reasons. First, so we know what people are working on so we're not doing duplicate work. Second, so that your PR starts running through the CI pipeline. This will make it easier down the line! + +For more details of how to do all this, see the [GitHub documentation](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) ### Step 4: Implement Your Changes @@ -59,6 +68,8 @@ This is where the magic happens. Implement your features! If you have questions, Our CI pipeline performs linting with [ruff](https://astral.sh/ruff) and static type checking with [mypy](https://www.mypy-lang.org/). Both should have been installed automatically when you ran `uv sync`. Your PR must pass both to be merged. If you prefer, you can use [pre-commit](https://pre-commit.com/) to automatically run linting and type checking before you commit. Altenatively, you can call them manually on your last commit. +If you're unfamiliar with type hints, go ahead and implement your features without them and we can worry about the details later. + Many libraries do not have full typing support. If this is the case, you can add a `# type: ignore` directive when you import them. If the type stubs exist as a seperate library, you should instead add them as a development dependency in the pyprojet.toml with ```bash @@ -79,7 +90,7 @@ uv run pytest --ignore=test/parallel If you are working on features designed to be used in an MPI context, you can run the parallel tests with: ```bash -uv run mpiexec -n 4 pytest -m parallel test/parallel -x +uv run mpiexec -n 4 pytest -m parallel test/parallel ``` All tests must pass for your PR to be merged. 
From c6a69220030c733169f86f31caa02ba9da5dc53f Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 7 Apr 2026 09:19:06 -0500 Subject: [PATCH 012/139] Update building CI pipeline --- .github/workflows/build.yaml | 51 +++++++++--------- .github/workflows/release.yaml | 99 ++++++++++++++++++++++++++++------ test/parallel/test_lc_mpi.py | 9 +++- 3 files changed, 117 insertions(+), 42 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3aff91fe..1d31beb3 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -1,5 +1,5 @@ name: build dev -on: +on: workflow_call: inputs: push-docker: @@ -14,34 +14,37 @@ on: required: false jobs: - build-wheel: - runs-on: ubuntu-latest + build-wheels: + name: Build wheels (${{ matrix.target }}) + strategy: + matrix: + include: + - { os: ubuntu-latest, target: x86_64 } + - { os: ubuntu-latest, target: aarch64 } + - { os: macos-latest, target: x86_64-apple-darwin } + - { os: macos-latest, target: aarch64-apple-darwin } + runs-on: ${{ matrix.os }} steps: - name: Checkout repository uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + if: matrix.os != 'ubuntu-latest' + uses: actions/setup-python@v5 with: - python-version: '3.11' - - name: Install uv - uses: astral-sh/setup-uv@v6 + python-version: | + 3.12 + 3.13 + 3.14 + - name: Build wheels + uses: PyO3/maturin-action@v1 with: - version: "0.8.2" - - name: Build package - run: uv build - - build-container: - runs-on: ubuntu-latest - steps: - - id: login - if: ${{ inputs.push-docker }} - uses: docker/login-action@v3 + target: ${{ matrix.target }} + manylinux: auto + args: --release --out dist --interpreter 3.12 3.13 3.14 + - name: Upload wheels + uses: actions/upload-artifact@v4 with: - username: ${{ secrets.docker-username }} - password: ${{ secrets.docker-access-key }} - - uses: docker/setup-buildx-action@v3 - - uses: docker/bake-action@v6 - with: - push: ${{ inputs.push-docker }} - targets: dev + 
name: wheels-${{ matrix.target }} + path: dist/ + diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index f0d67eec..fa455c9c 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -8,7 +8,7 @@ on: type: string jobs: - release: + prepare: runs-on: ubuntu-latest permissions: contents: write @@ -24,21 +24,14 @@ jobs: with: version: "0.8.2" - - name: Set version - run: echo "VERSION=${{ inputs.version }}" >> $GITHUB_ENV - - name: Bump version files - # Uses --new-version so you don't need to specify major/minor/patch. - # commit=false and tag=false are already set in .bumpversion.toml so - # bump-my-version only updates the files. - run: uvx bump-my-version bump --new-version ${{ env.VERSION }} + run: uvx bump-my-version bump --new-version ${{ inputs.version }} - name: Draft changelog (captured for GitHub release body) - run: uv run towncrier build --draft --version ${{ env.VERSION }} > release_notes.rst + run: uv run towncrier build --draft --version ${{ inputs.version }} > release_notes.rst - name: Build changelog - # Writes to docs/source/changelog.rst and removes news fragments in changes/ - run: uv run towncrier build --yes --version ${{ env.VERSION }} + run: uv run towncrier build --yes --version ${{ inputs.version }} - name: Commit, tag, and push release branch run: | @@ -46,21 +39,93 @@ jobs: git config user.email "github-actions[bot]@users.noreply.github.com" git checkout -B release git add --all - git commit -m "Release ${{ env.VERSION }}" - git tag ${{ env.VERSION }} + git commit -m "Release ${{ inputs.version }}" + git tag ${{ inputs.version }} git push origin release --force - git push origin ${{ env.VERSION }} + git push origin ${{ inputs.version }} + + - name: Upload release notes + uses: actions/upload-artifact@v4 + with: + name: release-notes + path: release_notes.rst + + build-wheels: + needs: prepare + name: Build wheels (${{ matrix.target }}) + strategy: + matrix: + include: + - { os: 
ubuntu-latest, target: x86_64 } + - { os: ubuntu-latest, target: aarch64 } + - { os: macos-latest, target: x86_64-apple-darwin } + - { os: macos-latest, target: aarch64-apple-darwin } + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ inputs.version }} + - name: Set up Python + if: matrix.os != 'ubuntu-latest' + uses: actions/setup-python@v5 + with: + python-version: | + 3.12 + 3.13 + 3.14 + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + manylinux: auto + args: --release --out dist --interpreter 3.12 3.13 3.14 + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-${{ matrix.target }} + path: dist/ + + publish: + needs: build-wheels + runs-on: ubuntu-latest + steps: + - name: Download wheels + uses: actions/download-artifact@v4 + with: + pattern: wheels-* + path: dist/ + merge-multiple: true - - name: Build package - run: uv build + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + version: "0.8.2" - name: Publish to PyPI run: uv publish --token ${{ secrets.PYPI_TOKEN }} + github-release: + needs: publish + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - name: Download wheels + uses: actions/download-artifact@v4 + with: + pattern: wheels-* + path: dist/ + merge-multiple: true + + - name: Download release notes + uses: actions/download-artifact@v4 + with: + name: release-notes + - name: Create GitHub release uses: ncipollo/release-action@v1 with: - tag: ${{ env.VERSION }} + tag: ${{ inputs.version }} bodyFile: release_notes.rst artifacts: dist/* makeLatest: true diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index 9f0f12dc..fe53851a 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -121,6 +121,7 @@ def test_healpix_index_chain_failure(haloproperties_600_path): @pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.parallel(nprocs=4) def 
test_healpix_write(haloproperties_600_path, per_test_dir): + comm = get_comm_world() ds = oc.open(haloproperties_600_path) pixel = np.random.choice(ds.region.pixels) @@ -137,7 +138,13 @@ def test_healpix_write(haloproperties_600_path, per_test_dir): new_ds = new_ds.bound(region2) ds = ds.bound(region2) - assert set(ds.get_data()["fof_halo_tag"]) == set(new_ds.get_data()["fof_halo_tag"]) + rank_tags = ds.select("fof_halo_tag").get_data() + new_rank_tags = new_ds.select("fof_halo_tag").get_data() + + all_tags = np.concatenate(comm.allgather(rank_tags)) + all_new_tags = np.concatenate(comm.allgather(new_rank_tags)) + + parallel_assert(np.all(np.sort(all_tags) == np.sort(all_new_tags))) @pytest.mark.filterwarnings("ignore::UserWarning") From a357de6438c69e4f142fcace93071141bc3cd413 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 8 Apr 2026 08:10:37 -0500 Subject: [PATCH 013/139] Update healsparse map logic to build from scratch --- .../collection/lightcone/healpix_map.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/python/opencosmo/collection/lightcone/healpix_map.py b/python/opencosmo/collection/lightcone/healpix_map.py index 80097804..8ab19a95 100644 --- a/python/opencosmo/collection/lightcone/healpix_map.py +++ b/python/opencosmo/collection/lightcone/healpix_map.py @@ -340,12 +340,6 @@ def get_data(self, format="healsparse", nside_out: Optional[int] = None, **kwarg table["pixel"] = pixels table.sort("pixel", reverse=False) - if format == "healpix": - if self.__len__() != hp.nside2npix(self.nside): - raise ValueError( - "healpix format chosen but length of dataset doesn't match nside value. 
Use healsparse" - ) - if len(table.colnames) == 1: table = next(table.itercols()) @@ -355,15 +349,31 @@ def get_data(self, format="healsparse", nside_out: Optional[int] = None, **kwarg else: table.remove_columns(self.__hidden) return {name: col.value for name, col in table.items()} + elif format == "healsparse": + pixels = table["pixel"].value + sentinel = np.float32(hp.UNSEEN) + + # Build coverage map once and compute sparse indices once, + # shared across all columns to avoid repeating this work. + cov_map = hsp.HealSparseCoverage.make_empty(self.nside_lr, self.nside) + cov_pix = cov_map.cov_pixels(pixels) + unique_cov_pix = np.unique(cov_pix) + cov_map.initialize_pixels(unique_cov_pix) + sparse_indices = pixels + cov_map[cov_pix] + sparse_map_size = (len(unique_cov_pix) + 1) * cov_map.nfine_per_cov + dict_maps = {} for name, col in table.items(): if name != "pixel": - hsp_out = hsp.HealSparseMap.make_empty( - self.nside_lr, self.nside, dtype=np.float32 + sparse_map = np.full(sparse_map_size, sentinel, dtype=np.float32) + sparse_map[sparse_indices] = col.value.astype(np.float32) + dict_maps[name] = hsp.HealSparseMap( + cov_map=cov_map, + sparse_map=sparse_map, + nside_sparse=self.nside, + sentinel=sentinel, ) - hsp_out[table["pixel"].value] = (col.value).astype(np.float32) - dict_maps[name] = hsp_out return dict_maps @property From ea0d7ad362d9922d4bb73bac1563f35b45fa444d Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 8 Apr 2026 08:14:04 -0500 Subject: [PATCH 014/139] Move healsparse map creation logic out of get_data --- .../collection/lightcone/healpix_map.py | 52 +++++++++++-------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/python/opencosmo/collection/lightcone/healpix_map.py b/python/opencosmo/collection/lightcone/healpix_map.py index 8ab19a95..db0247a2 100644 --- a/python/opencosmo/collection/lightcone/healpix_map.py +++ b/python/opencosmo/collection/lightcone/healpix_map.py @@ -33,6 +33,30 @@ from opencosmo.spatial import 
Region +def make_healsparse_map( + pixels: np.ndarray, + values: np.ndarray, + nside: int, + nside_lr: int, +) -> hsp.HealSparseMap: + sentinel = np.float32(hp.UNSEEN) + cov_map = hsp.HealSparseCoverage.make_empty(nside_lr, nside) + cov_pix = cov_map.cov_pixels(pixels) + unique_cov_pix = np.unique(cov_pix) + cov_map.initialize_pixels(unique_cov_pix) + sparse_indices = pixels + cov_map[cov_pix] + sparse_map = np.full( + (len(unique_cov_pix) + 1) * cov_map.nfine_per_cov, sentinel, dtype=np.float32 + ) + sparse_map[sparse_indices] = values.astype(np.float32) + return hsp.HealSparseMap( + cov_map=cov_map, + sparse_map=sparse_map, + nside_sparse=nside, + sentinel=sentinel, + ) + + def take_from_sorted( healpix_map: "HealpixMap", sort_by: str, invert: bool, n: int, at: str | int ): @@ -352,29 +376,11 @@ def get_data(self, format="healsparse", nside_out: Optional[int] = None, **kwarg elif format == "healsparse": pixels = table["pixel"].value - sentinel = np.float32(hp.UNSEEN) - - # Build coverage map once and compute sparse indices once, - # shared across all columns to avoid repeating this work. 
- cov_map = hsp.HealSparseCoverage.make_empty(self.nside_lr, self.nside) - cov_pix = cov_map.cov_pixels(pixels) - unique_cov_pix = np.unique(cov_pix) - cov_map.initialize_pixels(unique_cov_pix) - sparse_indices = pixels + cov_map[cov_pix] - sparse_map_size = (len(unique_cov_pix) + 1) * cov_map.nfine_per_cov - - dict_maps = {} - for name, col in table.items(): - if name != "pixel": - sparse_map = np.full(sparse_map_size, sentinel, dtype=np.float32) - sparse_map[sparse_indices] = col.value.astype(np.float32) - dict_maps[name] = hsp.HealSparseMap( - cov_map=cov_map, - sparse_map=sparse_map, - nside_sparse=self.nside, - sentinel=sentinel, - ) - return dict_maps + return { + name: make_healsparse_map(pixels, col.value, self.nside, self.nside_lr) + for name, col in table.items() + if name != "pixel" + } @property def data(self): From 071f4cd65cfc4b996f8f5d218b627df239e65801 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 8 Apr 2026 08:19:42 -0500 Subject: [PATCH 015/139] Process all columns in healpixmap at once when making healsparse map. --- .../collection/lightcone/healpix_map.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/python/opencosmo/collection/lightcone/healpix_map.py b/python/opencosmo/collection/lightcone/healpix_map.py index db0247a2..4a2f863b 100644 --- a/python/opencosmo/collection/lightcone/healpix_map.py +++ b/python/opencosmo/collection/lightcone/healpix_map.py @@ -33,28 +33,34 @@ from opencosmo.spatial import Region -def make_healsparse_map( - pixels: np.ndarray, - values: np.ndarray, +def make_healsparse_maps( + table, nside: int, nside_lr: int, -) -> hsp.HealSparseMap: +) -> dict[str, hsp.HealSparseMap]: sentinel = np.float32(hp.UNSEEN) + pixels = table["pixel"].value + + # Build coverage map once, shared across all columns. 
cov_map = hsp.HealSparseCoverage.make_empty(nside_lr, nside) cov_pix = cov_map.cov_pixels(pixels) unique_cov_pix = np.unique(cov_pix) cov_map.initialize_pixels(unique_cov_pix) sparse_indices = pixels + cov_map[cov_pix] - sparse_map = np.full( - (len(unique_cov_pix) + 1) * cov_map.nfine_per_cov, sentinel, dtype=np.float32 - ) - sparse_map[sparse_indices] = values.astype(np.float32) - return hsp.HealSparseMap( - cov_map=cov_map, - sparse_map=sparse_map, - nside_sparse=nside, - sentinel=sentinel, - ) + sparse_map_size = (len(unique_cov_pix) + 1) * cov_map.nfine_per_cov + + result = {} + for name, col in table.items(): + if name != "pixel": + sparse_map = np.full(sparse_map_size, sentinel, dtype=np.float32) + sparse_map[sparse_indices] = col.value.astype(np.float32) + result[name] = hsp.HealSparseMap( + cov_map=cov_map, + sparse_map=sparse_map, + nside_sparse=nside, + sentinel=sentinel, + ) + return result def take_from_sorted( @@ -375,12 +381,7 @@ def get_data(self, format="healsparse", nside_out: Optional[int] = None, **kwarg return {name: col.value for name, col in table.items()} elif format == "healsparse": - pixels = table["pixel"].value - return { - name: make_healsparse_map(pixels, col.value, self.nside, self.nside_lr) - for name, col in table.items() - if name != "pixel" - } + return make_healsparse_maps(table, self.nside, self.nside_lr) @property def data(self): From 37f253cc9212d00d25a4950809d02b6e68d00fb5 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 8 Apr 2026 08:43:04 -0500 Subject: [PATCH 016/139] Add changelog --- changes/+31323023.improvement.rst | 1 + changes/+91ec5e88.improvement.rst | 1 + 2 files changed, 2 insertions(+) create mode 100644 changes/+31323023.improvement.rst create mode 100644 changes/+91ec5e88.improvement.rst diff --git a/changes/+31323023.improvement.rst b/changes/+31323023.improvement.rst new file mode 100644 index 00000000..38084d53 --- /dev/null +++ b/changes/+31323023.improvement.rst @@ -0,0 +1 @@ +You can now request 
HealpixMaps in Healpix format even when the map does not cover the full sky. Maps requested in this format will contain a `pixel` column. diff --git a/changes/+91ec5e88.improvement.rst b/changes/+91ec5e88.improvement.rst new file mode 100644 index 00000000..c8f87405 --- /dev/null +++ b/changes/+91ec5e88.improvement.rst @@ -0,0 +1 @@ +Conversion to healsparse in :py:meth:`HealpixMap.get_data Date: Wed, 8 Apr 2026 11:28:02 -0500 Subject: [PATCH 017/139] Small test fix --- test/parallel/test_lc_mpi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index fe53851a..fd9c2f3f 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -133,7 +133,7 @@ def test_healpix_write(haloproperties_600_path, per_test_dir): oc.write(per_test_dir / "lightcone_test.hdf5", ds) new_ds = oc.open(per_test_dir / "lightcone_test.hdf5") - radius2 = 2 * u.deg + radius2 = 1 * u.deg region2 = oc.make_cone(center, radius2) new_ds = new_ds.bound(region2) ds = ds.bound(region2) From bdbc91bc2a0ab8f1c3a07a9d81c1f0ab02b77e5f Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 8 Apr 2026 12:03:56 -0500 Subject: [PATCH 018/139] Implement masked numpy arrays for healpix output --- .../collection/lightcone/healpix_map.py | 21 +++++++++++++++++-- test/test_healpixmap.py | 11 ++++++++-- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/python/opencosmo/collection/lightcone/healpix_map.py b/python/opencosmo/collection/lightcone/healpix_map.py index 4a2f863b..37b01543 100644 --- a/python/opencosmo/collection/lightcone/healpix_map.py +++ b/python/opencosmo/collection/lightcone/healpix_map.py @@ -374,11 +374,28 @@ def get_data(self, format="healsparse", nside_out: Optional[int] = None, **kwarg table = next(table.itercols()) if format == "healpix": + npix = hp.nside2npix(self.nside) if isinstance(table, (u.Quantity, Column)): - return table.value + vals = np.zeros(npix, dtype=np.float32) + 
vals[pixels] = table.value + storage = {"vals": vals} + else: table.remove_columns(self.__hidden) - return {name: col.value for name, col in table.items()} + storage = { + name: np.zeros(npix, dtype=np.float32) for name in table.columns + } + for name, arr in storage.items(): + arr[pixels] = table[name].value + if len(pixels) != hp.nside2npix(self.nside): + mask = np.zeros(hp.nside2npix(self.nside), dtype=bool) + mask[pixels] = True + storage = { + name: np.ma.masked_array(arr, mask) for name, arr in storage.items() + } + if len(storage) == 1: + return next(iter(storage.values())) + return storage elif format == "healsparse": return make_healsparse_maps(table, self.nside, self.nside_lr) diff --git a/test/test_healpixmap.py b/test/test_healpixmap.py index d94ad778..2e03743d 100644 --- a/test/test_healpixmap.py +++ b/test/test_healpixmap.py @@ -2,10 +2,9 @@ import healpy as hp import healsparse as hsp import numpy as np +import opencosmo as oc import pytest from astropy.coordinates import SkyCoord - -import opencosmo as oc from opencosmo.spatial.healpix import HealpixRegion @@ -301,6 +300,14 @@ def test_healpix_collection_select(healpix_map_path): assert columns_found == to_select +def test_healpix_collection_take_healpix(healpix_map_path): + ds = oc.open(healpix_map_path) + ds = ds.take_range(500, 1000) + data = ds.get_data("healpix") + for value in data.values(): + assert np.all(np.where(value.mask)[0] == np.arange(500, 1000)) + + def test_healpix_collection_select_healsparse(healpix_map_path): ds = oc.open(healpix_map_path) to_select = set(["tsz", "ksz"]) From b96a8028b9e803f5f80de63d2fa54287477ce3c0 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 8 Apr 2026 12:06:06 -0500 Subject: [PATCH 019/139] Update documentation --- python/opencosmo/collection/lightcone/healpix_map.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/opencosmo/collection/lightcone/healpix_map.py b/python/opencosmo/collection/lightcone/healpix_map.py index 
37b01543..7d57561d 100644 --- a/python/opencosmo/collection/lightcone/healpix_map.py +++ b/python/opencosmo/collection/lightcone/healpix_map.py @@ -332,7 +332,8 @@ def get_data(self, format="healsparse", nside_out: Optional[int] = None, **kwarg You can get the data in two formats, "healsparse" (the default) and "healpix". "healsparse" format will return the data as a healsparse sparse map. - "healpix" will return the data as a dictionary of numpy arrays. For map data, + "healpix" will return the data as a dictionary of numpy arrays. If the map does not + cover the full sky, this will be a masked numpy array. For map data, due to format requirements, no units will be attached to the data itself, although these will match the units from the data attributes. From 363fef7a699cd1bd77505734fb24b8b4d9837d94 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 8 Apr 2026 12:24:24 -0500 Subject: [PATCH 020/139] Update changelog --- changes/+31323023.improvement.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changes/+31323023.improvement.rst b/changes/+31323023.improvement.rst index 38084d53..0e7104d7 100644 --- a/changes/+31323023.improvement.rst +++ b/changes/+31323023.improvement.rst @@ -1 +1 @@ -You can now request HealpixMaps in Healpix format even when the map does not cover the full sky. Maps requested in this format will contain a `pixel` column. +You can now request HealpixMaps in Healpix format even when the map does not cover the full sky. Maps requested in this format will be returned as masked numpy arrays. 
From 13f479c10c648292617f74e73efc672db8e5d12d Mon Sep 17 00:00:00 2001 From: William Hicks Date: Mon, 9 Mar 2026 15:07:01 -0400 Subject: [PATCH 021/139] added animate_halo function --- python/opencosmo/analysis/__init__.py | 1 + python/opencosmo/analysis/yt_viz.py | 278 ++++++++++++++++++++++++-- 2 files changed, 260 insertions(+), 19 deletions(-) diff --git a/python/opencosmo/analysis/__init__.py b/python/opencosmo/analysis/__init__.py index e1c7f928..665aea09 100644 --- a/python/opencosmo/analysis/__init__.py +++ b/python/opencosmo/analysis/__init__.py @@ -13,6 +13,7 @@ "PhasePlot", "visualize_halo", "halo_projection_array", + "animate_halo", ] diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index e86c989b..4fcdbaff 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -4,7 +4,9 @@ import numpy as np import yt # type: ignore +import matplotlib.pyplot as plt from matplotlib.colors import LogNorm # type: ignore +from matplotlib.animation import FuncAnimation from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar # type: ignore from unyt import unyt_quantity # type: ignore from yt.visualization.base_plot_types import get_multi_plot # type: ignore @@ -136,7 +138,9 @@ def PhasePlot(*args, **kwargs) -> yt.PhasePlot: def visualize_halo( halo_id: int, data: oc.StructureCollection, + yt_ds: Optional[Any] = None, projection_axis: Optional[str] = "z", + north_vector: Optional[list[float]] = None, length_scale: Optional[str] = "top left", text_color: Optional[str] = "lightgray", width: Optional[float] = None, @@ -251,14 +255,28 @@ def visualize_halo( ) halo_ids: list[int] | tuple[list[int], list[int]] + + if yt_ds is not None: + yt_dataset_provided = True + else: + yt_dataset_provided = False + if len(params["fields"]) == 4: # if 4 fields, make a 2x2 figure halo_ids = ([halo_id, halo_id], [halo_id, halo_id]) + + if yt_dataset_provided: + yt_ds = ([yt_ds, yt_ds],[yt_ds, yt_ds]) + 
params = {key: (value[:2], value[2:]) for key, value in params.items()} else: # otherwise, do 1xN halo_ids = np.shape(params["fields"])[0] * [halo_id] + + if yt_dataset_provided: + yt_ds = np.shape(params["fields"])[0] * [yt_ds] + params = {key: [value] for key, value in params.items()} return halo_projection_array( @@ -269,15 +287,19 @@ def visualize_halo( width=width, projection_axis=projection_axis, text_color=text_color, + north_vector=north_vector, + yt_ds=yt_ds, ) def halo_projection_array( halo_ids: int | list[int] | tuple[list[int], list[int]] | np.ndarray, data: oc.StructureCollection, + yt_ds: Optional[ list[Any] | tuple[list[Any], list[Any]] ] = None, field: Optional[Tuple[str, str]] = ("dm", "particle_mass"), weight_field: Optional[Tuple[str, str]] = None, projection_axis: Optional[str] = "z", + north_vector: Optional[list[float]] = None, cmap: Optional[str] = "gray", cmap_norm: Optional[Normalize] = None, # type: ignore zlim: Optional[Tuple[float, float]] = None, @@ -317,8 +339,10 @@ def halo_projection_array( weight_field : tuple of str, optional Field to weight by during projection. Follows yt naming conventions. Overridden if ``params["weight_fields"]`` is provided. - projection_axis : str, optional - Data is projected along this axis (``"x"``, ``"y"``, or ``"z"``). + projection_axis : str, int, or 3-element sequence of floats, optional + Data is projected along this axis (``"x"``, ``"y"``, or ``"z"``), or, alternatively, + (0, 1, or 2). ``projection_axis`` is forwarded to the `normal` parameter of `ParticleProjectionPlot`. + An arbitrary projection axis may be provided as a 3-element sequence of floats. Overridden if ``params["projection_axes"]`` is provided cmap : str Matplotlib colormap to use for all panels. Overridden if ``params["cmaps"]`` is provided. 
@@ -391,9 +415,32 @@ def halo_projection_array( zlim_ = np.full(fig_shape, None) else: zlim_ = np.reshape( - [zlim for _ in range(np.prod(fig_shape))], (fig_shape[0], fig_shape[1], 2) + [zlim for _ in range(np.prod(fig_shape))], + (fig_shape[0], fig_shape[1], 2) ) + if isinstance(projection_axis, (str, int)): + projection_axis_ = np.full(fig_shape, projection_axis) + + elif isinstance(projection_axis, (list, tuple, np.ndarray)): + projection_axis_ = np.reshape( + [projection_axis for _ in range(np.prod(fig_shape))], + (fig_shape[0], fig_shape[1], 3), + ) + + else: + raise RuntimeError(f"`projection_axis` has unsopported type ({type(projection_axis)}).") + + + if north_vector is None: + north_vector_ = np.full(fig_shape, None) + else: + north_vector_ = np.reshape( + [north_vector for _ in range(np.prod(fig_shape))], + (fig_shape[0], fig_shape[1], 3) + ) + + default_params = { "fields": ( np.reshape( @@ -403,7 +450,8 @@ def halo_projection_array( ), "weight_fields": (weight_field_), "zlims": (zlim_), - "projection_axes": (np.full(fig_shape, projection_axis)), + "projection_axes": (projection_axis_), + "north_vectors": (north_vector_), "labels": (np.full(fig_shape, None)), "cmaps": (np.full(fig_shape, cmap)), "cmap_norms": (np.full(fig_shape, None)), @@ -416,6 +464,7 @@ def halo_projection_array( fields = params.get("fields", default_params["fields"]) weight_fields = params.get("weight_fields", default_params["weight_fields"]) projection_axes = params.get("projection_axes", default_params["projection_axes"]) + north_vectors = params.get("north_vectors", default_params["north_vectors"]) zlims = params.get("zlims", default_params["zlims"]) labels = params.get("labels", default_params["labels"]) cmaps = params.get("cmaps", default_params["cmaps"]) @@ -433,6 +482,13 @@ def halo_projection_array( halo_ids = np.array(halo_ids) halo_id_previous = np.inf + if yt_ds is not None: + yt_dataset_provided = True + yt_ds = np.atleast_2d(yt_ds) + else: + yt_dataset_provided = False + 
+ for i in range(nrow): for j in range(ncol): halo_id = halo_ids[i][j] @@ -442,23 +498,30 @@ def halo_projection_array( ax.set_facecolor("black") continue - # retrieve halo particle info if new halo - if (i == 0 and j == 0) or halo_id != halo_id_previous: - # retrieve properties of halo - if len(data) > 1: - data_id = data.filter(oc.col("unique_tag") == halo_id) - else: - if data["halo_properties"].data["unique_tag"] != halo_id: # type: ignore - raise RuntimeError(f"Halo ID {halo_id} not in dataset!") - data_id = data - halo_data = next(iter(data_id.objects())) + if yt_dataset_provided: + ds = yt_ds[i][j] + + # sodbighaloparticles holds particle data out to 2*R200 + Rh = ds.domain_width[0] / 4 - # load particles into yt - ds = create_yt_dataset(halo_data) + else: + # retrieve halo particle info if new halo + if (i == 0 and j == 0) or halo_id != halo_id_previous: + # retrieve properties of halo + if len(data) > 1: + data_id = data.filter(oc.col("unique_tag") == halo_id) + else: + if data["halo_properties"].data["unique_tag"] != halo_id: # type: ignore + raise RuntimeError(f"Halo ID {halo_id} not in dataset!") + data_id = data + halo_data = next(iter(data_id.objects())) - halo_properties = halo_data["halo_properties"] + # load particles into yt + ds = create_yt_dataset(halo_data) - Rh = unyt_quantity.from_astropy(halo_properties["sod_halo_radius"]) + halo_properties = halo_data["halo_properties"] + + Rh = unyt_quantity.from_astropy(halo_properties["sod_halo_radius"]) field, weight_field, zlim, width = ( tuple(fields[i][j]), @@ -475,7 +538,11 @@ def halo_projection_array( label = labels[i][j] proj = ParticleProjectionPlot( - ds, projection_axes[i][j], field, weight_field=weight_field + ds, + projection_axes[i][j], + field, + weight_field=weight_field, + north_vector=north_vectors[i][j], ) proj.set_background_color(field, color="black") @@ -567,3 +634,176 @@ def halo_projection_array( halo_id_previous = halo_id return fig + +def fig_to_rgb(fig): + """ + Render a 
Matplotlib Figure to an (H, W, 3) uint8 RGB array in memory. + """ + fig.canvas.draw() + buf = np.asarray(fig.canvas.buffer_rgba()) # (H, W, 4) + return buf[..., :3] # drop alpha channel + +def _normalize(v, eps=0): + v = np.asarray(v, dtype=float) + + if eps > 0: + # nudge exact on-axis directions off-axis + if np.allclose(v, [1, 0, 0]): + v = np.array([1.0, eps, 0.0]) + elif np.allclose(v, [-1, 0, 0]): + v = np.array([-1.0, eps, 0.0]) + elif np.allclose(v, [0, 1, 0]): + v = np.array([eps, 1.0, 0.0]) + elif np.allclose(v, [0, -1, 0]): + v = np.array([eps, -1.0, 0.0]) + elif np.allclose(v, [0, 0, 1]): + v = np.array([eps, 0.0, 1.0]) + elif np.allclose(v, [0, 0, -1]): + v = np.array([eps, 0.0, -1.0]) + + n = np.linalg.norm(v) + + return v / n + +def _rodrigues_rotate(v, axis, angle): + """ + Rotate vector v around 'axis' by 'angle' radians (right-hand rule). + """ + v = np.asarray(v, dtype=float) + k = _normalize(axis) + c = np.cos(angle) + s = np.sin(angle) + vrot = v * c + np.cross(k, v) * s + k * np.dot(k, v) * (1 - c) + + return vrot + +def _parse_rotations(rotations: str): + rotations = rotations.replace(" ", "") + parts = rotations.split("+") + factors = [] + axes = [] + for part in parts: + if "*" in part: + if part.count("*") > 1: + raise RuntimeError(f'rotation "{part}" not recognized') + fac_str, axis = part.split("*") + factor = float(fac_str) + else: + factor = 1.0 + axis = part + + if axis not in ("x", "y", "z"): + raise RuntimeError(f'rotation axis "{axis}" not recognized') + + factors.append(factor) + axes.append(axis) + return factors, axes + +def _get_rotation_vectors(rotations, frames, normal0=(0, 0, 1), north0=(0, 1, 0)): + normals = [normal0] + norths = [north0] + + factors = [] + axes = [] + + axis_map = {"x": np.array([1.0, 0.0, 0.0]), + "y": np.array([0.0, 1.0, 0.0]), + "z": np.array([0.0, 0.0, 1.0])} + + rotation_list = rotations.replace(" ", "").split("+") + for rotation in rotation_list: + if "*" in rotation: + if rotation.count("*") > 1: + 
raise RuntimeError(f"rotation \"{rotation}\" not recognized") + factor, axis = rotation.split("*") + factor = float(factor) + else: + factor = 1 + axis = rotation + + factors.append(factor) + axes.append(axis) + + # loop through rotations again, and actually apply them + for i, rotation in enumerate(rotation_list): + + # determine number of frames for this rotation (round up) + frames_i = int(np.ceil( frames * factors[i]/sum(factors) )) + + # angular distance traveled in theta and phi + delta_angle_i = factors[i] * 2*np.pi / frames_i + + axis = axes[i] + + for _ in range(frames_i): + normals.append(_normalize(_rodrigues_rotate(normals[-1], axis_map[axis], delta_angle_i), eps=1e-3)) + norths.append(_normalize(_rodrigues_rotate(norths[-1], axis_map[axis], delta_angle_i), eps=1e-3)) + + return normals, norths + + + +def animate_halo(halo_id, data, rotations="x", frames=30, dpi=100, normal0=(0, 0, 1), north0=(0, 1, 0)): + + # retrieve properties of halo and load into yt + if len(data) > 1: + data_id = data.filter(oc.col("unique_tag") == halo_id) + else: + if data["halo_properties"].data["unique_tag"] != halo_id: # type: ignore + raise RuntimeError(f"Halo ID {halo_id} not in dataset!") + data_id = data + + halo_data = next(iter(data_id.objects())) + + # load particles into yt + ds = create_yt_dataset(halo_data) + + normals, norths = _get_rotation_vectors(rotations, + frames=frames, + normal0=normal0, + north0=north0 + ) + + fig0 = visualize_halo( + halo_id, + data, + projection_axis=normals[0], + north_vector=norths[0], + yt_ds=ds, + ) + + frame0 = fig_to_rgb(fig0) + plt.close(fig0) + + H, W = frame0.shape[:2] + + # ---- animation "display" figure (single persistent figure) ---- + fig = plt.figure(figsize=(W / dpi, H / dpi), dpi=dpi) + ax = fig.add_axes([0,0,1,1]) + ax.set_axis_off() + ax.set_aspect("auto") + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + #fig.patch.set_facecolor("black") + im = ax.imshow(frame0, interpolation="nearest") + + def update(i): + 
normal = normals[i] + north = norths[i] + + f = visualize_halo( + halo_id, + data, + projection_axis=normal, + north_vector=north, + yt_ds=ds, + ) + frame = fig_to_rgb(f) + plt.close(f) # close each per-frame figure + + im.set_data(frame) + + return (im,) + + anim = FuncAnimation(fig, update, frames=frames, interval=50, blit=True) + + return anim From f2163380dddbf0358c07c107d1453ecaa1f01415 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 2 Apr 2026 14:36:20 -0400 Subject: [PATCH 022/139] debugging animate_halo function + other things --- python/opencosmo/analysis/yt_viz.py | 131 +++++++++++++++++++--------- 1 file changed, 88 insertions(+), 43 deletions(-) diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index 4fcdbaff..028975b9 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -10,6 +10,7 @@ from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar # type: ignore from unyt import unyt_quantity # type: ignore from yt.visualization.base_plot_types import get_multi_plot # type: ignore +from yt.visualization.particle_plots import OffAxisParticleProjectionPlot import opencosmo as oc from opencosmo.analysis import create_yt_dataset @@ -144,6 +145,8 @@ def visualize_halo( length_scale: Optional[str] = "top left", text_color: Optional[str] = "lightgray", width: Optional[float] = None, + manual_axis_alignment: Optional[bool] = False, + ) -> Figure: """ Creates a figure showing particle projections of dark matter, stars, gas, and/or gas temperature @@ -307,6 +310,7 @@ def halo_projection_array( length_scale: Optional[str] = None, text_color: Optional[str] = "lightgray", width: Optional[float] = None, + manual_axis_alignment: Optional[bool] = False, ) -> Figure: """ Creates a multipanel figure of projections for different fields and/or halos. 
@@ -537,13 +541,38 @@ def halo_projection_array( label = labels[i][j] - proj = ParticleProjectionPlot( - ds, - projection_axes[i][j], - field, - weight_field=weight_field, - north_vector=north_vectors[i][j], - ) + # we need to determine of the projection is going to be axis-aligned. + # yt internally checks this within ParticleProjectionPlot and forwards to + # OffAxisParticleProjectionPlot if it is not axis-aligned. We are manually + # calling OffAxisParticleProjectionPlot for more control over the normal/north + # vectors (ParticleProjectionPlot ignores these inputs if axis-aligned). + if manual_axis_alignment: + if isinstance(projection_axis, str): + match projection_axis: + case "x": + projection_axis = (1, 0, 0) + case "y": + projection_axis = (0, 1, 0) + case "z": + projection_axis = (0, 0, 1) + + proj = OffAxisParticleProjectionPlot( + ds, + projection_axis, + field, + weight_field=weight_field, + north_vector=north_vectors[i][j], + ) + + else: + proj = ParticleProjectionPlot( + ds, + projection_axis, + field, + weight_field=weight_field, + north_vector=north_vectors[i][j], + ) + proj.set_background_color(field, color="black") @@ -646,24 +675,48 @@ def fig_to_rgb(fig): def _normalize(v, eps=0): v = np.asarray(v, dtype=float) + ''' if eps > 0: # nudge exact on-axis directions off-axis if np.allclose(v, [1, 0, 0]): - v = np.array([1.0, eps, 0.0]) + v = np.array([1.0, eps, eps]) elif np.allclose(v, [-1, 0, 0]): - v = np.array([-1.0, eps, 0.0]) + v = np.array([-1.0, eps, eps]) elif np.allclose(v, [0, 1, 0]): - v = np.array([eps, 1.0, 0.0]) + v = np.array([eps, 1.0, eps]) elif np.allclose(v, [0, -1, 0]): - v = np.array([eps, -1.0, 0.0]) + v = np.array([eps, -1.0, eps]) elif np.allclose(v, [0, 0, 1]): - v = np.array([eps, 0.0, 1.0]) + v = np.array([eps, eps, 1.0]) elif np.allclose(v, [0, 0, -1]): - v = np.array([eps, 0.0, -1.0]) + v = np.array([eps, eps, -1.0]) n = np.linalg.norm(v) return v / n + ''' + + # normalize first + v = v / np.linalg.norm(v) + + if eps > 
0: + # count "significant" components + nz = np.sum(np.abs(v) > 1e-12) + + if nz == 1: + # find dominant axis + i = np.argmax(np.abs(v)) + + # add epsilon in a perpendicular direction + if i == 0: + v[1] = eps + elif i == 1: + v[2] = eps + elif i == 2: + v[0] = eps + + v = v / np.linalg.norm(v) + return v def _rodrigues_rotate(v, axis, angle): """ @@ -677,31 +730,13 @@ def _rodrigues_rotate(v, axis, angle): return vrot -def _parse_rotations(rotations: str): - rotations = rotations.replace(" ", "") - parts = rotations.split("+") - factors = [] - axes = [] - for part in parts: - if "*" in part: - if part.count("*") > 1: - raise RuntimeError(f'rotation "{part}" not recognized') - fac_str, axis = part.split("*") - factor = float(fac_str) - else: - factor = 1.0 - axis = part - - if axis not in ("x", "y", "z"): - raise RuntimeError(f'rotation axis "{axis}" not recognized') - - factors.append(factor) - axes.append(axis) - return factors, axes +def _enforce_orthogonality(v1, v2): + # enforce orthogonality of v1, relative to v2 + return v1 - np.dot(v1, v2) * v2 def _get_rotation_vectors(rotations, frames, normal0=(0, 0, 1), north0=(0, 1, 0)): - normals = [normal0] - norths = [north0] + normals = [_normalize(normal0, eps=1e-3)] + norths = [_normalize(_enforce_orthogonality(north0, normals[-1]))] factors = [] axes = [] @@ -710,8 +745,8 @@ def _get_rotation_vectors(rotations, frames, normal0=(0, 0, 1), north0=(0, 1, 0) "y": np.array([0.0, 1.0, 0.0]), "z": np.array([0.0, 0.0, 1.0])} - rotation_list = rotations.replace(" ", "").split("+") - for rotation in rotation_list: + # get a list of rotations + for rotation in rotations: if "*" in rotation: if rotation.count("*") > 1: raise RuntimeError(f"rotation \"{rotation}\" not recognized") @@ -724,8 +759,8 @@ def _get_rotation_vectors(rotations, frames, normal0=(0, 0, 1), north0=(0, 1, 0) factors.append(factor) axes.append(axis) - # loop through rotations again, and actually apply them - for i, rotation in enumerate(rotation_list): + 
# loop through rotations again and actually apply them + for i, rotation in enumerate(rotations): # determine number of frames for this rotation (round up) frames_i = int(np.ceil( frames * factors[i]/sum(factors) )) @@ -736,8 +771,18 @@ def _get_rotation_vectors(rotations, frames, normal0=(0, 0, 1), north0=(0, 1, 0) axis = axes[i] for _ in range(frames_i): - normals.append(_normalize(_rodrigues_rotate(normals[-1], axis_map[axis], delta_angle_i), eps=1e-3)) - norths.append(_normalize(_rodrigues_rotate(norths[-1], axis_map[axis], delta_angle_i), eps=1e-3)) + n = _normalize(_rodrigues_rotate(normals[-1], axis_map[axis], delta_angle_i), eps=1e-3) + u = _normalize(_rodrigues_rotate(norths[-1], axis_map[axis], delta_angle_i)) + + # enforce orthogonality of normal and north vectors + u = _normalize( _enforce_orthogonality(u, n) ) + + # continuity guard: prevent sudden 180-degree flips of north + if np.dot(u, norths[-1]) < 0: + u = -u + + normals.append(n) + norths.append(u) return normals, norths @@ -783,7 +828,6 @@ def animate_halo(halo_id, data, rotations="x", frames=30, dpi=100, normal0=(0, 0 ax.set_axis_off() ax.set_aspect("auto") fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - #fig.patch.set_facecolor("black") im = ax.imshow(frame0, interpolation="nearest") def update(i): @@ -796,6 +840,7 @@ def update(i): projection_axis=normal, north_vector=north, yt_ds=ds, + manual_axis_alignment=True, ) frame = fig_to_rgb(f) plt.close(f) # close each per-frame figure From 7e001e0c06401c210869d9fcb2d9673b8f4385b6 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 2 Apr 2026 16:49:42 -0400 Subject: [PATCH 023/139] can call either visualize_halo or alo --- python/opencosmo/analysis/yt_viz.py | 136 ++++++++++++++++++++-------- 1 file changed, 99 insertions(+), 37 deletions(-) diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index 028975b9..eda2668f 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ 
b/python/opencosmo/analysis/yt_viz.py @@ -402,6 +402,7 @@ def halo_projection_array( data = data.with_units("comoving") halo_ids = np.atleast_2d(halo_ids) + yt_ds = np.atleast_2d(yt_ds) # determine shape of figure fig_shape = np.shape(halo_ids) @@ -486,13 +487,6 @@ def halo_projection_array( halo_ids = np.array(halo_ids) halo_id_previous = np.inf - if yt_ds is not None: - yt_dataset_provided = True - yt_ds = np.atleast_2d(yt_ds) - else: - yt_dataset_provided = False - - for i in range(nrow): for j in range(ncol): halo_id = halo_ids[i][j] @@ -502,14 +496,14 @@ def halo_projection_array( ax.set_facecolor("black") continue - if yt_dataset_provided: - ds = yt_ds[i][j] - + ds = yt_ds[i][j] + if ds is not None: # sodbighaloparticles holds particle data out to 2*R200 Rh = ds.domain_width[0] / 4 else: - # retrieve halo particle info if new halo + # retrieve halo particle info if new halo, or if yt dataset + # is not already provided if (i == 0 and j == 0) or halo_id != halo_id_previous: # retrieve properties of halo if len(data) > 1: @@ -788,20 +782,50 @@ def _get_rotation_vectors(rotations, frames, normal0=(0, 0, 1), north0=(0, 1, 0) -def animate_halo(halo_id, data, rotations="x", frames=30, dpi=100, normal0=(0, 0, 1), north0=(0, 1, 0)): +def animate_halo( + halo_ids, data, + func="visualize_halo", rotations="x", + frames=30, dpi=100, + normal0=(0, 0, 1), north0=(0, 1, 0), + **kwargs, +): - # retrieve properties of halo and load into yt - if len(data) > 1: - data_id = data.filter(oc.col("unique_tag") == halo_id) - else: - if data["halo_properties"].data["unique_tag"] != halo_id: # type: ignore - raise RuntimeError(f"Halo ID {halo_id} not in dataset!") - data_id = data + halo_ids = np.atleast_2d(halo_ids) - halo_data = next(iter(data_id.objects())) + fig_shape = np.shape(halo_ids) + yt_ds_arr = np.full(fig_shape, None) + + nrow, ncol = fig_shape + halo_id_previous = np.inf + for i in range(nrow): + for j in range(ncol): + halo_id = halo_ids[i][j] + + if (i == 0 and j == 0) 
or halo_id != halo_id_previous: + # retrieve properties of halo and load into yt + # this part is skipped if the halo has just been found/loaded in the + # previous iteration + # TODO: make this slightly faster by copying directly yt_ds_arr in cases where + # the halo was loaded into yt more than 1 iteration ago + + if len(data) > 1: + data_id = data.filter(oc.col("unique_tag") == halo_id) + else: + if data["halo_properties"].data["unique_tag"] != halo_id: # type: ignore + raise RuntimeError(f"Halo ID {halo_id} not in dataset!") + data_id = data + + halo_data = next(iter(data_id.objects())) + + # load particles into yt + ds = create_yt_dataset(halo_data) - # load particles into yt - ds = create_yt_dataset(halo_data) + yt_ds_arr[i][j] = ds + + else: + yt_ds_arr[i][j] = ds + + halo_id_previous = halo_id normals, norths = _get_rotation_vectors(rotations, frames=frames, @@ -809,13 +833,39 @@ def animate_halo(halo_id, data, rotations="x", frames=30, dpi=100, normal0=(0, 0 north0=north0 ) - fig0 = visualize_halo( - halo_id, - data, - projection_axis=normals[0], - north_vector=norths[0], - yt_ds=ds, - ) + call_visualize_halo = False + call_halo_projection_array = False + if func == "visualize_halo": + call_visualize_halo = True + if np.prod(np.shape(halo_ids) != 1): + raise ValueError(f"`visualize_halo` requires a single int for `halo_id`, not an array of values") + + + elif func == "halo_projection_array": + call_halo_projection_array = True + else: + raise RuntimeError(f"\`func\` {func} not recognized ") + + if call_visualize_halo: + fig0 = visualize_halo( + halo_ids[0][0], + data, + projection_axis=normals[0], + north_vector=norths[0], + yt_ds=yt_ds_arr[0][0], + **kwargs, + ) + + elif call_halo_projection_array: + fig0 = halo_projection_array( + halo_ids, + data, + projection_axis=normals[0], + north_vector=norths[0], + yt_ds = yt_ds_arr, + **kwargs, + ) + frame0 = fig_to_rgb(fig0) plt.close(fig0) @@ -834,14 +884,26 @@ def update(i): normal = normals[i] north = 
norths[i] - f = visualize_halo( - halo_id, - data, - projection_axis=normal, - north_vector=north, - yt_ds=ds, - manual_axis_alignment=True, - ) + if call_visualize_halo: + f = visualize_halo( + halo_ids[0][0], + data, + projection_axis=normal, + north_vector=north, + yt_ds=yt_ds[0][0], + manual_axis_alignment=True, + **kwargs, + ) + elif call_halo_projection_array: + f = halo_projection_array( + halo_ids, + data, + projection_axis=normal, + north_vector=north, + yt_ds=yt_ds_arr, + manual_axis_alignment=True, + **kwargs, + ) frame = fig_to_rgb(f) plt.close(f) # close each per-frame figure From b3f817dfffa4018c45fbe1fc62e42e6a521db418 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 9 Apr 2026 14:32:21 -0400 Subject: [PATCH 024/139] documenting animate_halos --- python/opencosmo/analysis/__init__.py | 2 +- python/opencosmo/analysis/yt_viz.py | 215 +++++++++++++++++--------- 2 files changed, 139 insertions(+), 78 deletions(-) diff --git a/python/opencosmo/analysis/__init__.py b/python/opencosmo/analysis/__init__.py index 665aea09..66de110f 100644 --- a/python/opencosmo/analysis/__init__.py +++ b/python/opencosmo/analysis/__init__.py @@ -13,7 +13,7 @@ "PhasePlot", "visualize_halo", "halo_projection_array", - "animate_halo", + "animate_halos", ] diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index eda2668f..fd1d3aa5 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -19,6 +19,7 @@ from matplotlib.colors import Normalize from matplotlib.figure import Figure from yt.visualization.plot_window import NormalPlot + from yt.data_objects.static_output import Dataset as YT_Dataset # ruff: noqa: E501 @@ -139,7 +140,7 @@ def PhasePlot(*args, **kwargs) -> yt.PhasePlot: def visualize_halo( halo_id: int, data: oc.StructureCollection, - yt_ds: Optional[Any] = None, + yt_ds: Optional[YT_Dataset] = None, projection_axis: Optional[str] = "z", north_vector: Optional[list[float]] = None, 
length_scale: Optional[str] = "top left", @@ -298,7 +299,7 @@ def visualize_halo( def halo_projection_array( halo_ids: int | list[int] | tuple[list[int], list[int]] | np.ndarray, data: oc.StructureCollection, - yt_ds: Optional[ list[Any] | tuple[list[Any], list[Any]] ] = None, + yt_ds: Optional[ list[YT_Dataset] | tuple[list[YT_Dataset], list[YT_Dataset]] ] = None, field: Optional[Tuple[str, str]] = ("dm", "particle_mass"), weight_field: Optional[Tuple[str, str]] = None, projection_axis: Optional[str] = "z", @@ -337,6 +338,9 @@ def halo_projection_array( data : opencosmo.StructureCollection OpenCosmo StructureCollection dataset containing both halo properties and particle data (e.g., output of ``opencosmo.open([haloproperties, sodbighaloparticles])``). + yt_ds : yt dataset or 2D array of yt datasets, optional + Pre-loaded yt dataset (e.g., output of ``opencosmo.analysis.create_yt_dataset()``). + If ``None``, ``halo_projection_array`` will internally search for the halo ID to create the yt dataset. field : tuple of str, optional Field to plot for all panels. Follows yt naming conventions (e.g., ``("dm", "particle_mass")``, ``("gas", "temperature")``). Overridden if ``params["fields"]`` is provided. @@ -345,9 +349,14 @@ def halo_projection_array( Overridden if ``params["weight_fields"]`` is provided. projection_axis : str, int, or 3-element sequence of floats, optional Data is projected along this axis (``"x"``, ``"y"``, or ``"z"``), or, alternatively, - (0, 1, or 2). ``projection_axis`` is forwarded to the `normal` parameter of `ParticleProjectionPlot`. + (0, 1, or 2). ``projection_axis`` is forwarded to the ``normal`` parameter of `ParticleProjectionPlot`. An arbitrary projection axis may be provided as a 3-element sequence of floats. Overridden if ``params["projection_axes"]`` is provided + north_vector : str, int, or 3-element sequence of floats, optional + Sets the north vector of the projection (i.e. which axis corresponds to "up" in the final image). 
+ Setting ``north_vector`` requires setting a ``projection_axis`` that is perpendicular to ``north_vector``. + If ``north_vector`` is not set, yt will choose a north vector internally. + ``north_vector`` is forwarded to the ``north_vector`` parameter of ParticleProjectionPlot. cmap : str Matplotlib colormap to use for all panels. Overridden if ``params["cmaps"]`` is provided. See https://matplotlib.org/stable/gallery/color/colormap_reference.html for named colormaps. @@ -358,6 +367,10 @@ def halo_projection_array( Colorbar limits for `field`. Overridden if ``params["zlims"]`` is provided. length_scale : str or None, optional Optionally add a horizontal bar denoting length scale in Mpc. + manual_axis_alignment : bool, optional + Generate images by directly calling yt.OffAxisParticleProjectionPlot, + which can give more flexibility for managing image orientation. + If False, ``halo_projection_array`` will use yt.ParticleProjectionPlot. Options: - ``"top left"``: add to top left panel @@ -534,6 +547,9 @@ def halo_projection_array( zlim = tuple(zlim) # type: ignore label = labels[i][j] + + north_vector = _sanitize_input_vector(north_vector) + # we need to determine of the projection is going to be axis-aligned. # yt internally checks this within ParticleProjectionPlot and forwards to @@ -541,14 +557,7 @@ def halo_projection_array( # calling OffAxisParticleProjectionPlot for more control over the normal/north # vectors (ParticleProjectionPlot ignores these inputs if axis-aligned). if manual_axis_alignment: - if isinstance(projection_axis, str): - match projection_axis: - case "x": - projection_axis = (1, 0, 0) - case "y": - projection_axis = (0, 1, 0) - case "z": - projection_axis = (0, 0, 1) + projection_axis = _sanitize_input_vector(projection_axis) proj = OffAxisParticleProjectionPlot( ds, @@ -658,7 +667,7 @@ def halo_projection_array( return fig -def fig_to_rgb(fig): +def _fig_to_rgb(fig): """ Render a Matplotlib Figure to an (H, W, 3) uint8 RGB array in memory. 
""" @@ -666,50 +675,30 @@ def fig_to_rgb(fig): buf = np.asarray(fig.canvas.buffer_rgba()) # (H, W, 4) return buf[..., :3] # drop alpha channel +def _sanitize_input_vector(v): + if isinstance(v, str): + match v: + case "x" | 0: + return (1, 0, 0) + case "y" | 1: + return (0, 1, 0) + case "z" | 2: + return (0, 0, 1) + else: + return v + def _normalize(v, eps=0): v = np.asarray(v, dtype=float) - - ''' - if eps > 0: - # nudge exact on-axis directions off-axis - if np.allclose(v, [1, 0, 0]): - v = np.array([1.0, eps, eps]) - elif np.allclose(v, [-1, 0, 0]): - v = np.array([-1.0, eps, eps]) - elif np.allclose(v, [0, 1, 0]): - v = np.array([eps, 1.0, eps]) - elif np.allclose(v, [0, -1, 0]): - v = np.array([eps, -1.0, eps]) - elif np.allclose(v, [0, 0, 1]): - v = np.array([eps, eps, 1.0]) - elif np.allclose(v, [0, 0, -1]): - v = np.array([eps, eps, -1.0]) - - n = np.linalg.norm(v) - - return v / n - ''' - - # normalize first - v = v / np.linalg.norm(v) + + # normalize + v /= np.linalg.norm(v) if eps > 0: - # count "significant" components - nz = np.sum(np.abs(v) > 1e-12) - - if nz == 1: - # find dominant axis - i = np.argmax(np.abs(v)) - - # add epsilon in a perpendicular direction - if i == 0: - v[1] = eps - elif i == 1: - v[2] = eps - elif i == 2: - v[0] = eps - - v = v / np.linalg.norm(v) + # pad zeros with some non-zero value + v[v==0] = eps + + # normalize again + v /= np.linalg.norm(v) return v def _rodrigues_rotate(v, axis, angle): @@ -729,16 +718,16 @@ def _enforce_orthogonality(v1, v2): return v1 - np.dot(v1, v2) * v2 def _get_rotation_vectors(rotations, frames, normal0=(0, 0, 1), north0=(0, 1, 0)): + + normal0 = _sanitize_input_vector(normal0) + north0 = _sanitize_input_vector(north0) + normals = [_normalize(normal0, eps=1e-3)] norths = [_normalize(_enforce_orthogonality(north0, normals[-1]))] factors = [] axes = [] - axis_map = {"x": np.array([1.0, 0.0, 0.0]), - "y": np.array([0.0, 1.0, 0.0]), - "z": np.array([0.0, 0.0, 1.0])} - # get a list of rotations for 
rotation in rotations: if "*" in rotation: @@ -757,38 +746,111 @@ def _get_rotation_vectors(rotations, frames, normal0=(0, 0, 1), north0=(0, 1, 0) for i, rotation in enumerate(rotations): # determine number of frames for this rotation (round up) - frames_i = int(np.ceil( frames * factors[i]/sum(factors) )) + frames_i = int(np.ceil( frames * np.absolute(factors[i])/sum(np.absolute(factors)) )) # angular distance traveled in theta and phi delta_angle_i = factors[i] * 2*np.pi / frames_i - axis = axes[i] + axis = _sanitize_input_vector(axes[i]) for _ in range(frames_i): - n = _normalize(_rodrigues_rotate(normals[-1], axis_map[axis], delta_angle_i), eps=1e-3) - u = _normalize(_rodrigues_rotate(norths[-1], axis_map[axis], delta_angle_i)) + n = _normalize(_rodrigues_rotate(normals[-1], axis, delta_angle_i), eps=1e-3) + u = _normalize(_rodrigues_rotate(norths[-1], axis, delta_angle_i)) # enforce orthogonality of normal and north vectors u = _normalize( _enforce_orthogonality(u, n) ) - # continuity guard: prevent sudden 180-degree flips of north - if np.dot(u, norths[-1]) < 0: - u = -u - normals.append(n) norths.append(u) return normals, norths - -def animate_halo( - halo_ids, data, - func="visualize_halo", rotations="x", - frames=30, dpi=100, - normal0=(0, 0, 1), north0=(0, 1, 0), +def animate_halos( + halo_ids: int | list[int] | tuple[list[int], list[int]] | np.ndarray, + data: oc.StructureCollection, + func: str = "visualize_halo", + rotations: str | int | list[str] = "y", + frames: int = 30, + dpi: int = 100, + normal0: str | int | list[int] | tuple[int] = "z", + north0: str | int | list[int] | tuple[int] = "y", **kwargs, ): + """ + Creates an animation of one or more halo projections while rotating the + viewing direction. + + The animation is constructed by repeatedly calling either ``visualize_halo`` or + ``halo_projection_array`` for a sequence of projection orientations and stacking the individual + frames into an animation. 
The viewing orientation evolves according + to ``rotations``, beginning from the initial projection axis ``normal0`` and + initial "up" direction ``north0``. + + By default, this function animates a single halo using ``visualize_halo`` while + rotating about the y-axis. It can also animate a customizable multipanel projection layout by + setting ``func="halo_projection_array"`` and passing a 2D arrangement of halo IDs. + + Parameters + ---------- + halo_ids : int or array of int + Unique ID of the halo(s) to be animated. `halo_ids` is forwarded to the parameter of the same name + in either `visualize_halo` or `halo_projection_array`, depending on the value of ``func``. + + When ``func="visualize_halo"``, only a single halo ID is allowed. When + ``func="halo_projection_array"``, ``halo_ids`` can be an int, list, or 2D array. + + data : opencosmo.StructureCollection + OpenCosmo StructureCollection containing the halo properties and particle data + needed to create yt datasets for the requested halos. For example, this may be + the output of ``opencosmo.open(["haloproperties.hdf5", "haloparticles.hdf5"])``. + + func : str, optional + Name of the plotting function used to generate each animation frame. + + - ``"visualize_halo"``: animate a single halo using ``visualize_halo``. + - ``"halo_projection_array"``: animate a panel array using + ``halo_projection_array``. + + rotations : str or sequence of str, optional + Specification for how the camera rotates during the animation. + + For example, ``rotations = "x"`` rotates the object once about the x-axis, while + ``rotations = ["x", "y"]`` rotates the object once about the x-axis, then once about the y-axis. + Partial rotations can be defined by prepending the string with a float + (e.g. ``rotations=["0.5*x", "0.25*y"]`` does half a rotation about the x-axis, then a quarter rotation about the y-axis). + Prepending the string with a negative value reverses the rotation direction. 
+ + frames : int, optional + Total number of frames in the animation. + + dpi : int, optional + Resolution of the persistent display figure used to assemble the animation. + This also controls the effective pixel size of the output animation. + + normal0 : str, int, or 3-tuple of float, optional + Projection axis for the initial frame. + + north0 : str, int, or 3-tuple of float, optional + Initial north vector, i.e. the initial "up" direction in the image plane. + ``north0`` must be perpendicular to ``normal0``. + + **kwargs + Additional keyword arguments passed directly to the selected plotting function + (either ``visualize_halo`` or ``halo_projection_array``). This can be used to + customize the field being projected, color normalization, labels, plot width, + colormap, and so on. + + Note that ``projection_axis``, ``north_vector``, ``yt_ds``, and + ``manual_axis_alignment`` are set internally by ``animate_halos`` and will be + overridden regardless of values passed through ``kwargs``. + + Returns + ------- + matplotlib.animation.FuncAnimation + Matplotlib animation object.
+ + """ halo_ids = np.atleast_2d(halo_ids) @@ -805,7 +867,7 @@ def animate_halo( # retrieve properties of halo and load into yt # this part is skipped if the halo has just been found/loaded in the # previous iteration - # TODO: make this slightly faster by copying directly yt_ds_arr in cases where + # Can make this slightly faster by copying directly yt_ds_arr in cases where # the halo was loaded into yt more than 1 iteration ago if len(data) > 1: @@ -837,14 +899,13 @@ def animate_halo( call_halo_projection_array = False if func == "visualize_halo": call_visualize_halo = True - if np.prod(np.shape(halo_ids) != 1): + if np.prod(np.shape(halo_ids)) > 1: raise ValueError(f"`visualize_halo` requires a single int for `halo_id`, not an array of values") - elif func == "halo_projection_array": call_halo_projection_array = True else: - raise RuntimeError(f"\`func\` {func} not recognized ") + raise RuntimeError(f"\`func\` {func} not recognized") if call_visualize_halo: fig0 = visualize_halo( @@ -867,7 +928,7 @@ def animate_halo( ) - frame0 = fig_to_rgb(fig0) + frame0 = _fig_to_rgb(fig0) plt.close(fig0) H, W = frame0.shape[:2] @@ -890,7 +951,7 @@ def update(i): data, projection_axis=normal, north_vector=north, - yt_ds=yt_ds[0][0], + yt_ds=yt_ds_arr[0][0], manual_axis_alignment=True, **kwargs, ) @@ -904,7 +965,7 @@ def update(i): manual_axis_alignment=True, **kwargs, ) - frame = fig_to_rgb(f) + frame = _fig_to_rgb(f) plt.close(f) # close each per-frame figure im.set_data(frame) From 5b94bd05a4833c809b168fb9923f0fbfa72b3728 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 9 Apr 2026 14:37:40 -0400 Subject: [PATCH 025/139] changelog --- changes/+07fe4e9.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/+07fe4e9.feature.rst diff --git a/changes/+07fe4e9.feature.rst b/changes/+07fe4e9.feature.rst new file mode 100644 index 00000000..0e8debe3 --- /dev/null +++ b/changes/+07fe4e9.feature.rst @@ -0,0 +1 @@ +Added animate_halos function, which calls 
either "visualize_halo" or "halo_projection_array" in a loop to create an animated visualization of a given halo or set of halos. From d8361fed58fef78a9e5652b957a481c681c7da4b Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 9 Apr 2026 15:03:05 -0400 Subject: [PATCH 026/139] rendering animate_halo docs --- docs/source/analysis_ref.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/analysis_ref.rst b/docs/source/analysis_ref.rst index 934d9b17..761381a8 100644 --- a/docs/source/analysis_ref.rst +++ b/docs/source/analysis_ref.rst @@ -9,6 +9,8 @@ Analysis .. autofunction:: opencosmo.analysis.halo_projection_array +.. autofunction:: opencosmo.analysis.animate_halos + .. autofunction:: opencosmo.analysis.ParticleProjectionPlot .. autofunction:: opencosmo.analysis.ProjectionPlot From 9abe953f6015d76c95631003769e6e637d33aaa5 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 9 Apr 2026 15:37:29 -0400 Subject: [PATCH 027/139] type checker --- python/opencosmo/analysis/yt_viz.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index fd1d3aa5..2c23a47c 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -547,6 +547,7 @@ def halo_projection_array( zlim = tuple(zlim) # type: ignore label = labels[i][j] + projection_axis = projection_axes[i][j] north_vector = _sanitize_input_vector(north_vector) @@ -900,7 +901,7 @@ def animate_halos( if func == "visualize_halo": call_visualize_halo = True if np.prod(np.shape(halo_ids)) > 1: - raise ValueError(f"`visualize_halo` requires a single int for `halo_id`, not an array of values") + raise ValueError("`visualize_halo` requires a single int for `halo_id`, not an array of values") elif func == "halo_projection_array": call_halo_projection_array = True From 5fce7fd025f41698e7d7dc3367c3cc3dc903903a Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 9 Apr 2026 16:08:00 -0400 
Subject: [PATCH 028/139] type check --- python/opencosmo/analysis/yt_viz.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index 2c23a47c..4c261cec 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -270,7 +270,7 @@ def visualize_halo( halo_ids = ([halo_id, halo_id], [halo_id, halo_id]) if yt_dataset_provided: - yt_ds = ([yt_ds, yt_ds],[yt_ds, yt_ds]) + yt_ds_arr = ([yt_ds, yt_ds],[yt_ds, yt_ds]) params = {key: (value[:2], value[2:]) for key, value in params.items()} @@ -279,7 +279,7 @@ def visualize_halo( halo_ids = np.shape(params["fields"])[0] * [halo_id] if yt_dataset_provided: - yt_ds = np.shape(params["fields"])[0] * [yt_ds] + yt_ds_arr = np.shape(params["fields"])[0] * [yt_ds] params = {key: [value] for key, value in params.items()} @@ -292,14 +292,14 @@ def visualize_halo( projection_axis=projection_axis, text_color=text_color, north_vector=north_vector, - yt_ds=yt_ds, + yt_ds=yt_ds_arr, ) def halo_projection_array( halo_ids: int | list[int] | tuple[list[int], list[int]] | np.ndarray, data: oc.StructureCollection, - yt_ds: Optional[ list[YT_Dataset] | tuple[list[YT_Dataset], list[YT_Dataset]] ] = None, + yt_ds: Optional[ YT_Dataset | list[YT_Dataset] | tuple[list[YT_Dataset], list[YT_Dataset]] | np.ndarray ] = None, field: Optional[Tuple[str, str]] = ("dm", "particle_mass"), weight_field: Optional[Tuple[str, str]] = None, projection_axis: Optional[str] = "z", @@ -557,6 +557,9 @@ def halo_projection_array( # OffAxisParticleProjectionPlot if it is not axis-aligned. We are manually # calling OffAxisParticleProjectionPlot for more control over the normal/north # vectors (ParticleProjectionPlot ignores these inputs if axis-aligned). 
+ + proj: OffAxisParticleProjectionPlot | ParticleProjectionPlot + if manual_axis_alignment: projection_axis = _sanitize_input_vector(projection_axis) @@ -936,7 +939,7 @@ def animate_halos( # ---- animation "display" figure (single persistent figure) ---- fig = plt.figure(figsize=(W / dpi, H / dpi), dpi=dpi) - ax = fig.add_axes([0,0,1,1]) + ax = fig.add_axes((0.0, 0.0, 1.0, 1.0)) ax.set_axis_off() ax.set_aspect("auto") fig.subplots_adjust(left=0, right=1, bottom=0, top=1) From d4cbc75072d3a6e21db91f8f1da54574ca64780c Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 9 Apr 2026 16:38:57 -0400 Subject: [PATCH 029/139] type check --- python/opencosmo/analysis/yt_viz.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index 4c261cec..f47bcb9b 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -414,8 +414,8 @@ def halo_projection_array( # easily translatable to yt's unit conventions data = data.with_units("comoving") - halo_ids = np.atleast_2d(halo_ids) - yt_ds = np.atleast_2d(yt_ds) + halo_ids = np.atleast_2d(halo_ids) # type: ignore + yt_ds = np.atleast_2d(yt_ds) # type: ignore # determine shape of figure fig_shape = np.shape(halo_ids) @@ -558,7 +558,7 @@ def halo_projection_array( # calling OffAxisParticleProjectionPlot for more control over the normal/north # vectors (ParticleProjectionPlot ignores these inputs if axis-aligned). 
- proj: OffAxisParticleProjectionPlot | ParticleProjectionPlot + proj: type[OffAxisParticleProjectionPlot] | type[ParticleProjectionPlot] if manual_axis_alignment: projection_axis = _sanitize_input_vector(projection_axis) From c1eec634611d5e2ba6141f002f0a68b37803253e Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 9 Apr 2026 16:56:54 -0400 Subject: [PATCH 030/139] type check --- python/opencosmo/analysis/yt_viz.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index f47bcb9b..5938e634 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -414,8 +414,8 @@ def halo_projection_array( # easily translatable to yt's unit conventions data = data.with_units("comoving") - halo_ids = np.atleast_2d(halo_ids) # type: ignore - yt_ds = np.atleast_2d(yt_ds) # type: ignore + halo_ids_2d = np.atleast_2d(halo_ids) + yt_ds_2d = np.atleast_2d(yt_ds) # determine shape of figure fig_shape = np.shape(halo_ids) @@ -509,7 +509,7 @@ def halo_projection_array( ax.set_facecolor("black") continue - ds = yt_ds[i][j] + ds = yt_ds_2d[i][j] if ds is not None: # sodbighaloparticles holds particle data out to 2*R200 Rh = ds.domain_width[0] / 4 @@ -557,9 +557,6 @@ def halo_projection_array( # OffAxisParticleProjectionPlot if it is not axis-aligned. We are manually # calling OffAxisParticleProjectionPlot for more control over the normal/north # vectors (ParticleProjectionPlot ignores these inputs if axis-aligned). 
- - proj: type[OffAxisParticleProjectionPlot] | type[ParticleProjectionPlot] - if manual_axis_alignment: projection_axis = _sanitize_input_vector(projection_axis) @@ -569,7 +566,7 @@ def halo_projection_array( field, weight_field=weight_field, north_vector=north_vectors[i][j], - ) + ) # type: ignore else: proj = ParticleProjectionPlot( @@ -578,7 +575,7 @@ def halo_projection_array( field, weight_field=weight_field, north_vector=north_vectors[i][j], - ) + ) # type: ignore proj.set_background_color(field, color="black") From 671afca7a9e7cf9fc468ed3b19146061cda970a0 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 9 Apr 2026 17:00:03 -0400 Subject: [PATCH 031/139] type check --- python/opencosmo/analysis/yt_viz.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index 5938e634..04cd90e6 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -497,12 +497,11 @@ def halo_projection_array( fig, axes, cbars = get_multi_plot(fig_shape[1], fig_shape[0], cbar_padding=0) # are we plotting a single halo multiple times? 
- halo_ids = np.array(halo_ids) halo_id_previous = np.inf for i in range(nrow): for j in range(ncol): - halo_id = halo_ids[i][j] + halo_id = halo_ids_2d[i][j] ax = axes[i][j] if halo_id is None: From f5f3a4836b8aa20bf3d022b085a500b3adff42c3 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 9 Apr 2026 19:38:55 -0500 Subject: [PATCH 032/139] Move everything to column producers --- python/opencosmo/column/cache.py | 4 +- python/opencosmo/column/column.py | 119 ++++++++++++++++-------- python/opencosmo/dataset/state.py | 146 ++++++++---------------------- 3 files changed, 122 insertions(+), 147 deletions(-) diff --git a/python/opencosmo/column/cache.py b/python/opencosmo/column/cache.py index f388a489..c2be9ffc 100644 --- a/python/opencosmo/column/cache.py +++ b/python/opencosmo/column/cache.py @@ -177,9 +177,9 @@ def __push_up(self, data: dict[str, np.ndarray]): ) self.__cached_data |= {key: data[key] for key in columns_to_keep} - def register_column_group(self, state_id: int, columns: set[str]): + def register_column_group(self, state_id: int, columns: Iterable[str]): assert state_id not in self.__registered_column_groups - self.__registered_column_groups[state_id] = columns + self.__registered_column_groups[state_id] = set(columns) def deregister_column_group(self, state_id: int): assert state_id in self.__registered_column_groups diff --git a/python/opencosmo/column/column.py b/python/opencosmo/column/column.py index 3d0a1ee9..9002dac5 100644 --- a/python/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -31,9 +31,31 @@ from opencosmo import Dataset Comparison = Callable[[float, float], bool] +""" +The structures in this file are used both internally and user-facing. +The Column class represents a reference to a single column in the dataset. If +`is_raw` is set to true, we are requiring that this column is actually instantiated +from data originally in the hdf5 file. This is not intended to be set by the user, +and is only used internally. 
-def col(column_name: str) -> Column: + +A DerivedColumn represents a combination of columns that produces a single new column. +The columns it depends on may or may not be raw columns themselves, but eventually +the dependency graph always points back to raw columns, or columns that were provided +directly by the user as numpy arrays/astropy quantities. + +An evaluated column takes an arbitrary number of columns as input and returns an arbitrary +number of columns as output. These columns HAVE NOT actually been evaluated yet. + +These combinations all form a dependency graph, which can easily be evaluated for validity. + +Raw columns and columns that are in-memory are allowed to be sources. All other columns +must take input. The dependency graph must be a DAG, where only those two types of columns +have no inputs. +""" + + +def col(name: str) -> Column: """ Create a reference to a column with a given name. These references can be combined to produce new columns or express queries that operate on the values in a given @@ -50,7 +72,7 @@ def col(column_name: str) -> Column: For more advanced usage, see :doc:`cols` """ - return Column(column_name) + return Column(name) ColumnOrScalar = Union["Column", "DerivedColumn", int, float] @@ -152,45 +174,30 @@ class Column: """ - def __init__(self, column_name: str): - self.column_name = column_name + def __init__(self, name: str): + self.name = name self.description = None - @property - def requires(self): - return set([self.column_name]) - - @property - def produces(self): - return None - - def get_units(self, column_units: dict[str, u.Unit]): - return column_units[self.column_name] - - def evaluate(self, data: dict[str, np.ndarray], *args): - return data[self.column_name] - - # mypy doesn't reason about eq and neq correctly def __eq__(self, other: float | u.Quantity) -> ColumnMask: # type: ignore - return ColumnMask(self.column_name, other, op.eq) + return ColumnMask(self.name, other, op.eq) def __ne__(self, other: float |
u.Quantity) -> ColumnMask: # type: ignore - return ColumnMask(self.column_name, other, op.ne) + return ColumnMask(self.name, other, op.ne) def __gt__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.column_name, other, op.gt) + return ColumnMask(self.name, other, op.gt) def __ge__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.column_name, other, op.ge) + return ColumnMask(self.name, other, op.ge) def __lt__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.column_name, other, op.lt) + return ColumnMask(self.name, other, op.lt) def __le__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.column_name, other, op.le) + return ColumnMask(self.name, other, op.le) def isin(self, other: Iterable[float | u.Quantity]) -> ColumnMask: - return ColumnMask(self.column_name, other, np.isin) + return ColumnMask(self.name, other, np.isin) def __rmul__(self, other: Any) -> DerivedColumn: match other: @@ -292,6 +299,34 @@ def evaluate( def get_units(self, values: dict[str, u.Quantity]) -> dict[str, u.Unit]: ... 
+class RawColumn: + def __init__(self, name, description): + self.__name = name + self.__description = description + + @property + def name(self): + return self.__name + + @property + def requires(self) -> set[str]: + return set() + + @property + def produces(self) -> set[str]: + return set([self.__name]) + + @property + def description(self): + return self.__description + + def get_units(self, values: dict[str, u.Quantity]) -> dict[str, u.Unit]: + return values[self.__name] + + def evaluate(self, data: dict[str, np.ndarray], *args): + return data[self.name] + + class DerivedColumn: """ A derived column represents a combination of multiple columns that already exist in @@ -314,9 +349,11 @@ def __init__( rhs: Optional[ColumnOrScalar], operation: Callable, description: Optional[str] = None, + output_name: Optional[str] = None, ): self.lhs = lhs self.rhs = rhs + self.name = output_name self.operation = operation self.description = description if description is not None else "None" @@ -328,12 +365,12 @@ def requires(self): vals = set() match self.lhs: case Column(): - vals.add(self.lhs.column_name) + vals.add(self.lhs.name) case DerivedColumn(): vals = vals | self.lhs.requires match self.rhs: case Column(): - vals.add(self.rhs.column_name) + vals.add(self.rhs.name) case DerivedColumn(): vals = vals | self.rhs.requires @@ -341,12 +378,12 @@ def requires(self): @property def produces(self): - return None + return None if self.name is None else set([self.name]) def check_parent_existance(self, names: set[str]): match self.rhs: case Column(): - rhs_valid = self.rhs.column_name in names + rhs_valid = self.rhs.name in names case DerivedColumn(): rhs_valid = self.rhs.check_parent_existance(names) case _: @@ -354,7 +391,7 @@ def check_parent_existance(self, names: set[str]): match self.lhs: case Column(): - lhs_valid = self.lhs.column_name in names + lhs_valid = self.lhs.name in names case DerivedColumn(): lhs_valid = self.lhs.check_parent_existance(names) case _: @@ -365,14 
+402,14 @@ def check_parent_existance(self, names: set[str]): def get_units(self, units: dict[str, u.Unit]): match self.lhs: case Column(): - lhs_unit = units[self.lhs.column_name] + lhs_unit = units[self.lhs.name] case DerivedColumn(): lhs_unit = self.lhs.get_units(units) case _: lhs_unit = None match self.rhs: case Column(): - rhs_unit = units[self.rhs.column_name] + rhs_unit = units[self.rhs.name] case DerivedColumn(): rhs_unit = self.rhs.get_units(units) case _: @@ -446,14 +483,14 @@ def evaluate(self, data: dict[str, np.ndarray], *args) -> np.ndarray: case DerivedColumn(): lhs = self.lhs.evaluate(data) case Column(): - lhs = data[self.lhs.column_name] + lhs = data[self.lhs.name] case _: lhs = self.lhs match self.rhs: case DerivedColumn(): rhs = self.rhs.evaluate(data) case Column(): - rhs = data[self.rhs.column_name] + rhs = data[self.rhs.name] case _: rhs = self.rhs @@ -498,6 +535,10 @@ def with_kwargs(self, **new_kwargs: Any): **new_kwargs, ) + @property + def name(self): + return self.__func.__name__ + @property def requires(self): return copy(self.__requires) @@ -603,17 +644,17 @@ class ColumnMask: def __init__( self, - column_name: str, + name: str, value: float | u.Quantity, operator: Callable[[table.Column, float | u.Quantity], np.ndarray], ): - self.column_name = column_name + self.name = name self.value = value self.operator = operator @property def requires(self): - return {self.column_name} + return {self.name} def apply(self, column: u.Quantity | np.ndarray) -> np.ndarray: """ @@ -621,7 +662,7 @@ def apply(self, column: u.Quantity | np.ndarray) -> np.ndarray: """ # Astropy's errors are good enough here if isinstance(column, table.Table): - column = column[self.column_name] + column = column[self.name] if isinstance(self.value, u.Quantity) and isinstance(column, u.Quantity): if self.value.unit != column.unit: diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index b47535d9..ba252b6b 100644 --- 
a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -2,20 +2,18 @@ from copy import copy from functools import reduce -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Iterable, Optional from weakref import finalize import astropy.units as u import numpy as np from opencosmo.column.cache import ColumnCache -from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn +from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn, RawColumn from opencosmo.column.select import get_column_selection -from opencosmo.dataset.derived import ( - build_derived_columns, - validate_derived_columns, -) +from opencosmo.dataset.graph import validate_column_producers from opencosmo.dataset.im import resort, validate_in_memory_columns +from opencosmo.dataset.instantiate import instantiate_dataset from opencosmo.handler.empty import EmptyHandler from opencosmo.handler.hdf5 import Hdf5Handler from opencosmo.index.build import single_chunk @@ -55,21 +53,21 @@ class DatasetState: def __init__( self, + column_producers: list[ConstructedColumn], raw_data_handler: DataHandler, cache: DataCache, - derived_columns: dict[str, ConstructedColumn], unit_handler: UnitHandler, header: OpenCosmoHeader, - columns: set[str], + columns: Iterable[str], region: Region, sort_by: Optional[tuple[str, bool]], ): + self.__producers = column_producers self.__raw_data_handler = raw_data_handler self.__cache = cache - self.__derived_columns = derived_columns self.__unit_handler = unit_handler self.__header = header - self.__columns = columns + self.__columns = set(columns) self.__region = region self.__sort_by = sort_by self.__cache.register_column_group(id(self), self.__columns) @@ -79,7 +77,7 @@ def __rebuild(self, **updates): new = { "raw_data_handler": self.__raw_data_handler, "cache": self.__cache, - "derived_columns": self.__derived_columns, + "column_producers": self.__producers, "unit_handler": self.__unit_handler, 
"header": self.__header, "columns": self.__columns, @@ -116,13 +114,18 @@ def from_target( unit_handler = make_unit_handler_from_hdf5( target["columns"], target["header"], unit_convention ) + descriptions = handler.descriptions + producers = [ + RawColumn(cname, descriptions.get(cname, "None")) + for cname in handler.columns + ] columns = set(handler.columns) cache = ColumnCache.empty() return DatasetState( + producers, handler, cache, - {}, unit_handler, target["header"], columns, @@ -171,11 +174,10 @@ def __len__(self): @property def descriptions(self): - all_descriptions = ( - {name: col.description for name, col in self.__derived_columns.items()} - | self.__raw_data_handler.descriptions - | self.__cache.descriptions - ) + all_descriptions = {} + for producer in self.__producers: + update = {name: producer.description for name in producer.produces} + all_descriptions |= update return { name: description for name, description in all_descriptions.items() @@ -231,43 +233,13 @@ def get_data( """ Get the data for a given handler. 
""" - data = self.__build_derived_columns(unit_kwargs) - cached_data = self.__cache.get_data(self.columns) - converted_cached_data = self.__unit_handler.apply_unit_conversions( - cached_data, unit_kwargs - ) - - data |= cached_data - if converted_cached_data: - self.__cache.add_data(converted_cached_data, {}, push_up=False) - data |= converted_cached_data - - raw_columns = ( - set(self.columns) - .intersection(self.__raw_data_handler.columns) - .difference(data.keys()) + data = instantiate_dataset( + self.__producers, + self.__raw_data_handler, + self.__cache, + self.__unit_handler, + unit_kwargs, ) - if ( - self.__sort_by is not None - and self.__sort_by[0] in self.__raw_data_handler.columns - ): - raw_columns.add(self.__sort_by[0]) - - if raw_columns: - raw_data = self.__raw_data_handler.get_data(raw_columns) - raw_data = self.__unit_handler.apply_raw_units(raw_data, unit_kwargs) - if raw_data: - self.__cache.add_data(raw_data, {}) - - updated_data = self.__unit_handler.apply_unit_conversions( - raw_data, unit_kwargs - ) - if updated_data: - self.__cache.add_data(updated_data, {}, push_up=False) - - new_data = raw_data | updated_data - data |= new_data - if missing := set(self.columns).difference(data.keys()): raise RuntimeError( f"Some columns are missing from the output! This is likely a bug. Please report it on GitHub. 
Missing: {missing}" @@ -426,15 +398,22 @@ def with_new_columns( if inter := existing_columns.intersection(new_columns.keys()): raise ValueError(f"Some columns are already in the dataset: {inter}") - new_derived_columns = {} + new_derived_columns = [] new_in_memory_columns = {} new_in_memory_descriptions = {} + new_column_names = self.columns for colname, column in new_columns.items(): match column: - case DerivedColumn() | EvaluatedColumn() | Column(): + case DerivedColumn(): + column.name = colname column.description = descriptions.get(colname, "None") - new_derived_columns[colname] = column + new_derived_columns.append(column) + new_column_names.extend(column.produces) + case EvaluatedColumn() | Column(): + column.description = descriptions.get(colname, "None") + new_derived_columns.append(column) + new_column_names.extend(column.produces) case np.ndarray(): if len(column) != len(self): raise ValueError( @@ -444,14 +423,16 @@ def with_new_columns( colname, "None" ) new_in_memory_columns[colname] = column + new_column_names.append(colname) + continue case _: raise ValueError( f"Got an invalid new column of type {type(column)}" ) new_unit_handler = self.__unit_handler - new_derived = copy(self.__derived_columns) - new_column_names: set[str] = set(self.columns) + new_producers = copy(self.__producers) + new_derived_columns + validate_column_producers(new_producers) if new_in_memory_columns: new_unit_handler = validate_in_memory_columns( new_in_memory_columns, self.__unit_handler, len(self) @@ -462,61 +443,14 @@ def with_new_columns( self.__cache.add_data( new_in_memory_columns, descriptions=new_in_memory_descriptions ) - new_column_names |= set(new_in_memory_columns.keys()) - - if new_derived_columns: - new_units = validate_derived_columns( - self.__derived_columns | new_derived_columns, - existing_columns.union(new_in_memory_columns.keys()).difference( - self.__derived_columns.keys() - ), - new_unit_handler.base_units, - ) - new_derived |= new_derived_columns - for 
colname, derived in new_derived.items(): - if (prod := derived.produces) is not None: - new_column_names |= prod - else: - new_column_names.add(colname) - - new_unit_handler = new_unit_handler.with_new_columns(**new_units) return self.__rebuild( cache=self.__cache, - derived_columns=new_derived, + column_producers=new_producers, columns=new_column_names, unit_handler=new_unit_handler, ) - def __build_derived_columns(self, unit_kwargs: dict) -> table.Table: - """ - Build any derived columns that are present in this dataset - """ - if not self.__derived_columns: - return {} - - all_derived_columns: set[str] = reduce( - lambda acc, dc: acc.union( - dc[1].produces if dc[1].produces is not None else {dc[0]} - ), - self.__derived_columns.items(), - set(), - ) - derived_names = all_derived_columns.intersection(self.columns) - if self.__sort_by is not None and self.__sort_by[0] in all_derived_columns: - derived_names.add(self.__sort_by[0]) - - dc = build_derived_columns( - self.__derived_columns, - derived_names, - self.__cache, - self.__raw_data_handler, - self.__unit_handler, - unit_kwargs, - self.__raw_data_handler.index, - ) - return dc - def with_region(self, region: Region): """ Return the same dataset but with a different region From 4a52ed25a86cf998398753d44f17cafc53f8a784 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 9 Apr 2026 19:54:01 -0500 Subject: [PATCH 033/139] Basic writes working with producers --- python/opencosmo/dataset/state.py | 52 ++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index ba252b6b..1d8cd0f7 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -336,7 +336,15 @@ def make_schema(self, name: Optional[str] = None): data_schema, metadata_schema = self.__raw_data_handler.make_schema( raw_columns, header ) - derived_names = set(self.__derived_columns.keys()).intersection(self.columns) 
+ derived_names = reduce( + lambda acc, col: acc.union( + col.produces if not isinstance(col, RawColumn) else set() + ), + self.__producers, + set, + ) + derived_names = derived_names.intersection(self.columns) + derived_data = ( self.select(derived_names) .with_units(self.__unit_handler.base_convention, {}, {}, None, None) @@ -346,22 +354,28 @@ def make_schema(self, name: Optional[str] = None): self.columns + self.meta_columns ) - for colname in derived_names: - if colname in cached_data_schema.columns: + for producer in self.__producers: + if isinstance(producer, RawColumn) or producer.produces.issubset( + cached_data_schema.columns.keys() + ): continue - coldata = derived_data[colname] - unit = "" - if isinstance(coldata, u.Quantity): - unit = str(coldata.unit) - coldata = derived_data[colname].value - - attrs = { - "unit": unit, - "description": str(self.__derived_columns[colname].description), + coldata = {name: derived_data[name] for name in producer.produces} + units = { + name: str(cd.unit) if isinstance(cd, u.Quantity) else "" + for name, cd in coldata.items() } - source = NumpySource(coldata) - writer = ColumnWriter([source], ColumnCombineStrategy.CONCAT, attrs=attrs) - data_schema.columns[colname] = writer + coldata = { + name: cd.value if isinstance(cd, u.Quantity) else cd + for name, cd in coldata.items() + } + for name, cd in coldata.items(): + attrs = {"unit": units[name], "description": producer.description} + + source = NumpySource(cd) + writer = ColumnWriter( + [source], ColumnCombineStrategy.CONCAT, attrs=attrs + ) + data_schema.columns[name] = writer attributes = {} if (load_conditions := self.__raw_data_handler.load_conditions) is not None: @@ -612,10 +626,12 @@ def with_units( cache = self.__cache.create_child() else: all_derived_names: set[str] = reduce( - lambda acc, next: acc.union(next[1].produces or {next[0]}), - self.__derived_columns.items(), + lambda acc, col: acc.union( + col.produces if not isinstance(col, RawColumn) else set() + ), 
+ self.__producers, set(), - ) + ).intersection(self.columns) columns_to_drop = all_derived_names.union(self.__raw_data_handler.columns) cache = self.__cache.drop(columns_to_drop) return self.__rebuild(unit_handler=new_handler, cache=cache) From c3acea7a7ca6bc779a5d5fa198a56dec027303db Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 9 Apr 2026 20:15:49 -0500 Subject: [PATCH 034/139] Correct raw data caching behavior with column conversions --- python/opencosmo/dataset/state.py | 6 ++++-- python/opencosmo/units/handler.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index 1d8cd0f7..39591513 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -178,6 +178,8 @@ def descriptions(self): for producer in self.__producers: update = {name: producer.description for name in producer.produces} all_descriptions |= update + all_descriptions |= self.__cache.descriptions + return { name: description for name, description in all_descriptions.items() @@ -277,9 +279,9 @@ def get_data( def rows(self, metadata_columns: list = [], unit_kwargs: dict = {}): derived_to_collect = ( - set(self.__derived_columns.keys()) - .intersection(self.columns) + set(self.columns) .difference(self.__cache.columns) + .difference(self.__raw_data_handler.columns) ) derived_storage: dict[str, list[np.ndarray]] = { name: [] for name in derived_to_collect diff --git a/python/opencosmo/units/handler.py b/python/opencosmo/units/handler.py index 3b9e53fd..2281ed8a 100644 --- a/python/opencosmo/units/handler.py +++ b/python/opencosmo/units/handler.py @@ -87,6 +87,16 @@ def current_convention(self): def base_convention(self): return self.__base_convention + @property + def columns_with_conversions(self): + return { + name + for name, applicator in self.__applicators.items() + if name in self.__column_conversions + or 
str(applicator.unit_in_convention(self.current_convention)) + in self.__conversions + } + @cached_property def base_units(self): return {key: app.base_unit for key, app in self.__applicators.items()} From 1449a203ffebaf746483475c7cacf4ab52954a50 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 9 Apr 2026 20:18:12 -0500 Subject: [PATCH 035/139] Add new instantiate file, remove derived file --- python/opencosmo/dataset/derived.py | 226 ------------------------ python/opencosmo/dataset/instantiate.py | 62 +++++++ 2 files changed, 62 insertions(+), 226 deletions(-) delete mode 100644 python/opencosmo/dataset/derived.py create mode 100644 python/opencosmo/dataset/instantiate.py diff --git a/python/opencosmo/dataset/derived.py b/python/opencosmo/dataset/derived.py deleted file mode 100644 index bedafb42..00000000 --- a/python/opencosmo/dataset/derived.py +++ /dev/null @@ -1,226 +0,0 @@ -from __future__ import annotations - -from functools import reduce -from itertools import product -from typing import TYPE_CHECKING, Mapping, Optional - -import rustworkx as rx - -if TYPE_CHECKING: - import astropy.units as u - import numpy as np - - from opencosmo.column.column import ConstructedColumn - from opencosmo.handler.protocols import DataCache, DataHandler - from opencosmo.index import DataIndex - from opencosmo.units.handler import UnitHandler - - -def build_dependency_graph( - derived_columns: Mapping[str, ConstructedColumn], - names_to_keep: Optional[set[str]] = None, -): - dependency_graph = rx.PyDiGraph() - all_requires: set[str] = reduce( - lambda known, dc: known.union(dc.requires), derived_columns.values(), set() - ) - nodeidx = dependency_graph.add_nodes_from(all_requires) - nodemap = {name: idx for (name, idx) in zip(all_requires, nodeidx)} - - for target, derived_column in derived_columns.items(): - requires = derived_column.requires - produces = derived_column.produces - if produces is None: - produces = set((target,)) - to_add = list(filter(lambda p: p not 
in nodemap, produces)) - new_map = dependency_graph.add_nodes_from(to_add) - nodemap.update({name: idx for (name, idx) in zip(to_add, new_map)}) - - requires_idx = tuple(nodemap[r] for r in requires) - produces_idx = tuple(nodemap[r] for r in produces) - - dependency_graph.add_edges_from_no_data(product(requires_idx, produces_idx)) - - if names_to_keep is not None: - nodes_to_keep = reduce( - lambda acc, name: acc.union(rx.ancestors(dependency_graph, nodemap[name])), - names_to_keep, - {nodemap[name] for name in names_to_keep}, - ) - names_to_keep = {dependency_graph[n] for n in nodes_to_keep} - dependency_graph = dependency_graph.subgraph(list(nodes_to_keep)) - derived_columns = { - name: dc for name, dc in derived_columns.items() if name in names_to_keep - } - - return dependency_graph - - -def replace_multi_producers( - graph: rx.PyDiGraph, - derived_columns: Mapping[str, ConstructedColumn], - columns_to_get: Optional[set[str]] = None, -): - """ - Some derived columns actually produce multiple outputs. At this stage, the dependency - graph is working solely with actual column names, meaning if any of those columns is - produced by one of these "multi-produces" they will not be in the derived_columns - dictionary and therefore cannot be instantiated. This function replaces such - columns with the name of the derived_column that produces them. 
- """ - - node_map = {name: i for i, name in enumerate(graph.nodes())} - missing = set(derived_columns.keys()).difference(node_map.keys()) - if not missing: - return graph - for missing_column in missing: - missing_column_produces = derived_columns[missing_column].produces - if missing_column_produces is None or not missing_column_produces.intersection( - columns_to_get or missing_column_produces - ): - continue - outputs = [ - node_map[name] for name in missing_column_produces if name in node_map - ] - graph.contract_nodes(outputs, missing_column) - return graph - - -def validate_derived_columns( - derived_columns: dict[str, ConstructedColumn], - known_raw_columns: set[str], - units: dict[str, u.Unit], -): - """ - Validate the network of derived columns. This - """ - dependency_graph = build_dependency_graph(derived_columns) - if cycle := rx.digraph_find_cycle(dependency_graph): - all_nodes: set[int] = reduce( - lambda known, edge: known.union(edge), cycle, set() - ) - names = [dependency_graph[i] for i in all_nodes] - raise ValueError( - f"Found derived columns that depend on each other! 
Columns: {names}" - ) - - sources = set( - filter( - lambda i: not dependency_graph.in_degree(i), - range(dependency_graph.num_nodes()), - ) - ) - source_names = map(lambda i: dependency_graph[i], sources) - if missing := set(source_names).difference(known_raw_columns): - raise ValueError(f"Tried to derive columns from unknown columns: {missing}") - - dependency_graph = replace_multi_producers(dependency_graph, derived_columns) - validate_dependency_graph( - dependency_graph, known_raw_columns, set(derived_columns.keys()) - ) - - return validate_derived_units(dependency_graph, derived_columns, units) - - -def validate_dependency_graph( - dependency_graph: rx.PyDiGraph, - known_raw_columns: set[str], - derived_columns: set[str], -): - expected = set(dependency_graph.nodes()) - known = known_raw_columns.union(derived_columns) - missing = expected.difference(known) - assert len(missing) == 0 - - -def validate_derived_units( - dependency_graph: rx.PyDiGraph, - derived_columns: dict[str, ConstructedColumn], - units: dict[str, u.Unit], -): - output_units: dict[str, Optional[u.Unit]] = {} - for node in rx.topological_sort(dependency_graph): - node_name = dependency_graph[node] - if node_name in units: - continue - new_units = derived_columns[node_name].get_units(units) - if not isinstance(new_units, dict): - new_units = {node_name: new_units} - units |= new_units - output_units |= new_units - return output_units - - -def build_derived_columns( - all_derived_columns: dict[str, ConstructedColumn], - derived_columns_to_get: set[str], - cache: DataCache, - data_handler: DataHandler, - unit_handler: UnitHandler, - unit_kwargs: dict, - index: DataIndex, -) -> dict[str, np.ndarray]: - """ - Build any derived columns that are present in this dataset. Also returns any columns that - had to be instantiated in order to build these derived columns. 
- """ - if not derived_columns_to_get: - return {} - - column_names: set[str] = reduce( - lambda known, dc: known.union(dc[1].produces) - if dc[1].produces is not None - else known.union((dc[0],)), - all_derived_columns.items(), - set(), - ) - - dependency_graph = build_dependency_graph( - all_derived_columns, derived_columns_to_get - ) - cached_data = cache.get_data(dependency_graph.nodes()) - cached_data |= unit_handler.apply_unit_conversions(cached_data, unit_kwargs) - - additional_derived = column_names.difference(cached_data.keys()) - - if not additional_derived: - return cached_data - - columns_to_fetch = ( - set(dependency_graph.nodes()) - .intersection(data_handler.columns) - .difference(cached_data.keys()) - ) - - raw_data = data_handler.get_data(columns_to_fetch) - data = cached_data | unit_handler.apply_units(raw_data, unit_kwargs) - - dependency_graph = replace_multi_producers( - dependency_graph, all_derived_columns, derived_columns_to_get - ) - new_derived: dict[str, np.ndarray] = {} - - for colidx in rx.topological_sort(dependency_graph): - colname = dependency_graph[colidx] - if colname in data: - continue - derived_column = all_derived_columns[colname] - produces = derived_column.produces - if produces is None: - produces = set((colname,)) - if all(name in data for name in produces): - continue - output = derived_column.evaluate( - data, index[1] if isinstance(index, tuple) else None - ) - if isinstance(output, dict): - data |= output - new_derived |= output - else: - data[colname] = output - new_derived[colname] = output - - if new_derived: - cache.add_data(new_derived, {}) - - return data | new_derived diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py new file mode 100644 index 00000000..7aa4c117 --- /dev/null +++ b/python/opencosmo/dataset/instantiate.py @@ -0,0 +1,62 @@ +import rustworkx as rx + +from opencosmo.column.column import RawColumn +from opencosmo.dataset.graph import 
build_dependency_graph, contract_derived_columns + + +def instantiate_dataset( + column_producers, raw_data_handler, cache, unit_handler, unit_kwargs +): + dependency_graph = build_dependency_graph(column_producers) + cached_data = cache.get_data(dependency_graph.nodes()) + converted_cached_data = unit_handler.apply_unit_conversions( + cached_data, unit_kwargs + ) + + push_up = True + if converted_cached_data: + push_up = False + cache.add_data(converted_cached_data, {}, push_up=push_up) + + raw_columns = filter( + lambda col: ( + isinstance(col, RawColumn) + and not col.requires.intersection(cached_data.keys()) + ), + column_producers, + ) + raw_data = raw_data_handler.get_data([col.name for col in raw_columns]) + raw_data = unit_handler.apply_raw_units(raw_data, unit_kwargs) + new_derived_columns = build_derived_columns( + column_producers, converted_cached_data | raw_data, dependency_graph, None + ) + if raw_data: + cache.add_data(raw_data, {}, push_up=True) + converted_raw_data = unit_handler.apply_unit_conversions(raw_data, unit_kwargs) + if converted_raw_data: + cache.add_data(converted_raw_data, {}, push_up=False) + raw_data |= converted_raw_data + + if new_derived_columns: + cache.add_data(new_derived_columns, {}, push_up=push_up) + + return converted_cached_data | raw_data | new_derived_columns + + +def build_derived_columns(column_producers, data, dependency_graph, index): + dependency_graph = contract_derived_columns(dependency_graph, column_producers) + new_derived = {} + for colidx in rx.topological_sort(dependency_graph): + column = dependency_graph[colidx] + if isinstance(column, str): + assert column in data + continue + produces = column.produces + if all(name in data for name in produces): + continue + output = column.evaluate(data, index[1] if isinstance(index, tuple) else None) + if isinstance(output, dict): + new_derived |= output + else: + new_derived[column.name] = output + return new_derived From 6371e7ae8d3bccc3f1e22f2f5c40edaec9162a55 Mon 
Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 10 Apr 2026 12:30:03 -0500 Subject: [PATCH 036/139] All serial tests working --- python/opencosmo/column/column.py | 15 ++- python/opencosmo/dataset/graph.py | 129 ++++++++++++++++++++++++ python/opencosmo/dataset/instantiate.py | 75 ++++++++++---- python/opencosmo/dataset/state.py | 54 +++++++--- 4 files changed, 233 insertions(+), 40 deletions(-) create mode 100644 python/opencosmo/dataset/graph.py diff --git a/python/opencosmo/column/column.py b/python/opencosmo/column/column.py index 9002dac5..5a0aeb92 100644 --- a/python/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -286,7 +286,7 @@ class ConstructedColumn(Protocol): @property def requires(self) -> set[str]: ... @property - def produces(self) -> Optional[set[str]]: ... + def produces(self) -> set[str]: ... @property def description(self) -> Optional[str]: ... @@ -300,9 +300,10 @@ def get_units(self, values: dict[str, u.Quantity]) -> dict[str, u.Unit]: ... class RawColumn: - def __init__(self, name, description): + def __init__(self, name, description, alias=None): self.__name = name self.__description = description + self.__alias = alias @property def name(self): @@ -310,11 +311,17 @@ def name(self): @property def requires(self) -> set[str]: + if self.__alias is not None: + return set([self.__name]) return set() + @property + def alias(self) -> str | None: + return self.__alias + @property def produces(self) -> set[str]: - return set([self.__name]) + return set([self.__alias or self.__name]) @property def description(self): @@ -324,7 +331,7 @@ def get_units(self, values: dict[str, u.Quantity]) -> dict[str, u.Unit]: return values[self.__name] def evaluate(self, data: dict[str, np.ndarray], *args): - return data[self.name] + return data[self.__name] class DerivedColumn: diff --git a/python/opencosmo/dataset/graph.py b/python/opencosmo/dataset/graph.py new file mode 100644 index 00000000..6123b923 --- /dev/null +++ 
b/python/opencosmo/dataset/graph.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from functools import reduce +from itertools import product +from typing import TYPE_CHECKING + +import rustworkx as rx + +from opencosmo.column.column import RawColumn + +if TYPE_CHECKING: + import astropy.units as u + + from opencosmo.column.column import ConstructedColumn + from opencosmo.units.handler import UnitHandler + + +def validate_column_producers( + producers: list[ConstructedColumn], unit_handler: UnitHandler +): + """ + Validate the network of column producers. + """ + raw_columns = set(col.name for col in producers if isinstance(col, RawColumn)) + + dependency_graph = build_dependency_graph(producers) + if cycle := rx.digraph_find_cycle(dependency_graph): + all_nodes: set[int] = reduce( + lambda known, edge: known.union(edge), cycle, set() + ) + names = [dependency_graph[i] for i in all_nodes] + raise ValueError(f"Found columns that depend on each other! Columns: {names}") + + sources = set( + filter( + lambda i: not dependency_graph.in_degree(i), + range(dependency_graph.num_nodes()), + ) + ) + source_names = set(map(lambda i: dependency_graph[i], sources)) + if missing := set(source_names).difference(raw_columns): + raise ValueError(f"Tried to derive columns from unknown columns: {missing}") + + return get_derived_units(dependency_graph, producers, unit_handler.base_units) + + +def build_dependency_graph(producers: list[ConstructedColumn]): + dependency_graph = rx.PyDiGraph() + all_requires: set[str] = reduce( + lambda known, dc: known.union(dc.requires if dc.requires is not None else []), + producers, + set(), + ) + nodeidx = dependency_graph.add_nodes_from(all_requires) + nodemap = {name: idx for (name, idx) in zip(all_requires, nodeidx)} + + for column_producer in producers: + requires = column_producer.requires + produces = column_producer.produces + assert produces is not None + to_add = list(filter(lambda p: p not in nodemap, produces)) + new_map = 
dependency_graph.add_nodes_from(to_add) + nodemap.update({name: idx for (name, idx) in zip(to_add, new_map)}) + if not requires: + continue + + requires_idx = tuple(nodemap[r] for r in requires) + produces_idx = tuple(nodemap[r] for r in produces) + + dependency_graph.add_edges_from_no_data(product(requires_idx, produces_idx)) + + return dependency_graph + + +def contract_derived_columns( + graph: rx.PyDiGraph, + column_names: set[str], + column_producers: list[ConstructedColumn], +): + """ + Some derived columns actually produce multiple outputs. At this stage, the dependency + graph is working solely with actual column names, meaning if any of those columns is + produced by one of these "multi-produces" they will not be in the derived_columns + dictionary and therefore cannot be instantiated. This function replaces such + columns with the name of the derived_column that produces them. + """ + node_map = {name: i for i, name in enumerate(graph.nodes())} + nodes_to_keep: set[int] = set() + for producer in column_producers: + if producer.produces.intersection(column_names): + produces_idx = {node_map[name] for name in producer.produces} + nodes_to_keep = reduce( + lambda acc, node: acc.union(rx.ancestors(graph, node)), + produces_idx, + nodes_to_keep, + ) + nodes_to_keep.update(produces_idx) + subgraph = graph.subgraph(list(nodes_to_keep)) + node_map = {name: i for i, name in enumerate(subgraph.nodes())} + + for producer in column_producers: + if isinstance(producer, RawColumn): + continue + produces = producer.produces + produces_index = [node_map[name] for name in produces if name in node_map] + if produces_index: + subgraph.contract_nodes(produces_index, producer) + return subgraph + + +def get_derived_units( + dependency_graph: rx.PyDiGraph, + producers: list[ConstructedColumn], + units: dict[str, u.Unit], +): + dependency_graph = contract_derived_columns( + dependency_graph, set(dependency_graph.nodes()), producers + ) + + new_units: dict[str, u.Unit | None] = 
{} + for node in rx.topological_sort(dependency_graph): + node = dependency_graph[node] + if isinstance(node, str): + continue + column_units = node.get_units(units | new_units) + if not isinstance(column_units, dict): + column_units = {prod: column_units for prod in node.produces} + new_units |= column_units + return new_units diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py index 7aa4c117..51314e1d 100644 --- a/python/opencosmo/dataset/instantiate.py +++ b/python/opencosmo/dataset/instantiate.py @@ -4,31 +4,59 @@ from opencosmo.dataset.graph import build_dependency_graph, contract_derived_columns +def get_all_required_columns(column_names: set[str], dependency_graph: rx.PyDiGraph): + required_columns = set() + node_map = {name: i for i, name in enumerate(dependency_graph.nodes())} + for name in column_names: + required_columns.add(name) + ancestors = rx.ancestors(dependency_graph, node_map[name]) + required_columns.update(dependency_graph[i] for i in ancestors) + + return required_columns + + def instantiate_dataset( - column_producers, raw_data_handler, cache, unit_handler, unit_kwargs + column_producers, + column_names, + raw_data_handler, + cache, + unit_handler, + unit_kwargs, ): dependency_graph = build_dependency_graph(column_producers) - cached_data = cache.get_data(dependency_graph.nodes()) + all_required_columns = get_all_required_columns(column_names, dependency_graph) + cached_data = cache.get_data(all_required_columns) converted_cached_data = unit_handler.apply_unit_conversions( cached_data, unit_kwargs ) - push_up = True if converted_cached_data: - push_up = False - cache.add_data(converted_cached_data, {}, push_up=push_up) - - raw_columns = filter( - lambda col: ( - isinstance(col, RawColumn) - and not col.requires.intersection(cached_data.keys()) - ), - column_producers, + cache.add_data(converted_cached_data, {}, push_up=False) + cached_data |= converted_cached_data + + raw_columns = list( + filter( + 
lambda col: ( + isinstance(col, RawColumn) + and col.name not in cached_data + and col.name in all_required_columns + ), + column_producers, + ) ) - raw_data = raw_data_handler.get_data([col.name for col in raw_columns]) + raw_data = raw_data_handler.get_data(set(col.name for col in raw_columns)) + for column in raw_columns: + if column.alias is None: + continue + raw_data[column.alias] = raw_data[column.name] + raw_data = unit_handler.apply_raw_units(raw_data, unit_kwargs) new_derived_columns = build_derived_columns( - column_producers, converted_cached_data | raw_data, dependency_graph, None + column_producers, + all_required_columns, + cached_data | raw_data, + dependency_graph, + raw_data_handler.index, ) if raw_data: cache.add_data(raw_data, {}, push_up=True) @@ -37,24 +65,27 @@ def instantiate_dataset( cache.add_data(converted_raw_data, {}, push_up=False) raw_data |= converted_raw_data - if new_derived_columns: - cache.add_data(new_derived_columns, {}, push_up=push_up) - - return converted_cached_data | raw_data | new_derived_columns + return cached_data | raw_data | new_derived_columns -def build_derived_columns(column_producers, data, dependency_graph, index): - dependency_graph = contract_derived_columns(dependency_graph, column_producers) +def build_derived_columns( + column_producers, column_names, data, dependency_graph, index +): + dependency_graph = contract_derived_columns( + dependency_graph, column_names, column_producers + ) new_derived = {} for colidx in rx.topological_sort(dependency_graph): column = dependency_graph[colidx] if isinstance(column, str): - assert column in data + assert column in data or column not in column_names continue produces = column.produces if all(name in data for name in produces): continue - output = column.evaluate(data, index[1] if isinstance(index, tuple) else None) + output = column.evaluate( + data | new_derived, index[1] if isinstance(index, tuple) else None + ) if isinstance(output, dict): new_derived |= output 
else: diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index 39591513..0832fca9 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -2,7 +2,7 @@ from copy import copy from functools import reduce -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Optional, Sequence from weakref import finalize import astropy.units as u @@ -53,7 +53,7 @@ class DatasetState: def __init__( self, - column_producers: list[ConstructedColumn], + column_producers: Sequence[ConstructedColumn], raw_data_handler: DataHandler, cache: DataCache, unit_handler: UnitHandler, @@ -62,7 +62,7 @@ def __init__( region: Region, sort_by: Optional[tuple[str, bool]], ): - self.__producers = column_producers + self.__producers = list(column_producers) self.__raw_data_handler = raw_data_handler self.__cache = cache self.__unit_handler = unit_handler @@ -141,9 +141,10 @@ def in_memory( header: OpenCosmoHeader, unit_convention: UnitConvention, region: Region, - descriptions: Optional[dict[str, str]] = {}, + descriptions: Optional[dict[str, str]] = None, index: Optional[DataIndex] = None, ): + descriptions = descriptions or {} cache = ColumnCache.empty() cache.add_data( data_columns | metadata_columns, descriptions, set(metadata_columns.keys()) @@ -154,12 +155,16 @@ def in_memory( if isinstance(column, u.Quantity): units[name] = column.unit + producers = [ + RawColumn(cname, descriptions.get(cname, "None")) + for cname in data_columns.keys() + ] unit_handler = make_unit_handler_from_units(units, header, unit_convention) return DatasetState( + producers, EmptyHandler(), cache, - {}, unit_handler, header, set(data_columns.keys()), @@ -235,8 +240,13 @@ def get_data( """ Get the data for a given handler. 
""" + columns_to_get = copy(self.__columns) + if self.__sort_by is not None: + columns_to_get.add(self.__sort_by[0]) + data = instantiate_dataset( self.__producers, + columns_to_get, self.__raw_data_handler, self.__cache, self.__unit_handler, @@ -338,12 +348,12 @@ def make_schema(self, name: Optional[str] = None): data_schema, metadata_schema = self.__raw_data_handler.make_schema( raw_columns, header ) - derived_names = reduce( + derived_names: set[str] = reduce( lambda acc, col: acc.union( col.produces if not isinstance(col, RawColumn) else set() ), self.__producers, - set, + set(), ) derived_names = derived_names.intersection(self.columns) @@ -414,11 +424,12 @@ def with_new_columns( if inter := existing_columns.intersection(new_columns.keys()): raise ValueError(f"Some columns are already in the dataset: {inter}") - new_derived_columns = [] + new_derived_columns: list[ConstructedColumn] = [] new_in_memory_columns = {} new_in_memory_descriptions = {} new_column_names = self.columns + new_static_units = {} for colname, column in new_columns.items(): match column: case DerivedColumn(): @@ -426,10 +437,17 @@ def with_new_columns( column.description = descriptions.get(colname, "None") new_derived_columns.append(column) new_column_names.extend(column.produces) - case EvaluatedColumn() | Column(): + case EvaluatedColumn(): column.description = descriptions.get(colname, "None") new_derived_columns.append(column) new_column_names.extend(column.produces) + case Column(): + producer = RawColumn( + column.name, descriptions.get(colname, None), alias=colname + ) + new_derived_columns.append(producer) + new_column_names.extend(producer.produces) + case np.ndarray(): if len(column) != len(self): raise ValueError( @@ -440,15 +458,22 @@ def with_new_columns( ) new_in_memory_columns[colname] = column new_column_names.append(colname) - continue + new_derived_columns.append(RawColumn(colname, None)) + new_static_units[colname] = ( + column.unit if isinstance(column, u.Quantity) else 
None + ) + case _: raise ValueError( f"Got an invalid new column of type {type(column)}" ) - new_unit_handler = self.__unit_handler + new_unit_handler = self.__unit_handler.with_static_columns(**new_static_units) + new_producers = copy(self.__producers) + new_derived_columns - validate_column_producers(new_producers) + new_units = validate_column_producers(new_producers, new_unit_handler) + if new_units: + new_unit_handler = new_unit_handler.with_new_columns(**new_units) if new_in_memory_columns: new_unit_handler = validate_in_memory_columns( new_in_memory_columns, self.__unit_handler, len(self) @@ -627,12 +652,13 @@ def with_units( if convention_ == self.__unit_handler.current_convention: cache = self.__cache.create_child() else: - all_derived_names: set[str] = reduce( + all_derived_names: set[str] = set() + all_derived_names = reduce( lambda acc, col: acc.union( col.produces if not isinstance(col, RawColumn) else set() ), self.__producers, - set(), + all_derived_names, ).intersection(self.columns) columns_to_drop = all_derived_names.union(self.__raw_data_handler.columns) cache = self.__cache.drop(columns_to_drop) From b03aa2a0c4d4e7ba690d5e48f99012985326caaf Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 10 Apr 2026 14:35:26 -0500 Subject: [PATCH 037/139] Move rest of dataset instantiation logic to new file --- python/opencosmo/dataset/instantiate.py | 88 +++++++++++++++++++------ python/opencosmo/dataset/state.py | 32 +-------- 2 files changed, 72 insertions(+), 48 deletions(-) diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py index 51314e1d..d2ebdeb3 100644 --- a/python/opencosmo/dataset/instantiate.py +++ b/python/opencosmo/dataset/instantiate.py @@ -1,3 +1,9 @@ +from __future__ import annotations + +from copy import copy +from typing import TYPE_CHECKING, Any + +import numpy as np import rustworkx as rx from opencosmo.column.column import RawColumn @@ -15,14 +21,27 @@ def 
get_all_required_columns(column_names: set[str], dependency_graph: rx.PyDiGr return required_columns +if TYPE_CHECKING: + from opencosmo.column.column import ConstructedColumn + from opencosmo.handler.protocols import DataCache, DataHandler + from opencosmo.index import DataIndex + from opencosmo.units.handler import UnitHandler + + def instantiate_dataset( - column_producers, - column_names, - raw_data_handler, - cache, - unit_handler, - unit_kwargs, + column_producers: list[ConstructedColumn], + column_names: set[str], + raw_data_handler: DataHandler, + cache: DataCache, + unit_handler: UnitHandler, + unit_kwargs: dict[str, Any], + metadata_columns: list[str] | None = None, + sort_by: tuple[str, bool] | None = None, ): + column_names = copy(column_names) + if sort_by is not None: + column_names.add(sort_by[0]) + dependency_graph = build_dependency_graph(column_producers) all_required_columns = get_all_required_columns(column_names, dependency_graph) cached_data = cache.get_data(all_required_columns) @@ -34,16 +53,14 @@ def instantiate_dataset( cache.add_data(converted_cached_data, {}, push_up=False) cached_data |= converted_cached_data - raw_columns = list( - filter( - lambda col: ( - isinstance(col, RawColumn) - and col.name not in cached_data - and col.name in all_required_columns - ), - column_producers, - ) - ) + raw_columns = [ + col + for col in column_producers + if isinstance(col, RawColumn) + and col.name not in cached_data + and col.name in all_required_columns + ] + raw_data = raw_data_handler.get_data(set(col.name for col in raw_columns)) for column in raw_columns: if column.alias is None: @@ -65,16 +82,22 @@ def instantiate_dataset( cache.add_data(converted_raw_data, {}, push_up=False) raw_data |= converted_raw_data - return cached_data | raw_data | new_derived_columns + data = cached_data | raw_data | new_derived_columns + data |= get_metadata_columns(raw_data_handler, cache, metadata_columns) + return sort_data(data, sort_by) def 
build_derived_columns( - column_producers, column_names, data, dependency_graph, index + column_producers: list[ConstructedColumn], + column_names: set[str], + data: dict[str, np.ndarray], + dependency_graph: rx.PyDiGraph, + index: DataIndex, ): dependency_graph = contract_derived_columns( dependency_graph, column_names, column_producers ) - new_derived = {} + new_derived: dict[str, ConstructedColumn] = {} for colidx in rx.topological_sort(dependency_graph): column = dependency_graph[colidx] if isinstance(column, str): @@ -91,3 +114,30 @@ def build_derived_columns( else: new_derived[column.name] = output return new_derived + + +def get_metadata_columns( + raw_data_handler: DataHandler, cache: DataCache, metadata_columns: list[str] | None +): + if metadata_columns is None: + return {} + metadata = cache.get_data(metadata_columns) + additional_metadata_columns_to_fetch = set(metadata_columns).difference( + metadata.keys() + ) + metadata |= ( + raw_data_handler.get_metadata(additional_metadata_columns_to_fetch) or {} + ) + + return metadata + + +def sort_data(data: dict[str, np.ndarray], sort_by: tuple[str, bool] | None): + if sort_by is None: + return data + sort_column = data[sort_by[0]] + order = np.argsort(sort_column) + if sort_by[1]: + order = order[::-1] + + return {key: value[order] for key, value in data.items()} diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index 0832fca9..edd263e0 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -240,47 +240,21 @@ def get_data( """ Get the data for a given handler. 
""" - columns_to_get = copy(self.__columns) - if self.__sort_by is not None: - columns_to_get.add(self.__sort_by[0]) - data = instantiate_dataset( self.__producers, - columns_to_get, + self.__columns, self.__raw_data_handler, self.__cache, self.__unit_handler, unit_kwargs, + metadata_columns, + None if ignore_sort else self.__sort_by, ) if missing := set(self.columns).difference(data.keys()): raise RuntimeError( f"Some columns are missing from the output! This is likely a bug. Please report it on GitHub. Missing: {missing}" ) - # keep ordering - - if metadata_columns: - metadata = self.__cache.get_data(metadata_columns) - additional_metadata_columns_to_fetch = set(metadata_columns).difference( - metadata.keys() - ) - metadata |= ( - self.__raw_data_handler.get_metadata( - additional_metadata_columns_to_fetch - ) - or {} - ) - - data.update(metadata) - - if not ignore_sort and self.__sort_by is not None: - sort_by = data[self.__sort_by[0]] - order = np.argsort(sort_by) - if self.__sort_by[1]: - order = order[::-1] - - data = {key: value[order] for key, value in data.items()} - new_order = [c for c in self.columns] if metadata_columns: new_order.extend(metadata_columns) From b1091f37ad1b3b97977c47125908ac86555d8ed3 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 10 Apr 2026 14:45:59 -0500 Subject: [PATCH 038/139] Move make_schema to its own file --- python/opencosmo/dataset/output.py | 103 +++++++++++++++++++++++++++++ python/opencosmo/dataset/state.py | 85 ++++++------------------ 2 files changed, 122 insertions(+), 66 deletions(-) create mode 100644 python/opencosmo/dataset/output.py diff --git a/python/opencosmo/dataset/output.py b/python/opencosmo/dataset/output.py new file mode 100644 index 00000000..a645ced1 --- /dev/null +++ b/python/opencosmo/dataset/output.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from functools import reduce +from typing import TYPE_CHECKING, Optional + +import astropy.units as u + +from opencosmo.column.column 
import RawColumn +from opencosmo.io.schema import ( + FileEntry, + combine_with_cached_schema, + make_schema, +) +from opencosmo.io.writer import ColumnCombineStrategy, ColumnWriter, NumpySource + +if TYPE_CHECKING: + from opencosmo.column.column import ConstructedColumn + from opencosmo.handler.protocols import DataCache, DataHandler + from opencosmo.header import OpenCosmoHeader + from opencosmo.io.schema import Schema + from opencosmo.spatial.protocols import Region + + +def get_derived_column_names( + producers: list[ConstructedColumn], columns: set[str] +) -> set[str]: + all_derived: set[str] = reduce( + lambda acc, col: acc.union( + col.produces if not isinstance(col, RawColumn) else set() + ), + producers, + set(), + ) + return all_derived.intersection(columns) + + +def build_derived_writers( + producers: list[ConstructedColumn], + derived_data: dict, + data_schema: Schema, + cached_data_schema: Schema, +) -> None: + """Add ColumnWriter entries to data_schema for each non-raw, non-cached producer.""" + for producer in producers: + if isinstance(producer, RawColumn) or producer.produces.issubset( + cached_data_schema.columns.keys() + ): + continue + coldata = {name: derived_data[name] for name in producer.produces} + units = { + name: str(cd.unit) if isinstance(cd, u.Quantity) else "" + for name, cd in coldata.items() + } + coldata = { + name: cd.value if isinstance(cd, u.Quantity) else cd + for name, cd in coldata.items() + } + for name, cd in coldata.items(): + attrs = {"unit": units[name], "description": producer.description} + source = NumpySource(cd) + writer = ColumnWriter([source], ColumnCombineStrategy.CONCAT, attrs=attrs) + data_schema.columns[name] = writer + + +def make_dataset_schema( + producers: list[ConstructedColumn], + raw_data_handler: DataHandler, + cache: DataCache, + columns: set[str], + meta_columns: list[str], + header: OpenCosmoHeader, + region: Region, + derived_data: dict, + name: Optional[str] = None, +) -> Schema: + header = 
header.with_region(region) + raw_columns = columns.intersection(raw_data_handler.columns) + data_schema, metadata_schema = raw_data_handler.make_schema(raw_columns, header) + + cached_data_schema, cached_metadata_schema = cache.make_schema( + list(columns) + meta_columns + ) + + build_derived_writers(producers, derived_data, data_schema, cached_data_schema) + + attributes = {} + if (load_conditions := raw_data_handler.load_conditions) is not None: + attributes["load/if"] = load_conditions + + data_schema = combine_with_cached_schema(data_schema, cached_data_schema) + metadata_schema = combine_with_cached_schema( + metadata_schema, cached_metadata_schema + ) + + children = {"data": data_schema} + if metadata_schema.type != FileEntry.EMPTY: + children[metadata_schema.name] = metadata_schema + if name is None: + name = "" + + return make_schema( + name, FileEntry.DATASET, children=children, attributes=attributes + ) diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index edd263e0..1c1b6f26 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -14,13 +14,12 @@ from opencosmo.dataset.graph import validate_column_producers from opencosmo.dataset.im import resort, validate_in_memory_columns from opencosmo.dataset.instantiate import instantiate_dataset +from opencosmo.dataset.output import get_derived_column_names, make_dataset_schema from opencosmo.handler.empty import EmptyHandler from opencosmo.handler.hdf5 import Hdf5Handler from opencosmo.index.build import single_chunk from opencosmo.index.mask import into_array from opencosmo.index.unary import get_range -from opencosmo.io.schema import FileEntry, combine_with_cached_schema, make_schema -from opencosmo.io.writer import ColumnCombineStrategy, ColumnWriter, NumpySource from opencosmo.units import UnitConvention from opencosmo.units.handler import ( make_unit_handler_from_hdf5, @@ -316,71 +315,25 @@ def with_mask(self, mask: NDArray[np.bool_]): ) def 
make_schema(self, name: Optional[str] = None): - header = self.__header.with_region(self.__region) - raw_columns = self.__columns.intersection(self.__raw_data_handler.columns) - - data_schema, metadata_schema = self.__raw_data_handler.make_schema( - raw_columns, header - ) - derived_names: set[str] = reduce( - lambda acc, col: acc.union( - col.produces if not isinstance(col, RawColumn) else set() - ), + derived_names = get_derived_column_names(self.__producers, self.__columns) + if derived_names: + derived_data = ( + self.select(derived_names) + .with_units(self.__unit_handler.base_convention, {}, {}, None, None) + .get_data(ignore_sort=True) + ) + else: + derived_data = {} + return make_dataset_schema( self.__producers, - set(), - ) - derived_names = derived_names.intersection(self.columns) - - derived_data = ( - self.select(derived_names) - .with_units(self.__unit_handler.base_convention, {}, {}, None, None) - .get_data(ignore_sort=True) - ) - cached_data_schema, cached_metadata_schema = self.__cache.make_schema( - self.columns + self.meta_columns - ) - - for producer in self.__producers: - if isinstance(producer, RawColumn) or producer.produces.issubset( - cached_data_schema.columns.keys() - ): - continue - coldata = {name: derived_data[name] for name in producer.produces} - units = { - name: str(cd.unit) if isinstance(cd, u.Quantity) else "" - for name, cd in coldata.items() - } - coldata = { - name: cd.value if isinstance(cd, u.Quantity) else cd - for name, cd in coldata.items() - } - for name, cd in coldata.items(): - attrs = {"unit": units[name], "description": producer.description} - - source = NumpySource(cd) - writer = ColumnWriter( - [source], ColumnCombineStrategy.CONCAT, attrs=attrs - ) - data_schema.columns[name] = writer - - attributes = {} - if (load_conditions := self.__raw_data_handler.load_conditions) is not None: - attributes["load/if"] = load_conditions - - data_schema = combine_with_cached_schema(data_schema, cached_data_schema) - - 
metadata_schema = combine_with_cached_schema( - metadata_schema, cached_metadata_schema - ) - children = {"data": data_schema} - - if metadata_schema.type != FileEntry.EMPTY: - children[metadata_schema.name] = metadata_schema - if name is None: - name = "" - - return make_schema( - name, FileEntry.DATASET, children=children, attributes=attributes + self.__raw_data_handler, + self.__cache, + self.__columns, + self.meta_columns, + self.__header, + self.__region, + derived_data, + name, ) def with_new_columns( From 1f3f4f2a93d28aca4312f4bb80df2a7f7c24dfac Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 10 Apr 2026 15:28:55 -0500 Subject: [PATCH 039/139] Changelog --- changes/+caf6ec12.improvement.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/+caf6ec12.improvement.rst diff --git a/changes/+caf6ec12.improvement.rst b/changes/+caf6ec12.improvement.rst new file mode 100644 index 00000000..ef170bf4 --- /dev/null +++ b/changes/+caf6ec12.improvement.rst @@ -0,0 +1 @@ +Dataset instantiation and backend process has been reworked to allow for dynamic column updating. 
From e03f1d88c46be16250bc1ea3175840437b98ca7e Mon Sep 17 00:00:00 2001 From: William Hicks Date: Mon, 13 Apr 2026 15:03:33 -0400 Subject: [PATCH 040/139] type check --- python/opencosmo/analysis/yt_viz.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index 04cd90e6..91ab604a 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -264,6 +264,7 @@ def visualize_halo( yt_dataset_provided = True else: yt_dataset_provided = False + yt_ds_arr = None if len(params["fields"]) == 4: # if 4 fields, make a 2x2 figure @@ -299,7 +300,7 @@ def visualize_halo( def halo_projection_array( halo_ids: int | list[int] | tuple[list[int], list[int]] | np.ndarray, data: oc.StructureCollection, - yt_ds: Optional[ YT_Dataset | list[YT_Dataset] | tuple[list[YT_Dataset], list[YT_Dataset]] | np.ndarray ] = None, + yt_ds: Optional[ YT_Dataset | list[YT_Dataset|None] | tuple[list[YT_Dataset|None], list[YT_Dataset|None]] | np.ndarray ] = None, field: Optional[Tuple[str, str]] = ("dm", "particle_mass"), weight_field: Optional[Tuple[str, str]] = None, projection_axis: Optional[str] = "z", @@ -415,7 +416,7 @@ def halo_projection_array( data = data.with_units("comoving") halo_ids_2d = np.atleast_2d(halo_ids) - yt_ds_2d = np.atleast_2d(yt_ds) + yt_ds_2d = np.atleast_2d(np.array(yt_ds)) # type: ignore # determine shape of figure fig_shape = np.shape(halo_ids) From dbaae37125ec8b09b0dda3936d5906fdb147b26d Mon Sep 17 00:00:00 2001 From: William Hicks Date: Mon, 13 Apr 2026 15:14:04 -0400 Subject: [PATCH 041/139] addressing failing tests --- python/opencosmo/analysis/yt_viz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index 91ab604a..a0699533 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -416,7 +416,7 @@ def 
halo_projection_array( data = data.with_units("comoving") halo_ids_2d = np.atleast_2d(halo_ids) - yt_ds_2d = np.atleast_2d(np.array(yt_ds)) # type: ignore + yt_ds_2d = np.atleast_2d(yt_ds) # type: ignore # determine shape of figure fig_shape = np.shape(halo_ids) From bba7d98b7edc1cbf5eead1548fa4aa32b517191f Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 14 Apr 2026 11:48:15 -0500 Subject: [PATCH 042/139] Move with_new_columns out of state --- python/opencosmo/column/plugin.py | 33 ++++++++++++ python/opencosmo/dataset/im.py | 31 ------------ python/opencosmo/dataset/state.py | 83 +++++-------------------------- 3 files changed, 45 insertions(+), 102 deletions(-) create mode 100644 python/opencosmo/column/plugin.py delete mode 100644 python/opencosmo/dataset/im.py diff --git a/python/opencosmo/column/plugin.py b/python/opencosmo/column/plugin.py new file mode 100644 index 00000000..d1248e1e --- /dev/null +++ b/python/opencosmo/column/plugin.py @@ -0,0 +1,33 @@ +""" +A column plugin is capable of updating the values of column dynamically when it is instantiated. It is a simple function where the first argument is the name of the column. Additional arguments can take: + +1. Additional columns (by name) +2. 
The dataset index "index" +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from opencosmo.index import DataIndex + + +import numpy as np + +from opencosmo.index import into_array + + +def top_host_idx_updater(top_host_idx: np.ndarray, index: DataIndex): + top_host_idx = top_host_idx.astype(np.int64) + index_array = into_array(index) + + index_array_sort_idx = np.argsort(index_array) + index_sorted = index_array[index_array_sort_idx] + + positions = np.searchsorted(index_sorted, top_host_idx) + valid = (positions < len(index_sorted)) & (index_sorted[positions] == top_host_idx) + + new_top_host_idx = np.full_like(top_host_idx, -1) + new_top_host_idx[valid] = index_sorted[positions[valid]] + return valid diff --git a/python/opencosmo/dataset/im.py b/python/opencosmo/dataset/im.py deleted file mode 100644 index a9ee1f86..00000000 --- a/python/opencosmo/dataset/im.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -import astropy.units as u -import numpy as np - -if TYPE_CHECKING: - from opencosmo.units.handler import UnitHandler - - -def resort(columns: dict[str, np.ndarray], sorted_index: Optional[np.ndarray]): - if sorted_index is None or not columns: - return columns - reverse_sort = np.argsort(sorted_index) - return {name: data[reverse_sort] for name, data in columns.items()} - - -def validate_in_memory_columns( - columns: dict[str, np.ndarray], unit_handler: UnitHandler, ds_length: int -): - new_units = {} - for colname, column in columns.items(): - if len(column) != ds_length: - raise ValueError(f"Column {colname} is not the same length as the dataset!") - if isinstance(column, u.Quantity): - new_units[colname] = column.unit - else: - new_units[colname] = None - - return unit_handler.with_static_columns(**new_units) diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index 1c1b6f26..dd0f916c 100644 --- 
a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -1,6 +1,5 @@ from __future__ import annotations -from copy import copy from functools import reduce from typing import TYPE_CHECKING, Iterable, Optional, Sequence from weakref import finalize @@ -9,10 +8,9 @@ import numpy as np from opencosmo.column.cache import ColumnCache -from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn, RawColumn +from opencosmo.column.column import RawColumn from opencosmo.column.select import get_column_selection -from opencosmo.dataset.graph import validate_column_producers -from opencosmo.dataset.im import resort, validate_in_memory_columns +from opencosmo.dataset.columns import add_columns, resort from opencosmo.dataset.instantiate import instantiate_dataset from opencosmo.dataset.output import get_derived_column_names, make_dataset_schema from opencosmo.handler.empty import EmptyHandler @@ -345,73 +343,16 @@ def with_new_columns( Add a set of derived columns to the dataset. A derived column is a column that has been created based on the values in another column. 
""" - - existing_columns = set(self.columns) - - if inter := existing_columns.intersection(new_columns.keys()): - raise ValueError(f"Some columns are already in the dataset: {inter}") - - new_derived_columns: list[ConstructedColumn] = [] - new_in_memory_columns = {} - new_in_memory_descriptions = {} - new_column_names = self.columns - - new_static_units = {} - for colname, column in new_columns.items(): - match column: - case DerivedColumn(): - column.name = colname - column.description = descriptions.get(colname, "None") - new_derived_columns.append(column) - new_column_names.extend(column.produces) - case EvaluatedColumn(): - column.description = descriptions.get(colname, "None") - new_derived_columns.append(column) - new_column_names.extend(column.produces) - case Column(): - producer = RawColumn( - column.name, descriptions.get(colname, None), alias=colname - ) - new_derived_columns.append(producer) - new_column_names.extend(producer.produces) - - case np.ndarray(): - if len(column) != len(self): - raise ValueError( - f"New column {colname} does not have the same length as this dataset!" 
- ) - new_in_memory_descriptions[colname] = descriptions.get( - colname, "None" - ) - new_in_memory_columns[colname] = column - new_column_names.append(colname) - new_derived_columns.append(RawColumn(colname, None)) - new_static_units[colname] = ( - column.unit if isinstance(column, u.Quantity) else None - ) - - case _: - raise ValueError( - f"Got an invalid new column of type {type(column)}" - ) - - new_unit_handler = self.__unit_handler.with_static_columns(**new_static_units) - - new_producers = copy(self.__producers) + new_derived_columns - new_units = validate_column_producers(new_producers, new_unit_handler) - if new_units: - new_unit_handler = new_unit_handler.with_new_columns(**new_units) - if new_in_memory_columns: - new_unit_handler = validate_in_memory_columns( - new_in_memory_columns, self.__unit_handler, len(self) - ) - new_in_memory_columns = resort( - new_in_memory_columns, self.get_sorted_index() - ) - self.__cache.add_data( - new_in_memory_columns, descriptions=new_in_memory_descriptions - ) - + new_producers, new_column_names, new_unit_handler = add_columns( + self.__producers, + self.__unit_handler, + self.__cache, + self.columns, + self.get_sorted_index(), + descriptions, + new_columns, + len(self), + ) return self.__rebuild( cache=self.__cache, column_producers=new_producers, From d6aac0a3405f0cc201db9cbd8aa576e9c3caf574 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 14 Apr 2026 11:49:50 -0500 Subject: [PATCH 043/139] Add correct file --- python/opencosmo/column/plugin.py | 33 ------- python/opencosmo/dataset/columns.py | 136 ++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 33 deletions(-) delete mode 100644 python/opencosmo/column/plugin.py create mode 100644 python/opencosmo/dataset/columns.py diff --git a/python/opencosmo/column/plugin.py b/python/opencosmo/column/plugin.py deleted file mode 100644 index d1248e1e..00000000 --- a/python/opencosmo/column/plugin.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -A column plugin is 
capable of updating the values of column dynamically when it is instantiated. It is a simple function where the first argument is the name of the column. Additional arguments can take: - -1. Additional columns (by name) -2. The dataset index "index" -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from opencosmo.index import DataIndex - - -import numpy as np - -from opencosmo.index import into_array - - -def top_host_idx_updater(top_host_idx: np.ndarray, index: DataIndex): - top_host_idx = top_host_idx.astype(np.int64) - index_array = into_array(index) - - index_array_sort_idx = np.argsort(index_array) - index_sorted = index_array[index_array_sort_idx] - - positions = np.searchsorted(index_sorted, top_host_idx) - valid = (positions < len(index_sorted)) & (index_sorted[positions] == top_host_idx) - - new_top_host_idx = np.full_like(top_host_idx, -1) - new_top_host_idx[valid] = index_sorted[positions[valid]] - return valid diff --git a/python/opencosmo/dataset/columns.py b/python/opencosmo/dataset/columns.py new file mode 100644 index 00000000..af670e36 --- /dev/null +++ b/python/opencosmo/dataset/columns.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from copy import copy +from typing import TYPE_CHECKING, Optional + +import astropy.units as u +import numpy as np + +from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn, RawColumn +from opencosmo.dataset.graph import validate_column_producers + +if TYPE_CHECKING: + from opencosmo.column.column import ConstructedColumn + from opencosmo.handler.protocols import DataCache + from opencosmo.units.handler import UnitHandler + + +def resort(columns: dict[str, np.ndarray], sorted_index: Optional[np.ndarray]): + if sorted_index is None or not columns: + return columns + reverse_sort = np.argsort(sorted_index) + return {name: data[reverse_sort] for name, data in columns.items()} + + +def validate_in_memory_columns( + columns: dict[str, 
np.ndarray], unit_handler: UnitHandler, ds_length: int +) -> UnitHandler: + new_units = {} + for colname, column in columns.items(): + if len(column) != ds_length: + raise ValueError(f"Column {colname} is not the same length as the dataset!") + if isinstance(column, u.Quantity): + new_units[colname] = column.unit + else: + new_units[colname] = None + + return unit_handler.with_static_columns(**new_units) + + +def __categorize_columns( + new_columns: dict, + descriptions: dict[str, str], + ds_length: int, +) -> tuple[list[ConstructedColumn], dict, dict, list[str], dict]: + """ + Classify incoming columns by type and build the producer list, in-memory column + dicts, and static unit map needed by add_columns. + + Returns: + new_derived_columns, new_in_memory_columns, new_in_memory_descriptions, + new_column_names, new_static_units + """ + new_derived_columns: list[ConstructedColumn] = [] + new_in_memory_columns: dict = {} + new_in_memory_descriptions: dict = {} + new_column_names: list[str] = [] + new_static_units: dict = {} + + for colname, column in new_columns.items(): + match column: + case DerivedColumn(): + column.name = colname + column.description = descriptions.get(colname, "None") + new_derived_columns.append(column) + new_column_names.extend(column.produces) + case EvaluatedColumn(): + column.description = descriptions.get(colname, "None") + new_derived_columns.append(column) + new_column_names.extend(column.produces) + case Column(): + producer = RawColumn( + column.name, descriptions.get(colname, None), alias=colname + ) + new_derived_columns.append(producer) + new_column_names.extend(producer.produces) + case np.ndarray(): + if len(column) != ds_length: + raise ValueError( + f"New column {colname} does not have the same length as this dataset!" 
+ ) + new_in_memory_descriptions[colname] = descriptions.get(colname, "None") + new_in_memory_columns[colname] = column + new_column_names.append(colname) + new_derived_columns.append(RawColumn(colname, None)) + new_static_units[colname] = ( + column.unit if isinstance(column, u.Quantity) else None + ) + case _: + raise ValueError(f"Got an invalid new column of type {type(column)}") + + return ( + new_derived_columns, + new_in_memory_columns, + new_in_memory_descriptions, + new_column_names, + new_static_units, + ) + + +def add_columns( + producers: list[ConstructedColumn], + unit_handler: UnitHandler, + cache: DataCache, + existing_column_names: list[str], + sorted_index: np.ndarray | None, + descriptions: dict[str, str], + new_columns: dict, + ds_length: int, +) -> tuple[list[ConstructedColumn], list[str], UnitHandler]: + existing_columns = set(existing_column_names) + if inter := existing_columns.intersection(new_columns.keys()): + raise ValueError(f"Some columns are already in the dataset: {inter}") + + ( + new_derived_columns, + new_in_memory_columns, + new_in_memory_descriptions, + new_column_names, + new_static_units, + ) = __categorize_columns(new_columns, descriptions, ds_length) + + new_unit_handler = unit_handler.with_static_columns(**new_static_units) + new_producers = copy(producers) + new_derived_columns + + new_units = validate_column_producers(new_producers, new_unit_handler) + if new_units: + new_unit_handler = new_unit_handler.with_new_columns(**new_units) + + if new_in_memory_columns: + new_unit_handler = validate_in_memory_columns( + new_in_memory_columns, unit_handler, ds_length + ) + new_in_memory_columns = resort(new_in_memory_columns, sorted_index) + cache.add_data(new_in_memory_columns, descriptions=new_in_memory_descriptions) + + return new_producers, existing_column_names + new_column_names, new_unit_handler From 99953b8e6a738f4e9b8f7c00b16cf6d11c33f833 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 14 Apr 2026 17:00:12 -0500 
Subject: [PATCH 044/139] Significant rewrite to ensure column producer uniqueness with UUIDs --- python/opencosmo/column/cache.py | 280 +++++++++++++++--------- python/opencosmo/column/column.py | 143 ++++++++++-- python/opencosmo/dataset/columns.py | 41 +++- python/opencosmo/dataset/evaluate.py | 10 +- python/opencosmo/dataset/graph.py | 116 +++------- python/opencosmo/dataset/instantiate.py | 244 ++++++++++++++------- python/opencosmo/dataset/output.py | 7 +- python/opencosmo/dataset/state.py | 69 +++--- python/opencosmo/handler/protocols.py | 43 +++- test/test_cache.py | 93 ++++---- test/test_collection.py | 9 +- test/test_dataset.py | 35 ++- 12 files changed, 720 insertions(+), 370 deletions(-) diff --git a/python/opencosmo/column/cache.py b/python/opencosmo/column/cache.py index c2be9ffc..ecf53329 100644 --- a/python/opencosmo/column/cache.py +++ b/python/opencosmo/column/cache.py @@ -1,12 +1,14 @@ from __future__ import annotations from typing import TYPE_CHECKING, Callable, Iterable, Optional +from uuid import UUID from weakref import finalize, ref import astropy.units as u import numpy as np from opencosmo.index import DataIndex +from opencosmo.index.build import from_size from opencosmo.index.get import get_data from opencosmo.index.take import take from opencosmo.index.unary import get_length, get_range @@ -16,12 +18,14 @@ if TYPE_CHECKING: from opencosmo.index import DataIndex +# (producer_uuid, column_name) — the unambiguous key for a cached column. 
+CacheKey = tuple[UUID, str] ColumnUpdater = Callable[[np.ndarray | u.Quantity], np.ndarray | u.Quantity] def finish( - cached_data: dict[str, np.ndarray], + cached_data: dict[CacheKey, np.ndarray], index: Optional[DataIndex], cache_ref: ref[ColumnCache], ): @@ -29,23 +33,24 @@ def finish( if cache is None: return - columns_to_add = ( - cache.registered_columns.intersection(cached_data.keys()) - cache.columns + # pylint: disable=protected-access + if index is None: + index = from_size(len(cache)) + pairs_to_add = cache.registered_pairs.intersection(cached_data.keys()) - set( + cache.keys() ) - data = {name: cached_data[name] for name in columns_to_add} - if index is not None: - data = {name: get_data(cd, index) for name, cd in data.items()} + data = {key[0]: {key[1]: get_data(cached_data[key], index)} for key in pairs_to_add} if data: cache.add_data(data) -def check_length(cache: ColumnCache, data: dict[str, np.ndarray]): - lengths = set(len(d) for d in data.values()) +def check_length(cache: ColumnCache, data: dict[UUID, dict[str, np.ndarray]]): + lengths = {len(arr) for uuid_data in data.values() for arr in uuid_data.values()} if len(lengths) > 1: raise ValueError( "When adding data to the cache, all columns must be the same length" ) - elif (length := len(cache)) > 0 and length != lengths.pop(): + elif (length := len(cache)) > 0 and lengths and length != lengths.pop(): raise ValueError( "When adding data to the cache, the columns must be the same length as the columns currently in the cache" ) @@ -65,14 +70,17 @@ class ColumnCache: cache. This allows us to preserve the standard "operations create new objects" pattern that is present throughout the library. + Internal storage uses (producer_uuid, column_name) tuples as keys so that multiple + producers that happen to produce a column with the same name are kept separate. 
""" def __init__( self, - cached_data: dict[str, np.ndarray], - registered_column_groups: dict[int, set[str]], + cached_data: dict[CacheKey, np.ndarray], + registered_column_groups: dict[int, set[CacheKey]], column_descriptions: dict[str, str], metadata_columns: set[str], + metadata_data: dict[str, np.ndarray], derived_index: Optional[DataIndex], parent: Optional[ref[ColumnCache]], children: Optional[list[ref[ColumnCache]]], @@ -80,6 +88,7 @@ def __init__( self.__cached_data = cached_data self.__registered_column_groups = registered_column_groups self.__metadata_columns = metadata_columns + self.__metadata_data = metadata_data self.__descriptions = column_descriptions self.__derived_index = derived_index self.__parent = parent @@ -100,97 +109,123 @@ def __init__( @classmethod def empty(cls): - return ColumnCache({}, {}, {}, set(), None, None, []) + return ColumnCache({}, {}, {}, set(), {}, None, None, []) @property - def columns(self): + def columns(self) -> set[str]: + return {name for (_, name) in self.__cached_data.keys()} + + def keys(self) -> set[CacheKey]: return set(self.__cached_data.keys()) @property - def metadata_columns(self): + def metadata_columns(self) -> set[str]: return self.__metadata_columns @property - def descriptions(self): + def descriptions(self) -> dict[str, str]: return self.__descriptions @property - def registered_columns(self): - return set().union(*list(self.__registered_column_groups.values())) + def registered_pairs(self) -> set[CacheKey]: + if not self.__registered_column_groups: + return set() + return set().union(*self.__registered_column_groups.values()) - def create_child(self): - return ColumnCache({}, {}, {}, self.__metadata_columns, None, ref(self), []) + def create_child(self) -> ColumnCache: + return ColumnCache({}, {}, {}, self.__metadata_columns, {}, None, ref(self), []) - def make_schema(self, columns: Iterable[str]): + def make_schema( + self, columns_to_uuid: dict[str, UUID], meta_columns: list[str] + ) -> tuple: data 
= {} metadata = {} - columns = set(columns) - cached_data = self.get_data(columns) - if not cached_data: - return ( - make_schema("data", FileEntry.EMPTY), - make_schema("metadata", FileEntry.EMPTY), - ) - for name, coldata in cached_data.items(): + cached = self.get_data({(uuid, name) for name, uuid in columns_to_uuid.items()}) + for name, coldata in _flatten(cached).items(): if isinstance(coldata, u.Quantity): column_data = coldata.value unit_str = str(coldata.unit) else: column_data = coldata unit_str = "" - attrs = {"unit": unit_str} - attrs["description"] = self.descriptions.get(name, "None") + attrs = { + "unit": unit_str, + "description": self.__descriptions.get(name, "None"), + } writer = ColumnWriter.from_numpy_array(column_data, attrs=attrs) - if name in self.metadata_columns: - metadata[name] = writer + data[name] = writer + + for name, coldata in self.get_metadata(meta_columns).items(): + if isinstance(coldata, u.Quantity): + column_data = coldata.value + unit_str = str(coldata.unit) else: - data[name] = writer + column_data = coldata + unit_str = "" + attrs = { + "unit": unit_str, + "description": self.__descriptions.get(name, "None"), + } + writer = ColumnWriter.from_numpy_array(column_data, attrs=attrs) + metadata[name] = writer - data_schema = make_schema("data", FileEntry.COLUMNS, columns=data) - if metadata: - metadata_schema = make_schema( - "metadata", FileEntry.COLUMNS, columns=metadata + if not data and not metadata: + return ( + make_schema("data", FileEntry.EMPTY), + make_schema("metadata", FileEntry.EMPTY), ) - else: - metadata_schema = make_schema("metadata", FileEntry.EMPTY) + data_schema = ( + make_schema("data", FileEntry.COLUMNS, columns=data) + if data + else make_schema("data", FileEntry.EMPTY) + ) + metadata_schema = ( + make_schema("metadata", FileEntry.COLUMNS, columns=metadata) + if metadata + else make_schema("metadata", FileEntry.EMPTY) + ) return data_schema, metadata_schema - def __push_down(self, data: dict[str, 
np.ndarray]): - columns_to_keep = self.registered_columns.intersection(data.keys()).difference( + def __push_down(self, data: dict[CacheKey, np.ndarray]): + pairs_to_keep = self.registered_pairs.intersection(data.keys()).difference( self.__cached_data.keys() ) - cached_data = {colname: data[colname] for colname in columns_to_keep} + cached_data = {key: data[key] for key in pairs_to_keep} if self.__derived_index is not None: cached_data = { - colname: get_data(coldata, self.__derived_index) - for colname, coldata in cached_data.items() + key: get_data(cd, self.__derived_index) + for key, cd in cached_data.items() } - self.__cached_data |= cached_data - def __push_up(self, data: dict[str, np.ndarray]): + def __push_up(self, data: dict[CacheKey, np.ndarray]): assert len(self) == 0 or all(len(d) == len(self) for d in data.values()) - columns_to_keep = self.registered_columns.intersection(data.keys()).difference( + pairs_to_keep = self.registered_pairs.intersection(data.keys()).difference( self.__cached_data.keys() ) - self.__cached_data |= {key: data[key] for key in columns_to_keep} + self.__cached_data |= {key: data[key] for key in pairs_to_keep} - def register_column_group(self, state_id: int, columns: Iterable[str]): + def register_column_group(self, state_id: int, columns: dict[str, UUID]): assert state_id not in self.__registered_column_groups - self.__registered_column_groups[state_id] = set(columns) + self.__registered_column_groups[state_id] = { + (uuid, name) for name, uuid in columns.items() + } def deregister_column_group(self, state_id: int): assert state_id in self.__registered_column_groups - columns = self.__registered_column_groups.pop(state_id) - remaining_columns = set().union(*list(self.__registered_column_groups.values())) - - to_drop = columns.difference(remaining_columns) + pairs = self.__registered_column_groups.pop(state_id) + remaining = ( + set().union(*self.__registered_column_groups.values()) + if self.__registered_column_groups + else set() 
+ ) + to_drop = pairs.difference(remaining) cached_data = { - name: self.__cached_data.pop(name) - for name in to_drop - if name in self.__cached_data + key: self.__cached_data.pop(key) + for key in to_drop + if key in self.__cached_data } if not cached_data: return @@ -222,58 +257,70 @@ def __len__(self): def add_data( self, - data: dict[str, np.ndarray], + data: dict[UUID, dict[str, np.ndarray]], descriptions: dict[str, str] = {}, - metadata_columns: Optional[set[str]] = None, - push_up=True, + push_up: bool = True, ): - """ - The in-place equivalent of with_data. Should not be used outside the context of this - file. - - """ + """Add UUID-keyed column data to the cache.""" check_length(self, data) - if metadata_columns is not None: - assert metadata_columns.issubset(data.keys()) - self.__metadata_columns = self.__metadata_columns.union(metadata_columns) - self.__descriptions |= descriptions + + flat: dict[CacheKey, np.ndarray] = { + (uuid, name): arr + for uuid, uuid_data in data.items() + for name, arr in uuid_data.items() + } if ( push_up and self.__derived_index is None and self.__parent is not None and (p := self.__parent()) is not None ): - p.__push_up(data) + p.__push_up(flat) - self.__cached_data = self.__cached_data | data + self.__cached_data |= flat - def drop(self, columns: Iterable[str]): - columns = set(columns) - columns_to_drop = set(self.__cached_data.keys()).intersection(columns) + def add_metadata( + self, + data: dict[str, np.ndarray], + descriptions: dict[str, str] = {}, + ): + """Add metadata columns (name-keyed, no producer UUID) to the cache.""" + self.__metadata_columns = self.__metadata_columns.union(data.keys()) + self.__descriptions |= descriptions + self.__metadata_data |= data + + def drop(self, column_names: Iterable[str]) -> ColumnCache: + names_to_drop = set(column_names) data = { - name: data - for name, data in self.__cached_data.items() - if name not in columns_to_drop + key: val + for key, val in self.__cached_data.items() + 
if key[1] not in names_to_drop } descriptions = { name: desc for name, desc in self.__descriptions.items() - if name not in columns_to_drop + if name not in names_to_drop } - new_meta_columns = self.__metadata_columns.difference(columns) - - return ColumnCache(data, {}, descriptions, new_meta_columns, None, None, []) - - def request(self, column_names: Iterable[str], index: Optional[DataIndex]): - column_names = set(column_names) - columns_in_cache = column_names.intersection(self.__cached_data.keys()) + new_meta_columns = self.__metadata_columns.difference(names_to_drop) + new_meta_data = { + name: val + for name, val in self.__metadata_data.items() + if name not in names_to_drop + } + return ColumnCache( + data, {}, descriptions, new_meta_columns, new_meta_data, None, None, [] + ) - data = {name: self.__cached_data[name] for name in columns_in_cache} + def request( + self, pairs: set[CacheKey], index: Optional[DataIndex] + ) -> dict[CacheKey, np.ndarray]: + pairs_in_cache = pairs.intersection(self.__cached_data.keys()) + data = {key: self.__cached_data[key] for key in pairs_in_cache} if index is not None: - data = {name: get_data(cd, index) for name, cd in data.items()} + data = {key: get_data(cd, index) for key, cd in data.items()} - if self.__parent is None or column_names == columns_in_cache: + if self.__parent is None or pairs == pairs_in_cache: return data parent = self.__parent() @@ -291,9 +338,9 @@ def request(self, column_names: Iterable[str], index: Optional[DataIndex]): assert self.__derived_index is not None and index is not None new_index = take(self.__derived_index, index) - return data | parent.request(column_names, new_index) + return data | parent.request(pairs - pairs_in_cache, new_index) - def take(self, index: DataIndex): + def take(self, index: DataIndex) -> ColumnCache: if len(self) == 0 and not self.columns: return ColumnCache.empty() if get_range(index)[1] > len(self): @@ -301,26 +348,55 @@ def take(self, index: DataIndex): "Tried to take 
more elements than the length of the cache!" ) new_cache = ColumnCache( - {}, {}, {}, self.__metadata_columns, index, ref(self), [] + {}, {}, {}, self.__metadata_columns, {}, index, ref(self), [] ) self.__children.append(ref(new_cache)) return new_cache - def get_data(self, columns: Iterable[str]): - columns = set(columns) - columns_in_cache = columns.intersection(self.__cached_data.keys()) - missing_columns = columns - columns_in_cache - output = {c: self.__cached_data[c] for c in columns_in_cache} - output |= self.__get_derived_columns(missing_columns) - return output + def get_data(self, pairs: set[CacheKey]) -> dict[UUID, dict[str, np.ndarray]]: + """ + Retrieve data for the requested (uuid, name) pairs. Returns a + UUID-keyed dict so callers can look up each producer's contribution. + """ + pairs_in_cache = pairs.intersection(self.__cached_data.keys()) + missing_pairs = pairs - pairs_in_cache + flat = {key: self.__cached_data[key] for key in pairs_in_cache} + flat |= self.__get_derived_pairs(missing_pairs) + return _unflatten(flat) + + def get_metadata(self, column_names: Iterable[str]) -> dict[str, np.ndarray]: + """Retrieve name-keyed metadata columns.""" + names = set(column_names) + result = { + name: self.__metadata_data[name] + for name in names + if name in self.__metadata_data + } + if self.__parent is not None and (p := self.__parent()) is not None: + missing = names - set(result.keys()) + if missing: + result |= p.get_metadata(missing) + return result - def __get_derived_columns(self, column_names: set[str]): + def __get_derived_pairs(self, pairs: set[CacheKey]) -> dict[CacheKey, np.ndarray]: if self.__parent is None: return {} parent = self.__parent() if parent is None: return {} - result = parent.request(column_names, self.__derived_index) - - self.__cached_data = self.__cached_data | result + result = parent.request(pairs, self.__derived_index) + self.__cached_data |= result return result + + +def _flatten(uuid_data: dict[UUID, dict[str, 
np.ndarray]]) -> dict[str, np.ndarray]: + """Collapse UUID-keyed data to a flat name-keyed dict. Last writer wins.""" + return {name: arr for d in uuid_data.values() for name, arr in d.items()} + + +def _unflatten(flat: dict[CacheKey, np.ndarray]) -> dict[UUID, dict[str, np.ndarray]]: + """Group (uuid, name) → array back into {uuid: {name: array}}.""" + result: dict[UUID, dict[str, np.ndarray]] = {} + for (uuid, name), arr in flat.items(): + result.setdefault(uuid, {})[name] = arr + return result diff --git a/python/opencosmo/column/column.py b/python/opencosmo/column/column.py index 5a0aeb92..edfb9a66 100644 --- a/python/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -2,7 +2,7 @@ import operator as op from copy import copy -from functools import cached_property, partial, partialmethod +from functools import partial, partialmethod from inspect import signature from typing import ( TYPE_CHECKING, @@ -14,6 +14,7 @@ Self, Union, ) +from uuid import uuid4 import astropy.units as u # type: ignore import numpy as np @@ -28,6 +29,8 @@ from opencosmo.units import UnitsError if TYPE_CHECKING: + from uuid import UUID + from opencosmo import Dataset Comparison = Callable[[float, float], bool] @@ -284,12 +287,22 @@ class ConstructedColumn(Protocol): pass @property - def requires(self) -> set[str]: ... + def uuid(self) -> UUID: ... + + @property + def requires(self) -> set[UUID]: ... + + @property + def dep_map(self) -> dict[str, UUID] | None: ... + @property def produces(self) -> set[str]: ... + @property def description(self) -> Optional[str]: ... + def bind(self, name_to_uuid: dict[str, UUID]) -> Self: ... + def evaluate( self, data: dict[str, np.ndarray], @@ -300,20 +313,48 @@ def get_units(self, values: dict[str, u.Quantity]) -> dict[str, u.Unit]: ... 
class RawColumn: - def __init__(self, name, description, alias=None): + def __init__(self, name, description, alias=None, _dep_uuid=None): self.__name = name self.__description = description self.__alias = alias + self.__uuid = uuid4() + self.__dep_uuid: UUID | None = _dep_uuid + + @property + def uuid(self) -> UUID: + return self.__uuid @property def name(self): return self.__name + def bind(self, name_to_uuid: dict[str, UUID]) -> RawColumn: + if self.__alias is None: + return self + dep_uuid = name_to_uuid[self.__name] + return RawColumn( + self.__name, self.__description, alias=self.__alias, _dep_uuid=dep_uuid + ) + @property - def requires(self) -> set[str]: - if self.__alias is not None: - return set([self.__name]) - return set() + def requires(self) -> set[UUID]: + if self.__alias is None: + return set() + if self.__dep_uuid is None: + raise RuntimeError( + f"RawColumn alias '{self.__alias}' has not been bound yet." + ) + return {self.__dep_uuid} + + @property + def dep_map(self) -> dict[str, UUID]: + if self.__alias is None: + return {} + if self.__dep_uuid is None: + raise RuntimeError( + f"RawColumn alias '{self.__alias}' has not been bound yet." 
+ ) + return {self.__name: self.__dep_uuid} @property def alias(self) -> str | None: @@ -357,32 +398,61 @@ def __init__( operation: Callable, description: Optional[str] = None, output_name: Optional[str] = None, + _dep_map: dict[str, UUID] | None = None, ): self.lhs = lhs self.rhs = rhs self.name = output_name self.operation = operation self.description = description if description is not None else "None" + self.__uuid = uuid4() + self.__dep_map: dict[str, UUID] | None = _dep_map - @cached_property - def requires(self): + @property + def uuid(self) -> UUID: + return self.__uuid + + @property + def dep_map(self) -> dict[str, UUID] | None: + return self.__dep_map + + def bind(self, name_to_uuid: dict[str, UUID]) -> DerivedColumn: """ - Return the raw data columns required to make this column + Resolve each dependency column name to the UUID of the producer that was + producing it at the time this column was registered with a dataset. + Returns a new bound DerivedColumn; does not mutate this instance. 
""" - vals = set() + dep_map = {name: name_to_uuid[name] for name in self._traverse_names()} + return DerivedColumn( + self.lhs, + self.rhs, + self.operation, + self.description, + self.name, + _dep_map=dep_map, + ) + + def _traverse_names(self) -> set[str]: + """Walk the expression tree and collect all Column leaf names.""" + vals: set[str] = set() match self.lhs: case Column(): vals.add(self.lhs.name) case DerivedColumn(): - vals = vals | self.lhs.requires + vals |= self.lhs._traverse_names() match self.rhs: case Column(): vals.add(self.rhs.name) case DerivedColumn(): - vals = vals | self.rhs.requires - + vals |= self.rhs._traverse_names() return vals + @property + def requires(self) -> set[UUID]: + if self.__dep_map is None: + raise RuntimeError(f"DerivedColumn '{self.name}' has not been bound yet.") + return set(self.__dep_map.values()) + @property def produces(self): return None if self.name is None else set([self.name]) @@ -516,6 +586,7 @@ def __init__( strategy: EvaluateStrategy = EvaluateStrategy.ROW_WISE, batch_size: int = -1, description: Optional[str] = None, + _dep_map: dict[str, UUID] | None = None, **kwargs: Any, ): self.__func = func @@ -527,6 +598,43 @@ def __init__( self.__strategy = strategy self.__batch_size = batch_size self.description = description + self.__uuid = uuid4() + self.__dep_map = _dep_map + + @property + def uuid(self) -> UUID: + return self.__uuid + + @property + def dep_map(self) -> dict[str, UUID] | None: + return self.__dep_map + + def bind(self, name_to_uuid: dict[str, UUID]) -> EvaluatedColumn: + """ + Resolve each dependency column name to the UUID of the producer that was + producing it at the time this column was registered with a dataset. + Returns a new bound EvaluatedColumn; does not mutate this instance. 
+ """ + dep_map = {name: name_to_uuid[name] for name in self.__requires} + return EvaluatedColumn( + self.__func, + self.__requires, + self.__produces, + self.__format, + self.__units, + self.__strategy, + self.__batch_size, + self.description, + _dep_map=dep_map, + **self.__kwargs, + ) + + @property + def requires_names(self) -> set[str]: + """Return the required column names (as strings), for use in data lookup.""" + if self.__dep_map is not None: + return set(self.__dep_map.keys()) + return copy(self.__requires) def with_kwargs(self, **new_kwargs: Any): new_kwargs = self.__kwargs | new_kwargs @@ -539,6 +647,7 @@ def with_kwargs(self, **new_kwargs: Any): self.__strategy, self.__batch_size, self.description, + _dep_map=self.__dep_map, **new_kwargs, ) @@ -547,8 +656,10 @@ def name(self): return self.__func.__name__ @property - def requires(self): - return copy(self.__requires) + def requires(self) -> set[UUID]: + if self.__dep_map is None: + raise RuntimeError(f"EvaluatedColumn '{self.name}' has not been bound yet.") + return set(self.__dep_map.values()) @property def produces(self): diff --git a/python/opencosmo/dataset/columns.py b/python/opencosmo/dataset/columns.py index af670e36..19290c4a 100644 --- a/python/opencosmo/dataset/columns.py +++ b/python/opencosmo/dataset/columns.py @@ -10,10 +10,14 @@ from opencosmo.dataset.graph import validate_column_producers if TYPE_CHECKING: + from uuid import UUID + from opencosmo.column.column import ConstructedColumn from opencosmo.handler.protocols import DataCache from opencosmo.units.handler import UnitHandler + ColumnMap = dict[str, UUID] + def resort(columns: dict[str, np.ndarray], sorted_index: Optional[np.ndarray]): if sorted_index is None or not columns: @@ -101,14 +105,13 @@ def add_columns( producers: list[ConstructedColumn], unit_handler: UnitHandler, cache: DataCache, - existing_column_names: list[str], + name_to_uuid: ColumnMap, sorted_index: np.ndarray | None, descriptions: dict[str, str], new_columns: dict, 
ds_length: int, -) -> tuple[list[ConstructedColumn], list[str], UnitHandler]: - existing_columns = set(existing_column_names) - if inter := existing_columns.intersection(new_columns.keys()): +) -> tuple[list[ConstructedColumn], ColumnMap, UnitHandler]: + if inter := set(name_to_uuid.keys()).intersection(new_columns.keys()): raise ValueError(f"Some columns are already in the dataset: {inter}") ( @@ -119,6 +122,18 @@ def add_columns( new_static_units, ) = __categorize_columns(new_columns, descriptions, ds_length) + # Extend the name→UUID map with new producers' outputs so that columns added + # in the same with_new_columns call can reference each other. + extended_name_to_uuid = dict(name_to_uuid) + for producer in new_derived_columns: + if producer.produces: + for name in producer.produces: + extended_name_to_uuid[name] = producer.uuid + + new_derived_columns = [ + producer.bind(extended_name_to_uuid) for producer in new_derived_columns + ] + new_unit_handler = unit_handler.with_static_columns(**new_static_units) new_producers = copy(producers) + new_derived_columns @@ -131,6 +146,18 @@ def add_columns( new_in_memory_columns, unit_handler, ds_length ) new_in_memory_columns = resort(new_in_memory_columns, sorted_index) - cache.add_data(new_in_memory_columns, descriptions=new_in_memory_descriptions) - - return new_producers, existing_column_names + new_column_names, new_unit_handler + # Build UUID-keyed data: each in-memory column is owned by its RawColumn producer. 
+ uuid_keyed = { + producer.uuid: {producer.name: new_in_memory_columns[producer.name]} + for producer in new_derived_columns + if isinstance(producer, RawColumn) + and producer.name in new_in_memory_columns + } + cache.add_data(uuid_keyed, descriptions=new_in_memory_descriptions) + + updated_columns = dict(name_to_uuid) + for producer in new_derived_columns: + for name in producer.produces: + updated_columns[name] = producer.uuid + + return new_producers, updated_columns, new_unit_handler diff --git a/python/opencosmo/dataset/evaluate.py b/python/opencosmo/dataset/evaluate.py index 9910e22b..38968dfd 100644 --- a/python/opencosmo/dataset/evaluate.py +++ b/python/opencosmo/dataset/evaluate.py @@ -69,11 +69,12 @@ def visit_dataset( ) -> dict[str, np.ndarray]: if column.batch_size > 0: return visit_dataset_batched(column, dataset) - data = dataset.select(column.requires).get_data(format=column.format) + requires_names = column.requires_names + data = dataset.select(requires_names).get_data(format=column.format) try: data = dict(data) except (TypeError, ValueError): - data = {column.requires.pop(): data} + data = {next(iter(requires_names)): data} output = column.evaluate(data, dataset.index) if not isinstance(output, dict): assert len(column.produces) == 1 @@ -88,16 +89,17 @@ def visit_dataset_batched(column: EvaluatedColumn, dataset: Dataset): output = defaultdict(list) + requires_names = column.requires_names for start, end in np.lib.stride_tricks.sliding_window_view(ranges, 2): batch_data = ( - dataset.select(column.requires) + dataset.select(requires_names) .take_range(start, end) .get_data(format=column.format, unpack=False) ) try: batch_data = dict(batch_data) except TypeError: - batch_data = {column.requires.pop(): batch_data} + batch_data = {next(iter(requires_names)): batch_data} batch_output = column.evaluate(batch_data, None) if batch_output is not None and not isinstance(batch_output, dict): batch_output = {column.produces.pop(): batch_output} diff --git 
a/python/opencosmo/dataset/graph.py b/python/opencosmo/dataset/graph.py index 6123b923..32ce073b 100644 --- a/python/opencosmo/dataset/graph.py +++ b/python/opencosmo/dataset/graph.py @@ -1,7 +1,6 @@ from __future__ import annotations from functools import reduce -from itertools import product from typing import TYPE_CHECKING import rustworkx as rx @@ -9,6 +8,8 @@ from opencosmo.column.column import RawColumn if TYPE_CHECKING: + from uuid import UUID + import astropy.units as u from opencosmo.column.column import ConstructedColumn @@ -21,106 +22,59 @@ def validate_column_producers( """ Validate the network of column producers. """ - raw_columns = set(col.name for col in producers if isinstance(col, RawColumn)) - dependency_graph = build_dependency_graph(producers) + if cycle := rx.digraph_find_cycle(dependency_graph): all_nodes: set[int] = reduce( lambda known, edge: known.union(edge), cycle, set() ) - names = [dependency_graph[i] for i in all_nodes] + names = [dependency_graph[i].produces for i in all_nodes] raise ValueError(f"Found columns that depend on each other! 
Columns: {names}") - sources = set( - filter( - lambda i: not dependency_graph.in_degree(i), - range(dependency_graph.num_nodes()), - ) - ) - source_names = set(map(lambda i: dependency_graph[i], sources)) - if missing := set(source_names).difference(raw_columns): - raise ValueError(f"Tried to derive columns from unknown columns: {missing}") - - return get_derived_units(dependency_graph, producers, unit_handler.base_units) - - -def build_dependency_graph(producers: list[ConstructedColumn]): - dependency_graph = rx.PyDiGraph() - all_requires: set[str] = reduce( - lambda known, dc: known.union(dc.requires if dc.requires is not None else []), - producers, - set(), - ) - nodeidx = dependency_graph.add_nodes_from(all_requires) - nodemap = {name: idx for (name, idx) in zip(all_requires, nodeidx)} - - for column_producer in producers: - requires = column_producer.requires - produces = column_producer.produces - assert produces is not None - to_add = list(filter(lambda p: p not in nodemap, produces)) - new_map = dependency_graph.add_nodes_from(to_add) - nodemap.update({name: idx for (name, idx) in zip(to_add, new_map)}) - if not requires: + for i in range(dependency_graph.num_nodes()): + if dependency_graph.in_degree(i): continue + node = dependency_graph[i] + if not isinstance(node, RawColumn): + raise ValueError( + f"Tried to derive columns from unknown columns: {node.produces}" + ) - requires_idx = tuple(nodemap[r] for r in requires) - produces_idx = tuple(nodemap[r] for r in produces) - - dependency_graph.add_edges_from_no_data(product(requires_idx, produces_idx)) - - return dependency_graph + return get_derived_units(dependency_graph, unit_handler.base_units) -def contract_derived_columns( - graph: rx.PyDiGraph, - column_names: set[str], - column_producers: list[ConstructedColumn], -): - """ - Some derived columns actually produce multiple outputs. 
At this stage, the dependency - graph is working solely with actual column names, meaning if any of those columns is - produced by one of these "multi-produces" they will not be in the derived_columns - dictionary and therefore cannot be instantiated. This function replaces such - columns with the name of the derived_column that produces them. - """ - node_map = {name: i for i, name in enumerate(graph.nodes())} - nodes_to_keep: set[int] = set() - for producer in column_producers: - if producer.produces.intersection(column_names): - produces_idx = {node_map[name] for name in producer.produces} - nodes_to_keep = reduce( - lambda acc, node: acc.union(rx.ancestors(graph, node)), - produces_idx, - nodes_to_keep, +def build_dependency_graph( + producers: list[ConstructedColumn], +) -> rx.PyDiGraph: + graph = rx.PyDiGraph() + uuid_to_node: dict[UUID, int] = {} + + for producer in producers: + node_idx = graph.add_node(producer) + uuid_to_node[producer.uuid] = node_idx + + for producer in producers: + produces_idx = uuid_to_node[producer.uuid] + if not producer.requires.issubset(uuid_to_node.keys()): + raise ValueError( + f"Producer {producer.produces} depends on an unknown producer UUID." 
) - nodes_to_keep.update(produces_idx) - subgraph = graph.subgraph(list(nodes_to_keep)) - node_map = {name: i for i, name in enumerate(subgraph.nodes())} + new_edges = ( + (uuid_to_node[dep_uuid], produces_idx) for dep_uuid in producer.requires + ) + graph.add_edges_from_no_data(new_edges) - for producer in column_producers: - if isinstance(producer, RawColumn): - continue - produces = producer.produces - produces_index = [node_map[name] for name in produces if name in node_map] - if produces_index: - subgraph.contract_nodes(produces_index, producer) - return subgraph + return graph def get_derived_units( dependency_graph: rx.PyDiGraph, - producers: list[ConstructedColumn], units: dict[str, u.Unit], ): - dependency_graph = contract_derived_columns( - dependency_graph, set(dependency_graph.nodes()), producers - ) - new_units: dict[str, u.Unit | None] = {} - for node in rx.topological_sort(dependency_graph): - node = dependency_graph[node] - if isinstance(node, str): + for node_idx in rx.topological_sort(dependency_graph): + node = dependency_graph[node_idx] + if isinstance(node, RawColumn): continue column_units = node.get_units(units | new_units) if not isinstance(column_units, dict): diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py index d2ebdeb3..d18bd84a 100644 --- a/python/opencosmo/dataset/instantiate.py +++ b/python/opencosmo/dataset/instantiate.py @@ -1,36 +1,115 @@ from __future__ import annotations -from copy import copy from typing import TYPE_CHECKING, Any import numpy as np import rustworkx as rx from opencosmo.column.column import RawColumn -from opencosmo.dataset.graph import build_dependency_graph, contract_derived_columns - - -def get_all_required_columns(column_names: set[str], dependency_graph: rx.PyDiGraph): - required_columns = set() - node_map = {name: i for i, name in enumerate(dependency_graph.nodes())} - for name in column_names: - required_columns.add(name) - ancestors = 
rx.ancestors(dependency_graph, node_map[name]) - required_columns.update(dependency_graph[i] for i in ancestors) - - return required_columns - +from opencosmo.dataset.graph import build_dependency_graph if TYPE_CHECKING: + from uuid import UUID + + from opencosmo.column.cache import CacheKey from opencosmo.column.column import ConstructedColumn from opencosmo.handler.protocols import DataCache, DataHandler from opencosmo.index import DataIndex from opencosmo.units.handler import UnitHandler +def get_all_required_pairs( + columns_to_uuid: dict[str, UUID], dependency_graph: rx.PyDiGraph +) -> set[CacheKey]: + """ + Return the full set of (producer_uuid, column_name) pairs needed to + produce the requested columns, including all transitive dependencies. + """ + uuid_to_node: dict[UUID, int] = { + dependency_graph[i].uuid: i for i in range(dependency_graph.num_nodes()) + } + required_nodes: set[int] = set() + for uuid in columns_to_uuid.values(): + if uuid in uuid_to_node: + node_idx = uuid_to_node[uuid] + required_nodes.add(node_idx) + required_nodes.update(rx.ancestors(dependency_graph, node_idx)) + + pairs: set[CacheKey] = {(uuid, name) for name, uuid in columns_to_uuid.items()} + for node_idx in required_nodes: + producer = dependency_graph[node_idx] + for name in producer.produces: + pairs.add((producer.uuid, name)) + return pairs + + +def build_initial_uuid_data( + column_producers: list[ConstructedColumn], + raw_data: dict[str, np.ndarray], + cached_data: dict[UUID, dict[str, np.ndarray]], +) -> dict[UUID, dict[str, np.ndarray]]: + """ + Merge cached and freshly-fetched raw data into UUID-keyed storage. + Cached data is the starting point; raw data fills in any gaps. 
+ """ + uuid_data: dict[UUID, dict[str, np.ndarray]] = {**cached_data} + for producer in column_producers: + if not isinstance(producer, RawColumn) or producer.uuid in uuid_data: + continue + output_name = producer.alias or producer.name + if output_name in raw_data: + uuid_data[producer.uuid] = {output_name: raw_data[output_name]} + return uuid_data + + +def build_derived_columns( + columns_to_uuid: dict[str, UUID], + uuid_data: dict[UUID, dict[str, np.ndarray]], + dependency_graph: rx.PyDiGraph, + index: DataIndex, +) -> dict[UUID, dict[str, np.ndarray]]: + """ + Evaluate all derived producers needed to produce the requested columns, + in topological order. Each producer's inputs are resolved by UUID via + dep_map, so column-name shadowing cannot cause a derived column to + receive data from the wrong producer. + """ + uuid_to_node: dict[UUID, int] = { + dependency_graph[i].uuid: i for i in range(dependency_graph.num_nodes()) + } + + required_uuids: set[UUID] = set(columns_to_uuid.values()) + for producer_uuid in list(required_uuids): + if producer_uuid in uuid_to_node: + for node_idx in rx.ancestors(dependency_graph, uuid_to_node[producer_uuid]): + required_uuids.add(dependency_graph[node_idx].uuid) + + new_derived: dict[UUID, dict[str, np.ndarray]] = {} + for node_idx in rx.topological_sort(dependency_graph): + producer = dependency_graph[node_idx] + if isinstance(producer, RawColumn): + continue + if producer.uuid not in required_uuids or producer.uuid in uuid_data: + continue + + all_data = uuid_data | new_derived + input_data = { + name: all_data[dep_uuid][name] + for name, dep_uuid in producer.dep_map.items() + } + output = producer.evaluate( + input_data, index[1] if isinstance(index, tuple) else None + ) + if isinstance(output, dict): + new_derived[producer.uuid] = output + else: + new_derived[producer.uuid] = {next(iter(producer.produces)): output} + return new_derived + + def instantiate_dataset( column_producers: list[ConstructedColumn], - 
column_names: set[str], + columns_to_uuid: dict[str, UUID], raw_data_handler: DataHandler, cache: DataCache, unit_handler: UnitHandler, @@ -38,29 +117,39 @@ def instantiate_dataset( metadata_columns: list[str] | None = None, sort_by: tuple[str, bool] | None = None, ): - column_names = copy(column_names) - if sort_by is not None: - column_names.add(sort_by[0]) + # Extend working_columns with the sort column if it isn't already included. + working_columns = dict(columns_to_uuid) + if sort_by is not None and sort_by[0] not in working_columns: + sort_name = sort_by[0] + for producer in column_producers: + if sort_name in producer.produces: + working_columns[sort_name] = producer.uuid + break dependency_graph = build_dependency_graph(column_producers) - all_required_columns = get_all_required_columns(column_names, dependency_graph) - cached_data = cache.get_data(all_required_columns) - converted_cached_data = unit_handler.apply_unit_conversions( - cached_data, unit_kwargs - ) - - if converted_cached_data: - cache.add_data(converted_cached_data, {}, push_up=False) - cached_data |= converted_cached_data - + required_pairs = get_all_required_pairs(working_columns, dependency_graph) + + cached_data = cache.get_data(required_pairs) + + # Apply unit conversions to cached data. The unit handler works on flat + # name-keyed dicts; we flatten, convert, then fold results back in. + flat_cached = _flatten(cached_data) + converted_flat = unit_handler.apply_unit_conversions(flat_cached, unit_kwargs) + if converted_flat: + converted_uuid = _apply_uuid_mapping(converted_flat, required_pairs) + cache.add_data(converted_uuid, {}, push_up=False) + for uuid, col_data in converted_uuid.items(): + cached_data.setdefault(uuid, {}).update(col_data) + + # Determine which raw columns still need to be fetched from the handler. 
+ cached_uuids = set(cached_data.keys()) raw_columns = [ col for col in column_producers if isinstance(col, RawColumn) - and col.name not in cached_data - and col.name in all_required_columns + and col.uuid not in cached_uuids + and col.name in {name for (_, name) in required_pairs} ] - raw_data = raw_data_handler.get_data(set(col.name for col in raw_columns)) for column in raw_columns: if column.alias is None: @@ -68,67 +157,50 @@ def instantiate_dataset( raw_data[column.alias] = raw_data[column.name] raw_data = unit_handler.apply_raw_units(raw_data, unit_kwargs) - new_derived_columns = build_derived_columns( - column_producers, - all_required_columns, - cached_data | raw_data, - dependency_graph, - raw_data_handler.index, + + uuid_data = build_initial_uuid_data(column_producers, raw_data, cached_data) + new_derived = build_derived_columns( + working_columns, uuid_data, dependency_graph, raw_data_handler.index ) - if raw_data: - cache.add_data(raw_data, {}, push_up=True) - converted_raw_data = unit_handler.apply_unit_conversions(raw_data, unit_kwargs) - if converted_raw_data: - cache.add_data(converted_raw_data, {}, push_up=False) - raw_data |= converted_raw_data + uuid_data |= new_derived - data = cached_data | raw_data | new_derived_columns + # Write freshly-fetched raw data back to the cache. 
+ if raw_data: + raw_uuid = { + col.uuid: {col.alias or col.name: raw_data[col.alias or col.name]} + for col in raw_columns + if (col.alias or col.name) in raw_data + } + cache.add_data(raw_uuid, {}, push_up=True) + + converted_raw_flat = unit_handler.apply_unit_conversions(raw_data, unit_kwargs) + if converted_raw_flat: + converted_raw_uuid = _apply_uuid_mapping(converted_raw_flat, required_pairs) + cache.add_data(converted_raw_uuid, {}, push_up=False) + for uuid, col_data in converted_raw_uuid.items(): + uuid_data.setdefault(uuid, {}).update(col_data) + + data = { + name: uuid_data[producer_uuid][name] + for name, producer_uuid in working_columns.items() + if producer_uuid in uuid_data and name in uuid_data[producer_uuid] + } data |= get_metadata_columns(raw_data_handler, cache, metadata_columns) return sort_data(data, sort_by) -def build_derived_columns( - column_producers: list[ConstructedColumn], - column_names: set[str], - data: dict[str, np.ndarray], - dependency_graph: rx.PyDiGraph, - index: DataIndex, -): - dependency_graph = contract_derived_columns( - dependency_graph, column_names, column_producers - ) - new_derived: dict[str, ConstructedColumn] = {} - for colidx in rx.topological_sort(dependency_graph): - column = dependency_graph[colidx] - if isinstance(column, str): - assert column in data or column not in column_names - continue - produces = column.produces - if all(name in data for name in produces): - continue - output = column.evaluate( - data | new_derived, index[1] if isinstance(index, tuple) else None - ) - if isinstance(output, dict): - new_derived |= output - else: - new_derived[column.name] = output - return new_derived - - def get_metadata_columns( raw_data_handler: DataHandler, cache: DataCache, metadata_columns: list[str] | None ): if metadata_columns is None: return {} - metadata = cache.get_data(metadata_columns) + metadata = cache.get_metadata(metadata_columns) additional_metadata_columns_to_fetch = set(metadata_columns).difference( 
metadata.keys() ) metadata |= ( raw_data_handler.get_metadata(additional_metadata_columns_to_fetch) or {} ) - return metadata @@ -139,5 +211,25 @@ def sort_data(data: dict[str, np.ndarray], sort_by: tuple[str, bool] | None): order = np.argsort(sort_column) if sort_by[1]: order = order[::-1] - return {key: value[order] for key, value in data.items()} + + +def _flatten(uuid_data: dict[UUID, dict[str, np.ndarray]]) -> dict[str, np.ndarray]: + """Collapse UUID-keyed data to a flat name-keyed dict (last writer wins).""" + return {name: arr for d in uuid_data.values() for name, arr in d.items()} + + +def _apply_uuid_mapping( + flat_data: dict[str, np.ndarray], + required_pairs: set[CacheKey], +) -> dict[UUID, dict[str, np.ndarray]]: + """ + Map a flat name-keyed dict back to UUID-keyed form using required_pairs to + resolve which UUID owns each column name. + """ + name_to_uuid: dict[str, UUID] = {name: uuid for uuid, name in required_pairs} + result: dict[UUID, dict[str, np.ndarray]] = {} + for name, arr in flat_data.items(): + if name in name_to_uuid: + result.setdefault(name_to_uuid[name], {})[name] = arr + return result diff --git a/python/opencosmo/dataset/output.py b/python/opencosmo/dataset/output.py index a645ced1..6cd7f86d 100644 --- a/python/opencosmo/dataset/output.py +++ b/python/opencosmo/dataset/output.py @@ -14,6 +14,8 @@ from opencosmo.io.writer import ColumnCombineStrategy, ColumnWriter, NumpySource if TYPE_CHECKING: + from uuid import UUID + from opencosmo.column.column import ConstructedColumn from opencosmo.handler.protocols import DataCache, DataHandler from opencosmo.header import OpenCosmoHeader @@ -66,19 +68,20 @@ def make_dataset_schema( producers: list[ConstructedColumn], raw_data_handler: DataHandler, cache: DataCache, - columns: set[str], + columns_to_uuid: dict[str, UUID], meta_columns: list[str], header: OpenCosmoHeader, region: Region, derived_data: dict, name: Optional[str] = None, ) -> Schema: + columns = set(columns_to_uuid.keys()) 
header = header.with_region(region) raw_columns = columns.intersection(raw_data_handler.columns) data_schema, metadata_schema = raw_data_handler.make_schema(raw_columns, header) cached_data_schema, cached_metadata_schema = cache.make_schema( - list(columns) + meta_columns + columns_to_uuid, meta_columns ) build_derived_writers(producers, derived_data, data_schema, cached_data_schema) diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index dd0f916c..bb6e1666 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import reduce -from typing import TYPE_CHECKING, Iterable, Optional, Sequence +from typing import TYPE_CHECKING, Optional, Sequence from weakref import finalize import astropy.units as u @@ -25,6 +25,8 @@ ) if TYPE_CHECKING: + from uuid import UUID + from astropy import table from astropy.cosmology import Cosmology from numpy.typing import NDArray @@ -55,16 +57,18 @@ def __init__( cache: DataCache, unit_handler: UnitHandler, header: OpenCosmoHeader, - columns: Iterable[str], + columns: dict[str, UUID], region: Region, sort_by: Optional[tuple[str, bool]], ): - self.__producers = list(column_producers) + self.__producers: dict[UUID, ConstructedColumn] = { + p.uuid: p for p in column_producers + } self.__raw_data_handler = raw_data_handler self.__cache = cache self.__unit_handler = unit_handler self.__header = header - self.__columns = set(columns) + self.__columns: dict[str, UUID] = columns self.__region = region self.__sort_by = sort_by self.__cache.register_column_group(id(self), self.__columns) @@ -74,7 +78,7 @@ def __rebuild(self, **updates): new = { "raw_data_handler": self.__raw_data_handler, "cache": self.__cache, - "column_producers": self.__producers, + "column_producers": list(self.__producers.values()), "unit_handler": self.__unit_handler, "header": self.__header, "columns": self.__columns, @@ -117,7 +121,7 @@ def 
from_target( RawColumn(cname, descriptions.get(cname, "None")) for cname in handler.columns ] - columns = set(handler.columns) + columns = {p.name: p.uuid for p in producers} cache = ColumnCache.empty() return DatasetState( producers, @@ -142,20 +146,27 @@ def in_memory( index: Optional[DataIndex] = None, ): descriptions = descriptions or {} + + # Producers must be created first so their UUIDs are available for the cache. + producers = [ + RawColumn(cname, descriptions.get(cname, "None")) + for cname in data_columns.keys() + ] + columns = {p.name: p.uuid for p in producers} + cache = ColumnCache.empty() - cache.add_data( - data_columns | metadata_columns, descriptions, set(metadata_columns.keys()) - ) + if data_columns: + uuid_data = {p.uuid: {p.name: data_columns[p.name]} for p in producers} + cache.add_data(uuid_data, descriptions) + if metadata_columns: + cache.add_metadata(dict(metadata_columns), {}) + units: dict[str, u.Unit] = {} for name, column in data_columns.items(): units[name] = None if isinstance(column, u.Quantity): units[name] = column.unit - producers = [ - RawColumn(cname, descriptions.get(cname, "None")) - for cname in data_columns.keys() - ] unit_handler = make_unit_handler_from_units(units, header, unit_convention) return DatasetState( @@ -164,7 +175,7 @@ def in_memory( cache, unit_handler, header, - set(data_columns.keys()), + columns, region, None, ) @@ -177,7 +188,7 @@ def __len__(self): @property def descriptions(self): all_descriptions = {} - for producer in self.__producers: + for producer in self.__producers.values(): update = {name: producer.description for name in producer.produces} all_descriptions |= update all_descriptions |= self.__cache.descriptions @@ -219,7 +230,7 @@ def header(self): @property def columns(self) -> list[str]: - return list(self.__columns) + return list(self.__columns.keys()) @property def meta_columns(self) -> list[str]: @@ -238,7 +249,7 @@ def get_data( Get the data for a given handler. 
""" data = instantiate_dataset( - self.__producers, + list(self.__producers.values()), self.__columns, self.__raw_data_handler, self.__cache, @@ -290,7 +301,11 @@ def rows(self, metadata_columns: list = [], unit_kwargs: dict = {}): } derived_storage = resort(all_derived, self.get_sorted_index()) if derived_storage: - self.__cache.add_data(data, {}) + uuid_keyed: dict = {} + for name, arr in derived_storage.items(): + uuid = self.__columns[name] + uuid_keyed.setdefault(uuid, {})[name] = arr + self.__cache.add_data(uuid_keyed, {}) except GeneratorExit: pass except BaseException: @@ -313,7 +328,9 @@ def with_mask(self, mask: NDArray[np.bool_]): ) def make_schema(self, name: Optional[str] = None): - derived_names = get_derived_column_names(self.__producers, self.__columns) + producers = list(self.__producers.values()) + columns = set(self.__columns.keys()) + derived_names = get_derived_column_names(producers, columns) if derived_names: derived_data = ( self.select(derived_names) @@ -323,7 +340,7 @@ def make_schema(self, name: Optional[str] = None): else: derived_data = {} return make_dataset_schema( - self.__producers, + producers, self.__raw_data_handler, self.__cache, self.__columns, @@ -343,11 +360,11 @@ def with_new_columns( Add a set of derived columns to the dataset. A derived column is a column that has been created based on the values in another column. 
""" - new_producers, new_column_names, new_unit_handler = add_columns( - self.__producers, + new_producers, new_column_map, new_unit_handler = add_columns( + list(self.__producers.values()), self.__unit_handler, self.__cache, - self.columns, + self.__columns, self.get_sorted_index(), descriptions, new_columns, @@ -356,7 +373,7 @@ def with_new_columns( return self.__rebuild( cache=self.__cache, column_producers=new_producers, - columns=new_column_names, + columns=new_column_map, unit_handler=new_unit_handler, ) @@ -388,7 +405,7 @@ def select(self, columns: set[str], drop=False): if drop: selections = set(self.columns) - selections - return self.__rebuild(columns=selections) + return self.__rebuild(columns={n: self.__columns[n] for n in selections}) def sort_by(self, column_name: str, invert: bool): if column_name not in self.columns: @@ -525,7 +542,7 @@ def with_units( lambda acc, col: acc.union( col.produces if not isinstance(col, RawColumn) else set() ), - self.__producers, + self.__producers.values(), all_derived_names, ).intersection(self.columns) columns_to_drop = all_derived_names.union(self.__raw_data_handler.columns) diff --git a/python/opencosmo/handler/protocols.py b/python/opencosmo/handler/protocols.py index e46335f1..6e674861 100644 --- a/python/opencosmo/handler/protocols.py +++ b/python/opencosmo/handler/protocols.py @@ -3,11 +3,14 @@ from typing import TYPE_CHECKING, Iterable, Optional, Protocol, Self if TYPE_CHECKING: + from uuid import UUID + import numpy as np from opencosmo.header import OpenCosmoHeader - from opencosmo.index import DataIndex from opencosmo.io.schema import Schema + from opencosmo.index import DataIndex + class DataHandler(Protocol): def get_data(self, columns: Iterable[str]) -> dict[str, np.ndarray]: @@ -32,17 +35,49 @@ def load_conditions(self) -> Optional[dict]: ... def index(self) -> DataIndex: ... 
-class DataCache(DataHandler, Protocol): +class DataCache(Protocol): def add_data( self, - data: dict[str, np.ndarray], + data: dict[UUID, dict[str, np.ndarray]], descriptions: dict[str, str], push_up: bool = True, ): ... + def add_metadata( + self, + data: dict[str, np.ndarray], + descriptions: dict[str, str] = {}, + ): ... + + def get_data( + self, pairs: set[tuple[UUID, str]] + ) -> dict[UUID, dict[str, np.ndarray]]: ... + + def get_metadata(self, column_names: Iterable[str]) -> dict[str, np.ndarray]: ... + def __len__(self) -> int: ... + def take(self, index: DataIndex) -> Self: ... + def drop(self, columns: Iterable[str]) -> Self: ... - def register_column_group(self, state_id: int, columns: set[str]) -> None: ... + + def register_column_group( + self, state_id: int, columns: dict[str, UUID] + ) -> None: ... + def deregister_column_group(self, state_id: int) -> None: ... + def create_child(self) -> Self: ... + + @property + def columns(self) -> set[str]: ... + + @property + def metadata_columns(self) -> set[str]: ... + + @property + def descriptions(self) -> dict[str, str]: ... + + def make_schema( + self, columns: dict[str, UUID], meta_columns: list[str] + ) -> tuple[Schema, Schema]: ... 
diff --git a/test/test_cache.py b/test/test_cache.py index 41932644..592c4f00 100644 --- a/test/test_cache.py +++ b/test/test_cache.py @@ -1,96 +1,115 @@ -import numpy as np +from uuid import uuid4 +import numpy as np from opencosmo.column.cache import ColumnCache +def _make_mapping(names): + """Create a {name: UUID} mapping for testing.""" + return {name: uuid4() for name in names} + + +def _add_raw_data(cache, name_to_uuid, raw_data): + """Add name-keyed data to cache via UUID-keyed interface.""" + uuid_data = {name_to_uuid[name]: {name: raw_data[name]} for name in raw_data} + cache.add_data(uuid_data) + + +def _get_flat_data(cache, name_to_uuid): + """Get data from cache, flattened to a plain name-keyed dict.""" + pairs = {(uuid, name) for name, uuid in name_to_uuid.items()} + uuid_data = cache.get_data(pairs) + return {name: arr for d in uuid_data.values() for name, arr in d.items()} + + def test_cache_take(): - data = {} - for name in "abcdefg": - data[name] = np.random.randint(0, 1000, 10_000) + raw_data = {name: np.random.randint(0, 1000, 10_000) for name in "abcdefg"} + name_to_uuid = _make_mapping("abcdefg") + cache = ColumnCache.empty() - cache.register_column_group(0, set("abcdefg")) - cache.add_data(data) + cache.register_column_group(0, name_to_uuid) + _add_raw_data(cache, name_to_uuid, raw_data) index2 = np.sort(np.random.choice(10_000, 100, replace=False)) cache2 = cache.take(index2) - cache2.register_column_group(0, set("abcdefg")) - cached_data = cache2.get_data("abcdefg") + cache2.register_column_group(0, name_to_uuid) + cached_data = _get_flat_data(cache2, name_to_uuid) for colname, column in cached_data.items(): - assert np.all(column == data[colname][index2]) + assert np.all(column == raw_data[colname][index2]) def test_cache_passthrough(): - data = {} - for name in "abcdefg": - data[name] = np.random.randint(0, 1000, 10_000) + raw_data = {name: np.random.randint(0, 1000, 10_000) for name in "abcdefg"} + name_to_uuid = _make_mapping("abcdefg") 
+ cache = ColumnCache.empty() - cache.register_column_group(0, set("abcdefg")) - cache.add_data(data) + cache.register_column_group(0, name_to_uuid) + _add_raw_data(cache, name_to_uuid, raw_data) index2 = np.sort(np.random.choice(10_000, 1000, replace=False)) cache2 = cache.take(index2) - cache2.register_column_group(0, set("abcdefg")) + cache2.register_column_group(0, name_to_uuid) index3 = np.sort(np.random.choice(1000, 100, replace=False)) cache3 = cache2.take(index3) - cache3.register_column_group(0, set("abcdefg")) + cache3.register_column_group(0, name_to_uuid) - cached_data = cache3.get_data("abcdefg") + cached_data = _get_flat_data(cache3, name_to_uuid) for colname, column in cached_data.items(): - assert np.all(column == data[colname][index2[index3]]) + assert np.all(column == raw_data[colname][index2[index3]]) assert set(cache3.columns) == set("abcdefg") assert len(cache2.columns) == 0 def test_cache_passthrough_delete(): - data = {} - for name in "abcdefg": - data[name] = np.random.randint(0, 1000, 10_000) + raw_data = {name: np.random.randint(0, 1000, 10_000) for name in "abcdefg"} + name_to_uuid = _make_mapping("abcdefg") + cache = ColumnCache.empty() - cache.register_column_group(0, set("abcdefg")) - cache.add_data(data) + cache.register_column_group(0, name_to_uuid) + _add_raw_data(cache, name_to_uuid, raw_data) index2 = np.sort(np.random.choice(10_000, 1000, replace=False)) cache2 = cache.take(index2) - cache2.register_column_group(0, set("abcdefg")) - _ = cache2.get_data("abcdefg") + cache2.register_column_group(0, name_to_uuid) + _ = _get_flat_data(cache2, name_to_uuid) index3 = np.sort(np.random.choice(1000, 100, replace=False)) cache3 = cache2.take(index3) - cache3.register_column_group(0, set("abcdefg")) + cache3.register_column_group(0, name_to_uuid) assert set(cache3.columns) == set() del cache2 assert set(cache3.columns) == set("abcdefg") - cached_data = cache3.get_data("abcdefg") + cached_data = _get_flat_data(cache3, name_to_uuid) for 
colname, column in cached_data.items(): - assert np.all(column == data[colname][index2[index3]]) + assert np.all(column == raw_data[colname][index2[index3]]) def test_cache_passthrough_twice_delete(): - data = {} - for name in "abcdefg": - data[name] = np.random.randint(0, 1000, 10_000) + raw_data = {name: np.random.randint(0, 1000, 10_000) for name in "abcdefg"} + name_to_uuid = _make_mapping("abcdefg") + cache = ColumnCache.empty() - cache.register_column_group(0, set("abcdefg")) - cache.add_data(data) + cache.register_column_group(0, name_to_uuid) + _add_raw_data(cache, name_to_uuid, raw_data) index2 = np.sort(np.random.choice(10_000, 1000, replace=False)) cache2 = cache.take(index2) - cache2.register_column_group(0, set("abcdefg")) + cache2.register_column_group(0, name_to_uuid) index3 = np.sort(np.random.choice(1000, 100, replace=False)) cache3 = cache2.take(index3) - cache3.register_column_group(0, set("abcdefg")) + cache3.register_column_group(0, name_to_uuid) assert set(cache2.columns) == set() del cache assert set(cache2.columns) == set("abcdefg") - cached_data = cache3.get_data("abcdefg") + cached_data = _get_flat_data(cache3, name_to_uuid) for colname, column in cached_data.items(): - assert np.all(column == data[colname][index2[index3]]) + assert np.all(column == raw_data[colname][index2[index3]]) assert set(cache3.columns) == set("abcdefg") diff --git a/test/test_collection.py b/test/test_collection.py index 75488920..e9a6496e 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -1227,7 +1227,10 @@ def test_data_cached_after_objects(halo_paths): pass dataset = ds["dm_particles"] - cache = dataset._Dataset__state._DatasetState__cache - data = cache.get_data(("gpe",)) - assert data.get("gpe") is not None + state = dataset._Dataset__state + cache = state._DatasetState__cache + columns = state._DatasetState__columns # dict[str, UUID] + gpe_uuid = columns["gpe"] + uuid_data = cache.get_data({(gpe_uuid, "gpe")}) + assert 
uuid_data.get(gpe_uuid, {}).get("gpe") is not None assert dataset.descriptions["gpe"] != "None" diff --git a/test/test_dataset.py b/test/test_dataset.py index f7e90135..ecf4de53 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -543,12 +543,15 @@ def test_rows_cache(input_path): for i, row in enumerate(dataset.rows()): assert row["fof_px"] == row["fof_halo_mass"] * row["fof_halo_com_vx"] - cache = dataset._Dataset__state._DatasetState__cache - cached_data = cache.get_data(["fof_halo_mass", "fof_halo_com_vx", "fof_px"]) - assert np.all( - cached_data["fof_px"] - == cached_data["fof_halo_mass"] * cached_data["fof_halo_com_vx"] - ) + # After iterating rows(), derived columns should be cached. + state = dataset._Dataset__state + cache = state._DatasetState__cache + columns = state._DatasetState__columns + assert "fof_px" in cache.columns + pairs = {(columns["fof_px"], "fof_px")} + uuid_data = cache.get_data(pairs) + flat = {name: arr for d in uuid_data.values() for name, arr in d.items()} + assert len(flat["fof_px"]) == 100 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" @@ -635,16 +638,24 @@ def test_cache_conversion_propogation(input_path): dataset2 = dataset.with_units(conversions={u.Mpc: u.lyr}, fof_halo_center_x=u.km) dataset2.get_data() - cache = dataset._Dataset__state._DatasetState__cache - cache2 = dataset2._Dataset__state._DatasetState__cache + state = dataset._Dataset__state + state2 = dataset2._Dataset__state + cache = state._DatasetState__cache + cache2 = state2._DatasetState__cache + col_to_uuid = state._DatasetState__columns + pairs = {(uuid, name) for name, uuid in col_to_uuid.items()} assert len(cache.columns) == len(dataset2.columns) # just to be safe - cached_columns = cache.get_data(dataset.columns) - cached_columns2 = cache2.get_data(dataset.columns) - for col in cached_columns.values(): + flat = { + name: arr for d in cache.get_data(pairs).values() for name, arr in d.items() + } + flat2 = { + name: arr for d in 
cache2.get_data(pairs).values() for name, arr in d.items() + } + for col in flat.values(): if isinstance(col, u.Quantity): assert col.unit not in [u.lyr, u.km] - for col in cached_columns2.values(): + for col in flat2.values(): if isinstance(col, u.Quantity): assert col.unit != u.Mpc From a0ba14bf361c9f88dbe4b50e7c02417d9a60b2f2 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Wed, 15 Apr 2026 15:45:10 -0400 Subject: [PATCH 045/139] bugfix + updated docs --- docs/source/analysis.rst | 2 +- python/opencosmo/analysis/yt_viz.py | 43 +++++++++++++++++++++-------- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/docs/source/analysis.rst b/docs/source/analysis.rst index d1c96588..24836212 100644 --- a/docs/source/analysis.rst +++ b/docs/source/analysis.rst @@ -105,7 +105,7 @@ The two primary functions for this purpose are: - :func:`opencosmo.analysis.visualize_halo` — a simple 2x2 panel plot for one halo - :func:`opencosmo.analysis.halo_projection_array` — a customizable grid of halos and fields -These use yt under the hood, and are useful for visually inspecting halos with minimal input required. +These use yt under the hood, and are useful for visually inspecting halos with minimal input required. Animated versions of the visualizations outputted by either of these functions can be made using `opencosmo.analysis.animate_halos`. 
Quick Projections diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index a0699533..258666d6 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -260,10 +260,8 @@ def visualize_halo( halo_ids: list[int] | tuple[list[int], list[int]] - if yt_ds is not None: - yt_dataset_provided = True - else: - yt_dataset_provided = False + yt_dataset_provided = yt_ds is not None + if not yt_dataset_provided: yt_ds_arr = None if len(params["fields"]) == 4: @@ -416,10 +414,16 @@ def halo_projection_array( data = data.with_units("comoving") halo_ids_2d = np.atleast_2d(halo_ids) - yt_ds_2d = np.atleast_2d(yt_ds) # type: ignore # determine shape of figure - fig_shape = np.shape(halo_ids) + fig_shape = np.shape(halo_ids_2d) + + yt_datasets_provided = yt_ds is not None + + if yt_datasets_provided: + yt_ds_2d = np.atleast_2d(yt_ds) # type: ignore + else: + yt_ds_2d = np.full(fig_shape, None) # Default plotting parameters if weight_field is None: @@ -509,22 +513,25 @@ def halo_projection_array( ax.set_facecolor("black") continue - ds = yt_ds_2d[i][j] - if ds is not None: + if yt_datasets_provided: + ds = yt_ds_2d[i][j] + if ds is None: + raise ValueError(f"provided yt dataset cannot be None") + # sodbighaloparticles holds particle data out to 2*R200 Rh = ds.domain_width[0] / 4 else: - # retrieve halo particle info if new halo, or if yt dataset - # is not already provided + # retrieve halo particle info if new halo if (i == 0 and j == 0) or halo_id != halo_id_previous: # retrieve properties of halo if len(data) > 1: data_id = data.filter(oc.col("unique_tag") == halo_id) else: - if data["halo_properties"].data["unique_tag"] != halo_id: # type: ignore + if halo_id != data["halo_properties"].select("unique_tag").get_data(): # type: ignore raise RuntimeError(f"Halo ID {halo_id} not in dataset!") data_id = data + halo_data = next(iter(data_id.objects())) # load particles into yt @@ -792,6 +799,20 @@ def animate_halos( 
rotating about the y-axis. It can also animate a customizable multipanel projection layout by setting ``func="halo_projection_array"`` and passing a 2D arrangement of halo IDs. + Example usage for visualizing the most massive halo in a dataset: + .. code-block:: python + + from opencosmo.analysis import animate_halos + + # fetch data and ID for most massive halo + ds = oc.open("haloproperties.hdf5", "haloparticles.hdf5").sort_by("sod_halo_mass").take(1, at="end") + halo_id = ds.select("unique_tag").get_data("numpy") + + # create a 30-frame animation that rotates the object once about the y-axis, then once about the x-axis. + anim = animate_halos(halo_id, ds, rotations=["y", "x"], frames=30) + anim.save("animation.gif", fps=10) + + Parameters ---------- halo_ids : int or array of int From 7e6a4472d73093f1861220103ab0f9b5583339dc Mon Sep 17 00:00:00 2001 From: William Hicks Date: Wed, 15 Apr 2026 15:47:51 -0400 Subject: [PATCH 046/139] type check again... --- python/opencosmo/analysis/yt_viz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index 258666d6..e5126caf 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -516,7 +516,7 @@ def halo_projection_array( if yt_datasets_provided: ds = yt_ds_2d[i][j] if ds is None: - raise ValueError(f"provided yt dataset cannot be None") + raise ValueError("provided yt dataset cannot be None") # sodbighaloparticles holds particle data out to 2*R200 Rh = ds.domain_width[0] / 4 From d85944e3bdb7b981dea2e5d7a90b7942d75cb50a Mon Sep 17 00:00:00 2001 From: William Hicks Date: Wed, 15 Apr 2026 16:05:20 -0400 Subject: [PATCH 047/139] docs --- docs/source/analysis.rst | 2 +- python/opencosmo/analysis/yt_viz.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/analysis.rst b/docs/source/analysis.rst index 24836212..c11f6c9f 100644 --- a/docs/source/analysis.rst +++ 
b/docs/source/analysis.rst @@ -105,7 +105,7 @@ The two primary functions for this purpose are: - :func:`opencosmo.analysis.visualize_halo` — a simple 2x2 panel plot for one halo - :func:`opencosmo.analysis.halo_projection_array` — a customizable grid of halos and fields -These use yt under the hood, and are useful for visually inspecting halos with minimal input required. Animated versions of the visualizations outputted by either of these functions can be made using `opencosmo.analysis.animate_halos`. +These use yt under the hood, and are useful for visually inspecting halos with minimal input required. Animated versions of the visualizations outputted by either of these functions can be made using :func:`opencosmo.analysis.animate_halos`. Quick Projections diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index e5126caf..cc5f681b 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -800,6 +800,7 @@ def animate_halos( setting ``func="halo_projection_array"`` and passing a 2D arrangement of halo IDs. Example usage for visualizing the most massive halo in a dataset: + .. code-block:: python from opencosmo.analysis import animate_halos From 19121594003dc8f3ae4e6e14ab4a6a0b77ce3016 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Wed, 15 Apr 2026 16:16:01 -0400 Subject: [PATCH 048/139] docs --- python/opencosmo/analysis/yt_viz.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py index cc5f681b..053fa4eb 100644 --- a/python/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -803,11 +803,12 @@ def animate_halos( .. 
code-block:: python + import opencosmo as oc from opencosmo.analysis import animate_halos # fetch data and ID for most massive halo ds = oc.open("haloproperties.hdf5", "haloparticles.hdf5").sort_by("sod_halo_mass").take(1, at="end") - halo_id = ds.select("unique_tag").get_data("numpy") + halo_id = ds["halo_properties"].select("unique_tag").get_data() # create a 30-frame animation that rotates the object once about the y-axis, then once about the x-axis. anim = animate_halos(halo_id, ds, rotations=["y", "x"], frames=30) From b7dd0a428dca3e8f00a72438040aff484e07f00e Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 17 Apr 2026 10:07:54 -0500 Subject: [PATCH 049/139] Add no_cache option to derived columns --- python/opencosmo/column/cache.py | 3 ++ python/opencosmo/column/column.py | 18 +++++++++++ python/opencosmo/dataset/instantiate.py | 41 +++++++++++++------------ 3 files changed, 43 insertions(+), 19 deletions(-) diff --git a/python/opencosmo/column/cache.py b/python/opencosmo/column/cache.py index ecf53329..9e154a7c 100644 --- a/python/opencosmo/column/cache.py +++ b/python/opencosmo/column/cache.py @@ -262,6 +262,9 @@ def add_data( push_up: bool = True, ): """Add UUID-keyed column data to the cache.""" + + if not data: + return check_length(self, data) self.__descriptions |= descriptions diff --git a/python/opencosmo/column/column.py b/python/opencosmo/column/column.py index edfb9a66..08ba4a28 100644 --- a/python/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -303,6 +303,8 @@ def description(self) -> Optional[str]: ... def bind(self, name_to_uuid: dict[str, UUID]) -> Self: ... + @property + def no_cache(self) -> bool: ... 
def evaluate( self, data: dict[str, np.ndarray], @@ -356,6 +358,10 @@ def dep_map(self) -> dict[str, UUID]: ) return {self.__name: self.__dep_uuid} + @property + def no_cache(self): + return False + @property def alias(self) -> str | None: return self.__alias @@ -399,6 +405,7 @@ def __init__( description: Optional[str] = None, output_name: Optional[str] = None, _dep_map: dict[str, UUID] | None = None, + no_cache: bool = False, ): self.lhs = lhs self.rhs = rhs @@ -407,6 +414,7 @@ def __init__( self.description = description if description is not None else "None" self.__uuid = uuid4() self.__dep_map: dict[str, UUID] | None = _dep_map + self.__no_cache = no_cache @property def uuid(self) -> UUID: @@ -457,6 +465,10 @@ def requires(self) -> set[UUID]: def produces(self): return None if self.name is None else set([self.name]) + @property + def no_cache(self): + return self.__no_cache + def check_parent_existance(self, names: set[str]): match self.rhs: case Column(): @@ -587,6 +599,7 @@ def __init__( batch_size: int = -1, description: Optional[str] = None, _dep_map: dict[str, UUID] | None = None, + no_cache: bool = False, **kwargs: Any, ): self.__func = func @@ -597,6 +610,7 @@ def __init__( self.__format = format self.__strategy = strategy self.__batch_size = batch_size + self.__no_cache = no_cache self.description = description self.__uuid = uuid4() self.__dep_map = _dep_map @@ -661,6 +675,10 @@ def requires(self) -> set[UUID]: raise RuntimeError(f"EvaluatedColumn '{self.name}' has not been bound yet.") return set(self.__dep_map.values()) + @property + def no_cache(self): + return self.__no_cache + @property def produces(self): return copy(self.__produces) diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py index d18bd84a..ffc5669b 100644 --- a/python/opencosmo/dataset/instantiate.py +++ b/python/opencosmo/dataset/instantiate.py @@ -67,6 +67,7 @@ def build_derived_columns( uuid_data: dict[UUID, dict[str, np.ndarray]], 
dependency_graph: rx.PyDiGraph, index: DataIndex, + cache: DataCache, ) -> dict[UUID, dict[str, np.ndarray]]: """ Evaluate all derived producers needed to produce the requested columns, @@ -85,6 +86,7 @@ def build_derived_columns( required_uuids.add(dependency_graph[node_idx].uuid) new_derived: dict[UUID, dict[str, np.ndarray]] = {} + to_cache: dict[UUID, dict[str, np.ndarray]] = {} for node_idx in rx.topological_sort(dependency_graph): producer = dependency_graph[node_idx] if isinstance(producer, RawColumn): @@ -100,10 +102,13 @@ def build_derived_columns( output = producer.evaluate( input_data, index[1] if isinstance(index, tuple) else None ) - if isinstance(output, dict): - new_derived[producer.uuid] = output - else: - new_derived[producer.uuid] = {next(iter(producer.produces)): output} + if not isinstance(output, dict): + output = {next(iter(producer.produces)): output} + new_derived[producer.uuid] = output + if not producer.no_cache: + to_cache[producer.uuid] = output + + cache.add_data(to_cache, {}) return new_derived @@ -160,25 +165,23 @@ def instantiate_dataset( uuid_data = build_initial_uuid_data(column_producers, raw_data, cached_data) new_derived = build_derived_columns( - working_columns, uuid_data, dependency_graph, raw_data_handler.index + working_columns, uuid_data, dependency_graph, raw_data_handler.index, cache ) + uuid_data |= new_derived # Write freshly-fetched raw data back to the cache. 
- if raw_data: - raw_uuid = { - col.uuid: {col.alias or col.name: raw_data[col.alias or col.name]} - for col in raw_columns - if (col.alias or col.name) in raw_data - } - cache.add_data(raw_uuid, {}, push_up=True) - - converted_raw_flat = unit_handler.apply_unit_conversions(raw_data, unit_kwargs) - if converted_raw_flat: - converted_raw_uuid = _apply_uuid_mapping(converted_raw_flat, required_pairs) - cache.add_data(converted_raw_uuid, {}, push_up=False) - for uuid, col_data in converted_raw_uuid.items(): - uuid_data.setdefault(uuid, {}).update(col_data) + raw_data |= unit_handler.apply_unit_conversions(raw_data, unit_kwargs) + raw_data_uuid = _apply_uuid_mapping(raw_data, required_pairs) + to_add = { + uuid: value + for uuid, value in raw_data_uuid.items() + if uuid in working_columns.values() + } + + cache.add_data(to_add, {}, push_up=False) + for uuid, col_data in raw_data_uuid.items(): + uuid_data.setdefault(uuid, {}).update(col_data) data = { name: uuid_data[producer_uuid][name] From 6022621a5c17337868e49dbc5a476bb4cb3004cd Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 17 Apr 2026 15:57:46 -0500 Subject: [PATCH 050/139] Add reindexing logic for diffsky --- python/opencosmo/_lib/index.pyi | 1 + .../collection/lightcone/healpix_map.py | 4 +- .../collection/lightcone/lightcone.py | 4 +- .../collection/simulation/simulation.py | 4 +- .../collection/structure/structure.py | 4 +- python/opencosmo/column/column.py | 9 +++-- python/opencosmo/column/evaluate.py | 4 +- python/opencosmo/cosmology.py | 2 +- python/opencosmo/dataset/dataset.py | 4 +- python/opencosmo/dataset/instantiate.py | 18 ++++++--- python/opencosmo/dataset/state.py | 4 ++ .../{parameters => dtypes}/__init__.py | 0 .../{parameters => dtypes}/cosmology.py | 0 .../{parameters => dtypes}/diffsky.py | 20 ++++++++++ .../opencosmo/{parameters => dtypes}/dtype.py | 28 ++++++++++++-- .../opencosmo/{parameters => dtypes}/file.py | 0 .../opencosmo/{parameters => dtypes}/hacc.py | 0 .../{parameters => 
dtypes}/lightcone.py | 0 .../{parameters => dtypes}/origin.py | 2 +- .../{parameters => dtypes}/parameters.py | 0 .../opencosmo/{parameters => dtypes}/units.py | 0 .../opencosmo/{parameters => dtypes}/utils.py | 0 python/opencosmo/header.py | 10 ++--- python/opencosmo/index/ops.py | 16 ++++++++ python/opencosmo/spatial/check.py | 2 +- src/index.rs | 37 ++++++++++++++++++- test/test_diffsky.py | 8 +++- 27 files changed, 146 insertions(+), 35 deletions(-) rename python/opencosmo/{parameters => dtypes}/__init__.py (100%) rename python/opencosmo/{parameters => dtypes}/cosmology.py (100%) rename python/opencosmo/{parameters => dtypes}/diffsky.py (62%) rename python/opencosmo/{parameters => dtypes}/dtype.py (54%) rename python/opencosmo/{parameters => dtypes}/file.py (100%) rename python/opencosmo/{parameters => dtypes}/hacc.py (100%) rename python/opencosmo/{parameters => dtypes}/lightcone.py (100%) rename python/opencosmo/{parameters => dtypes}/origin.py (86%) rename python/opencosmo/{parameters => dtypes}/parameters.py (100%) rename python/opencosmo/{parameters => dtypes}/units.py (100%) rename python/opencosmo/{parameters => dtypes}/utils.py (100%) create mode 100644 python/opencosmo/index/ops.py diff --git a/python/opencosmo/_lib/index.pyi b/python/opencosmo/_lib/index.pyi index 4a05fbde..ca046099 100644 --- a/python/opencosmo/_lib/index.pyi +++ b/python/opencosmo/_lib/index.pyi @@ -12,3 +12,4 @@ def chunked_into_array(start: IndexArray, size: IndexArray) -> IndexArray: ... def take_chunked_from_simple( simple: IndexArray, start: IndexArray, size: IndexArray ) -> IndexArray: ... +def reindex_column(index: IndexArray, index_column: IndexArray) -> IndexArray: ... 
diff --git a/python/opencosmo/collection/lightcone/healpix_map.py b/python/opencosmo/collection/lightcone/healpix_map.py index 7d57561d..4aa4fbbb 100644 --- a/python/opencosmo/collection/lightcone/healpix_map.py +++ b/python/opencosmo/collection/lightcone/healpix_map.py @@ -26,10 +26,10 @@ from opencosmo.column.column import ColumnMask, ConstructedColumn from opencosmo.dataset import Dataset from opencosmo.dataset.build import GroupedColumnData + from opencosmo.dtypes.hacc import HaccSimulationParameters from opencosmo.header import OpenCosmoHeader from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema - from opencosmo.parameters.hacc import HaccSimulationParameters from opencosmo.spatial import Region @@ -306,7 +306,7 @@ def simulation(self) -> HaccSimulationParameters: Returns ------- - parameters: opencosmo.parameters.hacc.HaccSimulationParameters + parameters: opencosmo.dtypes.hacc.HaccSimulationParameters """ return self.__header.simulation diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index e721ee65..8ad88ea6 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -38,10 +38,10 @@ from astropy.table import Table from opencosmo.column.column import ColumnMask, ConstructedColumn + from opencosmo.dtypes.hacc import HaccSimulationParameters from opencosmo.header import OpenCosmoHeader from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema - from opencosmo.parameters.hacc import HaccSimulationParameters from opencosmo.spatial import Region @@ -420,7 +420,7 @@ def simulation(self) -> HaccSimulationParameters: Returns ------- - parameters: opencosmo.parameters.hacc.HaccSimulationParameters + parameters: opencosmo.dtypes.hacc.HaccSimulationParameters """ return self.__header.simulation diff --git a/python/opencosmo/collection/simulation/simulation.py 
b/python/opencosmo/collection/simulation/simulation.py index 19ba8745..f7402ad0 100644 --- a/python/opencosmo/collection/simulation/simulation.py +++ b/python/opencosmo/collection/simulation/simulation.py @@ -14,10 +14,10 @@ from opencosmo.collection.protocols import Collection from opencosmo.column.column import ColumnMask, ConstructedColumn + from opencosmo.dtypes import HaccSimulationParameters from opencosmo.header import OpenCosmoHeader from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema - from opencosmo.parameters import HaccSimulationParameters from opencosmo.spatial.protocols import Region @@ -154,7 +154,7 @@ def simulation(self) -> dict[str, HaccSimulationParameters]: Returns -------- - simulation_parameters: dict[str, opencosmo.parameters.HaccSimulationParameters] + simulation_parameters: dict[str, opencosmo.dtypes.HaccSimulationParameters] """ return self.__map_attribute("simulation") diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index 05e9463f..fb9e8ae2 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -21,11 +21,11 @@ import astropy.units as u from opencosmo.column.column import ConstructedColumn + from opencosmo.dtypes import HaccSimulationParameters from opencosmo.index import DataIndex from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema from opencosmo.mpi import MPI - from opencosmo.parameters import HaccSimulationParameters from opencosmo.spatial.protocols import Region @@ -193,7 +193,7 @@ def simulation(self) -> HaccSimulationParameters: Returns ------- - parameters: opencosmo.parameters.HaccSimulationParameters + parameters: opencosmo.dtypes.HaccSimulationParameters """ return self.__header.simulation diff --git a/python/opencosmo/column/column.py b/python/opencosmo/column/column.py index 08ba4a28..c06ceb53 100644 --- 
a/python/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -32,6 +32,7 @@ from uuid import UUID from opencosmo import Dataset + from opencosmo.index import DataIndex Comparison = Callable[[float, float], bool] """ @@ -308,7 +309,7 @@ def no_cache(self) -> bool: ... def evaluate( self, data: dict[str, np.ndarray], - chunk_sizes: Optional[np.ndarray], + index: DataIndex, ) -> np.ndarray | dict[str, np.ndarray]: ... def get_units(self, values: dict[str, u.Quantity]) -> dict[str, u.Unit]: ... @@ -640,6 +641,7 @@ def bind(self, name_to_uuid: dict[str, UUID]) -> EvaluatedColumn: self.__batch_size, self.description, _dep_map=dep_map, + no_cache=self.__no_cache, **self.__kwargs, ) @@ -706,8 +708,9 @@ def kwarg_names(self): def get_units(self, units: dict[str, np.ndarray]): return self.__units - def evaluate(self, data: dict[str, np.ndarray], chunk_sizes: Optional[np.ndarray]): + def evaluate(self, data: dict[str, np.ndarray], index: DataIndex | None): data = {name: data[name] for name in self.__requires} + chunk_sizes = index[1] if isinstance(index, tuple) else None if self.__format != "astropy": data = { name: val.value if isinstance(val, u.Quantity) else val @@ -727,7 +730,7 @@ def evaluate(self, data: dict[str, np.ndarray], chunk_sizes: Optional[np.ndarray match strategy: case EvaluateStrategy.VECTORIZE: - return evaluate_vectorized(data, self.__func, self.__kwargs) + return evaluate_vectorized(data, self.__func, self.__kwargs, index) case EvaluateStrategy.ROW_WISE: return evaluate_rows(data, self.__func, self.__kwargs) case EvaluateStrategy.CHUNKED: diff --git a/python/opencosmo/column/evaluate.py b/python/opencosmo/column/evaluate.py index 7a04a119..ae5a67e0 100644 --- a/python/opencosmo/column/evaluate.py +++ b/python/opencosmo/column/evaluate.py @@ -91,8 +91,8 @@ def __make_chunked_based_output_from_first_values(values, data_length): return storage -def evaluate_vectorized(data, func, kwargs): - return func(**data, **kwargs) +def 
evaluate_vectorized(data, func, kwargs, index): + return func(**data, **kwargs, index=index) def do_first_evaluation( diff --git a/python/opencosmo/cosmology.py b/python/opencosmo/cosmology.py index b9d23548..3c07e294 100644 --- a/python/opencosmo/cosmology.py +++ b/python/opencosmo/cosmology.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: import h5py - from opencosmo.parameters import CosmologyParameters + from opencosmo.dtypes import CosmologyParameters """ Reads cosmology from the header of the file and returns the diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index a3531e2c..24e26d5e 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -29,10 +29,10 @@ from opencosmo.column.column import Column, ColumnMask, ConstructedColumn from opencosmo.dataset.state import DatasetState + from opencosmo.dtypes import HaccSimulationParameters from opencosmo.header import OpenCosmoHeader from opencosmo.index import DataIndex from opencosmo.io.schema import Schema - from opencosmo.parameters import HaccSimulationParameters from opencosmo.spatial.protocols import Region from opencosmo.spatial.tree import Tree @@ -207,7 +207,7 @@ def simulation(self) -> Optional[HaccSimulationParameters]: Returns ------- - parameters: Optional[opencosmo.parameters.hacc.HaccSimulationParameters] + parameters: Optional[opencosmo.dtypes.hacc.HaccSimulationParameters] """ return getattr(self.__header, "simulation", None) diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py index ffc5669b..6a9bb511 100644 --- a/python/opencosmo/dataset/instantiate.py +++ b/python/opencosmo/dataset/instantiate.py @@ -99,9 +99,7 @@ def build_derived_columns( name: all_data[dep_uuid][name] for name, dep_uuid in producer.dep_map.items() } - output = producer.evaluate( - input_data, index[1] if isinstance(index, tuple) else None - ) + output = producer.evaluate(input_data, index) if not isinstance(output, 
dict): output = {next(iter(producer.produces)): output} new_derived[producer.uuid] = output @@ -171,15 +169,23 @@ def instantiate_dataset( uuid_data |= new_derived # Write freshly-fetched raw data back to the cache. + # Only map raw columns by UUID — if a non-raw producer shadows a raw column name, + # the set-iteration order in _apply_uuid_mapping would nondeterministically pick one, + # potentially caching unprocessed data under the derived producer's UUID. + raw_producer_uuids = { + col.uuid for col in column_producers if isinstance(col, RawColumn) + } + raw_required_pairs = { + pair for pair in required_pairs if pair[0] in raw_producer_uuids + } raw_data |= unit_handler.apply_unit_conversions(raw_data, unit_kwargs) - raw_data_uuid = _apply_uuid_mapping(raw_data, required_pairs) + raw_data_uuid = _apply_uuid_mapping(raw_data, raw_required_pairs) to_add = { uuid: value for uuid, value in raw_data_uuid.items() if uuid in working_columns.values() } - - cache.add_data(to_add, {}, push_up=False) + cache.add_data(to_add, {}, push_up=True) for uuid, col_data in raw_data_uuid.items(): uuid_data.setdefault(uuid, {}).update(col_data) diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index bb6e1666..8820e4cd 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -13,6 +13,7 @@ from opencosmo.dataset.columns import add_columns, resort from opencosmo.dataset.instantiate import instantiate_dataset from opencosmo.dataset.output import get_derived_column_names, make_dataset_schema +from opencosmo.dtypes.dtype import get_dtype_column_plugins from opencosmo.handler.empty import EmptyHandler from opencosmo.handler.hdf5 import Hdf5Handler from opencosmo.index.build import single_chunk @@ -122,6 +123,9 @@ def from_target( for cname in handler.columns ] columns = {p.name: p.uuid for p in producers} + producers, columns = get_dtype_column_plugins( + target["header"], producers, columns + ) cache = ColumnCache.empty() 
return DatasetState( producers, diff --git a/python/opencosmo/parameters/__init__.py b/python/opencosmo/dtypes/__init__.py similarity index 100% rename from python/opencosmo/parameters/__init__.py rename to python/opencosmo/dtypes/__init__.py diff --git a/python/opencosmo/parameters/cosmology.py b/python/opencosmo/dtypes/cosmology.py similarity index 100% rename from python/opencosmo/parameters/cosmology.py rename to python/opencosmo/dtypes/cosmology.py diff --git a/python/opencosmo/parameters/diffsky.py b/python/opencosmo/dtypes/diffsky.py similarity index 62% rename from python/opencosmo/parameters/diffsky.py rename to python/opencosmo/dtypes/diffsky.py index 35502e6e..c2f3ad78 100644 --- a/python/opencosmo/parameters/diffsky.py +++ b/python/opencosmo/dtypes/diffsky.py @@ -5,6 +5,9 @@ from pydantic import BaseModel, ConfigDict, field_serializer +from opencosmo.column.column import EvaluatedColumn, EvaluateStrategy +from opencosmo.index.ops import reindex_column + class DiffskyVersionInfo(BaseModel): model_config = ConfigDict(frozen=True) @@ -30,3 +33,20 @@ def serialize_zphot_table(self, value): if value is not None: return list(value) return None + + +def rebuild_top_host_idx(top_host_idx, index): + result = reindex_column(index, top_host_idx) + + return {"top_host_idx": result} + + +top_host_idx = EvaluatedColumn( + rebuild_top_host_idx, + requires=set(["top_host_idx"]), + produces=set(["top_host_idx"]), + format="numpy", + units={"top_host_idx": None}, + strategy=EvaluateStrategy.VECTORIZE, + no_cache=True, +) diff --git a/python/opencosmo/parameters/dtype.py b/python/opencosmo/dtypes/dtype.py similarity index 54% rename from python/opencosmo/parameters/dtype.py rename to python/opencosmo/dtypes/dtype.py index 558c5c22..bb4137c8 100644 --- a/python/opencosmo/parameters/dtype.py +++ b/python/opencosmo/dtypes/dtype.py @@ -2,15 +2,15 @@ from typing import TYPE_CHECKING -from opencosmo.parameters import hacc, lightcone +from opencosmo.dtypes import hacc, lightcone 
+from opencosmo.dtypes.diffsky import top_host_idx if TYPE_CHECKING: from pydantic import BaseModel - from opencosmo.parameters.file import FileParameters + from opencosmo.dtypes.file import FileParameters -# TODO: think I need to alter this def get_dtype_parameters( file_parameters: FileParameters, ) -> dict[str, dict[str, type[BaseModel]]]: @@ -25,3 +25,25 @@ def get_dtype_parameters( required_dtype_params["lightcone"] = lightcone_parameters dtype_parameters["required"] = required_dtype_params return dtype_parameters + + +def get_dtype_column_plugins( + header, + producers, + columns, +): + plugins = __get_plugins(header) + for name, producer in plugins.items(): + if name not in columns: + continue + producer = producer.bind(columns) + producers.append(producer) + columns[name] = producer.uuid + + return producers, columns + + +def __get_plugins(header): + if header.file.data_type == "synthetic_galaxies": + return {"top_host_idx": top_host_idx} + return {} diff --git a/python/opencosmo/parameters/file.py b/python/opencosmo/dtypes/file.py similarity index 100% rename from python/opencosmo/parameters/file.py rename to python/opencosmo/dtypes/file.py diff --git a/python/opencosmo/parameters/hacc.py b/python/opencosmo/dtypes/hacc.py similarity index 100% rename from python/opencosmo/parameters/hacc.py rename to python/opencosmo/dtypes/hacc.py diff --git a/python/opencosmo/parameters/lightcone.py b/python/opencosmo/dtypes/lightcone.py similarity index 100% rename from python/opencosmo/parameters/lightcone.py rename to python/opencosmo/dtypes/lightcone.py diff --git a/python/opencosmo/parameters/origin.py b/python/opencosmo/dtypes/origin.py similarity index 86% rename from python/opencosmo/parameters/origin.py rename to python/opencosmo/dtypes/origin.py index ad65f431..e62b59ca 100644 --- a/python/opencosmo/parameters/origin.py +++ b/python/opencosmo/dtypes/origin.py @@ -1,4 +1,4 @@ -from opencosmo.parameters import hacc +from opencosmo.dtypes import hacc from 
.cosmology import CosmologyParameters diff --git a/python/opencosmo/parameters/parameters.py b/python/opencosmo/dtypes/parameters.py similarity index 100% rename from python/opencosmo/parameters/parameters.py rename to python/opencosmo/dtypes/parameters.py diff --git a/python/opencosmo/parameters/units.py b/python/opencosmo/dtypes/units.py similarity index 100% rename from python/opencosmo/parameters/units.py rename to python/opencosmo/dtypes/units.py diff --git a/python/opencosmo/parameters/utils.py b/python/opencosmo/dtypes/utils.py similarity index 100% rename from python/opencosmo/parameters/utils.py rename to python/opencosmo/dtypes/utils.py diff --git a/python/opencosmo/header.py b/python/opencosmo/header.py index ba9b7b27..d8dd742a 100644 --- a/python/opencosmo/header.py +++ b/python/opencosmo/header.py @@ -10,17 +10,17 @@ import numpy as np from pydantic import BaseModel, ValidationError -from opencosmo.file import broadcast_read, file_reader, file_writer -from opencosmo.io.schema import FileEntry, make_schema -from opencosmo.io.writer import ColumnCombineStrategy, ColumnWriter -from opencosmo.parameters import ( +from opencosmo.dtypes import ( FileParameters, dtype, origin, read_header_attributes, write_header_attributes, ) -from opencosmo.parameters.units import apply_units +from opencosmo.dtypes.units import apply_units +from opencosmo.file import broadcast_read, file_reader, file_writer +from opencosmo.io.schema import FileEntry, make_schema +from opencosmo.io.writer import ColumnCombineStrategy, ColumnWriter from opencosmo.units import UnitConvention if TYPE_CHECKING: diff --git a/python/opencosmo/index/ops.py b/python/opencosmo/index/ops.py new file mode 100644 index 00000000..b841d5e4 --- /dev/null +++ b/python/opencosmo/index/ops.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from opencosmo._lib import index as idxlib +from opencosmo.index import into_array + +if TYPE_CHECKING: + 
from opencosmo.index import DataIndex + + +def reindex_column(index: DataIndex, column: np.ndarray): + column = column.astype(np.int64) + return idxlib.reindex_column(into_array(index), column) diff --git a/python/opencosmo/spatial/check.py b/python/opencosmo/spatial/check.py index 8c1473fe..9553cdf6 100644 --- a/python/opencosmo/spatial/check.py +++ b/python/opencosmo/spatial/check.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from opencosmo.dataset.dataset import Dataset - from opencosmo.parameters import FileParameters + from opencosmo.dtypes import FileParameters from opencosmo.spatial.protocols import Region ALLOWED_COORDINATES_3D = { diff --git a/src/index.rs b/src/index.rs index b0c99152..1d76c58f 100644 --- a/src/index.rs +++ b/src/index.rs @@ -6,6 +6,7 @@ pub(crate) mod index { use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1}; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; + use std::collections::HashMap; use std::iter::zip; fn unpack_index_array<'py>(index: &Bound<'py, PyAny>) -> PyResult> { @@ -277,6 +278,40 @@ pub(crate) mod index { } } - Ok((Array1::from_vec(output_start), Array1::from_vec(output_size))) + Ok(( + Array1::from_vec(output_start), + Array1::from_vec(output_size), + )) + } + #[pyfunction(name = "reindex_column")] + fn reindex_columns_py<'py>( + py: Python<'py>, + index: &Bound<'_, PyAny>, + index_column: &Bound<'_, PyAny>, + ) -> PyResult>> { + let index_arr = unpack_index_array(index)?; + let index_column_arr = unpack_index_array(index_column)?; + Ok(reindex_column(index_arr.as_array(), index_column_arr.as_array()).into_pyarray(py)) + } + + fn reindex_column( + index: ArrayView1<'_, i64>, + index_column: ArrayView1<'_, i64>, + ) -> Array1 { + let mut index_map: HashMap = HashMap::new(); + for (i, &index_entry) in index.iter().enumerate() { + index_map.insert(index_entry, i as i64); + } + + let mut output: Vec = Vec::with_capacity(index_column.len()); + for val in index_column.iter() { + let 
val_index_opt = index_map.get(val); + if let Some(&val_index) = val_index_opt { + output.push(val_index) + } else { + output.push(-1) + } + } + Array1::from_vec(output) } } diff --git a/test/test_diffsky.py b/test/test_diffsky.py index 890ad4ba..252ccf89 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -4,10 +4,10 @@ import astropy.units as u import numpy as np import pytest +from opencosmo.spatial.region import HealpixRegion import opencosmo as oc from opencosmo.column import add_mag_cols -from opencosmo.spatial.region import HealpixRegion @pytest.fixture @@ -277,10 +277,14 @@ def test_add_mag_units_unitless(core_path_475, core_path_487): def test_region(core_path_475, core_path_487): ds = oc.open(core_path_475, core_path_487) assert isinstance(ds.region, HealpixRegion) - print(ds.region.pixels) assert len(ds.region.pixels) == 502 def test_open_bad_data(core_path_475, core_path_487, invalid_data_path): with pytest.raises(ValueError, match=str(invalid_data_path)): oc.open(core_path_475, core_path_487, invalid_data_path) + + +def test_reindex_top_host(core_path_475, core_path_487): + _ = oc.open(core_path_475, core_path_487) + raise NotImplementedError From 8c9d3c7e3f5044d08ab57e1bc0644a40e9047abb Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 21 Apr 2026 17:22:44 -0500 Subject: [PATCH 051/139] Add allow_overwrite to with_new_columns and evalute apis --- .../collection/lightcone/lightcone.py | 8 ++- .../collection/simulation/simulation.py | 12 +++- .../collection/structure/structure.py | 59 +++++++++++++--- python/opencosmo/column/evaluate.py | 5 +- python/opencosmo/dataset/columns.py | 11 ++- python/opencosmo/dataset/dataset.py | 65 ++++++++++++++--- python/opencosmo/dataset/state.py | 2 + test/test_collection.py | 70 +++++++++++++++++++ test/test_dataset.py | 36 ++++++++++ 9 files changed, 243 insertions(+), 25 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py 
index 8ad88ea6..39389d34 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -744,6 +744,7 @@ def evaluate( insert=True, format: str = "astropy", batch_size: int = -1, + allow_overwrite: bool = False, **evaluate_kwargs, ): """ @@ -832,6 +833,7 @@ def evaluate( "with_new_columns", mapped_arguments=mapped_evaluated_columns, construct=True, + allow_overwrite=allow_overwrite, ) result = self.__map( @@ -842,6 +844,7 @@ def evaluate( insert=insert, mapped_arguments=mapped_kwargs, batch_size=batch_size, + allow_overwrite=allow_overwrite, construct=insert, **evaluate_kwargs, ) @@ -1162,6 +1165,7 @@ def __take_rows(self, rows: np.ndarray): def with_new_columns( self, descriptions: str | dict[str, str] = {}, + allow_overwrite: bool = False, **columns: ConstructedColumn | np.ndarray | u.Quantity, ): """ @@ -1216,7 +1220,9 @@ def with_new_columns( for i, (ds_name, ds) in enumerate(self.items()): raw_columns = {name: arrs[i] for name, arrs in raw_split.items()} columns_input = raw_columns | derived - new_dataset = ds.with_new_columns(descriptions, **columns_input) + new_dataset = ds.with_new_columns( + descriptions, allow_overwrite=allow_overwrite, **columns_input + ) new_datasets[ds_name] = new_dataset return Lightcone(new_datasets, self.z_range, self.__hidden, self.__ordered_by) diff --git a/python/opencosmo/collection/simulation/simulation.py b/python/opencosmo/collection/simulation/simulation.py index f7402ad0..e3284c16 100644 --- a/python/opencosmo/collection/simulation/simulation.py +++ b/python/opencosmo/collection/simulation/simulation.py @@ -301,6 +301,7 @@ def with_new_columns( *args, datasets: Optional[str | Iterable[str]] = None, descriptions: str | dict[str, str] = {}, + allow_overwrite: bool = False, **new_columns: ConstructedColumn | np.ndarray, ): """ @@ -324,6 +325,9 @@ def with_new_columns( :py:attr:`SimulationCollection(datasets).descriptions `. 
If a dictionary, should have keys matching the column names. + allow_overwrite: bool, default = False + + ** columns : opencosmo.DerivedColumn | np.ndarray | units.Quantity The new columns """ @@ -336,7 +340,10 @@ def with_new_columns( output = {name: ds for name, ds in self.items()} for ds_name in datasets: output[ds_name] = output[ds_name].with_new_columns( - *args, descriptions=descriptions, **new_columns + *args, + descriptions=descriptions, + allow_overwrite=allow_overwrite, + **new_columns, ) return SimulationCollection(output) @@ -345,6 +352,7 @@ def with_new_columns( *args, descriptions=descriptions, datasets=datasets, + allow_overwrite=allow_overwrite, **new_columns, ) @@ -355,6 +363,7 @@ def evaluate( format: str = "astropy", vectorize: bool = False, insert: bool = False, + allow_overwrite: bool = False, **evaluate_kwargs, ): """ @@ -403,6 +412,7 @@ def evaluate( vectorize=vectorize, insert=insert, format=format, + allow_overwrite=allow_overwrite, construct=insert, **evaluate_kwargs, ) diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index fb9e8ae2..43570924 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -290,6 +290,7 @@ def evaluate( dataset: Optional[str] = None, format: str = "astropy", insert: bool = True, + allow_overwrite: bool = False, **evaluate_kwargs: Any, ): """ @@ -383,7 +384,12 @@ def computation(halo_properties, dm_particles): # If the user sets insert=False, everything is eager. 
if dataset is not None and dataset == self.__source.dtype: return self.evaluate_on_dataset( - func, dataset=dataset, format=format, insert=insert, **evaluate_kwargs + func, + dataset=dataset, + format=format, + insert=insert, + allow_overwrite=allow_overwrite, + **evaluate_kwargs, ) if format not in ["astropy", "numpy"]: @@ -397,7 +403,12 @@ def computation(halo_properties, dm_particles): sub_dataset = dataset_path[1] result = self[dataset_path[0]].evaluate( - func, sub_dataset, format, insert, **evaluate_kwargs + func, + sub_dataset, + format, + insert, + allow_overwrite=allow_overwrite, + **evaluate_kwargs, ) if not insert: return result @@ -427,6 +438,7 @@ def computation(halo_properties, dm_particles): func, insert=insert, format=format, + allow_overwrite=allow_overwrite, strategy="chunked", **evaluate_kwargs, ) @@ -463,6 +475,7 @@ def computation(halo_properties, dm_particles): return self.with_new_columns( **output, dataset=dataset if dataset is not None else self.__source.dtype, + allow_overwrite=allow_overwrite, ) def evaluate_on_dataset( @@ -473,6 +486,7 @@ def evaluate_on_dataset( format: str = "astropy", insert: bool = True, batch_size: int = -1, + allow_overwrite: bool = False, **evaluate_kwargs: Any, ): """ @@ -530,7 +544,12 @@ def evaluate_on_dataset( ds: oc.Dataset | StructureCollection if dataset is None or dataset == self.__source.dtype: result = self.__source.evaluate( - func, vectorize, insert, format, **evaluate_kwargs + func, + vectorize, + insert, + format, + allow_overwrite=allow_overwrite, + **evaluate_kwargs, ) if not insert: return result @@ -554,7 +573,14 @@ def evaluate_on_dataset( ) if len(ds_path) == 1 and isinstance(ds, oc.Dataset): - result = ds.evaluate(func, vectorize, insert, format, **evaluate_kwargs) + result = ds.evaluate( + func, + vectorize, + insert, + format, + allow_overwrite=allow_overwrite, + **evaluate_kwargs, + ) if not insert: return result assert isinstance(result, oc.Dataset) @@ -570,11 +596,18 @@ def 
evaluate_on_dataset( self.__derived_columns.union(new_derived_columns_), ) elif len(ds_path) == 1 and isinstance(ds, oc.StructureCollection): - result = ds.evaluate(func, None, format, insert) + result = ds.evaluate( + func, None, format, insert, allow_overwrite=allow_overwrite + ) elif len(ds_path) > 1 and isinstance(ds, oc.StructureCollection): result = ds.evaluate_on_dataset( - func, ".".join(dataset[1:]), vectorize, format, insert + func, + ".".join(ds_path[1:]), + vectorize, + format, + insert, + allow_overwrite=allow_overwrite, ) if not insert: @@ -1054,6 +1087,7 @@ def with_new_columns( self, dataset: str, descriptions: str | dict[str, str] = {}, + allow_overwrite: bool = False, **new_columns: ConstructedColumn | np.ndarray, ): """ @@ -1122,7 +1156,10 @@ def with_new_columns( if not isinstance(new_collection, StructureCollection): raise ValueError(f"{collection_name} is not a collection!") new_collection = new_collection.with_new_columns( - ".".join(path[1:]), descriptions=descriptions, **new_columns + ".".join(path[1:]), + descriptions=descriptions, + allow_overwrite=allow_overwrite, + **new_columns, ) return StructureCollection( self.__source, @@ -1134,7 +1171,9 @@ def with_new_columns( if dataset == self.__source.dtype: new_source = self.__source.with_new_columns( - **new_columns, descriptions=descriptions + **new_columns, + descriptions=descriptions, + allow_overwrite=allow_overwrite, ) return StructureCollection( new_source, @@ -1156,7 +1195,9 @@ def with_new_columns( if not isinstance(ds, oc.Dataset): raise ValueError(f"{dataset} is not a dataset!") - new_ds = ds.with_new_columns(**new_columns, descriptions=descriptions) + new_ds = ds.with_new_columns( + **new_columns, descriptions=descriptions, allow_overwrite=allow_overwrite + ) new_derived_columns = ( set(new_ds.columns).difference(ds.columns).difference(new_im_cols) ) diff --git a/python/opencosmo/column/evaluate.py b/python/opencosmo/column/evaluate.py index ae5a67e0..20decfc1 100644 --- 
a/python/opencosmo/column/evaluate.py +++ b/python/opencosmo/column/evaluate.py @@ -92,7 +92,10 @@ def __make_chunked_based_output_from_first_values(values, data_length): def evaluate_vectorized(data, func, kwargs, index): - return func(**data, **kwargs, index=index) + try: + return func(**data, **kwargs, index=index) + except TypeError: + return func(**data, **kwargs) def do_first_evaluation( diff --git a/python/opencosmo/dataset/columns.py b/python/opencosmo/dataset/columns.py index 19290c4a..052db7ad 100644 --- a/python/opencosmo/dataset/columns.py +++ b/python/opencosmo/dataset/columns.py @@ -110,8 +110,12 @@ def add_columns( descriptions: dict[str, str], new_columns: dict, ds_length: int, + allow_overwrite: bool, ) -> tuple[list[ConstructedColumn], ColumnMap, UnitHandler]: - if inter := set(name_to_uuid.keys()).intersection(new_columns.keys()): + if ( + inter := set(name_to_uuid.keys()).intersection(new_columns.keys()) + and not allow_overwrite + ): raise ValueError(f"Some columns are already in the dataset: {inter}") ( @@ -124,11 +128,14 @@ def add_columns( # Extend the name→UUID map with new producers' outputs so that columns added # in the same with_new_columns call can reference each other. + # For overwritten columns, preserve the OLD UUID so that new producers can + # correctly declare a dependency on the existing data rather than on themselves. 
extended_name_to_uuid = dict(name_to_uuid) for producer in new_derived_columns: if producer.produces: for name in producer.produces: - extended_name_to_uuid[name] = producer.uuid + if name not in name_to_uuid: + extended_name_to_uuid[name] = producer.uuid new_derived_columns = [ producer.bind(extended_name_to_uuid) for producer in new_derived_columns diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 24e26d5e..666c1397 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -397,6 +397,7 @@ def evaluate( insert=True, format="astropy", batch_size: int = -1, + allow_overwrite: bool = False, _verify: bool = True, **evaluate_kwargs, ) -> Dataset | dict[str, np.ndarray]: @@ -411,18 +412,31 @@ def evaluate( instead of adding it as a column. The function should take in arguments with the same name as the columns in this dataset that - are needed for the computation, and should return a dictionary of output values. - The dataset will automatically selected the needed columns to avoid reading unnecessarily reading - data from disk + are needed for the computation, and should return a dictionary of output values. Any addition + arguments needed by the function can be passed as keyword arguments to :code:`evaluate`. - The new columns will have the same names as the keys of the output dictionary - See :ref:`Evaluating On Datasets` for more details. + The dataset will automatically selected the needed columns to avoid reading unnecessarily reading + data from disk. The new columns will have the same names as the keys of the output dictionary + See :ref:`Evaluating On Datasets` for more details. The keys of this dictionary must be different + from the names of the columns that are already in the dataset, unless allow_overwrite is set + to :code`True` If vectorize is set to True, the full columns will be pased to the dataset. Otherwise, - rows will be passed to the function one at a time. 
+ rows will be passed to the function one at a time. If the function returns None, this method + will also return None as output. + + Keyword arguments can be used to pass in external values that are not columns in the dataset. + For example, we can compute each halo's gas fraction bias — how much gas it retains relative to + the cosmic baryon fraction — by passing the dataset's cosmology object as a keyword argument: + + .. code-block:: python + + def baryon_fraction_bias(sod_halo_mass_gas, sod_halo_mass, cosmology): + f_gas = sod_halo_mass_gas / sod_halo_mass + f_cosmic = cosmology.Ob0 / cosmology.Om0 + return {"sod_halo_baryon_bias": f_gas / f_cosmic} - If the function returns None, this method will also return None as output. For example, the function - could simply produce plots and save the to files. + ds = ds.evaluate(baryon_fraction_bias, cosmology=ds.cosmology, vectorize=True) Parameters ---------- @@ -440,6 +454,7 @@ def evaluate( format: str, default = astropy Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that this method does not support all the formats available in :py:meth:`get_data ` + allow_overwrite: bool, default = False batch_size: int, default = -1 If set, feed data to the function in batches of the specified size. Default is -1, which disables batching. 
If @@ -464,7 +479,9 @@ def evaluate( return output return self.with_new_columns( - descriptions={}, **{func.__name__: evaluated_column} + descriptions={}, + allow_overwrite=allow_overwrite, + **{func.__name__: evaluated_column}, ) def filter(self, *masks: ColumnMask) -> Dataset: @@ -591,7 +608,7 @@ def select( new_state = self.__state if derived_columns: - new_state = new_state.with_new_columns({}, **derived_columns) + new_state = new_state.with_new_columns({}, False, **derived_columns) all_columns.update(derived_columns.keys()) new_state = new_state.select(all_columns) @@ -779,6 +796,7 @@ def take_rows(self, rows: np.ndarray | DataIndex): def with_new_columns( self, descriptions: str | dict[str, str] = {}, + allow_overwrite: bool = False, **new_columns: ConstructedColumn | Column | np.ndarray | u.Quantity, ): """ @@ -789,6 +807,25 @@ def with_new_columns( quantities will not change under unit transformations. See :ref:`Adding Custom Columns` for examples. + If allow_overwrite is :code:`True`, the new column may have the same name as + a column that already exists in the dataset. This can be used to transform a column, + for example: + + .. code-block:: python + + log_mass = oc.col("fof_halo_mass").log10() + ds = ds.with_new_columns(fof_halo_mass=log_mass, allow_overwrite=True) + + The "fof_halo_mass" column will now be the log of the original "fof_halo_mass" column. + + Columns will be given the same name as the argument you use when you pass them into the function. + For example, we could do the same as above but name the column "log_fof_halo_mass" with + + .. code-block:: python + + log_mass = oc.col("fof_halo_mass").logo10() + ds = ds.with_new_columns(log_fof_halo_mass = log_mass) + Parameters ---------- @@ -796,8 +833,12 @@ def with_new_columns( A description for the new columns. These descriptions will be accessible through :py:attr:`Dataset.descriptions `. If a dictionary, should have keys matching the column names. 
+ allow_overwrites : bool, default = False + If false, attempting to add a new column with the same name as an existing column will throw an error. + If true, overwrites are allowed. ** new_columns : opencosmo.DerivedColumn | np.ndarray | units.Quantity + The new columns to add. The name of the argument is the name the column will take. Returns ------- @@ -807,7 +848,9 @@ def with_new_columns( """ if isinstance(descriptions, str): descriptions = {key: descriptions for key in new_columns.keys()} - new_state = self.__state.with_new_columns(descriptions, **new_columns) + new_state = self.__state.with_new_columns( + descriptions, allow_overwrite, **new_columns + ) return Dataset(self.__header, new_state, self.__tree) def make_schema( diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index 8820e4cd..4587a0f8 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -358,6 +358,7 @@ def make_schema(self, name: Optional[str] = None): def with_new_columns( self, descriptions: dict[str, str] = {}, + allow_overwrite: bool = False, **new_columns: ConstructedColumn | np.ndarray | u.Quantity, ): """ @@ -373,6 +374,7 @@ def with_new_columns( descriptions, new_columns, len(self), + allow_overwrite=allow_overwrite, ) return self.__rebuild( cache=self.__cache, diff --git a/test/test_collection.py b/test/test_collection.py index e9a6496e..89f439a4 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -530,6 +530,36 @@ def eval_fn(fof_halo_mass, fof_halo_tag): ) +def test_structure_collection_evaluate_overwrite(halo_paths): + collection = oc.open(*halo_paths).take(20) + + def fof_halo_mass(fof_halo_mass, fof_halo_com_vx): + return fof_halo_mass * fof_halo_com_vx + + with pytest.raises(ValueError): + collection.evaluate(fof_halo_mass, dataset="halo_properties") + + collection_overwritten = collection.evaluate( + fof_halo_mass, + dataset="halo_properties", + allow_overwrite=True, + vectorize=True, + ) + 
original = ( + collection["halo_properties"] + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + overwritten = ( + collection_overwritten["halo_properties"] + .select("fof_halo_mass") + .get_data("numpy") + ) + assert np.all( + np.isclose(overwritten, original["fof_halo_mass"] * original["fof_halo_com_vx"]) + ) + + def test_visit_dataset_in_structure_collection_nochunk(halo_paths): collection = oc.open(*halo_paths) @@ -565,6 +595,22 @@ def offset( assert np.all(offset_vec == offset_loop) +def test_evaluate_on_dataset_nested_path(halo_paths, galaxy_paths): + collection = oc.open(*halo_paths, *galaxy_paths).take(10) + + def particle_id(x, y, z): + return np.arange(len(x)) + + collection = collection.evaluate_on_dataset( + particle_id, dataset="galaxies.star_particles", vectorize=True, insert=True + ) + assert "particle_id" in collection["galaxies"]["star_particles"].columns + for halo in collection.halos(["galaxies"]): + for galaxy in halo["galaxies"].galaxies(["star_particles"]): + pid = galaxy["star_particles"].select("particle_id").get_data("numpy") + assert np.all(pid == np.arange(len(pid))) + + def test_visit_galaxies_in_halo_collection(halo_paths, galaxy_paths): collection = oc.open(*halo_paths, *galaxy_paths).take(10) @@ -1027,6 +1073,30 @@ def fof_px(fof_halo_mass, fof_halo_com_vx, random_value, other_value): ) +def test_simulation_collection_evaluate_overwrite(multi_path): + collection = oc.open(multi_path) + + def fof_halo_mass(fof_halo_mass, fof_halo_com_vx): + return fof_halo_mass * fof_halo_com_vx + + with pytest.raises(ValueError): + collection.evaluate(fof_halo_mass, vectorize=True, insert=True) + + collection_overwritten = collection.evaluate( + fof_halo_mass, vectorize=True, insert=True, allow_overwrite=True + ) + for ds_name, ds in collection.items(): + original = ds.select(["fof_halo_mass", "fof_halo_com_vx"]).get_data("numpy") + overwritten = ( + collection_overwritten[ds_name].select("fof_halo_mass").get_data("numpy") + ) + 
assert np.all( + np.isclose( + overwritten, original["fof_halo_mass"] * original["fof_halo_com_vx"] + ) + ) + + def test_simulation_collection_add(multi_path): collection = oc.open(multi_path) ds_name = next(iter(collection.keys())) diff --git a/test/test_dataset.py b/test/test_dataset.py index ecf4de53..43aedb30 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -717,3 +717,39 @@ def test_description_with_insert_multiple(input_path, tmp_path): descriptions = ds.descriptions assert descriptions["random_data"] == "random data for a test" assert descriptions["halo_px"] == "halo x momentum" + + +def test_with_new_columns_overwrite(input_path): + ds = oc.open(input_path) + log_mass = oc.col("fof_halo_mass").log10() + + with pytest.raises(ValueError): + ds.with_new_columns(fof_halo_mass=log_mass) + + ds_overwritten = ds.with_new_columns(fof_halo_mass=log_mass, allow_overwrite=True) + assert "fof_halo_mass" in ds_overwritten.columns + + original = ds.select("fof_halo_mass").get_data("numpy") + overwritten = ds_overwritten.select("fof_halo_mass").get_data("numpy") + assert np.all(np.isclose(overwritten, np.log10(original))) + + +def test_evaluate_overwrite(input_path): + ds = oc.open(input_path) + + def fof_halo_mass(fof_halo_mass, fof_halo_com_vx): + return fof_halo_mass * fof_halo_com_vx + + with pytest.raises(ValueError): + ds.evaluate(fof_halo_mass, vectorize=True, insert=True) + + ds_overwritten = ds.evaluate( + fof_halo_mass, vectorize=True, insert=True, allow_overwrite=True + ) + assert "fof_halo_mass" in ds_overwritten.columns + + data = ds.select(["fof_halo_mass", "fof_halo_com_vx"]).get_data("numpy") + overwritten = ds_overwritten.select("fof_halo_mass").get_data("numpy") + assert np.all( + np.isclose(overwritten, data["fof_halo_mass"] * data["fof_halo_com_vx"]) + ) From f6079c46b70b963f80f44a46a77feeaee3af4494 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 23 Apr 2026 14:46:54 -0500 Subject: [PATCH 052/139] Offset plugin for diffsky 
lightcones --- .../collection/lightcone/lightcone.py | 15 ++- python/opencosmo/dataset/dataset.py | 2 +- python/opencosmo/dataset/instantiate.py | 104 +++++++++--------- python/opencosmo/dataset/output.py | 2 +- python/opencosmo/dtypes/diffsky.py | 27 ++++- python/opencosmo/dtypes/dtype.py | 12 +- python/opencosmo/units/handler.py | 37 ++++--- 7 files changed, 122 insertions(+), 77 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 39389d34..bd1037eb 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -26,6 +26,7 @@ from opencosmo.dataset import Dataset from opencosmo.dataset.evaluate import build_evaluated_column from opencosmo.dataset.formats import convert_data, verify_format +from opencosmo.dtypes.dtype import get_dtype_lightcone_plugins from opencosmo.io.iopen import open_single_dataset from opencosmo.io.mpi import get_all_keys from opencosmo.io.schema import FileEntry, make_schema @@ -224,6 +225,13 @@ def with_redshift_column(dataset: Dataset): ) +def apply_plugins(datasets: Iterable[Dataset | Lightcone], plugins: list[Callable]): + datasets = list(datasets) + if len(plugins) == 0: + return datasets + return reduce(lambda datasets_, plugin: plugin(datasets_), plugins, datasets) + + class Lightcone(dict): """ A lightcone contains two or more datasets that are part of a lightcone. 
Typically @@ -270,6 +278,7 @@ def __init__( self.__hidden = hidden self.__ordered_by = ordered_by + self.__plugins = get_dtype_lightcone_plugins(self.__header, self.columns) def __repr__(self): """ @@ -284,7 +293,7 @@ def __repr__(self): repr_ds = self.take(10, at="start") table_head = "First 10 rows:\n" - table_repr = repr_ds.data.__repr__() + table_repr = repr_ds.get_data().__repr__() # remove the first line table_repr = table_repr[table_repr.find("\n") + 1 :] z_range = self.z_range @@ -470,8 +479,10 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): ) format = kwargs["output"] verify_format(format) + datasets = apply_plugins(self.values(), self.__plugins) + print(datasets) - data = [ds.get_data(unpack=unpack) for ds in self.values()] + data = [ds.get_data(unpack=unpack) for ds in datasets] data_with_length = [d for d in data if len(d) > 0] if len(data_with_length) == 0: return data[0] diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 666c1397..ee490ca7 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -64,7 +64,7 @@ def __repr__(self): repr_ds = self.take(10, at="start") table_head = "First 10 rows:\n" - table_repr = repr_ds.data.__repr__() + table_repr = repr_ds.get_data().__repr__() # remove the first line table_repr = table_repr[table_repr.find("\n") + 1 :] head = f"OpenCosmo Dataset (length={length})\n" diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py index 6a9bb511..9335c57d 100644 --- a/python/opencosmo/dataset/instantiate.py +++ b/python/opencosmo/dataset/instantiate.py @@ -11,7 +11,6 @@ if TYPE_CHECKING: from uuid import UUID - from opencosmo.column.cache import CacheKey from opencosmo.column.column import ConstructedColumn from opencosmo.handler.protocols import DataCache, DataHandler from opencosmo.index import DataIndex @@ -20,7 +19,7 @@ def get_all_required_pairs( columns_to_uuid: dict[str, UUID], 
dependency_graph: rx.PyDiGraph -) -> set[CacheKey]: +) -> set[tuple[UUID, str]]: """ Return the full set of (producer_uuid, column_name) pairs needed to produce the requested columns, including all transitive dependencies. @@ -35,7 +34,9 @@ def get_all_required_pairs( required_nodes.add(node_idx) required_nodes.update(rx.ancestors(dependency_graph, node_idx)) - pairs: set[CacheKey] = {(uuid, name) for name, uuid in columns_to_uuid.items()} + pairs: set[tuple[UUID, str]] = { + (uuid, name) for name, uuid in columns_to_uuid.items() + } for node_idx in required_nodes: producer = dependency_graph[node_idx] for name in producer.produces: @@ -110,6 +111,47 @@ def build_derived_columns( return new_derived +def __cache_raw_columns( + raw_columns: list[RawColumn], + raw_data: dict[str, Any], + working_columns: dict[str, UUID], + unit_handler: UnitHandler, + unit_kwargs: dict[str, Any], + cache: DataCache, +) -> dict[UUID, dict[str, Any]]: + """ + Write freshly-fetched raw columns to the cache and return the merged + (pre- and post-conversion) UUID-keyed data for merging into uuid_data. + + Pre-conversion data is pushed up to parent caches. Converted data is kept + local (push_up=False) to avoid propagating dataset-specific unit conversions + to parent caches, which could cause drift through repeated rounding. 
+ """ + raw_by_uuid: dict[UUID, dict[str, Any]] = { + col.uuid: {col.alias or col.name: raw_data[col.alias or col.name]} + for col in raw_columns + if (col.alias or col.name) in raw_data + } + converted_by_uuid = unit_handler.apply_unit_conversions(raw_by_uuid, unit_kwargs) + + cacheable = set(working_columns.values()) + cache.add_data( + {uuid: data for uuid, data in raw_by_uuid.items() if uuid in cacheable}, + {}, + push_up=True, + ) + cache.add_data( + {uuid: data for uuid, data in converted_by_uuid.items() if uuid in cacheable}, + {}, + push_up=False, + ) + + return { + uuid: (data | converted_by_uuid.get(uuid, {})) + for uuid, data in raw_by_uuid.items() + } + + def instantiate_dataset( column_producers: list[ConstructedColumn], columns_to_uuid: dict[str, UUID], @@ -134,14 +176,10 @@ def instantiate_dataset( cached_data = cache.get_data(required_pairs) - # Apply unit conversions to cached data. The unit handler works on flat - # name-keyed dicts; we flatten, convert, then fold results back in. - flat_cached = _flatten(cached_data) - converted_flat = unit_handler.apply_unit_conversions(flat_cached, unit_kwargs) - if converted_flat: - converted_uuid = _apply_uuid_mapping(converted_flat, required_pairs) - cache.add_data(converted_uuid, {}, push_up=False) - for uuid, col_data in converted_uuid.items(): + converted_cached = unit_handler.apply_unit_conversions(cached_data, unit_kwargs) + if converted_cached: + cache.add_data(converted_cached, {}, push_up=False) + for uuid, col_data in converted_cached.items(): cached_data.setdefault(uuid, {}).update(col_data) # Determine which raw columns still need to be fetched from the handler. @@ -168,26 +206,9 @@ def instantiate_dataset( uuid_data |= new_derived - # Write freshly-fetched raw data back to the cache. 
- # Only map raw columns by UUID — if a non-raw producer shadows a raw column name, - # the set-iteration order in _apply_uuid_mapping would nondeterministically pick one, - # potentially caching unprocessed data under the derived producer's UUID. - raw_producer_uuids = { - col.uuid for col in column_producers if isinstance(col, RawColumn) - } - raw_required_pairs = { - pair for pair in required_pairs if pair[0] in raw_producer_uuids - } - raw_data |= unit_handler.apply_unit_conversions(raw_data, unit_kwargs) - raw_data_uuid = _apply_uuid_mapping(raw_data, raw_required_pairs) - to_add = { - uuid: value - for uuid, value in raw_data_uuid.items() - if uuid in working_columns.values() - } - cache.add_data(to_add, {}, push_up=True) - for uuid, col_data in raw_data_uuid.items(): - uuid_data.setdefault(uuid, {}).update(col_data) + uuid_data |= __cache_raw_columns( + raw_columns, raw_data, working_columns, unit_handler, unit_kwargs, cache + ) data = { name: uuid_data[producer_uuid][name] @@ -221,24 +242,3 @@ def sort_data(data: dict[str, np.ndarray], sort_by: tuple[str, bool] | None): if sort_by[1]: order = order[::-1] return {key: value[order] for key, value in data.items()} - - -def _flatten(uuid_data: dict[UUID, dict[str, np.ndarray]]) -> dict[str, np.ndarray]: - """Collapse UUID-keyed data to a flat name-keyed dict (last writer wins).""" - return {name: arr for d in uuid_data.values() for name, arr in d.items()} - - -def _apply_uuid_mapping( - flat_data: dict[str, np.ndarray], - required_pairs: set[CacheKey], -) -> dict[UUID, dict[str, np.ndarray]]: - """ - Map a flat name-keyed dict back to UUID-keyed form using required_pairs to - resolve which UUID owns each column name. 
- """ - name_to_uuid: dict[str, UUID] = {name: uuid for uuid, name in required_pairs} - result: dict[UUID, dict[str, np.ndarray]] = {} - for name, arr in flat_data.items(): - if name in name_to_uuid: - result.setdefault(name_to_uuid[name], {})[name] = arr - return result diff --git a/python/opencosmo/dataset/output.py b/python/opencosmo/dataset/output.py index 6cd7f86d..b535f792 100644 --- a/python/opencosmo/dataset/output.py +++ b/python/opencosmo/dataset/output.py @@ -58,7 +58,7 @@ def build_derived_writers( for name, cd in coldata.items() } for name, cd in coldata.items(): - attrs = {"unit": units[name], "description": producer.description} + attrs = {"unit": units[name], "description": producer.description or "None"} source = NumpySource(cd) writer = ColumnWriter([source], ColumnCombineStrategy.CONCAT, attrs=attrs) data_schema.columns[name] = writer diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py index c2f3ad78..56edb948 100644 --- a/python/opencosmo/dtypes/diffsky.py +++ b/python/opencosmo/dtypes/diffsky.py @@ -1,13 +1,17 @@ from __future__ import annotations from datetime import datetime # noqa -from typing import ClassVar, Optional +from typing import TYPE_CHECKING, ClassVar, Optional +import numpy as np from pydantic import BaseModel, ConfigDict, field_serializer from opencosmo.column.column import EvaluatedColumn, EvaluateStrategy from opencosmo.index.ops import reindex_column +if TYPE_CHECKING: + from opencosmo import Dataset + class DiffskyVersionInfo(BaseModel): model_config = ConfigDict(frozen=True) @@ -35,6 +39,27 @@ def serialize_zphot_table(self, value): return None +def __offset(top_host_idx, offset): + output = top_host_idx + output[output >= 0] += offset + return {"top_host_idx": output} + + +def offset_top_host_idx(datasets: list[Dataset]): + lengths = [len(ds) for ds in datasets] + offsets = np.cumsum(lengths) + output_datasets = [datasets[0]] + for offset, ds in zip(offsets, datasets[1:]): + output_ds = 
ds.evaluate( + __offset, + offset=offset, + vectorize=True, + allow_overwrite=True, + ) + output_datasets.append(output_ds) # type: ignore + return output_datasets + + def rebuild_top_host_idx(top_host_idx, index): result = reindex_column(index, top_host_idx) diff --git a/python/opencosmo/dtypes/dtype.py b/python/opencosmo/dtypes/dtype.py index bb4137c8..e86d6bc6 100644 --- a/python/opencosmo/dtypes/dtype.py +++ b/python/opencosmo/dtypes/dtype.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from opencosmo.dtypes import hacc, lightcone -from opencosmo.dtypes.diffsky import top_host_idx +from opencosmo.dtypes.diffsky import offset_top_host_idx, top_host_idx if TYPE_CHECKING: from pydantic import BaseModel @@ -32,7 +32,7 @@ def get_dtype_column_plugins( producers, columns, ): - plugins = __get_plugins(header) + plugins = __get_column_plugins(header) for name, producer in plugins.items(): if name not in columns: continue @@ -43,7 +43,13 @@ def get_dtype_column_plugins( return producers, columns -def __get_plugins(header): +def __get_column_plugins(header): if header.file.data_type == "synthetic_galaxies": return {"top_host_idx": top_host_idx} return {} + + +def get_dtype_lightcone_plugins(header, columns): + if header.file.data_type == "synthetic_galaxies" and "top_host_idx" in columns: + return [offset_top_host_idx] + return [] diff --git a/python/opencosmo/units/handler.py b/python/opencosmo/units/handler.py index 2281ed8a..de582254 100644 --- a/python/opencosmo/units/handler.py +++ b/python/opencosmo/units/handler.py @@ -13,6 +13,8 @@ ) if TYPE_CHECKING: + from uuid import UUID + import h5py import numpy as np from astropy.cosmology import Cosmology @@ -218,25 +220,26 @@ def apply_raw_units(self, data: dict[str, np.ndarray], unit_kwargs): return columns def apply_unit_conversions( - self, data: dict[str, u.Quantity | np.ndarray], unit_kwargs - ): - # Only apply the unit CONVERSIONS. 
Useful for cached data - # Does not return data that was not updated - output_data = {} + self, + data: dict[UUID, dict[str, u.Quantity | np.ndarray]], + unit_kwargs, + ) -> dict[UUID, dict[str, u.Quantity | np.ndarray]]: + # Only apply the unit CONVERSIONS. Useful for cached data. + # Does not return data that was not updated. if not self.__conversions and not self.__column_conversions: return {} - for colname, column in data.items(): - if not isinstance(column, u.Quantity): - continue - assert isinstance(column, u.Quantity) - column_conversion = self.__column_conversions.get(colname) - unit_conversion = self.__conversions.get(str(column.unit)) - if unit_conversion is not None and column_conversion is None: - output_data[colname] = column.to(unit_conversion) - elif column_conversion is not None: - output_data[colname] = column.to(column_conversion) - - return output_data + output: dict[UUID, dict[str, u.Quantity | np.ndarray]] = {} + for uuid, uuid_data in data.items(): + for colname, column in uuid_data.items(): + if not isinstance(column, u.Quantity): + continue + column_conversion = self.__column_conversions.get(colname) + unit_conversion = self.__conversions.get(str(column.unit)) + if unit_conversion is not None and column_conversion is None: + output.setdefault(uuid, {})[colname] = column.to(unit_conversion) + elif column_conversion is not None: + output.setdefault(uuid, {})[colname] = column.to(column_conversion) + return output def apply_units(self, data: dict[str, np.ndarray], unit_kwargs): if self.__current_convention == UnitConvention.UNITLESS: From 2c090f19d021689ecc7dc89fc259d0808f85d3e9 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 23 Apr 2026 15:20:48 -0500 Subject: [PATCH 053/139] Add basic plugin specification --- python/opencosmo/plugins/__init__.py | 0 python/opencosmo/plugins/plugin.py | 60 ++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 python/opencosmo/plugins/__init__.py create mode 100644 
python/opencosmo/plugins/plugin.py diff --git a/python/opencosmo/plugins/__init__.py b/python/opencosmo/plugins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/opencosmo/plugins/plugin.py b/python/opencosmo/plugins/plugin.py new file mode 100644 index 00000000..0b1f775a --- /dev/null +++ b/python/opencosmo/plugins/plugin.py @@ -0,0 +1,60 @@ +from collections import defaultdict +from enum import StrEnum +from functools import reduce +from typing import Callable, NamedTuple, TypedDict + +import opencosmo as oc + + +class PluginType(StrEnum): + DatasetOpen = "dataset_open" + DatasetInstantiate = "dataset_instantiate" + LightconeLoad = "lightcone_load" + LightconeInstantiate = "lightcone_instantiate" + + +DatasetTransformationPlugin = Callable[[oc.Dataset], oc.Dataset] +LightconeTransformationPlugn = Callable[[oc.Lightcone], oc.Lightcone] + + +type DatasetType[T: (oc.Dataset, oc.Lightcone)] = T +type Verifier[T: DatasetType] = Callable[[T], bool] +type Plugin[T: DatasetType] = Callable[[T], T] + + +class PluginSpec[T: DatasetType](NamedTuple): + plugin_type: PluginType + verifier: Verifier[T] + plugin: Plugin[T] + + +class Plugins(TypedDict): + dataset_open: list[PluginSpec[oc.Dataset]] + dataset_instantiate: list[PluginSpec[oc.Dataset]] + lightcone_load: list[PluginSpec[oc.Lightcone]] + lightcone_instantiate: list[PluginSpec[oc.Lightcone]] + + +KNOWN_PLUGINS: Plugins = defaultdict(list) # type: ignore + + +def register_plugin[T: DatasetType]( + plugin_type: PluginType, + verifier: Verifier[T], + plugin: Plugin[T], +) -> None: + spec = PluginSpec(plugin_type=plugin_type, verifier=verifier, plugin=plugin) + KNOWN_PLUGINS[str(plugin_type)].append(spec) # type: ignore + + +def apply_plugins[T: DatasetType](plugin_type: PluginType, dataset: T) -> T: + plugins_to_apply = KNOWN_PLUGINS[str(plugin_type)] # type: ignore + return reduce( + lambda ds, spec: apply_single_plugin(spec, ds), plugins_to_apply, dataset + ) + + +def apply_single_plugin[T: 
DatasetType](spec: PluginSpec[T], dataset: T) -> T: + if spec.verifier(dataset): + return spec.plugin(dataset) + return dataset From b083c4dc38f83f73ebdfd9407fd2465d5d3a0bf8 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 23 Apr 2026 16:51:43 -0500 Subject: [PATCH 054/139] Move top_host_idx reindexing into plugin system --- .../collection/lightcone/lightcone.py | 18 ++--- python/opencosmo/dataset/state.py | 19 +++-- python/opencosmo/dtypes/diffsky.py | 53 +++++++++++--- python/opencosmo/dtypes/dtype.py | 7 +- python/opencosmo/io/iopen.py | 3 +- python/opencosmo/plugins/plugin.py | 39 +++++----- test/test_diffsky.py | 71 ++++++++++++++++++- 7 files changed, 154 insertions(+), 56 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index bd1037eb..60b65926 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -27,10 +27,11 @@ from opencosmo.dataset.evaluate import build_evaluated_column from opencosmo.dataset.formats import convert_data, verify_format from opencosmo.dtypes.dtype import get_dtype_lightcone_plugins -from opencosmo.io.iopen import open_single_dataset +from opencosmo.io import iopen from opencosmo.io.mpi import get_all_keys from opencosmo.io.schema import FileEntry, make_schema from opencosmo.mpi import get_comm_world, get_mpi +from opencosmo.plugins import plugin if TYPE_CHECKING: import astropy.units as u # type: ignore @@ -225,13 +226,6 @@ def with_redshift_column(dataset: Dataset): ) -def apply_plugins(datasets: Iterable[Dataset | Lightcone], plugins: list[Callable]): - datasets = list(datasets) - if len(plugins) == 0: - return datasets - return reduce(lambda datasets_, plugin: plugin(datasets_), plugins, datasets) - - class Lightcone(dict): """ A lightcone contains two or more datasets that are part of a lightcone. 
Typically @@ -479,10 +473,8 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): ) format = kwargs["output"] verify_format(format) - datasets = apply_plugins(self.values(), self.__plugins) - print(datasets) - - data = [ds.get_data(unpack=unpack) for ds in datasets] + lightcone = plugin.apply_plugins(plugin.PluginType.LightconeInstantiate, self) + data = [ds.get_data(unpack=unpack) for ds in lightcone.values()] data_with_length = [d for d in data if len(d) > 0] if len(data_with_length) == 0: return data[0] @@ -527,7 +519,7 @@ def open(cls, targets: list[FileTarget], **kwargs): for i, ds_target in enumerate(dataset_targets): group_name = ds_target["dataset_group"].name.split("/")[-1] group_name = group_name.lstrip(f"{ds_target['header'].file.step}_") - ds = open_single_dataset(ds_target, bypass_lightcone=True) + ds = iopen.open_single_dataset(ds_target, bypass_lightcone=True) step = ds_target["header"].file.step if step is None: step = i diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index 4587a0f8..e2cf810c 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -13,12 +13,12 @@ from opencosmo.dataset.columns import add_columns, resort from opencosmo.dataset.instantiate import instantiate_dataset from opencosmo.dataset.output import get_derived_column_names, make_dataset_schema -from opencosmo.dtypes.dtype import get_dtype_column_plugins from opencosmo.handler.empty import EmptyHandler from opencosmo.handler.hdf5 import Hdf5Handler from opencosmo.index.build import single_chunk from opencosmo.index.mask import into_array from opencosmo.index.unary import get_range +from opencosmo.plugins.plugin import PluginType, apply_plugins from opencosmo.units import UnitConvention from opencosmo.units.handler import ( make_unit_handler_from_hdf5, @@ -123,9 +123,6 @@ def from_target( for cname in handler.columns ] columns = {p.name: p.uuid for p in producers} - producers, columns = 
get_dtype_column_plugins( - target["header"], producers, columns - ) cache = ColumnCache.empty() return DatasetState( producers, @@ -252,15 +249,17 @@ def get_data( """ Get the data for a given handler. """ + state = apply_plugins(PluginType.DatasetInstantiate, self) + data = instantiate_dataset( - list(self.__producers.values()), - self.__columns, - self.__raw_data_handler, - self.__cache, - self.__unit_handler, + list(state.__producers.values()), + state.__columns, + state.__raw_data_handler, + state.__cache, + state.__unit_handler, unit_kwargs, metadata_columns, - None if ignore_sort else self.__sort_by, + None if ignore_sort else state.__sort_by, ) if missing := set(self.columns).difference(data.keys()): raise RuntimeError( diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py index 56edb948..7b983c11 100644 --- a/python/opencosmo/dtypes/diffsky.py +++ b/python/opencosmo/dtypes/diffsky.py @@ -8,9 +8,11 @@ from opencosmo.column.column import EvaluatedColumn, EvaluateStrategy from opencosmo.index.ops import reindex_column +from opencosmo.plugins.plugin import PluginType, register_plugin if TYPE_CHECKING: - from opencosmo import Dataset + from opencosmo import Dataset, Lightcone + from opencosmo.dataset.state import DatasetState class DiffskyVersionInfo(BaseModel): @@ -66,12 +68,45 @@ def rebuild_top_host_idx(top_host_idx, index): return {"top_host_idx": result} -top_host_idx = EvaluatedColumn( - rebuild_top_host_idx, - requires=set(["top_host_idx"]), - produces=set(["top_host_idx"]), - format="numpy", - units={"top_host_idx": None}, - strategy=EvaluateStrategy.VECTORIZE, - no_cache=True, +def top_host_idx_plugin(dataset: DatasetState): + top_host_idx = EvaluatedColumn( + rebuild_top_host_idx, + requires=set(["top_host_idx"]), + produces=set(["top_host_idx"]), + format="numpy", + units={"top_host_idx": None}, + strategy=EvaluateStrategy.VECTORIZE, + no_cache=True, + ) + return dataset.with_new_columns(updated_host_idx=top_host_idx, 
allow_overwrite=True) + + +def top_host_idx_offset_plugin(lightcone: Lightcone) -> dict[str, Dataset]: + cs = 0 + output = {} + + def top_host_idx(top_host_idx, offset): + top_host_idx[top_host_idx >= 0] += offset + return top_host_idx + + for key, ds in lightcone.items(): + output[key] = ds.evaluate( + top_host_idx, allow_overwrite=True, vectorize=True, offset=cs + ) + cs += len(ds) + + return output + + +def top_host_idx_verifier[T: (DatasetState, Dataset, Lightcone)](dataset: T) -> bool: + return ( + dataset.header.file.data_type == "synthetic_galaxies" + and "top_host_idx" in dataset.columns + ) + + +register_plugin(PluginType.DatasetOpen, top_host_idx_verifier, top_host_idx_plugin) + +register_plugin( # type: ignore + PluginType.LightconeInstantiate, top_host_idx_verifier, top_host_idx_offset_plugin ) diff --git a/python/opencosmo/dtypes/dtype.py b/python/opencosmo/dtypes/dtype.py index e86d6bc6..5e2b4420 100644 --- a/python/opencosmo/dtypes/dtype.py +++ b/python/opencosmo/dtypes/dtype.py @@ -2,8 +2,7 @@ from typing import TYPE_CHECKING -from opencosmo.dtypes import hacc, lightcone -from opencosmo.dtypes.diffsky import offset_top_host_idx, top_host_idx +from opencosmo.dtypes import diffsky, hacc, lightcone if TYPE_CHECKING: from pydantic import BaseModel @@ -45,11 +44,11 @@ def get_dtype_column_plugins( def __get_column_plugins(header): if header.file.data_type == "synthetic_galaxies": - return {"top_host_idx": top_host_idx} + return {"top_host_idx": diffsky.top_host_idx} return {} def get_dtype_lightcone_plugins(header, columns): if header.file.data_type == "synthetic_galaxies" and "top_host_idx" in columns: - return [offset_top_host_idx] + return [diffsky.offset_top_host_idx] return [] diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index eaef0ff7..f87e3bd5 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -16,6 +16,7 @@ from opencosmo.header import OpenCosmoHeader, read_header from opencosmo.index.build 
import empty, from_range from opencosmo.mpi import get_comm_world +from opencosmo.plugins.plugin import PluginType, apply_plugins from opencosmo.spatial.builders import from_model from opencosmo.spatial.region import FullSkyRegion, HealpixRegion from opencosmo.spatial.tree import open_tree @@ -547,7 +548,7 @@ def open_single_dataset( {"data": dataset}, header.lightcone["z_range"] ) - return dataset + return apply_plugins(PluginType.DatasetOpen, dataset) def __open_healpix_map(dataset: oc.Dataset, sim_region): diff --git a/python/opencosmo/plugins/plugin.py b/python/opencosmo/plugins/plugin.py index 0b1f775a..c2a55584 100644 --- a/python/opencosmo/plugins/plugin.py +++ b/python/opencosmo/plugins/plugin.py @@ -1,9 +1,12 @@ +from __future__ import annotations + from collections import defaultdict from enum import StrEnum from functools import reduce from typing import Callable, NamedTuple, TypedDict -import opencosmo as oc +from opencosmo import dataset as ds +from opencosmo.collection.lightcone import lightcone as lc class PluginType(StrEnum): @@ -13,32 +16,32 @@ class PluginType(StrEnum): LightconeInstantiate = "lightcone_instantiate" -DatasetTransformationPlugin = Callable[[oc.Dataset], oc.Dataset] -LightconeTransformationPlugn = Callable[[oc.Lightcone], oc.Lightcone] - +DatasetTransformationPlugin = Callable[[ds.Dataset], ds.Dataset] +LightconeTransformationPlugin = Callable[[lc.Lightcone], dict[str, ds.Dataset]] -type DatasetType[T: (oc.Dataset, oc.Lightcone)] = T -type Verifier[T: DatasetType] = Callable[[T], bool] -type Plugin[T: DatasetType] = Callable[[T], T] +type Verifier[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)] = Callable[ + [T], bool +] +type Plugin[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)] = Callable[[T], T] -class PluginSpec[T: DatasetType](NamedTuple): +class PluginSpec[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)](NamedTuple): plugin_type: PluginType verifier: Verifier[T] plugin: Plugin[T] class Plugins(TypedDict): - 
dataset_open: list[PluginSpec[oc.Dataset]] - dataset_instantiate: list[PluginSpec[oc.Dataset]] - lightcone_load: list[PluginSpec[oc.Lightcone]] - lightcone_instantiate: list[PluginSpec[oc.Lightcone]] + dataset_open: list[PluginSpec[ds.Dataset]] + dataset_instantiate: list[PluginSpec[ds.state.DatasetState]] + lightcone_load: list[PluginSpec[lc.Lightcone]] + lightcone_instantiate: list[PluginSpec[lc.Lightcone]] KNOWN_PLUGINS: Plugins = defaultdict(list) # type: ignore -def register_plugin[T: DatasetType]( +def register_plugin[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)]( plugin_type: PluginType, verifier: Verifier[T], plugin: Plugin[T], @@ -47,14 +50,18 @@ def register_plugin[T: DatasetType]( KNOWN_PLUGINS[str(plugin_type)].append(spec) # type: ignore -def apply_plugins[T: DatasetType](plugin_type: PluginType, dataset: T) -> T: +def apply_plugins[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)]( + plugin_type: PluginType, dataset: T +) -> T: plugins_to_apply = KNOWN_PLUGINS[str(plugin_type)] # type: ignore return reduce( - lambda ds, spec: apply_single_plugin(spec, ds), plugins_to_apply, dataset + lambda ds_, spec: apply_single_plugin(spec, ds_), plugins_to_apply, dataset ) -def apply_single_plugin[T: DatasetType](spec: PluginSpec[T], dataset: T) -> T: +def apply_single_plugin[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)]( + spec: PluginSpec[T], dataset: T +) -> T: if spec.verifier(dataset): return spec.plugin(dataset) return dataset diff --git a/test/test_diffsky.py b/test/test_diffsky.py index 252ccf89..04be421e 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -2,6 +2,7 @@ import shutil import astropy.units as u +import h5py import numpy as np import pytest from opencosmo.spatial.region import HealpixRegion @@ -285,6 +286,70 @@ def test_open_bad_data(core_path_475, core_path_487, invalid_data_path): oc.open(core_path_475, core_path_487, invalid_data_path) -def test_reindex_top_host(core_path_475, core_path_487): - _ = 
oc.open(core_path_475, core_path_487) - raise NotImplementedError +def get_expected_core_tags(path): + with h5py.File(path) as f: + raw_top_host = f["cores"]["data"]["top_host_idx"][:] + core_tag = f["cores"]["data"]["core_tag"][:] + + top_host_core_tag = core_tag[raw_top_host] + return dict(zip(core_tag, top_host_core_tag)) + + +def test_reindex_top_host_take_none(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + + found_top_host_core_tag = data["core_tag"][data["top_host_idx"]] + found_core_map = dict(zip(data["core_tag"], found_top_host_core_tag)) + assert core_map == found_core_map + + +def test_reindex_top_host_take_random(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487).take(300) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + + has_top_host = data["top_host_idx"] >= 0 + + found_top_host_core_tag = data["core_tag"][data["top_host_idx"][has_top_host]] + found_core_map = dict(zip(data["core_tag"][has_top_host], found_top_host_core_tag)) + filtered_core_map = { + key: val for key, val in core_map.items() if key in found_core_map + } + assert filtered_core_map == found_core_map + + should_have_core_map = { + key: val + for key, val in core_map.items() + if val in data["core_tag"] and key in data["core_tag"] + } + assert should_have_core_map == found_core_map + + +def test_reindex_top_host_take_range(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487).take_range(100, 400) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + + has_top_host = data["top_host_idx"] >= 0 + + 
found_top_host_core_tag = data["core_tag"][data["top_host_idx"][has_top_host]] + found_core_map = dict(zip(data["core_tag"][has_top_host], found_top_host_core_tag)) + filtered_core_map = { + key: val for key, val in core_map.items() if key in found_core_map + } + assert filtered_core_map == found_core_map + + should_have_core_map = { + key: val + for key, val in core_map.items() + if val in data["core_tag"] and key in data["core_tag"] + } + assert should_have_core_map == found_core_map From 8f1511a1539d46ac0273564ca45e32014122641a Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 23 Apr 2026 17:22:49 -0500 Subject: [PATCH 055/139] Add new test and changelog --- changes/+e334af41.feature.rst | 1 + test/test_diffsky.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 changes/+e334af41.feature.rst diff --git a/changes/+e334af41.feature.rst b/changes/+e334af41.feature.rst new file mode 100644 index 00000000..3c60cba0 --- /dev/null +++ b/changes/+e334af41.feature.rst @@ -0,0 +1 @@ +Calls to :py:meth:`with_new_columns ` and :py:meth:`evaluate ` now accept an `allow_overwrite` flag. In this way you can "transform" a column by creating a derived column that depends on a column of the same name. The input will be the original column, and the output will be the new version. 
diff --git a/test/test_diffsky.py b/test/test_diffsky.py index 04be421e..be2374be 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -353,3 +353,27 @@ def test_reindex_top_host_take_range(core_path_475, core_path_487): if val in data["core_tag"] and key in data["core_tag"] } assert should_have_core_map == found_core_map + + +def test_reindex_top_host_filter(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487).filter(oc.col("logsm_obs") > 10) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + + has_top_host = data["top_host_idx"] >= 0 + + found_top_host_core_tag = data["core_tag"][data["top_host_idx"][has_top_host]] + found_core_map = dict(zip(data["core_tag"][has_top_host], found_top_host_core_tag)) + filtered_core_map = { + key: val for key, val in core_map.items() if key in found_core_map + } + assert filtered_core_map == found_core_map + + should_have_core_map = { + key: val + for key, val in core_map.items() + if val in data["core_tag"] and key in data["core_tag"] + } + assert should_have_core_map == found_core_map From c4e3de4f6f5dc319924e9c41cb953935e22b3bde Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 24 Apr 2026 12:14:24 -0500 Subject: [PATCH 056/139] Add fixes for writing with stacking --- .../opencosmo/collection/lightcone/stack.py | 42 +++++++- test/test_diffsky.py | 102 +++++++++--------- 2 files changed, 91 insertions(+), 53 deletions(-) diff --git a/python/opencosmo/collection/lightcone/stack.py b/python/opencosmo/collection/lightcone/stack.py index 22ac5be1..ed3ada3b 100644 --- a/python/opencosmo/collection/lightcone/stack.py +++ b/python/opencosmo/collection/lightcone/stack.py @@ -24,6 +24,37 @@ def update_order(data: np.ndarray, comm: Optional[MPI.Comm], order: np.ndarray): return data[order] +def update_top_host_idx( + data: np.ndarray, + comm: Optional[MPI.Comm], + order: 
np.ndarray, + slice_sizes: list[int], +): + result = data.copy() + + # Add per-slice global offsets so local row references become global + offset = 0 + pos = 0 + for size in slice_sizes: + segment = result[pos : pos + size] + segment[segment >= 0] += offset + offset += size + pos += size + + # Reorder row positions (same as all other columns) + if comm is not None: + result = update_global_order_mpi(result, comm, order) + else: + result = result[order] + + # Remap stored row references through the inverse permutation + inverse_order = np.argsort(order) + output = np.full_like(result, -1) + valid = result >= 0 + output[valid] = inverse_order[result[valid]] + return output + + def update_global_order_mpi(data, comm, order): needs_global_reordering = comm.allgather(np.any((order < 0) | (order > len(order)))) if not np.any(needs_global_reordering): @@ -117,9 +148,16 @@ def stack_lightcone_datasets_in_schema( order = get_stacked_lightcone_order(ds_list, max_level) updater = partial(update_order, order=order) + slice_sizes = [len(dataset) for dataset in ds_list] + top_host_idx_updater = partial( + update_top_host_idx, order=order, slice_sizes=slice_sizes + ) - for column in new_data_group.columns.values(): - column.set_transformation(updater) + for col_name, column in new_data_group.columns.items(): + if col_name == "top_host_idx": + column.set_transformation(top_host_idx_updater) + else: + column.set_transformation(updater) new_index_group = stack_index_groups( [schema.children["index"] for schema in schemas] diff --git a/test/test_diffsky.py b/test/test_diffsky.py index be2374be..90421c3c 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -150,9 +150,13 @@ def offset( def test_open_write_with_synthetics(core_path_475, core_path_487, per_test_dir): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + n = 10_000 ds = oc.open(core_path_487, core_path_475, synth_cores=True) ds = ds.filter(oc.col("lsst_g") < 
20).take(n).with_new_columns(galid=np.arange(n)) + original_data = ds.get_data() assert len(original_data) == n assert len(ds) == n @@ -164,13 +168,30 @@ def test_open_write_with_synthetics(core_path_475, core_path_487, per_test_dir): ds = oc.open(per_test_dir / "test.hdf5", synth_cores=True) assert len(ds) == n - written_data = ds.get_data() + # Check top_host_idx correctness on the spatially-sorted written data before + # any reordering, since top_host_idx values are indices into that ordering. + reindex_data = ( + oc.open(per_test_dir / "test.hdf5") + .select("top_host_idx", "core_tag") + .get_data("numpy") + ) + assert_top_host_idx_correct(reindex_data, core_map) + written_data = ds.get_data() written_data.sort("galid") - columns_to_check = np.random.choice(ds.columns, size=20, replace=False) + other_columns = [c for c in ds.columns if c != "top_host_idx"] + columns_to_check = np.random.choice(other_columns, size=20, replace=False) + for column in columns_to_check: assert np.all(original_data[column] == written_data[column]) + # Check the synthetic cores still point to themselves + synth_core_check_data = ds.select("top_host_idx", "core_tag").get_data("numpy") + synth_core_rows = np.where(synth_core_check_data["core_tag"] == -1)[0] + assert np.all( + synth_core_check_data["top_host_idx"][synth_core_rows] == synth_core_rows + ) + def test_open_write_with_multiple_synthetics( core_path_475, core_path_487, per_test_dir @@ -295,29 +316,20 @@ def get_expected_core_tags(path): return dict(zip(core_tag, top_host_core_tag)) -def test_reindex_top_host_take_none(core_path_475, core_path_487): - core_map = get_expected_core_tags(core_path_475) - core_map |= get_expected_core_tags(core_path_487) - - ds = oc.open(core_path_475, core_path_487) - data = ds.select("top_host_idx", "core_tag").get_data("numpy") - - found_top_host_core_tag = data["core_tag"][data["top_host_idx"]] - found_core_map = dict(zip(data["core_tag"], found_top_host_core_tag)) - assert core_map == 
found_core_map - - -def test_reindex_top_host_take_random(core_path_475, core_path_487): - core_map = get_expected_core_tags(core_path_475) - core_map |= get_expected_core_tags(core_path_487) - - ds = oc.open(core_path_475, core_path_487).take(300) - data = ds.select("top_host_idx", "core_tag").get_data("numpy") +def assert_top_host_idx_correct(data, core_map): + """ + Verify that top_host_idx correctly references the expected host after any + transformation. For each row with a valid index, checks that the referenced + row has the core_tag the original file associates with that host. Also + asserts that every galaxy whose host is present in the data has a valid index. + data: numpy dict with "top_host_idx" and "core_tag" columns. + core_map: dict mapping core_tag -> host core_tag from the raw HDF5 files. + """ has_top_host = data["top_host_idx"] >= 0 - found_top_host_core_tag = data["core_tag"][data["top_host_idx"][has_top_host]] found_core_map = dict(zip(data["core_tag"][has_top_host], found_top_host_core_tag)) + filtered_core_map = { key: val for key, val in core_map.items() if key in found_core_map } @@ -331,49 +343,37 @@ def test_reindex_top_host_take_random(core_path_475, core_path_487): assert should_have_core_map == found_core_map -def test_reindex_top_host_take_range(core_path_475, core_path_487): +def test_reindex_top_host_take_none(core_path_475, core_path_487): core_map = get_expected_core_tags(core_path_475) core_map |= get_expected_core_tags(core_path_487) - ds = oc.open(core_path_475, core_path_487).take_range(100, 400) + ds = oc.open(core_path_475, core_path_487) data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) - has_top_host = data["top_host_idx"] >= 0 - found_top_host_core_tag = data["core_tag"][data["top_host_idx"][has_top_host]] - found_core_map = dict(zip(data["core_tag"][has_top_host], found_top_host_core_tag)) - filtered_core_map = { - key: val for key, val in core_map.items() if key 
in found_core_map - } - assert filtered_core_map == found_core_map +def test_reindex_top_host_take_random(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) - should_have_core_map = { - key: val - for key, val in core_map.items() - if val in data["core_tag"] and key in data["core_tag"] - } - assert should_have_core_map == found_core_map + ds = oc.open(core_path_475, core_path_487).take(300) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) -def test_reindex_top_host_filter(core_path_475, core_path_487): +def test_reindex_top_host_take_range(core_path_475, core_path_487): core_map = get_expected_core_tags(core_path_475) core_map |= get_expected_core_tags(core_path_487) - ds = oc.open(core_path_475, core_path_487).filter(oc.col("logsm_obs") > 10) + ds = oc.open(core_path_475, core_path_487).take_range(100, 400) data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) - has_top_host = data["top_host_idx"] >= 0 - found_top_host_core_tag = data["core_tag"][data["top_host_idx"][has_top_host]] - found_core_map = dict(zip(data["core_tag"][has_top_host], found_top_host_core_tag)) - filtered_core_map = { - key: val for key, val in core_map.items() if key in found_core_map - } - assert filtered_core_map == found_core_map +def test_reindex_top_host_filter(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) - should_have_core_map = { - key: val - for key, val in core_map.items() - if val in data["core_tag"] and key in data["core_tag"] - } - assert should_have_core_map == found_core_map + ds = oc.open(core_path_475, core_path_487).filter(oc.col("logsm_obs") > 10) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) From 
7d22512277b6c86a53b17ac39b672e954973b910 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 24 Apr 2026 12:20:12 -0500 Subject: [PATCH 057/139] Fix docs for renaming --- docs/source/parameters_ref.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/parameters_ref.rst b/docs/source/parameters_ref.rst index c4074a25..9f78f0f5 100644 --- a/docs/source/parameters_ref.rst +++ b/docs/source/parameters_ref.rst @@ -17,7 +17,7 @@ Cosmology Most OpenCosmo files will contain cosmology parameters, which describe the cosmology the simulation was run under. In general you will not interact with this parameter block directly. Instead, requiresting it will return an astropy.cosmology.Cosmology object. Dataset and collections will generally make this object available directly with the :py:attr:`.cosmology ` attribute. -.. autoclass:: opencosmo.parameters.cosmology.CosmologyParameters +.. autoclass:: opencosmo.dtypes.cosmology.CosmologyParameters :members: :undoc-members: :exclude-members: model_config, ACCESS_PATH, ACCESS_TRANSFORMATION @@ -28,14 +28,14 @@ Simulation Parameters Data that was originally produced by HACC will contain the parameters that were used to initialize the simulation. Datasets and collections will generally make these paramters available with the :py:attr:`.simulation ` attribute. -.. autoclass:: opencosmo.parameters.hacc.HaccSimulationParameters +.. autoclass:: opencosmo.dtypes.hacc.HaccSimulationParameters :members: :undoc-members: :exclude-members: model_config,empty_string_to_none,cosmology_parameters,ACCESS_PATH :member-order: bysource -.. autoclass:: opencosmo.parameters.hacc.HaccHydroSimulationParameters +.. 
autoclass:: opencosmo.dtypes.hacc.HaccHydroSimulationParameters :members: :undoc-members: :exclude-members: model_config From 896aa3e79e3f6dd8cd526d0d1718383884e21a20 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 24 Apr 2026 12:26:57 -0500 Subject: [PATCH 058/139] Skip setting up AWS cli if cache already exists --- .github/workflows/test.yaml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 25311478..d318e67a 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -4,12 +4,6 @@ jobs: get-test-data: runs-on: ubuntu-latest steps: - - name: Setup AWS CLI - uses: aws-actions/configure-aws-credentials@v1 - with: - aws-access-key-id: ${{ secrets.TEST_DATA_ACCESS_KEY }} - aws-secret-access-key: ${{ secrets.TEST_DATA_SECRET_KEY }} - aws-region: us-west-2 - name: check if cache exists id: check-cache uses: actions/cache@v4 @@ -19,6 +13,13 @@ jobs: lookup-only: true restore-keys: | test-data + - name: Setup AWS CLI + if: steps.check-cache.outputs.cache-hit != 'true' + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.TEST_DATA_ACCESS_KEY }} + aws-secret-access-key: ${{ secrets.TEST_DATA_SECRET_KEY }} + aws-region: us-west-2 - name: Download test data if: steps.check-cache.outputs.cache-hit != 'true' run: aws s3 cp s3://${{ secrets.TEST_DATA_BUCKET }}/test_data.tar.gz test_data.tar.gz From 22abc15261d6f4d7f31ea28705d5be176ce8a861 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 27 Apr 2026 10:03:28 -0500 Subject: [PATCH 059/139] Add some MPI logic, which will be wired up in a later PR --- .../opencosmo/collection/lightcone/stack.py | 34 ++++++++++++- python/opencosmo/io/mpi.py | 1 + test/parallel/test_lc_mpi.py | 49 ++++++++++++++++++- 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/python/opencosmo/collection/lightcone/stack.py b/python/opencosmo/collection/lightcone/stack.py index 
ed3ada3b..904e3c17 100644 --- a/python/opencosmo/collection/lightcone/stack.py +++ b/python/opencosmo/collection/lightcone/stack.py @@ -24,6 +24,28 @@ def update_order(data: np.ndarray, comm: Optional[MPI.Comm], order: np.ndarray): return data[order] +def _global_inverse_order_mpi(order: np.ndarray, comm) -> np.ndarray: + """ + Build the global inverse permutation across all MPI ranks. + + Each rank's `order` (possibly referencing rows on other ranks) is combined + into a single global permutation of length N. The inverse maps global input + position -> global output position in the written file. + """ + sizes = comm.allgather(len(order)) + ends = np.cumsum(sizes) + starts = np.insert(ends, 0, 0) + N = int(starts[-1]) + + # global_order[i] is the global input position that lands at global output i + global_order = order + starts[comm.Get_rank()] + all_global_orders = np.concatenate(comm.allgather(global_order)) + + global_inv = np.empty(N, dtype=np.int64) + global_inv[all_global_orders] = np.arange(N) + return global_inv + + def update_top_host_idx( data: np.ndarray, comm: Optional[MPI.Comm], @@ -34,6 +56,12 @@ def update_top_host_idx( # Add per-slice global offsets so local row references become global offset = 0 + if comm is not None: + offsets = np.cumsum(comm.allgather(len(data))) + rank = comm.Get_rank() + if rank > 0: + offset = offsets[rank - 1] + pos = 0 for size in slice_sizes: segment = result[pos : pos + size] @@ -44,11 +72,15 @@ def update_top_host_idx( # Reorder row positions (same as all other columns) if comm is not None: result = update_global_order_mpi(result, comm, order) + # After a cross-rank shuffle, result[valid] may contain global indices + # (0..N-1). np.argsort(order) only covers this rank's local portion, so + # we need the full global inverse permutation instead. 
+ inverse_order = _global_inverse_order_mpi(order, comm) else: result = result[order] + inverse_order = np.argsort(order) # Remap stored row references through the inverse permutation - inverse_order = np.argsort(order) output = np.full_like(result, -1) valid = result >= 0 output[valid] = inverse_order[result[valid]] diff --git a/python/opencosmo/io/mpi.py b/python/opencosmo/io/mpi.py index cc990ad3..6d2daeb8 100644 --- a/python/opencosmo/io/mpi.py +++ b/python/opencosmo/io/mpi.py @@ -297,6 +297,7 @@ def __replace_writers_with_updates(schema: Schema, comm: MPI.Comm): child_schema = schema.children.get(cn, make_schema(cn, FileEntry.EMPTY)) new_child_schema = __replace_writers_with_updates(child_schema, comm) schema.children[cn] = new_child_schema + return schema diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index fd9c2f3f..2bc3ba3f 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -2,8 +2,8 @@ import shutil import astropy.units as u +import h5py import numpy as np -import opencosmo as oc import pytest from astropy.coordinates import SkyCoord from healpy import pix2ang @@ -11,6 +11,8 @@ from opencosmo.mpi import get_comm_world from pytest_mpi.parallel_assert import parallel_assert +import opencosmo as oc + IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" @@ -528,3 +530,48 @@ def test_lightcone_stacking( assert np.all(np.isin(all_fof_tags, all_fof_tags_new)) assert ds_new.z_range == ds.z_range assert next(iter(ds_new.values())).header.lightcone["z_range"] == ds_new.z_range + + +def _get_expected_core_tags(path): + with h5py.File(path) as f: + raw_top_host = f["cores"]["data"]["top_host_idx"][:] + core_tag = f["cores"]["data"]["core_tag"][:] + top_host_core_tag = core_tag[raw_top_host] + return dict(zip(core_tag, top_host_core_tag)) + + +def _assert_top_host_idx_correct(data, core_map): + """ + Verify top_host_idx is correctly remapped in `data` (a numpy dict with + "top_host_idx" and "core_tag" keys). 
Works for both local (per-rank) and + global (gathered) data. + + Synthetic cores (core_tag == -1) are checked separately: each must point + to its own row index. Real cores are checked against core_map. + """ + synth_mask = data["core_tag"] == -1 + synth_indices = np.where(synth_mask)[0] + assert np.all(data["top_host_idx"][synth_mask] == synth_indices) + + # Restrict to real cores for the map check, but dereference top_host_idx + # against the full data so that indices into synthetic rows resolve correctly. + real_mask = ~synth_mask + real_top_host_idx = data["top_host_idx"][real_mask] + real_core_tag = data["core_tag"][real_mask] + + has_top_host = real_top_host_idx >= 0 + found_top_host_core_tag = data["core_tag"][real_top_host_idx[has_top_host]] + found_core_map = dict(zip(real_core_tag[has_top_host], found_top_host_core_tag)) + + filtered_core_map = { + key: val for key, val in core_map.items() if key in found_core_map + } + assert filtered_core_map == found_core_map + + should_have_core_map = { + key: val + for key, val in core_map.items() + if val in data["core_tag"] and key in real_core_tag + } + print(len(should_have_core_map)) + assert should_have_core_map == found_core_map From 2dfa8cc6eca2147e6211d08f3bc20ed3023d21b0 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 27 Apr 2026 15:57:55 -0500 Subject: [PATCH 060/139] Add support for index plugins --- pyproject.toml | 4 + .../collection/lightcone/lightcone.py | 7 +- python/opencosmo/dataset/state.py | 15 ++- python/opencosmo/dtypes/diffsky.py | 74 +++++++++++++-- python/opencosmo/io/iopen.py | 13 ++- python/opencosmo/plugins/plugin.py | 93 ++++++++++++------- test/test_diffsky.py | 70 +++++++++++++- 7 files changed, 224 insertions(+), 52 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df1e4247..6e65d592 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,10 @@ dependencies = [ "rustworkx>=0.17.1,<1.0", ] +[project.urls] +Repository = 
"https://github.com/ArgonneCPAC/OpenCosmo" +Documentation = "https://opencosmo.readthedocs.io/en/stable/" + [dependency-groups] dev = [ "ruff>=0.9.4,<1.0.0", diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 60b65926..38c5135a 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -519,7 +519,9 @@ def open(cls, targets: list[FileTarget], **kwargs): for i, ds_target in enumerate(dataset_targets): group_name = ds_target["dataset_group"].name.split("/")[-1] group_name = group_name.lstrip(f"{ds_target['header'].file.step}_") - ds = iopen.open_single_dataset(ds_target, bypass_lightcone=True) + ds = iopen.open_single_dataset( + ds_target, bypass_lightcone=True, open_kwargs=kwargs + ) step = ds_target["header"].file.step if step is None: step = i @@ -538,6 +540,8 @@ def open(cls, targets: list[FileTarget], **kwargs): raise ValueError() result = cls(output) + result = plugin.apply_plugins(plugin.PluginType.LightconeLoad, result, **kwargs) + return make_radec_columns(result) @classmethod @@ -1163,6 +1167,7 @@ def __take_rows(self, rows: np.ndarray): if len(split) > 0: output[name] = dataset.take_rows(split - rs) rs += len(dataset) + return Lightcone(output, self.z_range, self.__hidden, self.__ordered_by) def with_new_columns( diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index e2cf810c..35e48bfc 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -18,7 +18,7 @@ from opencosmo.index.build import single_chunk from opencosmo.index.mask import into_array from opencosmo.index.unary import get_range -from opencosmo.plugins.plugin import PluginType, apply_plugins +from opencosmo.plugins.plugin import PluginType, apply_index_plugins, apply_plugins from opencosmo.units import UnitConvention from opencosmo.units.handler import ( make_unit_handler_from_hdf5, @@ -322,7 
+322,10 @@ def get_metadata(self, columns=[]): return metadata def with_mask(self, mask: NDArray[np.bool_]): - index = np.where(mask)[0] + return self.with_index(np.where(mask)[0]) + + def with_index(self, index: DataIndex): + index = apply_index_plugins(self, index) new_raw_handler = self.__raw_data_handler.take(index) new_cache = self.__cache.take(index) return self.__rebuild( @@ -479,16 +482,12 @@ def take_range(self, start: int, end: int): else: take_index = np.sort(sorted[start:end]) - new_raw_handler = self.__raw_data_handler.take(take_index) - new_im = self.__cache.take(take_index) - return self.__rebuild( - raw_data_handler=new_raw_handler, - cache=new_im, - ) + return self.take_rows(take_index) def take_rows(self, rows: DataIndex): if len(self) == 0: return self + rows = apply_index_plugins(self, rows) row_range = get_range(rows) if row_range[1] > len(self) or row_range[0] < 0: diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py index 7b983c11..221e82a3 100644 --- a/python/opencosmo/dtypes/diffsky.py +++ b/python/opencosmo/dtypes/diffsky.py @@ -7,12 +7,19 @@ from pydantic import BaseModel, ConfigDict, field_serializer from opencosmo.column.column import EvaluatedColumn, EvaluateStrategy +from opencosmo.index import into_array from opencosmo.index.ops import reindex_column -from opencosmo.plugins.plugin import PluginType, register_plugin +from opencosmo.plugins.plugin import ( + IndexPluginSpec, + PluginSpec, + PluginType, + register_plugin, +) if TYPE_CHECKING: from opencosmo import Dataset, Lightcone from opencosmo.dataset.state import DatasetState + from opencosmo.index import DataIndex class DiffskyVersionInfo(BaseModel): @@ -68,7 +75,7 @@ def rebuild_top_host_idx(top_host_idx, index): return {"top_host_idx": result} -def top_host_idx_plugin(dataset: DatasetState): +def top_host_idx_plugin(dataset: DatasetState, **kwargs): top_host_idx = EvaluatedColumn( rebuild_top_host_idx, requires=set(["top_host_idx"]), @@ -98,15 
+105,70 @@ def top_host_idx(top_host_idx, offset): return output -def top_host_idx_verifier[T: (DatasetState, Dataset, Lightcone)](dataset: T) -> bool: +def top_host_idx_verifier[T: (DatasetState, Dataset, Lightcone)]( + dataset: T, **kwargs +) -> bool: return ( dataset.header.file.data_type == "synthetic_galaxies" and "top_host_idx" in dataset.columns ) -register_plugin(PluginType.DatasetOpen, top_host_idx_verifier, top_host_idx_plugin) +def keep_top_host_idx(dataset: DatasetState, new_index: DataIndex): + index_array = into_array(new_index) + top_host_idx = dataset.select({"top_host_idx"}).get_data()["top_host_idx"] + unique_in_sample = np.unique(top_host_idx[index_array]) + + missing_hosts = np.setdiff1d(unique_in_sample, index_array) + all_satellites = np.where(np.isin(top_host_idx, unique_in_sample))[0] + missing_satellites = np.setdiff1d(all_satellites, index_array) + + if len(missing_hosts) == 0 and len(missing_satellites) == 0: + return new_index + + all_missing = np.sort(np.concatenate((missing_hosts, missing_satellites))) + + insert_idx = np.searchsorted(index_array, all_missing) + result = np.insert(index_array, insert_idx, all_missing) + + return result -register_plugin( # type: ignore - PluginType.LightconeInstantiate, top_host_idx_verifier, top_host_idx_offset_plugin + +def keep_top_host_idx_verifier(dataset: DatasetState): + return "top_host_idx" in dataset.columns + + +def register_keep_top_host_idx(dataset: Dataset, **kwargs): + register_plugin( + IndexPluginSpec( + PluginType.IndexUpdate, keep_top_host_idx_verifier, keep_top_host_idx + ) + ) + return dataset + + +def register_keep_top_host_idx_verifier( + dataset: Dataset, keep_top_host: bool = False, **kwargs +): + return keep_top_host and top_host_idx_verifier(dataset) + + +register_plugin( + PluginSpec(PluginType.DatasetOpen, top_host_idx_verifier, top_host_idx_plugin) +) + +register_plugin( + PluginSpec( + PluginType.LightconeInstantiate, + top_host_idx_verifier, + top_host_idx_offset_plugin, 
+ ) +) + +register_plugin( + PluginSpec( + PluginType.LightconeLoad, + register_keep_top_host_idx_verifier, + register_keep_top_host_idx, + ) ) diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index f87e3bd5..af94a17f 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -87,7 +87,7 @@ def open_files(paths: list[Path], open_kwargs: dict[str, Any]): if len(valid_targets) > 1: collection_type = __determine_multi_file_collection_type(valid_targets) return collection_type.open(valid_targets, **open_kwargs) - return __open_single_file(valid_targets[0]) + return __open_single_file(valid_targets[0], open_kwargs) def __make_group_map(group: h5py.File | h5py.Group, prefix: str = ""): @@ -119,14 +119,18 @@ def __make_file_target(path: Path, open_kwargs: dict[str, Any]) -> Optional[File ) -def __open_single_file(target: FileTarget) -> oc.Dataset | oc.collection.Collection: +def __open_single_file( + target: FileTarget, open_kwargs: dict[str, Any] = {} +) -> oc.Dataset | oc.collection.Collection: """ Opens a single file, which may or may not contain several datasets """ if len(target["dataset_targets"]) == 1: # Just one dataset, easy - return open_single_dataset(target["dataset_targets"][0]) + return open_single_dataset( + target["dataset_targets"][0], open_kwargs=open_kwargs + ) elif target["dataset_targets"]: # Multiple datasets, but all grouped together @@ -470,6 +474,7 @@ def open_single_dataset( metadata_group: Optional[str] = None, bypass_lightcone: bool = False, bypass_mpi: bool = False, + open_kwargs: dict[str, Any] = {}, ): header = target["header"] ds_group = target["dataset_group"] @@ -548,7 +553,7 @@ def open_single_dataset( {"data": dataset}, header.lightcone["z_range"] ) - return apply_plugins(PluginType.DatasetOpen, dataset) + return apply_plugins(PluginType.DatasetOpen, dataset, **open_kwargs) def __open_healpix_map(dataset: oc.Dataset, sim_region): diff --git a/python/opencosmo/plugins/plugin.py 
b/python/opencosmo/plugins/plugin.py index c2a55584..60734d5a 100644 --- a/python/opencosmo/plugins/plugin.py +++ b/python/opencosmo/plugins/plugin.py @@ -3,10 +3,17 @@ from collections import defaultdict from enum import StrEnum from functools import reduce -from typing import Callable, NamedTuple, TypedDict +from typing import ( + TYPE_CHECKING, + Any, + Callable, + NamedTuple, + TypedDict, +) -from opencosmo import dataset as ds -from opencosmo.collection.lightcone import lightcone as lc +if TYPE_CHECKING: + from opencosmo.dataset.state import DatasetState + from opencosmo.index import DataIndex class PluginType(StrEnum): @@ -14,54 +21,76 @@ class PluginType(StrEnum): DatasetInstantiate = "dataset_instantiate" LightconeLoad = "lightcone_load" LightconeInstantiate = "lightcone_instantiate" + IndexUpdate = "index_update" -DatasetTransformationPlugin = Callable[[ds.Dataset], ds.Dataset] -LightconeTransformationPlugin = Callable[[lc.Lightcone], dict[str, ds.Dataset]] +type Verifier[T] = Callable[[T], bool] +type Plugin[T] = Callable[[T], T] -type Verifier[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)] = Callable[ - [T], bool -] -type Plugin[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)] = Callable[[T], T] - -class PluginSpec[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)](NamedTuple): +class PluginSpec[T](NamedTuple): plugin_type: PluginType verifier: Verifier[T] plugin: Plugin[T] +class IndexPluginSpec(NamedTuple): + plugin_type: PluginType + verifier: Callable[[DatasetState], bool] + plugin: Callable[[DatasetState, DataIndex], DataIndex] + + class Plugins(TypedDict): - dataset_open: list[PluginSpec[ds.Dataset]] - dataset_instantiate: list[PluginSpec[ds.state.DatasetState]] - lightcone_load: list[PluginSpec[lc.Lightcone]] - lightcone_instantiate: list[PluginSpec[lc.Lightcone]] + dataset_open: list[PluginSpec] + dataset_instantiate: list[PluginSpec] + lightcone_load: list[PluginSpec] + lightcone_instantiate: list[PluginSpec] + index_update: 
list[IndexPluginSpec] KNOWN_PLUGINS: Plugins = defaultdict(list) # type: ignore -def register_plugin[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)]( - plugin_type: PluginType, - verifier: Verifier[T], - plugin: Plugin[T], -) -> None: - spec = PluginSpec(plugin_type=plugin_type, verifier=verifier, plugin=plugin) - KNOWN_PLUGINS[str(plugin_type)].append(spec) # type: ignore +def register_plugin(spec: PluginSpec | IndexPluginSpec) -> None: + KNOWN_PLUGINS[str(spec.plugin_type)].append(spec) # type: ignore + +def apply_plugins[T](plugin_type: PluginType, target: T, **kwargs: Any) -> T: + """Apply all registered plugins of the given type to target. -def apply_plugins[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)]( - plugin_type: PluginType, dataset: T -) -> T: + kwargs are forwarded to both the verifier and plugin, used to pass + open_kwargs through to DatasetOpen plugins. + """ plugins_to_apply = KNOWN_PLUGINS[str(plugin_type)] # type: ignore return reduce( - lambda ds_, spec: apply_single_plugin(spec, ds_), plugins_to_apply, dataset + lambda t, spec: _apply_single(spec, t, **kwargs), plugins_to_apply, target + ) + + +def apply_index_plugins(state: DatasetState, index: DataIndex) -> DataIndex: + """Apply all registered IndexUpdate plugins to index. + + Each plugin may expand or otherwise modify the position index returned by a + filter or take operation. Plugins run in registration order, each seeing the + output of the previous one. 
+ """ + plugins_to_apply: list[IndexPluginSpec] = KNOWN_PLUGINS[str(PluginType.IndexUpdate)] # type: ignore + return reduce( + lambda idx, spec: _apply_single_index(spec, state, idx), + plugins_to_apply, + index, ) -def apply_single_plugin[T: (ds.Dataset, lc.Lightcone, ds.state.DatasetState)]( - spec: PluginSpec[T], dataset: T -) -> T: - if spec.verifier(dataset): - return spec.plugin(dataset) - return dataset +def _apply_single[T](spec: PluginSpec[T], target: T, **kwargs: Any) -> T: + if spec.verifier(target, **kwargs): + return spec.plugin(target, **kwargs) + return target + + +def _apply_single_index( + spec: IndexPluginSpec, state: DatasetState, index: DataIndex +) -> DataIndex: + if spec.verifier(state): + return spec.plugin(state, index) + return index diff --git a/test/test_diffsky.py b/test/test_diffsky.py index 90421c3c..e41b5db0 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -299,7 +299,7 @@ def test_add_mag_units_unitless(core_path_475, core_path_487): def test_region(core_path_475, core_path_487): ds = oc.open(core_path_475, core_path_487) assert isinstance(ds.region, HealpixRegion) - assert len(ds.region.pixels) == 502 + assert len(ds.region.pixels) == 1610 def test_open_bad_data(core_path_475, core_path_487, invalid_data_path): @@ -316,6 +316,26 @@ def get_expected_core_tags(path): return dict(zip(core_tag, top_host_core_tag)) +def assert_all_group_members_present(data, core_map): + """ + Verify that for every top_host represented in the data, all rows from the + full dataset that point to that top_host are also present. 
+ """ + host_to_members: dict = {} + for ct, host_ct in core_map.items(): + host_to_members.setdefault(host_ct, set()).add(ct) + + present_core_tags = set(data["core_tag"]) + top_host_core_tags = set(data["core_tag"][data["top_host_idx"]]) + + for top_host_ct in top_host_core_tags: + expected_members = host_to_members.get(top_host_ct, set()) + missing = expected_members - present_core_tags + assert not missing, ( + f"top_host {top_host_ct}: {len(missing)} member(s) missing from result" + ) + + def assert_top_host_idx_correct(data, core_map): """ Verify that top_host_idx correctly references the expected host after any @@ -377,3 +397,51 @@ def test_reindex_top_host_filter(core_path_475, core_path_487): ds = oc.open(core_path_475, core_path_487).filter(oc.col("logsm_obs") > 10) data = ds.select("top_host_idx", "core_tag").get_data("numpy") assert_top_host_idx_correct(data, core_map) + + +def test_keep_top_host_take_random(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True) + ds = ds.take(20) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + assert np.all(data["top_host_idx"] >= 0) + assert_all_group_members_present(data, core_map) + + +def test_keep_top_host_take_start(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True) + ds = ds.take(20, at="start") + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + assert np.all(data["top_host_idx"] >= 0) + assert_all_group_members_present(data, core_map) + + +def test_keep_top_host_take_range(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= 
get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True) + ds = ds.take_range(20, 60) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + assert np.all(data["top_host_idx"] >= 0) + assert_all_group_members_present(data, core_map) + + +def test_keep_top_host_filter(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True) + ds = ds.filter(oc.col("logsm_obs") < 10) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + assert np.all(data["top_host_idx"] >= 0) + assert_all_group_members_present(data, core_map) From 333250e8926ed322f0a768e13f06246711347039 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 28 Apr 2026 08:23:09 -0500 Subject: [PATCH 061/139] Reorganization for lightcone --- .../collection/lightcone/__init__.py | 5 + python/opencosmo/collection/lightcone/io.py | 100 +++++++++ .../collection/lightcone/lightcone.py | 209 ++---------------- .../opencosmo/collection/lightcone/plugins.py | 42 ++++ .../opencosmo/collection/lightcone/utils.py | 77 +++++++ python/opencosmo/plugins/__init__.py | 3 + 6 files changed, 240 insertions(+), 196 deletions(-) create mode 100644 python/opencosmo/collection/lightcone/io.py create mode 100644 python/opencosmo/collection/lightcone/plugins.py create mode 100644 python/opencosmo/collection/lightcone/utils.py diff --git a/python/opencosmo/collection/lightcone/__init__.py b/python/opencosmo/collection/lightcone/__init__.py index 586a29eb..62f96d4a 100644 --- a/python/opencosmo/collection/lightcone/__init__.py +++ b/python/opencosmo/collection/lightcone/__init__.py @@ -1,4 +1,9 @@ +from importlib import import_module + from .healpix_map import HealpixMap from .lightcone import Lightcone +# register plugins 
+import_module(f"{__name__}.plugins") + __all__ = ["Lightcone", "HealpixMap"] diff --git a/python/opencosmo/collection/lightcone/io.py b/python/opencosmo/collection/lightcone/io.py new file mode 100644 index 00000000..d5ee892f --- /dev/null +++ b/python/opencosmo/collection/lightcone/io.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from collections import OrderedDict, defaultdict + +from opencosmo.collection.lightcone import utils as lcutils +from opencosmo.dataset import dataset as ocds +from opencosmo.io.mpi import get_all_keys +from opencosmo.mpi import get_comm_world, get_mpi + + +def order_by_redshift_range(datasets: dict[str, ocds.Dataset]): + redshift_ranges = { + key: lcutils.get_single_redshift_range(ds) for key, ds in datasets.items() + } + sorted_ranges = sorted(redshift_ranges.items(), key=lambda item: item[1][0]) + output = OrderedDict() + for name, _ in sorted_ranges: + output[name] = datasets[name] + return output + + +def combine_adjacent_datasets_mpi( + ordered_datasets: dict[str, dict[str, ocds.Dataset]], + min_dataset_size, +): + MIN_DATASET_SIZE = 100_000 + comm = get_comm_world() + MPI = get_mpi() + all_dataset_steps = get_all_keys(ordered_datasets, comm) + assert comm is not None and MPI is not None + rs = 0 + output_datasets: dict[str, list[dict[str, ocds.Dataset]]] = OrderedDict() + for step in all_dataset_steps: + if rs == 0: + current_key = step + output_datasets[current_key] = [] + + if step not in ordered_datasets: + rs += comm.allreduce(0, MPI.SUM) + else: + length = sum(len(ds) for ds in ordered_datasets[step].values()) + rs += comm.allreduce(length) + output_datasets[current_key].append(ordered_datasets[step]) + + if rs > MIN_DATASET_SIZE: + rs = 0 + + output = OrderedDict() + for step, datasets in output_datasets.items(): + step_output = defaultdict(list) + for ds_group in datasets: + for ds_type, ds in ds_group.items(): + step_output[ds_type].append(ds) + output[step] = step_output + + return output + + +def 
combine_adjacent_datasets( + ordered_datasets: dict[str, ocds.Dataset] | dict[str, dict[str, ocds.Dataset]], + min_dataset_size=100_000, +): + is_single = isinstance(next(iter(ordered_datasets.values())), ocds.Dataset) + datasets: dict[str, dict[str, ocds.Dataset]] + if is_single: + assert all(isinstance(ds, ocds.Dataset) for ds in ordered_datasets.values()) + datasets = {key: {"data": ds} for key, ds in ordered_datasets.items()} # type: ignore + else: + assert all(isinstance(ds, dict) for ds in ordered_datasets.values()) + datasets = ordered_datasets # type: ignore + + if get_comm_world() is not None: + return combine_adjacent_datasets_mpi(datasets, min_dataset_size) + + running_sum = 0 + + current_key = next(iter(ordered_datasets.keys())) + output_datasets: dict[str, list[dict[str, ocds.Dataset]]] = OrderedDict( + {current_key: []} + ) + + for key, step_datasets in datasets.items(): + if running_sum < min_dataset_size: + running_sum += sum(len(ds) for ds in step_datasets.values()) + output_datasets[current_key].append(step_datasets) + continue + current_key = key + output_datasets[current_key] = [step_datasets] + running_sum = sum(len(ds) for ds in step_datasets.values()) + + # We have list of dicts, go to dict of lists + output = OrderedDict() + for step, step_datasets_ in output_datasets.items(): + step_output = defaultdict(list) + for ds_group in step_datasets_: + for ds_type, ds in ds_group.items(): + step_output[ds_type].append(ds) + output[step] = step_output + + return output diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 38c5135a..a0b7e982 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections import OrderedDict, defaultdict +from collections import defaultdict from functools import cached_property, reduce from itertools import chain from typing 
import ( @@ -12,7 +12,6 @@ Mapping, Optional, Self, - Sequence, ) from warnings import warn @@ -20,26 +19,25 @@ from astropy.table import vstack # type: ignore import opencosmo as oc +from opencosmo.collection.lightcone import io as lcio +from opencosmo.collection.lightcone import utils as lcutils from opencosmo.collection.lightcone.coordinates import make_radec_columns from opencosmo.collection.lightcone.stack import stack_lightcone_datasets_in_schema from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn -from opencosmo.dataset import Dataset from opencosmo.dataset.evaluate import build_evaluated_column from opencosmo.dataset.formats import convert_data, verify_format from opencosmo.dtypes.dtype import get_dtype_lightcone_plugins from opencosmo.io import iopen -from opencosmo.io.mpi import get_all_keys from opencosmo.io.schema import FileEntry, make_schema -from opencosmo.mpi import get_comm_world, get_mpi from opencosmo.plugins import plugin if TYPE_CHECKING: import astropy.units as u # type: ignore from astropy.coordinates import SkyCoord from astropy.cosmology import Cosmology - from astropy.table import Table from opencosmo.column.column import ColumnMask, ConstructedColumn + from opencosmo.dataset import Dataset from opencosmo.dtypes.hacc import HaccSimulationParameters from opencosmo.header import OpenCosmoHeader from opencosmo.io.iopen import FileTarget @@ -47,185 +45,6 @@ from opencosmo.spatial import Region -def get_redshift_range(datasets: Sequence[Dataset | Lightcone]): - redshift_ranges = list(map(get_single_redshift_range, datasets)) - min_z = min(rr[0] for rr in redshift_ranges) - max_z = max(rr[1] for rr in redshift_ranges) - - return (min_z, max_z) - - -def get_single_redshift_range(dataset: Dataset | Lightcone): - if isinstance(dataset, Lightcone): - return dataset.z_range - redshift_range = dataset.header.lightcone["z_range"] - if redshift_range is not None: - return redshift_range - step_zs = 
dataset.header.simulation["step_zs"] - step = dataset.header.file.step - assert step is not None - min_redshift = step_zs[step] - max_redshift = step_zs[step - 1] - return (min_redshift, max_redshift) - - -def is_in_range(dataset: Dataset, z_low: float, z_high: float): - z_range = dataset.header.lightcone["z_range"] - if z_range is None: - z_range = get_single_redshift_range(dataset) - if z_high < z_range[0] or z_low > z_range[1]: - return False - return True - - -def sort_table(table: Table, column: str, invert: bool): - column_data = table[column] - if invert: - column_data = -column_data - indices = np.argsort(column_data) - for name in table.columns: - table[name] = table[name][indices] - return table - - -def take_from_sorted( - lightcone: "Lightcone", sort_by: str, invert: bool, n: int, at: str | int -): - column = np.concatenate( - [ds.select(sort_by).get_data("numpy") for ds in lightcone.values()] - ) - if invert: - column = -column - sort_index = np.argsort(column) - if at == "start": - sort_index = sort_index[:n] - elif at == "end": - sort_index = sort_index[-n:] - elif isinstance(at, int): - if at + n > len(sort_index) or at < 0: - raise ValueError( - "Requested a range that is outside the size of this dataset!" 
- ) - sort_index = sort_index[at : at + n] - - sorted_indices = np.sort(sort_index) - return sorted_indices - - -def order_by_redshift_range(datasets: dict[str, Dataset]): - redshift_ranges = { - key: get_single_redshift_range(ds) for key, ds in datasets.items() - } - sorted_ranges = sorted(redshift_ranges.items(), key=lambda item: item[1][0]) - output = OrderedDict() - for name, _ in sorted_ranges: - output[name] = datasets[name] - return output - - -def combine_adjacent_datasets_mpi( - ordered_datasets: dict[str, dict[str, Dataset]], - min_dataset_size, -): - MIN_DATASET_SIZE = 100_000 - comm = get_comm_world() - MPI = get_mpi() - all_dataset_steps = get_all_keys(ordered_datasets, comm) - assert comm is not None and MPI is not None - rs = 0 - output_datasets: dict[str, list[dict[str, Dataset]]] = OrderedDict() - for step in all_dataset_steps: - if rs == 0: - current_key = step - output_datasets[current_key] = [] - - if step not in ordered_datasets: - rs += comm.allreduce(0, MPI.SUM) - else: - length = sum(len(ds) for ds in ordered_datasets[step].values()) - rs += comm.allreduce(length) - output_datasets[current_key].append(ordered_datasets[step]) - - if rs > MIN_DATASET_SIZE: - rs = 0 - - output = OrderedDict() - for step, datasets in output_datasets.items(): - step_output = defaultdict(list) - for ds_group in datasets: - for ds_type, ds in ds_group.items(): - step_output[ds_type].append(ds) - output[step] = step_output - - return output - - -def combine_adjacent_datasets( - ordered_datasets: dict[str, Dataset] | dict[str, dict[str, Dataset]], - min_dataset_size=100_000, -): - is_single = isinstance(next(iter(ordered_datasets.values())), Dataset) - datasets: dict[str, dict[str, Dataset]] - if is_single: - assert all(isinstance(ds, Dataset) for ds in ordered_datasets.values()) - datasets = {key: {"data": ds} for key, ds in ordered_datasets.items()} # type: ignore - else: - assert all(isinstance(ds, dict) for ds in ordered_datasets.values()) - datasets = 
ordered_datasets # type: ignore - - if get_comm_world() is not None: - return combine_adjacent_datasets_mpi(datasets, min_dataset_size) - - running_sum = 0 - - current_key = next(iter(ordered_datasets.keys())) - output_datasets: dict[str, list[dict[str, Dataset]]] = OrderedDict( - {current_key: []} - ) - - for key, step_datasets in datasets.items(): - if running_sum < min_dataset_size: - running_sum += sum(len(ds) for ds in step_datasets.values()) - output_datasets[current_key].append(step_datasets) - continue - current_key = key - output_datasets[current_key] = [step_datasets] - running_sum = sum(len(ds) for ds in step_datasets.values()) - - # We have list of dicts, go to dict of lists - output = OrderedDict() - for step, step_datasets_ in output_datasets.items(): - step_output = defaultdict(list) - for ds_group in step_datasets_: - for ds_type, ds in ds_group.items(): - step_output[ds_type].append(ds) - output[step] = step_output - - return output - - -def with_redshift_column(dataset: Dataset): - """ - Ensures a column exists called "redshift" which contains the redshift of the objects - in the lightcone. - """ - if "redshift" in dataset.columns: - return dataset - - elif "fof_halo_center_a" in dataset.columns: - z_col = 1 / oc.col("fof_halo_center_a") - 1 - return dataset.with_new_columns(redshift=z_col) - elif "redshift_true" in dataset.columns: - z_col = oc.col("redshift_true") - return dataset.with_new_columns(redshift=z_col) - elif "zp" in dataset.columns: - z_col = oc.col("zp") - return dataset.with_new_columns(redshift=z_col) - raise ValueError( - "Unable to find a redshift or scale factor column for this lightcone dataset" - ) - - class Lightcone(dict): """ A lightcone contains two or more datasets that are part of a lightcone. 
Typically @@ -248,15 +67,11 @@ def __init__( hidden: Optional[set[str]] = None, ordered_by: Optional[tuple[str, bool]] = None, ): - datasets = { - k: with_redshift_column(ds) if isinstance(ds, Dataset) else ds - for k, ds in datasets.items() - } self.update(datasets) z_range = ( z_range if z_range is not None - else get_redshift_range(list(datasets.values())) + else lcutils.get_redshift_range(list(datasets.values())) ) columns: set[str] = reduce( @@ -574,7 +389,7 @@ def with_redshift_range(self, z_low: float, z_high: float): raise ValueError("Low and high values of the redshift range are the same!") new_datasets = {} for key, dataset in self.items(): - if not is_in_range(dataset, z_low, z_high): + if not lcutils.is_in_range(dataset, z_low, z_high): continue new_dataset = dataset.filter( oc.col("redshift") > z_low, oc.col("redshift") < z_high @@ -625,11 +440,11 @@ def __map_attribute(self, attribute): return {k: getattr(v, attribute) for k, v in self.items()} def make_schema(self, name: str = "", _min_size=100_000) -> Schema: - datasets = order_by_redshift_range(self) + datasets = lcio.order_by_redshift_range(self) for key in datasets: if isinstance(datasets[key], Lightcone): datasets[key] = dict(datasets[key]) - output_datasets = combine_adjacent_datasets( + output_datasets = lcio.combine_adjacent_datasets( datasets, min_dataset_size=_min_size ) children = {} @@ -640,7 +455,7 @@ def make_schema(self, name: str = "", _min_size=100_000) -> Schema: continue all_datasets = list(chain(*tuple(lst for lst in datasets.values()))) - header_zrange = get_redshift_range(all_datasets) + header_zrange = lcutils.get_redshift_range(all_datasets) my_zrange = self.z_range zrange = ( max(header_zrange[0], my_zrange[0]), @@ -1048,7 +863,7 @@ def take(self, n: int, at: str = "random") -> "Lightcone": return self.__take_rows(indices) elif self.__ordered_by is not None: - index = take_from_sorted(self, *self.__ordered_by, n=n, at=at) + index = lcutils.take_from_sorted(self, 
*self.__ordered_by, n=n, at=at) return self.__take_rows(index) elif at == "start": return self.take_range(0, n) @@ -1090,7 +905,9 @@ def take_range(self, start: int, end: int): raise ValueError("Got row indices that are out of range!") if self.__ordered_by is not None: - indices = take_from_sorted(self, *self.__ordered_by, end - start, at=start) + indices = lcutils.take_from_sorted( + self, *self.__ordered_by, end - start, at=start + ) return self.__take_rows(indices) ends = np.cumsum(np.fromiter((len(ds) for ds in self.values()), dtype=int)) diff --git a/python/opencosmo/collection/lightcone/plugins.py b/python/opencosmo/collection/lightcone/plugins.py new file mode 100644 index 00000000..cf1bb623 --- /dev/null +++ b/python/opencosmo/collection/lightcone/plugins.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import opencosmo as oc +from opencosmo.plugins import PluginSpec, PluginType, register_plugin + +if TYPE_CHECKING: + from opencosmo import Lightcone + + +def with_redshift_column(dataset: Lightcone): + """ + Ensures a column exists called "redshift" which contains the redshift of the objects + in the lightcone. 
+ """ + if "redshift" in dataset.columns: + return dataset + + elif "fof_halo_center_a" in dataset.columns: + z_col = 1 / oc.col("fof_halo_center_a") - 1 + return dataset.with_new_columns(redshift=z_col) + elif "redshift_true" in dataset.columns: + z_col = oc.col("redshift_true") + return dataset.with_new_columns(redshift=z_col) + elif "zp" in dataset.columns: + z_col = oc.col("zp") + return dataset.with_new_columns(redshift=z_col) + raise ValueError( + "Unable to find a redshift or scale factor column for this lightcone dataset" + ) + + +register_plugin( + PluginSpec( + PluginType.LightconeLoad, + lambda _: True, + with_redshift_column, + ) +) + +plugins = None diff --git a/python/opencosmo/collection/lightcone/utils.py b/python/opencosmo/collection/lightcone/utils.py new file mode 100644 index 00000000..c5799fa8 --- /dev/null +++ b/python/opencosmo/collection/lightcone/utils.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +import numpy as np + +from opencosmo.collection.lightcone import lightcone as oclc + +if TYPE_CHECKING: + from astropy.table import Table + + from opencosmo import Dataset + + +def get_redshift_range(datasets: Sequence[Dataset | oclc.Lightcone]): + redshift_ranges = list(map(get_single_redshift_range, datasets)) + min_z = min(rr[0] for rr in redshift_ranges) + max_z = max(rr[1] for rr in redshift_ranges) + + return (min_z, max_z) + + +def get_single_redshift_range(dataset: Dataset | oclc.Lightcone): + if isinstance(dataset, oclc.Lightcone): + return dataset.z_range + redshift_range = dataset.header.lightcone["z_range"] + if redshift_range is not None: + return redshift_range + step_zs = dataset.header.simulation["step_zs"] + step = dataset.header.file.step + assert step is not None + min_redshift = step_zs[step] + max_redshift = step_zs[step - 1] + return (min_redshift, max_redshift) + + +def is_in_range(dataset: Dataset, z_low: float, z_high: float): + z_range = 
dataset.header.lightcone["z_range"] + if z_range is None: + z_range = get_single_redshift_range(dataset) + if z_high < z_range[0] or z_low > z_range[1]: + return False + return True + + +def sort_table(table: Table, column: str, invert: bool): + column_data = table[column] + if invert: + column_data = -column_data + indices = np.argsort(column_data) + for name in table.columns: + table[name] = table[name][indices] + return table + + +def take_from_sorted( + lightcone: "oclc.Lightcone", sort_by: str, invert: bool, n: int, at: str | int +): + column = np.concatenate( + [ds.select(sort_by).get_data("numpy") for ds in lightcone.values()] + ) + if invert: + column = -column + sort_index = np.argsort(column) + if at == "start": + sort_index = sort_index[:n] + elif at == "end": + sort_index = sort_index[-n:] + elif isinstance(at, int): + if at + n > len(sort_index) or at < 0: + raise ValueError( + "Requested a range that is outside the size of this dataset!" + ) + sort_index = sort_index[at : at + n] + + sorted_indices = np.sort(sort_index) + return sorted_indices diff --git a/python/opencosmo/plugins/__init__.py b/python/opencosmo/plugins/__init__.py index e69de29b..928c4461 100644 --- a/python/opencosmo/plugins/__init__.py +++ b/python/opencosmo/plugins/__init__.py @@ -0,0 +1,3 @@ +from .plugin import PluginSpec, PluginType, register_plugin + +__all__ = ["register_plugin", "PluginSpec", "PluginType"] From 5aa07e24a1d89118be98f204d7ae25c8f429bb0e Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 28 Apr 2026 08:55:54 -0500 Subject: [PATCH 062/139] More plugin work --- .../collection/lightcone/healpix_map.py | 6 +++- .../collection/lightcone/lightcone.py | 10 +++++-- .../opencosmo/collection/lightcone/plugins.py | 4 +-- python/opencosmo/dataset/columns.py | 3 +- python/opencosmo/dtypes/dtype.py | 30 +------------------ python/opencosmo/io/iopen.py | 2 +- test/parallel/test_lc_mpi.py | 8 +---- test/test_healpixmap.py | 11 +++++-- 8 files changed, 27 insertions(+), 47 
deletions(-) diff --git a/python/opencosmo/collection/lightcone/healpix_map.py b/python/opencosmo/collection/lightcone/healpix_map.py index 4aa4fbbb..0bdd7546 100644 --- a/python/opencosmo/collection/lightcone/healpix_map.py +++ b/python/opencosmo/collection/lightcone/healpix_map.py @@ -384,16 +384,20 @@ def get_data(self, format="healsparse", nside_out: Optional[int] = None, **kwarg else: table.remove_columns(self.__hidden) storage = { - name: np.zeros(npix, dtype=np.float32) for name in table.columns + name: np.zeros(npix, dtype=np.float32) + for name in table.columns + if name != "pixel" } for name, arr in storage.items(): arr[pixels] = table[name].value + if len(pixels) != hp.nside2npix(self.nside): mask = np.zeros(hp.nside2npix(self.nside), dtype=bool) mask[pixels] = True storage = { name: np.ma.masked_array(arr, mask) for name, arr in storage.items() } + storage["pixel"] = self.pixels if len(storage) == 1: return next(iter(storage.values())) return storage diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index a0b7e982..e6cee885 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -26,7 +26,6 @@ from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn from opencosmo.dataset.evaluate import build_evaluated_column from opencosmo.dataset.formats import convert_data, verify_format -from opencosmo.dtypes.dtype import get_dtype_lightcone_plugins from opencosmo.io import iopen from opencosmo.io.schema import FileEntry, make_schema from opencosmo.plugins import plugin @@ -87,7 +86,6 @@ def __init__( self.__hidden = hidden self.__ordered_by = ordered_by - self.__plugins = get_dtype_lightcone_plugins(self.__header, self.columns) def __repr__(self): """ @@ -361,9 +359,15 @@ def open(cls, targets: list[FileTarget], **kwargs): @classmethod def from_datasets( - cls, datasets: dict[str, oc.Dataset], z_range: 
tuple[float, float] + cls, + datasets: dict[str, oc.Dataset], + z_range: tuple[float, float], + **open_kwargs, ): result = cls(datasets, z_range) + result = plugin.apply_plugins( + plugin.PluginType.LightconeLoad, result, **open_kwargs + ) return make_radec_columns(result) def with_redshift_range(self, z_low: float, z_high: float): diff --git a/python/opencosmo/collection/lightcone/plugins.py b/python/opencosmo/collection/lightcone/plugins.py index cf1bb623..5d2a2bb5 100644 --- a/python/opencosmo/collection/lightcone/plugins.py +++ b/python/opencosmo/collection/lightcone/plugins.py @@ -9,7 +9,7 @@ from opencosmo import Lightcone -def with_redshift_column(dataset: Lightcone): +def with_redshift_column(dataset: Lightcone, *args, **kwargs): """ Ensures a column exists called "redshift" which contains the redshift of the objects in the lightcone. @@ -34,7 +34,7 @@ def with_redshift_column(dataset: Lightcone): register_plugin( PluginSpec( PluginType.LightconeLoad, - lambda _: True, + lambda *args, **kwargs: True, with_redshift_column, ) ) diff --git a/python/opencosmo/dataset/columns.py b/python/opencosmo/dataset/columns.py index 052db7ad..8d287680 100644 --- a/python/opencosmo/dataset/columns.py +++ b/python/opencosmo/dataset/columns.py @@ -114,8 +114,7 @@ def add_columns( ) -> tuple[list[ConstructedColumn], ColumnMap, UnitHandler]: if ( inter := set(name_to_uuid.keys()).intersection(new_columns.keys()) - and not allow_overwrite - ): + ) and not allow_overwrite: raise ValueError(f"Some columns are already in the dataset: {inter}") ( diff --git a/python/opencosmo/dtypes/dtype.py b/python/opencosmo/dtypes/dtype.py index 5e2b4420..ee2dc125 100644 --- a/python/opencosmo/dtypes/dtype.py +++ b/python/opencosmo/dtypes/dtype.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from opencosmo.dtypes import diffsky, hacc, lightcone +from opencosmo.dtypes import hacc, lightcone if TYPE_CHECKING: from pydantic import BaseModel @@ -24,31 +24,3 @@ def get_dtype_parameters( 
required_dtype_params["lightcone"] = lightcone_parameters dtype_parameters["required"] = required_dtype_params return dtype_parameters - - -def get_dtype_column_plugins( - header, - producers, - columns, -): - plugins = __get_column_plugins(header) - for name, producer in plugins.items(): - if name not in columns: - continue - producer = producer.bind(columns) - producers.append(producer) - columns[name] = producer.uuid - - return producers, columns - - -def __get_column_plugins(header): - if header.file.data_type == "synthetic_galaxies": - return {"top_host_idx": diffsky.top_host_idx} - return {} - - -def get_dtype_lightcone_plugins(header, columns): - if header.file.data_type == "synthetic_galaxies" and "top_host_idx" in columns: - return [diffsky.offset_top_host_idx] - return [] diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index af94a17f..870c1a45 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -550,7 +550,7 @@ def open_single_dataset( return __open_healpix_map(dataset, sim_region) elif header.file.is_lightcone and not bypass_lightcone: return occ.Lightcone.from_datasets( - {"data": dataset}, header.lightcone["z_range"] + {"data": dataset}, header.lightcone["z_range"], **open_kwargs ) return apply_plugins(PluginType.DatasetOpen, dataset, **open_kwargs) diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index 2bc3ba3f..9a52369c 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -125,6 +125,7 @@ def test_healpix_index_chain_failure(haloproperties_600_path): def test_healpix_write(haloproperties_600_path, per_test_dir): comm = get_comm_world() ds = oc.open(haloproperties_600_path) + assert "redshift" in ds.columns pixel = np.random.choice(ds.region.pixels) center = pix2ang(ds.region.nside, pixel, True, True) @@ -418,7 +419,6 @@ def test_write_some_missing(core_path_487, core_path_475, per_test_dir): original_data = 
ds.select("early_index").get_data("numpy") original_data_length = comm.allgather(len(original_data)) - ds = ds.with_new_columns(gal_id=np.arange(len(ds))) oc.write(per_test_dir / "lightcone.hdf5", ds) ds = oc.open(per_test_dir / "lightcone.hdf5", synth_cores=True) written_data = ds.select("early_index").get_data("numpy") @@ -443,11 +443,6 @@ def test_write_diffsky_some_missing_no_stack( ds.pop(475) assert len(ds.keys()) == 1 - all_lengths = comm.allgather(len(ds)) - all_ends = np.insert(np.cumsum(all_lengths), 0, 0) - rank = comm.Get_rank() - ds = ds.with_new_columns(gal_id=np.arange(all_ends[rank], all_ends[rank + 1])) - columns_to_check = comm.bcast(np.random.choice(ds.columns, 10, replace=False)) columns_to_check = np.insert(columns_to_check, 0, "gal_id") @@ -573,5 +568,4 @@ def _assert_top_host_idx_correct(data, core_map): for key, val in core_map.items() if val in data["core_tag"] and key in real_core_tag } - print(len(should_have_core_map)) assert should_have_core_map == found_core_map diff --git a/test/test_healpixmap.py b/test/test_healpixmap.py index 2e03743d..4a51239f 100644 --- a/test/test_healpixmap.py +++ b/test/test_healpixmap.py @@ -2,11 +2,12 @@ import healpy as hp import healsparse as hsp import numpy as np -import opencosmo as oc import pytest from astropy.coordinates import SkyCoord from opencosmo.spatial.healpix import HealpixRegion +import opencosmo as oc + @pytest.fixture def healpix_map_path(map_path): @@ -222,7 +223,13 @@ def test_healpix_write_after_downgrade(healpix_map_path, tmp_path): oc.write(tmp_path / "map_test.hdf5", ds) new_ds = oc.open(tmp_path / "map_test.hdf5") - print(new_ds) + + original_data = ds.get_data("healpix") + written_data = new_ds.get_data("healpix") + + assert np.all(original_data["ksz"] == written_data["ksz"]) + assert np.all(original_data["tsz"] == written_data["tsz"]) + assert np.all(original_data["pixel"] == written_data["pixel"]) def test_healpix_write_after_take_range(healpix_map_path, tmp_path): From 
f14fd06dfabb1d6f534f8c0fd727a946c82ee1e2 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 28 Apr 2026 09:04:51 -0500 Subject: [PATCH 063/139] Small reworks to guarantee plugin consistency --- python/opencosmo/collection/lightcone/lightcone.py | 4 ++-- python/opencosmo/collection/lightcone/plugins.py | 2 +- python/opencosmo/dataset/state.py | 8 +------- python/opencosmo/dtypes/diffsky.py | 2 +- python/opencosmo/plugins/plugin.py | 2 +- test/test_diffsky.py | 1 + 6 files changed, 7 insertions(+), 12 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index e6cee885..4a318f22 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -353,7 +353,7 @@ def open(cls, targets: list[FileTarget], **kwargs): raise ValueError() result = cls(output) - result = plugin.apply_plugins(plugin.PluginType.LightconeLoad, result, **kwargs) + result = plugin.apply_plugins(plugin.PluginType.LightconeOpen, result, **kwargs) return make_radec_columns(result) @@ -366,7 +366,7 @@ def from_datasets( ): result = cls(datasets, z_range) result = plugin.apply_plugins( - plugin.PluginType.LightconeLoad, result, **open_kwargs + plugin.PluginType.LightconeOpen, result, **open_kwargs ) return make_radec_columns(result) diff --git a/python/opencosmo/collection/lightcone/plugins.py b/python/opencosmo/collection/lightcone/plugins.py index 5d2a2bb5..2aabedff 100644 --- a/python/opencosmo/collection/lightcone/plugins.py +++ b/python/opencosmo/collection/lightcone/plugins.py @@ -33,7 +33,7 @@ def with_redshift_column(dataset: Lightcone, *args, **kwargs): register_plugin( PluginSpec( - PluginType.LightconeLoad, + PluginType.LightconeOpen, lambda *args, **kwargs: True, with_redshift_column, ) diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index 35e48bfc..53c707e2 100644 --- a/python/opencosmo/dataset/state.py +++ 
b/python/opencosmo/dataset/state.py @@ -456,13 +456,7 @@ def take(self, n: int, at: str): else: take_index = np.sort(sorted[row_indices]) - new_handler = self.__raw_data_handler.take(take_index) - new_cache = self.__cache.take(take_index) - - return self.__rebuild( - raw_data_handler=new_handler, - cache=new_cache, - ) + return self.take_rows(take_index) def take_range(self, start: int, end: int): """ diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py index 221e82a3..a4b35848 100644 --- a/python/opencosmo/dtypes/diffsky.py +++ b/python/opencosmo/dtypes/diffsky.py @@ -167,7 +167,7 @@ def register_keep_top_host_idx_verifier( register_plugin( PluginSpec( - PluginType.LightconeLoad, + PluginType.LightconeOpen, register_keep_top_host_idx_verifier, register_keep_top_host_idx, ) diff --git a/python/opencosmo/plugins/plugin.py b/python/opencosmo/plugins/plugin.py index 60734d5a..3fdd4ee2 100644 --- a/python/opencosmo/plugins/plugin.py +++ b/python/opencosmo/plugins/plugin.py @@ -19,7 +19,7 @@ class PluginType(StrEnum): DatasetOpen = "dataset_open" DatasetInstantiate = "dataset_instantiate" - LightconeLoad = "lightcone_load" + LightconeOpen = "lightcone_open" LightconeInstantiate = "lightcone_instantiate" IndexUpdate = "index_update" diff --git a/test/test_diffsky.py b/test/test_diffsky.py index e41b5db0..8a96384c 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -409,6 +409,7 @@ def test_keep_top_host_take_random(core_path_475, core_path_487): assert_top_host_idx_correct(data, core_map) assert np.all(data["top_host_idx"] >= 0) assert_all_group_members_present(data, core_map) + print(len(ds)) def test_keep_top_host_take_start(core_path_475, core_path_487): From e6eb426676081c7a5f8bd2fdf94798176afd4197 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 28 Apr 2026 11:02:36 -0500 Subject: [PATCH 064/139] Unify 'take' operations on DatasetState to ensure consistent plugin application --- python/opencosmo/dataset/state.py | 
23 +++++------------------ python/opencosmo/handler/hdf5.py | 12 +++++++----- 2 files changed, 12 insertions(+), 23 deletions(-) diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index 53c707e2..d74f1c60 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -440,8 +440,6 @@ def take(self, n: int, at: str): Take rows from the dataset. """ - take_index: DataIndex - if at == "start": return self.take_range(0, n) elif at == "end": @@ -449,14 +447,7 @@ def take(self, n: int, at: str): elif at == "random": row_indices = np.random.choice(len(self), n, replace=False) row_indices.sort() - - sorted = self.get_sorted_index() - if sorted is None: - take_index = row_indices - else: - take_index = np.sort(sorted[row_indices]) - - return self.take_rows(take_index) + return self.take_rows(row_indices) def take_range(self, start: int, end: int): """ @@ -469,13 +460,7 @@ def take_range(self, start: int, end: int): if end > len(self): raise ValueError("end must be less than the length of the dataset.") - sorted = self.get_sorted_index() - take_index: DataIndex - if sorted is None: - take_index = single_chunk(start, end - start) - else: - take_index = np.sort(sorted[start:end]) - + take_index = single_chunk(start, end - start) return self.take_rows(take_index) def take_rows(self, rows: DataIndex): @@ -489,7 +474,9 @@ def take_rows(self, rows: DataIndex): "Row indices must be between 0 and the length of this dataset!" 
) sorted = self.get_sorted_index() - new_handler = self.__raw_data_handler.take(rows, sorted) + if sorted is not None: + rows = np.sort(sorted[into_array(rows)]) + new_handler = self.__raw_data_handler.take(rows) new_cache = self.__cache.take(rows) return self.__rebuild( diff --git a/python/opencosmo/handler/hdf5.py b/python/opencosmo/handler/hdf5.py index bd4f44ee..8e63c19e 100644 --- a/python/opencosmo/handler/hdf5.py +++ b/python/opencosmo/handler/hdf5.py @@ -5,6 +5,11 @@ from typing import TYPE_CHECKING, Iterable, Optional import numpy as np +from opencosmo.io.schema import FileEntry, make_schema +from opencosmo.io.writer import ( + ColumnWriter, +) + from opencosmo.index import ( SimpleIndex, from_size, @@ -13,17 +18,14 @@ into_array, take, ) -from opencosmo.io.schema import FileEntry, make_schema -from opencosmo.io.writer import ( - ColumnWriter, -) if TYPE_CHECKING: import h5py from opencosmo.header import OpenCosmoHeader - from opencosmo.index import DataIndex from opencosmo.io.schema import Schema + from opencosmo.index import DataIndex + class Hdf5Handler: """ From 74744b474ccc1cc6503e95c2a9b238927424a82a Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 28 Apr 2026 13:07:09 -0500 Subject: [PATCH 065/139] Diffsky plugins now work with sorting --- .../collection/lightcone/lightcone.py | 4 +- python/opencosmo/dataset/instantiate.py | 21 +++------- python/opencosmo/dataset/state.py | 30 ++++++++++++-- python/opencosmo/dtypes/diffsky.py | 15 +++++++ python/opencosmo/plugins/plugin.py | 40 ++++++++++++++++++- test/test_diffsky.py | 21 ++++++++++ 6 files changed, 110 insertions(+), 21 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 4a318f22..b5393f60 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -295,7 +295,9 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): table = 
vstack(data_with_length, join_type="exact") if self.__ordered_by is not None: - table.sort(self.__ordered_by[0], reverse=self.__ordered_by[1]) + order = table.argsort(self.__ordered_by[0], reverse=self.__ordered_by[1]) + table = table[order] + table = plugin.apply_post_sort_plugins(self, table, np.argsort(order)) to_remove = self.__hidden.intersection(table.colnames) table.remove_columns(to_remove) diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py index 9335c57d..2b4db718 100644 --- a/python/opencosmo/dataset/instantiate.py +++ b/python/opencosmo/dataset/instantiate.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Any -import numpy as np import rustworkx as rx from opencosmo.column.column import RawColumn @@ -11,6 +10,8 @@ if TYPE_CHECKING: from uuid import UUID + import numpy as np + from opencosmo.column.column import ConstructedColumn from opencosmo.handler.protocols import DataCache, DataHandler from opencosmo.index import DataIndex @@ -160,12 +161,12 @@ def instantiate_dataset( unit_handler: UnitHandler, unit_kwargs: dict[str, Any], metadata_columns: list[str] | None = None, - sort_by: tuple[str, bool] | None = None, + sort_by: str | None = None, ): # Extend working_columns with the sort column if it isn't already included. 
working_columns = dict(columns_to_uuid) - if sort_by is not None and sort_by[0] not in working_columns: - sort_name = sort_by[0] + if sort_by is not None and sort_by not in working_columns: + sort_name = sort_by for producer in column_producers: if sort_name in producer.produces: working_columns[sort_name] = producer.uuid @@ -216,7 +217,7 @@ def instantiate_dataset( if producer_uuid in uuid_data and name in uuid_data[producer_uuid] } data |= get_metadata_columns(raw_data_handler, cache, metadata_columns) - return sort_data(data, sort_by) + return data def get_metadata_columns( @@ -232,13 +233,3 @@ def get_metadata_columns( raw_data_handler.get_metadata(additional_metadata_columns_to_fetch) or {} ) return metadata - - -def sort_data(data: dict[str, np.ndarray], sort_by: tuple[str, bool] | None): - if sort_by is None: - return data - sort_column = data[sort_by[0]] - order = np.argsort(sort_column) - if sort_by[1]: - order = order[::-1] - return {key: value[order] for key, value in data.items()} diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index d74f1c60..f46dd2fa 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -18,7 +18,12 @@ from opencosmo.index.build import single_chunk from opencosmo.index.mask import into_array from opencosmo.index.unary import get_range -from opencosmo.plugins.plugin import PluginType, apply_index_plugins, apply_plugins +from opencosmo.plugins.plugin import ( + PluginType, + apply_index_plugins, + apply_plugins, + apply_post_sort_plugins, +) from opencosmo.units import UnitConvention from opencosmo.units.handler import ( make_unit_handler_from_hdf5, @@ -45,6 +50,22 @@ def deregister_state(id: int, cache: DataCache): cache.deregister_column_group(id) +def sort_data( + data: dict[str, np.ndarray], sort_by: tuple[str, bool] | None, state: DatasetState +): + if sort_by is None: + return data + sort_column = data[sort_by[0]] + order = np.argsort(sort_column) + if 
sort_by[1]: + order = order[::-1] + + data = {key: value[order] for key, value in data.items()} + if sort_by[0] not in state.columns: + data.pop(sort_by[0]) + return apply_post_sort_plugins(state, data, np.argsort(order)) + + class DatasetState: """ Holds mutable state required by the dataset. Cleans up the dataset to mostly focus @@ -250,7 +271,6 @@ def get_data( Get the data for a given handler. """ state = apply_plugins(PluginType.DatasetInstantiate, self) - data = instantiate_dataset( list(state.__producers.values()), state.__columns, @@ -259,13 +279,17 @@ def get_data( state.__unit_handler, unit_kwargs, metadata_columns, - None if ignore_sort else state.__sort_by, + None if (ignore_sort or state.__sort_by is None) else state.__sort_by[0], ) + if missing := set(self.columns).difference(data.keys()): raise RuntimeError( f"Some columns are missing from the output! This is likely a bug. Please report it on GitHub. Missing: {missing}" ) + if not ignore_sort: + data = sort_data(data, self.__sort_by, self) + new_order = [c for c in self.columns] if metadata_columns: new_order.extend(metadata_columns) diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py index a4b35848..16524af4 100644 --- a/python/opencosmo/dtypes/diffsky.py +++ b/python/opencosmo/dtypes/diffsky.py @@ -13,10 +13,13 @@ IndexPluginSpec, PluginSpec, PluginType, + PostSortPluginSpec, register_plugin, ) if TYPE_CHECKING: + from astropy.table import Table + from opencosmo import Dataset, Lightcone from opencosmo.dataset.state import DatasetState from opencosmo.index import DataIndex @@ -138,6 +141,12 @@ def keep_top_host_idx_verifier(dataset: DatasetState): return "top_host_idx" in dataset.columns +def update_top_host_idx_after_sort(data: Table, reverse_index: DataIndex): + mask = data["top_host_idx"] > 0 + data["top_host_idx"][mask] = reverse_index[data["top_host_idx"][mask]] + return data + + def register_keep_top_host_idx(dataset: Dataset, **kwargs): register_plugin( 
IndexPluginSpec( @@ -172,3 +181,9 @@ def register_keep_top_host_idx_verifier( register_keep_top_host_idx, ) ) + +register_plugin( + PostSortPluginSpec( + PluginType.PostSort, top_host_idx_verifier, update_top_host_idx_after_sort + ) +) diff --git a/python/opencosmo/plugins/plugin.py b/python/opencosmo/plugins/plugin.py index 3fdd4ee2..4dc1702b 100644 --- a/python/opencosmo/plugins/plugin.py +++ b/python/opencosmo/plugins/plugin.py @@ -11,9 +11,15 @@ TypedDict, ) +from opencosmo.index import into_array + if TYPE_CHECKING: + import numpy as np + from astropy.table import Table + + from opencosmo import Lightcone from opencosmo.dataset.state import DatasetState - from opencosmo.index import DataIndex + from opencosmo.index import DataIndex, IndexArray class PluginType(StrEnum): @@ -21,6 +27,7 @@ class PluginType(StrEnum): DatasetInstantiate = "dataset_instantiate" LightconeOpen = "lightcone_open" LightconeInstantiate = "lightcone_instantiate" + PostSort = "post_sort" IndexUpdate = "index_update" @@ -40,6 +47,12 @@ class IndexPluginSpec(NamedTuple): plugin: Callable[[DatasetState, DataIndex], DataIndex] +class PostSortPluginSpec[T: (DatasetState, Lightcone)](NamedTuple): + plugin_type: PluginType + verifier: Callable[[T], bool] + plugin: Callable[[Table, IndexArray], dict[str, np.ndarray]] + + class Plugins(TypedDict): dataset_open: list[PluginSpec] dataset_instantiate: list[PluginSpec] @@ -51,7 +64,7 @@ class Plugins(TypedDict): KNOWN_PLUGINS: Plugins = defaultdict(list) # type: ignore -def register_plugin(spec: PluginSpec | IndexPluginSpec) -> None: +def register_plugin(spec: PluginSpec | IndexPluginSpec | PostSortPluginSpec) -> None: KNOWN_PLUGINS[str(spec.plugin_type)].append(spec) # type: ignore @@ -82,6 +95,29 @@ def apply_index_plugins(state: DatasetState, index: DataIndex) -> DataIndex: ) +def apply_post_sort_plugins[T: (DatasetState, Lightcone)]( + state: T, + data: Table, + index: DataIndex, +) -> T: + plugins_to_apply: list[PostSortPluginSpec] = 
KNOWN_PLUGINS[str(PluginType.PostSort)] # type: ignore + index_arr = into_array(index) + + return reduce( + lambda data_, spec: _apply_single_post_sort(spec, state, data_, index_arr), + plugins_to_apply, + data, + ) + + +def _apply_single_post_sort[T: (DatasetState, Lightcone)]( + spec: PostSortPluginSpec, state: T, data: Table, index: IndexArray +): + if spec.verifier(state): + return spec.plugin(data, index) + return data + + def _apply_single[T](spec: PluginSpec[T], target: T, **kwargs: Any) -> T: if spec.verifier(target, **kwargs): return spec.plugin(target, **kwargs) diff --git a/test/test_diffsky.py b/test/test_diffsky.py index 8a96384c..4e6261cb 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -381,6 +381,15 @@ def test_reindex_top_host_take_random(core_path_475, core_path_487): assert_top_host_idx_correct(data, core_map) +def test_reindex_top_host_take_range_sorted(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487).sort_by("logsm_obs").take(300) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + + def test_reindex_top_host_take_range(core_path_475, core_path_487): core_map = get_expected_core_tags(core_path_475) core_map |= get_expected_core_tags(core_path_487) @@ -436,6 +445,18 @@ def test_keep_top_host_take_range(core_path_475, core_path_487): assert_all_group_members_present(data, core_map) +def test_keep_top_host_take_range_after_sort(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True).sort_by("logsm_obs") + ds = ds.take_range(20, 60) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + assert np.all(data["top_host_idx"] >= 0) + 
assert_all_group_members_present(data, core_map) + + def test_keep_top_host_filter(core_path_475, core_path_487): core_map = get_expected_core_tags(core_path_475) core_map |= get_expected_core_tags(core_path_487) From 550f303d3c67b96d1cd8527f68528f5191bc1ed0 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 28 Apr 2026 13:22:47 -0500 Subject: [PATCH 066/139] Changelog --- changes/+21d9d2f3.improvement.rst | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 changes/+21d9d2f3.improvement.rst diff --git a/changes/+21d9d2f3.improvement.rst b/changes/+21d9d2f3.improvement.rst new file mode 100644 index 00000000..c245a55f --- /dev/null +++ b/changes/+21d9d2f3.improvement.rst @@ -0,0 +1,3 @@ +The OpenCosmo internals now include a basic plugin architecture, which allows for specific data types to introduce +modified behavior at particular points. Currently being used to dynamically re-build the `top_host_idx` column in +diffsky data. From 715763f1373eb334f8c5efd029cef864c160b714 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 28 Apr 2026 13:34:37 -0500 Subject: [PATCH 067/139] Small bugfixes --- changes/+018f30cb.bugfix.rst | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 changes/+018f30cb.bugfix.rst diff --git a/changes/+018f30cb.bugfix.rst b/changes/+018f30cb.bugfix.rst new file mode 100644 index 00000000..b236b155 --- /dev/null +++ b/changes/+018f30cb.bugfix.rst @@ -0,0 +1,2 @@ +HealpixMap now correctly unpacks data when there is only a single map, instead of returning a dictionary. Mirrors +behavior in Dataset etc. 
From 42a8d252500401d93ffad6fc33392914348dd78e Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 28 Apr 2026 13:34:46 -0500 Subject: [PATCH 068/139] Small bugfixes --- .../collection/lightcone/healpix_map.py | 3 ++- test/test_healpixmap.py | 22 +++++++++---------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/python/opencosmo/collection/lightcone/healpix_map.py b/python/opencosmo/collection/lightcone/healpix_map.py index 0bdd7546..90157547 100644 --- a/python/opencosmo/collection/lightcone/healpix_map.py +++ b/python/opencosmo/collection/lightcone/healpix_map.py @@ -60,6 +60,8 @@ def make_healsparse_maps( nside_sparse=nside, sentinel=sentinel, ) + if len(result) == 1: + return next(iter(result.values())) return result @@ -397,7 +399,6 @@ def get_data(self, format="healsparse", nside_out: Optional[int] = None, **kwarg storage = { name: np.ma.masked_array(arr, mask) for name, arr in storage.items() } - storage["pixel"] = self.pixels if len(storage) == 1: return next(iter(storage.values())) return storage diff --git a/test/test_healpixmap.py b/test/test_healpixmap.py index 4a51239f..2ad77bb9 100644 --- a/test/test_healpixmap.py +++ b/test/test_healpixmap.py @@ -4,6 +4,7 @@ import numpy as np import pytest from astropy.coordinates import SkyCoord +from healsparse import HealSparseMap from opencosmo.spatial.healpix import HealpixRegion import opencosmo as oc @@ -67,7 +68,7 @@ def test_healpix_downgrade(healpix_map_path): original_data = ds.get_data("healpix") downgraded_data = downgraded_ds.get_data("healpix") - downgraded_data_2 = original_data["tsz"].reshape((-1, 4)).sum(axis=1) / 4 + downgraded_data_2 = original_data.reshape((-1, 4)).sum(axis=1) / 4 center = (0 * u.deg, 0 * u.deg) radius = 1 * u.deg @@ -81,23 +82,21 @@ def test_healpix_downgrade(healpix_map_path): hp.query_disc(downgraded_nside, [1, 0, 0], 1 * (np.pi / 180.0)) ) - assert len(original_data["tsz"]) == original_npix - assert len(downgraded_data["tsz"]) == downgraded_npix + assert 
len(original_data) == original_npix + assert len(downgraded_data) == downgraded_npix - assert len(data_region_original["tsz"].valid_pixels) == npix_region - assert len(data_region_downgraded["tsz"].valid_pixels) == npix_region_downgraded + assert len(data_region_original.valid_pixels) == npix_region + assert len(data_region_downgraded.valid_pixels) == npix_region_downgraded assert np.all( np.isclose( downgraded_data_2, - downgraded_data["tsz"], + downgraded_data, atol=1.0e-13, ) ) - assert np.isclose( - np.mean(original_data["tsz"]), np.mean(downgraded_data["tsz"]), atol=1.0e-13 - ) + assert np.isclose(np.mean(original_data), np.mean(downgraded_data), atol=1.0e-13) def test_healpix_downgrade_doesnt_have_file_handle(healpix_map_path): @@ -229,7 +228,7 @@ def test_healpix_write_after_downgrade(healpix_map_path, tmp_path): assert np.all(original_data["ksz"] == written_data["ksz"]) assert np.all(original_data["tsz"] == written_data["tsz"]) - assert np.all(original_data["pixel"] == written_data["pixel"]) + assert np.all(ds.pixels == new_ds.pixels) def test_healpix_write_after_take_range(healpix_map_path, tmp_path): @@ -264,9 +263,8 @@ def test_healpix_collection_drop(healpix_map_path): to_drop = set(["tsz"]) ds = ds.drop(to_drop) - columns_found = set(ds.get_data().keys()) - assert not columns_found.intersection(to_drop) + assert isinstance(ds.get_data(), HealSparseMap) def test_healpix_collection_take(healpix_map_path): From a9762a348a7b6ed8fd79cfb8fc37d87a3360672e Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 28 Apr 2026 13:48:41 -0500 Subject: [PATCH 069/139] Fixes for healpixmap tests --- test/test_healpixmap.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/test/test_healpixmap.py b/test/test_healpixmap.py index 2ad77bb9..97bf215f 100644 --- a/test/test_healpixmap.py +++ b/test/test_healpixmap.py @@ -273,10 +273,10 @@ def test_healpix_collection_take(healpix_map_path): ds_start = ds.take(n_to_take, 
"start") ds_end = ds.take(n_to_take, "end") ds_random = ds.take(n_to_take, "random") - tags = ds.select("tsz").get_data()["tsz"].valid_pixels - tags_start = ds_start.select("tsz").get_data()["tsz"].valid_pixels - tags_end = ds_end.select("tsz").get_data()["tsz"].valid_pixels - tags_random = ds_random.select("tsz").get_data()["tsz"].valid_pixels + tags = ds.select("tsz").get_data().valid_pixels + tags_start = ds_start.select("tsz").get_data().valid_pixels + tags_end = ds_end.select("tsz").get_data().valid_pixels + tags_random = ds_random.select("tsz").get_data().valid_pixels assert np.all(tags[:n_to_take] == tags_start) assert np.all(tags[-n_to_take:] == tags_end) assert len(tags_random) == n_to_take and len(set(tags_random)) == len(tags_random) @@ -288,8 +288,8 @@ def test_healpix_collection_range(healpix_map_path): end = int(0.75 * len(ds)) ds_range = ds.take_range(start, end) - halo_tags = ds.select("tsz").get_data("healsparse")["tsz"].valid_pixels[start:end] - range_halo_tags = ds_range.select("tsz").get_data("healsparse")["tsz"].valid_pixels + halo_tags = ds.select("tsz").get_data("healsparse").valid_pixels[start:end] + range_halo_tags = ds_range.select("tsz").get_data("healsparse").valid_pixels assert np.all(halo_tags == range_halo_tags) @@ -337,8 +337,14 @@ def test_healpix_collection_derive(healpix_map_path): ds = oc.open(healpix_map_path) sz_sqrd = oc.col("tsz") ** 2 + oc.col("ksz") ** 2 ds = ds.with_new_columns(weird_sz=sz_sqrd) - weird = ds.select("weird_sz").get_data() - assert isinstance(weird, dict) + data = ds.get_data() + assert isinstance(data, dict) + assert np.all(data["weird_sz"].valid_pixels == data["tsz"].valid_pixels) + + tsz = data["tsz"][data["tsz"].valid_pixels] + ksz = data["ksz"][data["ksz"].valid_pixels] + weird_sz = data["weird_sz"][data["weird_sz"].valid_pixels] + assert np.all(weird_sz == tsz**2 + ksz**2) def test_healpix_collection_add(healpix_map_path): @@ -346,7 +352,7 @@ def test_healpix_collection_add(healpix_map_path): map_data 
= ds.get_data("healsparse")["tsz"].valid_pixels data = np.zeros(len(map_data)) ds = ds.with_new_columns(random=data) - stored_data = ds.select("random").get_data("healpix")["random"] + stored_data = ds.select("random").get_data("healpix") assert np.all(stored_data == data) @@ -355,9 +361,7 @@ def test_healpix_collection_add_sparse(healpix_map_path): map_data = ds.get_data("healsparse")["tsz"].valid_pixels data = np.zeros(len(map_data)) ds = ds.with_new_columns(random=data) - stored_data = ( - ds.select("random").get_data("healsparse")["random"].get_values_pix(map_data) - ) + stored_data = ds.select("random").get_data("healsparse").get_values_pix(map_data) assert np.all(stored_data == data) @@ -379,9 +383,9 @@ def offset(tsz, ksz): ds_vec = ds.evaluate(offset, vectorize=True, insert=True) ds_iter = ds.evaluate(offset, insert=True) - offset_vec = ds_vec.select("offset").get_data("healsparse")["offset"] + offset_vec = ds_vec.select("offset").get_data("healsparse") offset_vec = offset_vec.get_values_pix(offset_vec.valid_pixels) - offset_iter = ds_iter.select("offset").get_data("healsparse")["offset"] + offset_iter = ds_iter.select("offset").get_data("healsparse") offset_iter = offset_iter.get_values_pix(offset_iter.valid_pixels) assert np.all(offset_vec == offset_iter) From b829eb1bc693f36def605d9feaabd42c0b10bae7 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 28 Apr 2026 19:28:36 -0500 Subject: [PATCH 070/139] Whatever you do, do not look at this commit --- python/opencosmo/dtypes/diffsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py index 16524af4..b9446a9d 100644 --- a/python/opencosmo/dtypes/diffsky.py +++ b/python/opencosmo/dtypes/diffsky.py @@ -142,7 +142,7 @@ def keep_top_host_idx_verifier(dataset: DatasetState): def update_top_host_idx_after_sort(data: Table, reverse_index: DataIndex): - mask = data["top_host_idx"] > 0 + mask = data["top_host_idx"] >= 0 
data["top_host_idx"][mask] = reverse_index[data["top_host_idx"][mask]] return data From 157631bcfdbb9a6a44b4275d01357d73666aedb4 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 29 Apr 2026 11:24:44 -0500 Subject: [PATCH 071/139] Add diffsky partition plugin --- python/opencosmo/dataset/mpi.py | 16 ++++++-- python/opencosmo/dtypes/diffsky.py | 61 ++++++++++++++++++++++++++++++ python/opencosmo/io/iopen.py | 3 +- python/opencosmo/plugins/plugin.py | 47 +++++++++++++++++++---- 4 files changed, 114 insertions(+), 13 deletions(-) diff --git a/python/opencosmo/dataset/mpi.py b/python/opencosmo/dataset/mpi.py index d5dda2f5..c6ebf8e0 100644 --- a/python/opencosmo/dataset/mpi.py +++ b/python/opencosmo/dataset/mpi.py @@ -4,19 +4,22 @@ from warnings import warn from opencosmo.index.build import single_chunk +from opencosmo.plugins.plugin import apply_partition_plugins from opencosmo.spatial.protocols import TreePartition if TYPE_CHECKING: import h5py from mpi4py import MPI + from opencosmo.header import OpenCosmoHeader from opencosmo.spatial.tree import Tree def partition( comm: MPI.Comm, - length: int, - counts: h5py.Group, + header: OpenCosmoHeader, + index_group: h5py.Group, + data_group: h5py.Group, tree: Optional[Tree], min_level: Optional[int] = None, ) -> Optional[TreePartition]: @@ -25,8 +28,14 @@ def partition( spatial index. In principle this means the number of objects are similar between ranks. 
""" + partition_plugin_result = apply_partition_plugins( + comm, header, index_group, data_group, tree, min_level + ) + if partition_plugin_result is not None: + return partition_plugin_result + if tree is not None: - partitions = tree.partition(comm.Get_size(), counts, min_level) + partitions = tree.partition(comm.Get_size(), index_group, min_level) try: part = partitions[comm.Get_rank()] except IndexError: @@ -37,6 +46,7 @@ def partition( part = None return part + length = len(next(iter(data_group.values()))) nranks = comm.Get_size() rank = comm.Get_rank() if rank == nranks - 1: diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py index b9446a9d..6d0c6635 100644 --- a/python/opencosmo/dtypes/diffsky.py +++ b/python/opencosmo/dtypes/diffsky.py @@ -9,20 +9,31 @@ from opencosmo.column.column import EvaluatedColumn, EvaluateStrategy from opencosmo.index import into_array from opencosmo.index.ops import reindex_column +from opencosmo.mpi import get_mpi from opencosmo.plugins.plugin import ( IndexPluginSpec, + PartitionPluginSpec, PluginSpec, PluginType, PostSortPluginSpec, register_plugin, ) +from opencosmo.spatial.tree import TreePartition if TYPE_CHECKING: + import h5py from astropy.table import Table + from mpi4py import MPI from opencosmo import Dataset, Lightcone from opencosmo.dataset.state import DatasetState + from opencosmo.header import OpenCosmoHeader from opencosmo.index import DataIndex + from opencosmo.spatial.tree import Tree + + +else: + MPI = get_mpi() class DiffskyVersionInfo(BaseModel): @@ -162,6 +173,50 @@ def register_keep_top_host_idx_verifier( return keep_top_host and top_host_idx_verifier(dataset) +def partition_plugin( + comm: MPI.Comm, + index_group: h5py.Group, + data_group: h5py.Group, + tree: Optional[Tree] = None, + min_level: Optional[int] = None, +): + top_host_idx = data_group["top_host_idx"][:] + if comm.Get_rank() == 0: + n_ranks = comm.Get_size() + unique_hosts = np.unique(top_host_idx, sorted=True) + 
ave, res = divmod(unique_hosts.size, n_ranks) + counts = np.array( + [ave + 1 if i < res else ave for i in range(n_ranks)], dtype=np.int64 + ) + displs = np.cumsum(np.concatenate(([0], counts[:-1]))) + counts = comm.bcast(counts) + # Should always be roughly balanced, since data is + # ordered spatially + rank_top_hosts = np.zeros(counts[0], dtype=np.int64) + comm.Scatterv([unique_hosts, counts, displs, MPI.INT64_T], rank_top_hosts) + else: + counts = comm.bcast(None) + rank_top_hosts = np.zeros(counts[comm.Get_rank()], dtype=np.int64) + comm.Scatterv([None, counts, None, MPI.INT64_T], rank_top_hosts) + + is_in_groups = np.isin(top_host_idx, rank_top_hosts) + index = np.where(is_in_groups)[0] + return TreePartition(idx=index, region=None, level=None) + + +def partition_plugin_verifier( + header: OpenCosmoHeader, + index_group: h5py.Group, + data_group: h5py.Group, + tree: Optional[Tree] = None, + min_level: Optional[int] = None, +): + return ( + header.file.data_type == "synthetic_galaxies" + and "top_host_idx" in data_group.keys() + ) + + register_plugin( PluginSpec(PluginType.DatasetOpen, top_host_idx_verifier, top_host_idx_plugin) ) @@ -187,3 +242,9 @@ def register_keep_top_host_idx_verifier( PluginType.PostSort, top_host_idx_verifier, update_top_host_idx_after_sort ) ) + +register_plugin( + PartitionPluginSpec( + PluginType.Partition, partition_plugin_verifier, partition_plugin + ) +) diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index 870c1a45..8041e762 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -515,8 +515,7 @@ def open_single_dataset( if not bypass_mpi and (comm := get_comm_world()) is not None: assert partition is not None try: - idx_data = ds_group["index"] - part = partition(comm, ds_length, idx_data, tree) + part = partition(comm, header, ds_group["index"], ds_group["data"], tree) if part is None: index = empty() else: diff --git a/python/opencosmo/plugins/plugin.py 
b/python/opencosmo/plugins/plugin.py index 4dc1702b..ee6ffe10 100644 --- a/python/opencosmo/plugins/plugin.py +++ b/python/opencosmo/plugins/plugin.py @@ -3,23 +3,21 @@ from collections import defaultdict from enum import StrEnum from functools import reduce -from typing import ( - TYPE_CHECKING, - Any, - Callable, - NamedTuple, - TypedDict, -) +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, TypedDict from opencosmo.index import into_array if TYPE_CHECKING: + import h5py import numpy as np from astropy.table import Table + from mpi4py import MPI from opencosmo import Lightcone from opencosmo.dataset.state import DatasetState + from opencosmo.header import OpenCosmoHeader from opencosmo.index import DataIndex, IndexArray + from opencosmo.spatial.tree import Tree, TreePartition class PluginType(StrEnum): @@ -29,6 +27,7 @@ class PluginType(StrEnum): LightconeInstantiate = "lightcone_instantiate" PostSort = "post_sort" IndexUpdate = "index_update" + Partition = "partition" type Verifier[T] = Callable[[T], bool] @@ -53,18 +52,30 @@ class PostSortPluginSpec[T: (DatasetState, Lightcone)](NamedTuple): plugin: Callable[[Table, IndexArray], dict[str, np.ndarray]] +class PartitionPluginSpec(NamedTuple): + plugin_type: PluginType + verifier: Callable[[OpenCosmoHeader, h5py.Group, h5py.Group], bool] + plugin: Callable[ + [MPI.Comm, h5py.Group, h5py.Group, Optional[Tree], Optional[int]], + Optional[TreePartition], + ] + + class Plugins(TypedDict): dataset_open: list[PluginSpec] dataset_instantiate: list[PluginSpec] lightcone_load: list[PluginSpec] lightcone_instantiate: list[PluginSpec] index_update: list[IndexPluginSpec] + partition: list[PartitionPluginSpec] KNOWN_PLUGINS: Plugins = defaultdict(list) # type: ignore -def register_plugin(spec: PluginSpec | IndexPluginSpec | PostSortPluginSpec) -> None: +def register_plugin( + spec: PluginSpec | IndexPluginSpec | PostSortPluginSpec | PartitionPluginSpec, +) -> None: 
KNOWN_PLUGINS[str(spec.plugin_type)].append(spec) # type: ignore @@ -110,6 +121,26 @@ def apply_post_sort_plugins[T: (DatasetState, Lightcone)]( ) +def apply_partition_plugins( + comm: MPI.Comm, + header: OpenCosmoHeader, + index_group: h5py.Group, + data_group: h5py.Group, + tree: Optional[Tree] = None, + min_level: Optional[int] = None, +) -> Optional[TreePartition]: + partition_plugins = KNOWN_PLUGINS[str(PluginType.Partition)] # type: ignore + if len(partition_plugins) == 0: + return None + if partition_plugins > 1: + raise ValueError("Only one partition plugin is allowed at a time") + plugin_spec = partition_plugins[0] + if not plugin_spec.verifier(header, index_group, data_group): + return None + + return plugin_spec.plugin(comm, index_group, data_group, tree, min_level) + + def _apply_single_post_sort[T: (DatasetState, Lightcone)]( spec: PostSortPluginSpec, state: T, data: Table, index: IndexArray ): From 9d6393f825a9ad04dddf5b6d02148d38cca4b285 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 29 Apr 2026 11:54:33 -0500 Subject: [PATCH 072/139] Add tests for diffsky partition plugin --- python/opencosmo/plugins/plugin.py | 2 +- test/parallel/test_lc_mpi.py | 51 ++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/python/opencosmo/plugins/plugin.py b/python/opencosmo/plugins/plugin.py index ee6ffe10..b3fec224 100644 --- a/python/opencosmo/plugins/plugin.py +++ b/python/opencosmo/plugins/plugin.py @@ -132,7 +132,7 @@ def apply_partition_plugins( partition_plugins = KNOWN_PLUGINS[str(PluginType.Partition)] # type: ignore if len(partition_plugins) == 0: return None - if partition_plugins > 1: + if len(partition_plugins) > 1: raise ValueError("Only one partition plugin is allowed at a time") plugin_spec = partition_plugins[0] if not plugin_spec.verifier(header, index_group, data_group): diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index 9a52369c..4af5fda8 100644 --- 
a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -474,6 +474,30 @@ def test_write_diffsky_some_missing_no_stack( ) +@pytest.mark.parallel(nprocs=4) +def test_open_parallel_top_host(core_path_487, core_path_475): + core_map = _get_expected_core_tags(core_path_487) + core_map |= _get_expected_core_tags(core_path_475) + + ds = oc.open(core_path_475, core_path_487) + data = ds.select("top_host_idx", "core_tag").get_data() + + _assert_top_host_idx_correct(data, core_map) + _assert_all_group_members_present(data, core_map) + + +@pytest.mark.parallel(nprocs=4) +def test_keep_top_host_filter(core_path_487, core_path_475): + core_map = _get_expected_core_tags(core_path_487) + core_map |= _get_expected_core_tags(core_path_475) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True) + data = ds.take(10).select("top_host_idx", "core_tag").get_data() + + _assert_top_host_idx_correct(data, core_map) + _assert_all_group_members_present(data, core_map) + + @pytest.mark.parallel(nprocs=4) def test_write_some_missing_no_stack( haloproperties_600_path, haloproperties_601_path, per_test_dir @@ -569,3 +593,30 @@ def _assert_top_host_idx_correct(data, core_map): if val in data["core_tag"] and key in real_core_tag } assert should_have_core_map == found_core_map + + comm = get_comm_world() + all_data_core_maps = comm.allgather(found_core_map) + seen = set() + for m in all_data_core_maps: + assert len(seen.intersection(m.keys())) == 0 + seen |= m.keys() + + +def _assert_all_group_members_present(data, core_map): + """ + Verify that for every top_host represented in the data, all rows from the + full dataset that point to that top_host are also present. 
+ """ + host_to_members: dict = {} + for ct, host_ct in core_map.items(): + host_to_members.setdefault(host_ct, set()).add(ct) + + present_core_tags = set(data["core_tag"]) + top_host_core_tags = set(data["core_tag"][data["top_host_idx"]]) + + for top_host_ct in top_host_core_tags: + expected_members = host_to_members.get(top_host_ct, set()) + missing = expected_members - present_core_tags + assert not missing, ( + f"top_host {top_host_ct}: {len(missing)} member(s) missing from result" + ) From cceaa1b48b633a7e3c10e8887b8b60c5f50c16d8 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 29 Apr 2026 12:56:45 -0500 Subject: [PATCH 073/139] Update diffsky partitioning algorithm. --- python/opencosmo/dtypes/diffsky.py | 45 +++++++++++++++++------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py index 6d0c6635..7a9eed21 100644 --- a/python/opencosmo/dtypes/diffsky.py +++ b/python/opencosmo/dtypes/diffsky.py @@ -181,26 +181,31 @@ def partition_plugin( min_level: Optional[int] = None, ): top_host_idx = data_group["top_host_idx"][:] - if comm.Get_rank() == 0: - n_ranks = comm.Get_size() - unique_hosts = np.unique(top_host_idx, sorted=True) - ave, res = divmod(unique_hosts.size, n_ranks) - counts = np.array( - [ave + 1 if i < res else ave for i in range(n_ranks)], dtype=np.int64 - ) - displs = np.cumsum(np.concatenate(([0], counts[:-1]))) - counts = comm.bcast(counts) - # Should always be roughly balanced, since data is - # ordered spatially - rank_top_hosts = np.zeros(counts[0], dtype=np.int64) - comm.Scatterv([unique_hosts, counts, displs, MPI.INT64_T], rank_top_hosts) - else: - counts = comm.bcast(None) - rank_top_hosts = np.zeros(counts[comm.Get_rank()], dtype=np.int64) - comm.Scatterv([None, counts, None, MPI.INT64_T], rank_top_hosts) - - is_in_groups = np.isin(top_host_idx, rank_top_hosts) - index = np.where(is_in_groups)[0] + n_rows = len(top_host_idx) + n_ranks = 
comm.Get_size() + rank = comm.Get_rank() + + # Step 1: partition row indices evenly between ranks + ave, res = divmod(n_rows, n_ranks) + counts = np.array( + [ave + 1 if i < res else ave for i in range(n_ranks)], dtype=np.int64 + ) + displs = np.cumsum(np.concatenate(([0], counts[:-1]))) + start = displs[rank] + count = counts[rank] + row_indices = np.arange(start, start + count, dtype=np.int64) + chunk = top_host_idx[start : start + count] + + # Step 2: find top hosts (self-referential rows) and orphans (top_host_idx == -1) + # in this rank's chunk + rank_top_hosts = row_indices[chunk == row_indices] + rank_orphans = row_indices[chunk == -1] + + # Step 3: search the full array for all rows that belong to this rank's top hosts + all_group_rows = np.where(np.isin(top_host_idx, rank_top_hosts))[0] + + # Step 4: combine group members with orphans from this rank's partition + index = np.union1d(all_group_rows, rank_orphans) return TreePartition(idx=index, region=None, level=None) From a00a889a3996951edb9b092813563c8560a5ea25 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 29 Apr 2026 15:32:19 -0500 Subject: [PATCH 074/139] Write tests for mpi write partitioning --- python/opencosmo/io/iopen.py | 3 +- test/parallel/test_lc_mpi.py | 60 +++++++++++++++++++++++++++++++++--- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index 8041e762..fb1f16d4 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -545,6 +545,7 @@ def open_single_dataset( state, tree=tree, ) + dataset = apply_plugins(PluginType.DatasetOpen, dataset, **open_kwargs) if header.file.data_type == "healpix_map": return __open_healpix_map(dataset, sim_region) elif header.file.is_lightcone and not bypass_lightcone: @@ -552,7 +553,7 @@ def open_single_dataset( {"data": dataset}, header.lightcone["z_range"], **open_kwargs ) - return apply_plugins(PluginType.DatasetOpen, dataset, **open_kwargs) + return 
dataset def __open_healpix_map(dataset: oc.Dataset, sim_region): diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index 4af5fda8..c4277243 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -486,6 +486,59 @@ def test_open_parallel_top_host(core_path_487, core_path_475): _assert_all_group_members_present(data, core_map) +@pytest.mark.parallel(nprocs=4) +def test_open_write_parallel_top_host(core_path_487, core_path_475, per_test_dir): + with h5py.File(core_path_475) as f: + core_map = _get_expected_core_tags(f["cores"]) + + with h5py.File(core_path_487) as f: + core_map |= _get_expected_core_tags(f["cores"]) + + ds = oc.open(core_path_475, core_path_487) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + + oc.write(per_test_dir / "test.hdf5", ds) + with h5py.File(per_test_dir / "test.hdf5") as f: + written_core_map = _get_expected_core_tags(f["475_475"]) + assert core_map == written_core_map + + data = ( + oc.open(per_test_dir / "test.hdf5") + .select("top_host_idx", "core_tag") + .get_data("numpy") + ) + + _assert_top_host_idx_correct(data, core_map) + _assert_all_group_members_present(data, core_map) + + +@pytest.mark.parallel(nprocs=4) +def test_open_write_parallel_top_after_filter( + core_path_487, core_path_475, per_test_dir +): + with h5py.File(core_path_475) as f: + core_map = _get_expected_core_tags(f["cores"]) + + with h5py.File(core_path_487) as f: + core_map |= _get_expected_core_tags(f["cores"]) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True).take(10) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + _assert_top_host_idx_correct(data, core_map) + _assert_all_group_members_present(data, core_map) + + oc.write(per_test_dir / "test.hdf5", ds) + + data = ( + oc.open(per_test_dir / "test.hdf5") + .select("top_host_idx", "core_tag") + .get_data("numpy") + ) + + _assert_top_host_idx_correct(data, core_map) + _assert_all_group_members_present(data, core_map) 
+ + @pytest.mark.parallel(nprocs=4) def test_keep_top_host_filter(core_path_487, core_path_475): core_map = _get_expected_core_tags(core_path_487) @@ -551,10 +604,9 @@ def test_lightcone_stacking( assert next(iter(ds_new.values())).header.lightcone["z_range"] == ds_new.z_range -def _get_expected_core_tags(path): - with h5py.File(path) as f: - raw_top_host = f["cores"]["data"]["top_host_idx"][:] - core_tag = f["cores"]["data"]["core_tag"][:] +def _get_expected_core_tags(group): + raw_top_host = group["data"]["top_host_idx"][:] + core_tag = group["data"]["core_tag"][:] top_host_core_tag = core_tag[raw_top_host] return dict(zip(core_tag, top_host_core_tag)) From c1d18857b3a4b9809f6ae9f87744531199e4fffe Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 29 Apr 2026 15:40:12 -0500 Subject: [PATCH 075/139] Add changelog, fix tests --- changes/+66861282.feature.rst | 2 ++ test/parallel/test_lc_mpi.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 changes/+66861282.feature.rst diff --git a/changes/+66861282.feature.rst b/changes/+66861282.feature.rst new file mode 100644 index 00000000..b0a1eac9 --- /dev/null +++ b/changes/+66861282.feature.rst @@ -0,0 +1,2 @@ +Diffsky catalog can be forced to keep groups during filtering etc. by setting the `keep_top_host=True` flag when +opening.
diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index c4277243..adf9a972 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -476,8 +476,10 @@ def test_write_diffsky_some_missing_no_stack( @pytest.mark.parallel(nprocs=4) def test_open_parallel_top_host(core_path_487, core_path_475): - core_map = _get_expected_core_tags(core_path_487) - core_map |= _get_expected_core_tags(core_path_475) + with h5py.File(core_path_487) as f: + core_map = _get_expected_core_tags(f["cores"]) + with h5py.File(core_path_475) as f: + core_map |= _get_expected_core_tags(f["cores"]) ds = oc.open(core_path_475, core_path_487) data = ds.select("top_host_idx", "core_tag").get_data() @@ -541,8 +543,10 @@ def test_open_write_parallel_top_after_filter( @pytest.mark.parallel(nprocs=4) def test_keep_top_host_filter(core_path_487, core_path_475): - core_map = _get_expected_core_tags(core_path_487) - core_map |= _get_expected_core_tags(core_path_475) + with h5py.File(core_path_487) as f: + core_map = _get_expected_core_tags(f["cores"]) + with h5py.File(core_path_475) as f: + core_map |= _get_expected_core_tags(f["cores"]) ds = oc.open(core_path_475, core_path_487, keep_top_host=True) data = ds.take(10).select("top_host_idx", "core_tag").get_data() From 759700b1b81342121e6046825d24a721eeac33fa Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 29 Apr 2026 16:38:17 -0500 Subject: [PATCH 076/139] Add context and hook implementation --- python/opencosmo/plugins/contexts.py | 127 +++++++++++++++++++++++++++ python/opencosmo/plugins/hook.py | 74 ++++++++++++++++ 2 files changed, 201 insertions(+) create mode 100644 python/opencosmo/plugins/contexts.py create mode 100644 python/opencosmo/plugins/hook.py diff --git a/python/opencosmo/plugins/contexts.py b/python/opencosmo/plugins/contexts.py new file mode 100644 index 00000000..19fee3ad --- /dev/null +++ b/python/opencosmo/plugins/contexts.py @@ -0,0 +1,127 @@ +from __future__ import annotations + 
+from dataclasses import dataclass +from enum import StrEnum +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + import h5py + import numpy as np + from astropy.table import Table + from mpi4py import MPI + + from opencosmo import Dataset, Lightcone + from opencosmo.dataset.state import DatasetState + from opencosmo.header import OpenCosmoHeader + from opencosmo.index import DataIndex, IndexArray + from opencosmo.spatial.tree import Tree + + +class HookPoint(StrEnum): + DatasetOpen = "dataset_open" + DatasetInstantiate = "dataset_instantiate" + LightconeOpen = "lightcone_open" + LightconeInstantiate = "lightcone_instantiate" + IndexUpdate = "index_update" + PostSort = "post_sort" + Partition = "partition" + + +# --- fold hooks --- +# Plugins receive and return these contexts. Use dataclasses.replace() to +# produce a modified copy rather than mutating in place. + + +@dataclass(frozen=True) +class DatasetOpenCtx: + """Fired once per dataset after it is opened from disk. + + open_kwargs holds any keyword arguments the user passed to opencosmo.open(). + Plugins can inspect them to conditionally modify the dataset. + """ + + dataset: Dataset + open_kwargs: dict[str, Any] + + +@dataclass(frozen=True) +class DatasetInstantiateCtx: + """Fired each time get_data() is called on a DatasetState. + + Plugins may add or modify derived columns before the data is materialised. + """ + + state: DatasetState + + +@dataclass(frozen=True) +class LightconeOpenCtx: + """Fired once per Lightcone after it is opened from disk. + + open_kwargs mirrors DatasetOpenCtx.open_kwargs. + """ + + lightcone: Lightcone + open_kwargs: dict[str, Any] + + +@dataclass(frozen=True) +class LightconeInstantiateCtx: + """Fired each time get_data() is called on a Lightcone. + + Plugins may re-order or re-index sub-datasets before they are stacked. 
+ """ + + lightcone: Lightcone + + +@dataclass(frozen=True) +class IndexUpdateCtx: + """Fired whenever a filter, take, or bound operation produces a new index. + + state is read-only context for the predicate and for reading column data. + index is the value being transformed; return a modified copy via + dataclasses.replace(ctx, index=new_index). + """ + + state: DatasetState + index: DataIndex + + +@dataclass(frozen=True) +class PostSortCtx: + """Fired after a sort operation reorders rows. + + state is read-only context for the predicate. + index is the reverse-sort permutation (i.e. np.argsort of the sort order), + which can be used to remap index-valued columns. + data is the value being transformed; return a modified copy via + dataclasses.replace(ctx, data=new_data). + """ + + state: DatasetState | Lightcone + data: Table | dict[str, np.ndarray] + index: IndexArray + + +# --- query hook --- +# Partition uses query() rather than fold(). The plugin returns a +# TreePartition directly (not a modified context), and the first non-None +# result wins. No plugin responding means the caller falls back to the +# default partitioning strategy. + + +@dataclass(frozen=True) +class PartitionCtx: + """Fired during MPI open to determine how rows are distributed across ranks. + + Plugins return a TreePartition, or None to defer to the default strategy. + This hook uses query() semantics: at most one plugin responds. 
+ """ + + comm: MPI.Comm + header: OpenCosmoHeader + index_group: h5py.Group + data_group: h5py.Group + tree: Optional[Tree] = None + min_level: Optional[int] = None diff --git a/python/opencosmo/plugins/hook.py b/python/opencosmo/plugins/hook.py new file mode 100644 index 00000000..1e6e4b6d --- /dev/null +++ b/python/opencosmo/plugins/hook.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from functools import reduce +from typing import Any, Callable, TypeVar + +Ctx = TypeVar("Ctx") +R = TypeVar("R") + +type Predicate[Ctx] = Callable[[Ctx], bool] + + +@dataclass(frozen=True) +class Registration[Ctx]: + predicate: Predicate[Ctx] + transform: Callable[[Ctx], Any] + + +_registry: dict[str, list[Registration]] = defaultdict(list) + + +def hook(name: str, *, when: Predicate = lambda _: True): + """Decorator that registers a function as a hook implementation. + + The decorated function receives and returns a context dataclass. For fold + hooks, it should return a (possibly modified) copy of the context; use + dataclasses.replace() to do so functionally. For query hooks, it should + return either a result value or None to signal no match. + + Parameters + ---------- + name: + The hook point name. Use a constant from HookPoint. + when: + Predicate called with the context. The hook only fires when this + returns True. + """ + + def decorator(fn: Callable) -> Callable: + _registry[name].append(Registration(when, fn)) + return fn + + return decorator + + +def fold(name: str, ctx: Ctx) -> Ctx: + """Apply all registered hooks for *name* in order, threading ctx through. + + Each matching hook receives the output of the previous one. Non-matching + hooks (predicate returns False) are skipped and ctx passes through unchanged. 
+ """ + return reduce( + lambda c, reg: reg.transform(c) if reg.predicate(c) else c, + _registry[name], + ctx, + ) + + +def query(name: str, ctx: Any) -> Any | None: + """Return the result of the first matching hook for *name*, or None. + + Intended for hook points where at most one plugin should respond — the + first hook whose predicate matches is called, and its return value is + returned immediately. Subsequent hooks are not evaluated. + """ + return next( + ( + result + for reg in _registry[name] + if reg.predicate(ctx) and (result := reg.transform(ctx)) is not None + ), + None, + ) From e9994f63e6fecca952dc5210cbed9e0cd5e00ae3 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 29 Apr 2026 16:52:49 -0500 Subject: [PATCH 077/139] Move plugin registration and calls to new system --- .../collection/lightcone/lightcone.py | 26 +- .../opencosmo/collection/lightcone/plugins.py | 44 ++-- python/opencosmo/dataset/mpi.py | 8 +- python/opencosmo/dataset/state.py | 19 +- python/opencosmo/dtypes/diffsky.py | 226 ++++++++---------- python/opencosmo/io/iopen.py | 5 +- python/opencosmo/plugins/__init__.py | 26 +- python/opencosmo/plugins/plugin.py | 163 ------------- 8 files changed, 174 insertions(+), 343 deletions(-) delete mode 100644 python/opencosmo/plugins/plugin.py diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index b5393f60..f4364a1f 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -28,7 +28,13 @@ from opencosmo.dataset.formats import convert_data, verify_format from opencosmo.io import iopen from opencosmo.io.schema import FileEntry, make_schema -from opencosmo.plugins import plugin +from opencosmo.plugins.contexts import ( + HookPoint, + LightconeInstantiateCtx, + LightconeOpenCtx, + PostSortCtx, +) +from opencosmo.plugins.hook import fold if TYPE_CHECKING: import astropy.units as u # type: ignore @@ -286,7 +292,9 @@ 
def get_data(self, format="astropy", unpack: bool = False, **kwargs): ) format = kwargs["output"] verify_format(format) - lightcone = plugin.apply_plugins(plugin.PluginType.LightconeInstantiate, self) + lightcone = fold( + HookPoint.LightconeInstantiate, LightconeInstantiateCtx(self) + ).lightcone data = [ds.get_data(unpack=unpack) for ds in lightcone.values()] data_with_length = [d for d in data if len(d) > 0] if len(data_with_length) == 0: @@ -297,7 +305,9 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): if self.__ordered_by is not None: order = table.argsort(self.__ordered_by[0], reverse=self.__ordered_by[1]) table = table[order] - table = plugin.apply_post_sort_plugins(self, table, np.argsort(order)) + table = fold( + HookPoint.PostSort, PostSortCtx(self, table, np.argsort(order)) + ).data to_remove = self.__hidden.intersection(table.colnames) table.remove_columns(to_remove) @@ -355,7 +365,9 @@ def open(cls, targets: list[FileTarget], **kwargs): raise ValueError() result = cls(output) - result = plugin.apply_plugins(plugin.PluginType.LightconeOpen, result, **kwargs) + result = fold( + HookPoint.LightconeOpen, LightconeOpenCtx(result, kwargs) + ).lightcone return make_radec_columns(result) @@ -367,9 +379,9 @@ def from_datasets( **open_kwargs, ): result = cls(datasets, z_range) - result = plugin.apply_plugins( - plugin.PluginType.LightconeOpen, result, **open_kwargs - ) + result = fold( + HookPoint.LightconeOpen, LightconeOpenCtx(result, open_kwargs) + ).lightcone return make_radec_columns(result) def with_redshift_range(self, z_low: float, z_high: float): diff --git a/python/opencosmo/collection/lightcone/plugins.py b/python/opencosmo/collection/lightcone/plugins.py index 2aabedff..38178ab4 100644 --- a/python/opencosmo/collection/lightcone/plugins.py +++ b/python/opencosmo/collection/lightcone/plugins.py @@ -1,42 +1,36 @@ from __future__ import annotations +import dataclasses from typing import TYPE_CHECKING import opencosmo as oc -from 
opencosmo.plugins import PluginSpec, PluginType, register_plugin +from opencosmo.plugins.contexts import HookPoint +from opencosmo.plugins.hook import hook if TYPE_CHECKING: from opencosmo import Lightcone + from opencosmo.plugins.contexts import LightconeOpenCtx -def with_redshift_column(dataset: Lightcone, *args, **kwargs): - """ - Ensures a column exists called "redshift" which contains the redshift of the objects - in the lightcone. - """ - if "redshift" in dataset.columns: - return dataset - - elif "fof_halo_center_a" in dataset.columns: +@hook(HookPoint.LightconeOpen) +def _ensure_redshift_column(ctx: LightconeOpenCtx) -> LightconeOpenCtx: + """Ensures a column called 'redshift' exists on every lightcone.""" + lightcone: Lightcone = ctx.lightcone + if "redshift" in lightcone.columns: + return ctx + elif "fof_halo_center_a" in lightcone.columns: z_col = 1 / oc.col("fof_halo_center_a") - 1 - return dataset.with_new_columns(redshift=z_col) - elif "redshift_true" in dataset.columns: + elif "redshift_true" in lightcone.columns: z_col = oc.col("redshift_true") - return dataset.with_new_columns(redshift=z_col) - elif "zp" in dataset.columns: + elif "zp" in lightcone.columns: z_col = oc.col("zp") - return dataset.with_new_columns(redshift=z_col) - raise ValueError( - "Unable to find a redshift or scale factor column for this lightcone dataset" + else: + raise ValueError( + "Unable to find a redshift or scale factor column for this lightcone dataset" + ) + return dataclasses.replace( + ctx, lightcone=lightcone.with_new_columns(redshift=z_col) ) -register_plugin( - PluginSpec( - PluginType.LightconeOpen, - lambda *args, **kwargs: True, - with_redshift_column, - ) -) - plugins = None diff --git a/python/opencosmo/dataset/mpi.py b/python/opencosmo/dataset/mpi.py index c6ebf8e0..67ae301b 100644 --- a/python/opencosmo/dataset/mpi.py +++ b/python/opencosmo/dataset/mpi.py @@ -4,7 +4,8 @@ from warnings import warn from opencosmo.index.build import single_chunk -from 
opencosmo.plugins.plugin import apply_partition_plugins +from opencosmo.plugins.contexts import HookPoint, PartitionCtx +from opencosmo.plugins.hook import query from opencosmo.spatial.protocols import TreePartition if TYPE_CHECKING: @@ -28,8 +29,9 @@ def partition( spatial index. In principle this means the number of objects are similar between ranks. """ - partition_plugin_result = apply_partition_plugins( - comm, header, index_group, data_group, tree, min_level + partition_plugin_result = query( + HookPoint.Partition, + PartitionCtx(comm, header, index_group, data_group, tree, min_level), ) if partition_plugin_result is not None: return partition_plugin_result diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index f46dd2fa..d964b004 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -18,12 +18,13 @@ from opencosmo.index.build import single_chunk from opencosmo.index.mask import into_array from opencosmo.index.unary import get_range -from opencosmo.plugins.plugin import ( - PluginType, - apply_index_plugins, - apply_plugins, - apply_post_sort_plugins, +from opencosmo.plugins.contexts import ( + DatasetInstantiateCtx, + HookPoint, + IndexUpdateCtx, + PostSortCtx, ) +from opencosmo.plugins.hook import fold from opencosmo.units import UnitConvention from opencosmo.units.handler import ( make_unit_handler_from_hdf5, @@ -63,7 +64,7 @@ def sort_data( data = {key: value[order] for key, value in data.items()} if sort_by[0] not in state.columns: data.pop(sort_by[0]) - return apply_post_sort_plugins(state, data, np.argsort(order)) + return fold(HookPoint.PostSort, PostSortCtx(state, data, np.argsort(order))).data class DatasetState: @@ -270,7 +271,7 @@ def get_data( """ Get the data for a given handler. 
""" - state = apply_plugins(PluginType.DatasetInstantiate, self) + state = fold(HookPoint.DatasetInstantiate, DatasetInstantiateCtx(self)).state data = instantiate_dataset( list(state.__producers.values()), state.__columns, @@ -349,7 +350,7 @@ def with_mask(self, mask: NDArray[np.bool_]): return self.with_index(np.where(mask)[0]) def with_index(self, index: DataIndex): - index = apply_index_plugins(self, index) + index = fold(HookPoint.IndexUpdate, IndexUpdateCtx(self, index)).index new_raw_handler = self.__raw_data_handler.take(index) new_cache = self.__cache.take(index) return self.__rebuild( @@ -490,7 +491,7 @@ def take_range(self, start: int, end: int): def take_rows(self, rows: DataIndex): if len(self) == 0: return self - rows = apply_index_plugins(self, rows) + rows = fold(HookPoint.IndexUpdate, IndexUpdateCtx(self, rows)).index row_range = get_range(rows) if row_range[1] > len(self) or row_range[0] < 0: diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py index 7a9eed21..fc9b9eb3 100644 --- a/python/opencosmo/dtypes/diffsky.py +++ b/python/opencosmo/dtypes/diffsky.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses from datetime import datetime # noqa from typing import TYPE_CHECKING, ClassVar, Optional @@ -10,27 +11,25 @@ from opencosmo.index import into_array from opencosmo.index.ops import reindex_column from opencosmo.mpi import get_mpi -from opencosmo.plugins.plugin import ( - IndexPluginSpec, - PartitionPluginSpec, - PluginSpec, - PluginType, - PostSortPluginSpec, - register_plugin, -) +from opencosmo.plugins.contexts import HookPoint +from opencosmo.plugins.hook import hook from opencosmo.spatial.tree import TreePartition if TYPE_CHECKING: - import h5py from astropy.table import Table from mpi4py import MPI - from opencosmo import Dataset, Lightcone + from opencosmo import Dataset from opencosmo.dataset.state import DatasetState - from opencosmo.header import OpenCosmoHeader from opencosmo.index 
import DataIndex - from opencosmo.spatial.tree import Tree - + from opencosmo.plugins.contexts import ( + DatasetOpenCtx, + IndexUpdateCtx, + LightconeInstantiateCtx, + LightconeOpenCtx, + PartitionCtx, + PostSortCtx, + ) else: MPI = get_mpi() @@ -62,6 +61,9 @@ def serialize_zphot_table(self, value): return None +# --- pure logic --- + + def __offset(top_host_idx, offset): output = top_host_idx output[output >= 0] += offset @@ -85,107 +87,115 @@ def offset_top_host_idx(datasets: list[Dataset]): def rebuild_top_host_idx(top_host_idx, index): result = reindex_column(index, top_host_idx) - return {"top_host_idx": result} -def top_host_idx_plugin(dataset: DatasetState, **kwargs): +def keep_top_host_idx(dataset: DatasetState, new_index: DataIndex): + index_array = into_array(new_index) + top_host_idx = dataset.select({"top_host_idx"}).get_data()["top_host_idx"] + unique_in_sample = np.unique(top_host_idx[index_array]) + + missing_hosts = np.setdiff1d(unique_in_sample, index_array) + all_satellites = np.where(np.isin(top_host_idx, unique_in_sample))[0] + missing_satellites = np.setdiff1d(all_satellites, index_array) + + if len(missing_hosts) == 0 and len(missing_satellites) == 0: + return new_index + + all_missing = np.sort(np.concatenate((missing_hosts, missing_satellites))) + insert_idx = np.searchsorted(index_array, all_missing) + return np.insert(index_array, insert_idx, all_missing) + + +def _is_synthetic_galaxies_with_top_host_idx(obj) -> bool: + return ( + obj.header.file.data_type == "synthetic_galaxies" + and "top_host_idx" in obj.columns + ) + + +# --- hooks --- + + +@hook( + HookPoint.DatasetOpen, + when=lambda ctx: _is_synthetic_galaxies_with_top_host_idx(ctx.dataset), +) +def _attach_top_host_idx_column(ctx: DatasetOpenCtx) -> DatasetOpenCtx: top_host_idx = EvaluatedColumn( rebuild_top_host_idx, - requires=set(["top_host_idx"]), - produces=set(["top_host_idx"]), + requires={"top_host_idx"}, + produces={"top_host_idx"}, format="numpy", units={"top_host_idx": 
None}, strategy=EvaluateStrategy.VECTORIZE, no_cache=True, ) - return dataset.with_new_columns(updated_host_idx=top_host_idx, allow_overwrite=True) + new_dataset = ctx.dataset.with_new_columns( + updated_host_idx=top_host_idx, allow_overwrite=True + ) + return dataclasses.replace(ctx, dataset=new_dataset) -def top_host_idx_offset_plugin(lightcone: Lightcone) -> dict[str, Dataset]: +@hook( + HookPoint.LightconeInstantiate, + when=lambda ctx: _is_synthetic_galaxies_with_top_host_idx(ctx.lightcone), +) +def _offset_top_host_idx(ctx: LightconeInstantiateCtx) -> LightconeInstantiateCtx: cs = 0 output = {} - def top_host_idx(top_host_idx, offset): + def _offset_top_host_idx(top_host_idx, offset): top_host_idx[top_host_idx >= 0] += offset return top_host_idx - for key, ds in lightcone.items(): + for key, ds in ctx.lightcone.items(): output[key] = ds.evaluate( - top_host_idx, allow_overwrite=True, vectorize=True, offset=cs + _offset_top_host_idx, allow_overwrite=True, vectorize=True, offset=cs ) cs += len(ds) - return output - - -def top_host_idx_verifier[T: (DatasetState, Dataset, Lightcone)]( - dataset: T, **kwargs -) -> bool: - return ( - dataset.header.file.data_type == "synthetic_galaxies" - and "top_host_idx" in dataset.columns - ) - - -def keep_top_host_idx(dataset: DatasetState, new_index: DataIndex): - index_array = into_array(new_index) - top_host_idx = dataset.select({"top_host_idx"}).get_data()["top_host_idx"] - unique_in_sample = np.unique(top_host_idx[index_array]) - - missing_hosts = np.setdiff1d(unique_in_sample, index_array) - all_satellites = np.where(np.isin(top_host_idx, unique_in_sample))[0] - missing_satellites = np.setdiff1d(all_satellites, index_array) - - if len(missing_hosts) == 0 and len(missing_satellites) == 0: - return new_index - - all_missing = np.sort(np.concatenate((missing_hosts, missing_satellites))) - - insert_idx = np.searchsorted(index_array, all_missing) - result = np.insert(index_array, insert_idx, all_missing) + return 
dataclasses.replace(ctx, lightcone=output) # type: ignore[arg-type] - return result +@hook( + HookPoint.LightconeOpen, + when=lambda ctx: ctx.open_kwargs.get("keep_top_host", False) + and _is_synthetic_galaxies_with_top_host_idx(ctx.lightcone), +) +def _register_keep_top_host_idx(ctx: LightconeOpenCtx) -> LightconeOpenCtx: + # Registers an IndexUpdate hook dynamically so that keep_top_host_idx only + # activates when the user explicitly requests it via open(..., keep_top_host=True). + # TODO: store intent in DatasetState to avoid mutating global hook registry. + @hook(HookPoint.IndexUpdate, when=lambda ctx: "top_host_idx" in ctx.state.columns) + def _keep(ctx: IndexUpdateCtx) -> IndexUpdateCtx: + return dataclasses.replace(ctx, index=keep_top_host_idx(ctx.state, ctx.index)) -def keep_top_host_idx_verifier(dataset: DatasetState): - return "top_host_idx" in dataset.columns + return ctx -def update_top_host_idx_after_sort(data: Table, reverse_index: DataIndex): +@hook( + HookPoint.PostSort, + when=lambda ctx: _is_synthetic_galaxies_with_top_host_idx(ctx.state), +) +def _remap_top_host_idx_after_sort(ctx: PostSortCtx) -> PostSortCtx: + data: Table = ctx.data # type: ignore[assignment] mask = data["top_host_idx"] >= 0 - data["top_host_idx"][mask] = reverse_index[data["top_host_idx"][mask]] - return data - - -def register_keep_top_host_idx(dataset: Dataset, **kwargs): - register_plugin( - IndexPluginSpec( - PluginType.IndexUpdate, keep_top_host_idx_verifier, keep_top_host_idx - ) - ) - return dataset - - -def register_keep_top_host_idx_verifier( - dataset: Dataset, keep_top_host: bool = False, **kwargs -): - return keep_top_host and top_host_idx_verifier(dataset) + data["top_host_idx"][mask] = ctx.index[data["top_host_idx"][mask]] + return ctx -def partition_plugin( - comm: MPI.Comm, - index_group: h5py.Group, - data_group: h5py.Group, - tree: Optional[Tree] = None, - min_level: Optional[int] = None, -): - top_host_idx = data_group["top_host_idx"][:] +@hook( + 
HookPoint.Partition, + when=lambda ctx: ctx.header.file.data_type == "synthetic_galaxies" + and "top_host_idx" in ctx.data_group.keys(), +) +def _partition_by_top_host_groups(ctx: PartitionCtx) -> Optional[TreePartition]: + top_host_idx = ctx.data_group["top_host_idx"][:] n_rows = len(top_host_idx) - n_ranks = comm.Get_size() - rank = comm.Get_rank() + n_ranks = ctx.comm.Get_size() + rank = ctx.comm.Get_rank() - # Step 1: partition row indices evenly between ranks ave, res = divmod(n_rows, n_ranks) counts = np.array( [ave + 1 if i < res else ave for i in range(n_ranks)], dtype=np.int64 @@ -196,60 +206,12 @@ def partition_plugin( row_indices = np.arange(start, start + count, dtype=np.int64) chunk = top_host_idx[start : start + count] - # Step 2: find top hosts (self-referential rows) and orphans (top_host_idx == -1) - # in this rank's chunk + # find top hosts (self-referential) and orphans (top_host_idx == -1) rank_top_hosts = row_indices[chunk == row_indices] rank_orphans = row_indices[chunk == -1] - # Step 3: search the full array for all rows that belong to this rank's top hosts + # gather all rows belonging to this rank's top hosts all_group_rows = np.where(np.isin(top_host_idx, rank_top_hosts))[0] - # Step 4: combine group members with orphans from this rank's partition index = np.union1d(all_group_rows, rank_orphans) return TreePartition(idx=index, region=None, level=None) - - -def partition_plugin_verifier( - header: OpenCosmoHeader, - index_group: h5py.Group, - data_group: h5py.Group, - tree: Optional[Tree] = None, - min_level: Optional[int] = None, -): - return ( - header.file.data_type == "synthetic_galaxies" - and "top_host_idx" in data_group.keys() - ) - - -register_plugin( - PluginSpec(PluginType.DatasetOpen, top_host_idx_verifier, top_host_idx_plugin) -) - -register_plugin( - PluginSpec( - PluginType.LightconeInstantiate, - top_host_idx_verifier, - top_host_idx_offset_plugin, - ) -) - -register_plugin( - PluginSpec( - PluginType.LightconeOpen, - 
register_keep_top_host_idx_verifier, - register_keep_top_host_idx, - ) -) - -register_plugin( - PostSortPluginSpec( - PluginType.PostSort, top_host_idx_verifier, update_top_host_idx_after_sort - ) -) - -register_plugin( - PartitionPluginSpec( - PluginType.Partition, partition_plugin_verifier, partition_plugin - ) -) diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index fb1f16d4..1d7a1eda 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -16,7 +16,8 @@ from opencosmo.header import OpenCosmoHeader, read_header from opencosmo.index.build import empty, from_range from opencosmo.mpi import get_comm_world -from opencosmo.plugins.plugin import PluginType, apply_plugins +from opencosmo.plugins.contexts import DatasetOpenCtx, HookPoint +from opencosmo.plugins.hook import fold from opencosmo.spatial.builders import from_model from opencosmo.spatial.region import FullSkyRegion, HealpixRegion from opencosmo.spatial.tree import open_tree @@ -545,7 +546,7 @@ def open_single_dataset( state, tree=tree, ) - dataset = apply_plugins(PluginType.DatasetOpen, dataset, **open_kwargs) + dataset = fold(HookPoint.DatasetOpen, DatasetOpenCtx(dataset, open_kwargs)).dataset if header.file.data_type == "healpix_map": return __open_healpix_map(dataset, sim_region) elif header.file.is_lightcone and not bypass_lightcone: diff --git a/python/opencosmo/plugins/__init__.py b/python/opencosmo/plugins/__init__.py index 928c4461..a4d67ff5 100644 --- a/python/opencosmo/plugins/__init__.py +++ b/python/opencosmo/plugins/__init__.py @@ -1,3 +1,25 @@ -from .plugin import PluginSpec, PluginType, register_plugin +from .contexts import ( + DatasetInstantiateCtx, + DatasetOpenCtx, + HookPoint, + IndexUpdateCtx, + LightconeInstantiateCtx, + LightconeOpenCtx, + PartitionCtx, + PostSortCtx, +) +from .hook import fold, hook, query -__all__ = ["register_plugin", "PluginSpec", "PluginType"] +__all__ = [ + "fold", + "hook", + "query", + "HookPoint", + 
"DatasetOpenCtx", + "DatasetInstantiateCtx", + "LightconeOpenCtx", + "LightconeInstantiateCtx", + "IndexUpdateCtx", + "PostSortCtx", + "PartitionCtx", +] diff --git a/python/opencosmo/plugins/plugin.py b/python/opencosmo/plugins/plugin.py deleted file mode 100644 index b3fec224..00000000 --- a/python/opencosmo/plugins/plugin.py +++ /dev/null @@ -1,163 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict -from enum import StrEnum -from functools import reduce -from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, TypedDict - -from opencosmo.index import into_array - -if TYPE_CHECKING: - import h5py - import numpy as np - from astropy.table import Table - from mpi4py import MPI - - from opencosmo import Lightcone - from opencosmo.dataset.state import DatasetState - from opencosmo.header import OpenCosmoHeader - from opencosmo.index import DataIndex, IndexArray - from opencosmo.spatial.tree import Tree, TreePartition - - -class PluginType(StrEnum): - DatasetOpen = "dataset_open" - DatasetInstantiate = "dataset_instantiate" - LightconeOpen = "lightcone_open" - LightconeInstantiate = "lightcone_instantiate" - PostSort = "post_sort" - IndexUpdate = "index_update" - Partition = "partition" - - -type Verifier[T] = Callable[[T], bool] -type Plugin[T] = Callable[[T], T] - - -class PluginSpec[T](NamedTuple): - plugin_type: PluginType - verifier: Verifier[T] - plugin: Plugin[T] - - -class IndexPluginSpec(NamedTuple): - plugin_type: PluginType - verifier: Callable[[DatasetState], bool] - plugin: Callable[[DatasetState, DataIndex], DataIndex] - - -class PostSortPluginSpec[T: (DatasetState, Lightcone)](NamedTuple): - plugin_type: PluginType - verifier: Callable[[T], bool] - plugin: Callable[[Table, IndexArray], dict[str, np.ndarray]] - - -class PartitionPluginSpec(NamedTuple): - plugin_type: PluginType - verifier: Callable[[OpenCosmoHeader, h5py.Group, h5py.Group], bool] - plugin: Callable[ - [MPI.Comm, h5py.Group, h5py.Group, 
Optional[Tree], Optional[int]], - Optional[TreePartition], - ] - - -class Plugins(TypedDict): - dataset_open: list[PluginSpec] - dataset_instantiate: list[PluginSpec] - lightcone_load: list[PluginSpec] - lightcone_instantiate: list[PluginSpec] - index_update: list[IndexPluginSpec] - partition: list[PartitionPluginSpec] - - -KNOWN_PLUGINS: Plugins = defaultdict(list) # type: ignore - - -def register_plugin( - spec: PluginSpec | IndexPluginSpec | PostSortPluginSpec | PartitionPluginSpec, -) -> None: - KNOWN_PLUGINS[str(spec.plugin_type)].append(spec) # type: ignore - - -def apply_plugins[T](plugin_type: PluginType, target: T, **kwargs: Any) -> T: - """Apply all registered plugins of the given type to target. - - kwargs are forwarded to both the verifier and plugin, used to pass - open_kwargs through to DatasetOpen plugins. - """ - plugins_to_apply = KNOWN_PLUGINS[str(plugin_type)] # type: ignore - return reduce( - lambda t, spec: _apply_single(spec, t, **kwargs), plugins_to_apply, target - ) - - -def apply_index_plugins(state: DatasetState, index: DataIndex) -> DataIndex: - """Apply all registered IndexUpdate plugins to index. - - Each plugin may expand or otherwise modify the position index returned by a - filter or take operation. Plugins run in registration order, each seeing the - output of the previous one. 
- """ - plugins_to_apply: list[IndexPluginSpec] = KNOWN_PLUGINS[str(PluginType.IndexUpdate)] # type: ignore - return reduce( - lambda idx, spec: _apply_single_index(spec, state, idx), - plugins_to_apply, - index, - ) - - -def apply_post_sort_plugins[T: (DatasetState, Lightcone)]( - state: T, - data: Table, - index: DataIndex, -) -> T: - plugins_to_apply: list[PostSortPluginSpec] = KNOWN_PLUGINS[str(PluginType.PostSort)] # type: ignore - index_arr = into_array(index) - - return reduce( - lambda data_, spec: _apply_single_post_sort(spec, state, data_, index_arr), - plugins_to_apply, - data, - ) - - -def apply_partition_plugins( - comm: MPI.Comm, - header: OpenCosmoHeader, - index_group: h5py.Group, - data_group: h5py.Group, - tree: Optional[Tree] = None, - min_level: Optional[int] = None, -) -> Optional[TreePartition]: - partition_plugins = KNOWN_PLUGINS[str(PluginType.Partition)] # type: ignore - if len(partition_plugins) == 0: - return None - if len(partition_plugins) > 1: - raise ValueError("Only one partition plugin is allowed at a time") - plugin_spec = partition_plugins[0] - if not plugin_spec.verifier(header, index_group, data_group): - return None - - return plugin_spec.plugin(comm, index_group, data_group, tree, min_level) - - -def _apply_single_post_sort[T: (DatasetState, Lightcone)]( - spec: PostSortPluginSpec, state: T, data: Table, index: IndexArray -): - if spec.verifier(state): - return spec.plugin(data, index) - return data - - -def _apply_single[T](spec: PluginSpec[T], target: T, **kwargs: Any) -> T: - if spec.verifier(target, **kwargs): - return spec.plugin(target, **kwargs) - return target - - -def _apply_single_index( - spec: IndexPluginSpec, state: DatasetState, index: DataIndex -) -> DataIndex: - if spec.verifier(state): - return spec.plugin(state, index) - return index From 6bfc9e58807f0765dd352041499d49dc56c6313b Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 29 Apr 2026 16:57:43 -0500 Subject: [PATCH 078/139] Move lightcone radec 
generation into plugin --- .../collection/lightcone/coordinates.py | 29 ------------------- .../collection/lightcone/lightcone.py | 10 ++----- .../opencosmo/collection/lightcone/plugins.py | 26 ++++++++++++++++- 3 files changed, 27 insertions(+), 38 deletions(-) delete mode 100644 python/opencosmo/collection/lightcone/coordinates.py diff --git a/python/opencosmo/collection/lightcone/coordinates.py b/python/opencosmo/collection/lightcone/coordinates.py deleted file mode 100644 index 47e21b67..00000000 --- a/python/opencosmo/collection/lightcone/coordinates.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -import warnings -from typing import TYPE_CHECKING - -import astropy.units as u -import numpy as np - -if TYPE_CHECKING: - import opencosmo as oc - - -def make_radec_columns(dataset: oc.Lightcone): - if "ra" in dataset.columns and "dec" in dataset.columns: - return dataset - elif "theta" in dataset.columns and "phi" in dataset.columns: - return dataset.evaluate( - radec_from_thetaphi, vectorize=True, insert=True, format="numpy" - ) - else: - warnings.warn( - "Could not find coordinates in this catalog. 
Spatial queries will not be available" - ) - - -def radec_from_thetaphi(theta, phi): - theta_deg = theta * 180 / np.pi - phi_deg = phi * 180 / np.pi - return {"ra": phi_deg * u.deg, "dec": (90.0 - theta_deg) * u.deg} diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index f4364a1f..db5bc37e 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -21,7 +21,6 @@ import opencosmo as oc from opencosmo.collection.lightcone import io as lcio from opencosmo.collection.lightcone import utils as lcutils -from opencosmo.collection.lightcone.coordinates import make_radec_columns from opencosmo.collection.lightcone.stack import stack_lightcone_datasets_in_schema from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn from opencosmo.dataset.evaluate import build_evaluated_column @@ -365,11 +364,7 @@ def open(cls, targets: list[FileTarget], **kwargs): raise ValueError() result = cls(output) - result = fold( - HookPoint.LightconeOpen, LightconeOpenCtx(result, kwargs) - ).lightcone - - return make_radec_columns(result) + return fold(HookPoint.LightconeOpen, LightconeOpenCtx(result, kwargs)).lightcone @classmethod def from_datasets( @@ -379,10 +374,9 @@ def from_datasets( **open_kwargs, ): result = cls(datasets, z_range) - result = fold( + return fold( HookPoint.LightconeOpen, LightconeOpenCtx(result, open_kwargs) ).lightcone - return make_radec_columns(result) def with_redshift_range(self, z_low: float, z_high: float): """ diff --git a/python/opencosmo/collection/lightcone/plugins.py b/python/opencosmo/collection/lightcone/plugins.py index 38178ab4..0cb6156d 100644 --- a/python/opencosmo/collection/lightcone/plugins.py +++ b/python/opencosmo/collection/lightcone/plugins.py @@ -1,8 +1,12 @@ from __future__ import annotations import dataclasses +import warnings from typing import TYPE_CHECKING +import astropy.units as u +import 
numpy as np + import opencosmo as oc from opencosmo.plugins.contexts import HookPoint from opencosmo.plugins.hook import hook @@ -33,4 +37,24 @@ def _ensure_redshift_column(ctx: LightconeOpenCtx) -> LightconeOpenCtx: ) -plugins = None +@hook(HookPoint.LightconeOpen) +def _make_radec_columns(ctx: LightconeOpenCtx): + lightcone: Lightcone = ctx.lightcone + if "ra" in lightcone.columns and "dec" in lightcone.columns: + pass + elif "theta" in lightcone.columns and "phi" in lightcone.columns: + lightcone = lightcone.evaluate( + radec_from_thetaphi, vectorize=True, insert=True, format="numpy" + ) + else: + warnings.warn( + "Could not find coordinates in this catalog. Spatial queries will not be available" + ) + + return dataclasses.replace(ctx, lightcone=lightcone) + + +def radec_from_thetaphi(theta, phi): + theta_deg = theta * 180 / np.pi + phi_deg = phi * 180 / np.pi + return {"ra": phi_deg * u.deg, "dec": (90.0 - theta_deg) * u.deg} From 28276962d3ccb93d2f58da5e0c059f801de1cdbb Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 30 Apr 2026 08:40:26 -0500 Subject: [PATCH 079/139] Fix keep_top_host logic to only fire on single datasets --- python/opencosmo/dataset/build.py | 1 + python/opencosmo/dataset/state.py | 14 ++++++++++++-- python/opencosmo/dtypes/diffsky.py | 28 +++++++++++++--------------- python/opencosmo/io/iopen.py | 1 + 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/python/opencosmo/dataset/build.py b/python/opencosmo/dataset/build.py index cfb8cd74..64473fbc 100644 --- a/python/opencosmo/dataset/build.py +++ b/python/opencosmo/dataset/build.py @@ -57,6 +57,7 @@ def build_dataset_from_data( header, header.file.unit_convention, region, + {}, data_descriptions, ) return Dataset(header, new_state, tree=tree) diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index d964b004..050e4316 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -1,7 +1,7 @@ from __future__ import 
annotations from functools import reduce -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional, Sequence from weakref import finalize import astropy.units as u @@ -82,6 +82,7 @@ def __init__( header: OpenCosmoHeader, columns: dict[str, UUID], region: Region, + open_kwargs: dict[str, Any], sort_by: Optional[tuple[str, bool]], ): self.__producers: dict[UUID, ConstructedColumn] = { @@ -95,6 +96,7 @@ def __init__( self.__region = region self.__sort_by = sort_by self.__cache.register_column_group(id(self), self.__columns) + self.__kwargs = open_kwargs finalize(self, deregister_state, id(self), self.__cache) def __rebuild(self, **updates): @@ -107,6 +109,7 @@ def __rebuild(self, **updates): "columns": self.__columns, "region": self.__region, "sort_by": self.__sort_by, + "open_kwargs": self.__kwargs, } | updates return DatasetState(**new) @@ -119,9 +122,9 @@ def from_target( target: DatasetTarget, unit_convention: UnitConvention, region: Region, + open_kwargs: dict[str, Any], index: Optional[DataIndex] = None, metadata_group: Optional[str] = None, - in_memory: bool = False, ): data_group = target["dataset_group"] if "load" in data_group.keys(): @@ -154,6 +157,7 @@ def from_target( target["header"], columns, region, + open_kwargs, None, ) @@ -165,6 +169,7 @@ def in_memory( header: OpenCosmoHeader, unit_convention: UnitConvention, region: Region, + open_kwargs: dict[str, Any], descriptions: Optional[dict[str, str]] = None, index: Optional[DataIndex] = None, ): @@ -200,6 +205,7 @@ def in_memory( header, columns, region, + open_kwargs, None, ) @@ -222,6 +228,10 @@ def descriptions(self): if name in self.columns } + @property + def kwargs(self): + return self.__kwargs + @property def raw_index(self): if (si := self.get_sorted_index()) is not None: diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py index fc9b9eb3..17ef116b 100644 --- a/python/opencosmo/dtypes/diffsky.py +++ 
b/python/opencosmo/dtypes/diffsky.py @@ -26,7 +26,6 @@ DatasetOpenCtx, IndexUpdateCtx, LightconeInstantiateCtx, - LightconeOpenCtx, PartitionCtx, PostSortCtx, ) @@ -158,20 +157,17 @@ def _offset_top_host_idx(top_host_idx, offset): return dataclasses.replace(ctx, lightcone=output) # type: ignore[arg-type] +# Registers an IndexUpdate hook dynamically so that keep_top_host_idx only +# activates when the user explicitly requests it via open(..., keep_top_host=True). @hook( - HookPoint.LightconeOpen, - when=lambda ctx: ctx.open_kwargs.get("keep_top_host", False) - and _is_synthetic_galaxies_with_top_host_idx(ctx.lightcone), + HookPoint.IndexUpdate, + when=lambda ctx: ( + "top_host_idx" in ctx.state.columns + and ctx.state.kwargs.get("keep_top_host", False) + ), ) -def _register_keep_top_host_idx(ctx: LightconeOpenCtx) -> LightconeOpenCtx: - # Registers an IndexUpdate hook dynamically so that keep_top_host_idx only - # activates when the user explicitly requests it via open(..., keep_top_host=True). - # TODO: store intent in DatasetState to avoid mutating global hook registry. 
- @hook(HookPoint.IndexUpdate, when=lambda ctx: "top_host_idx" in ctx.state.columns) - def _keep(ctx: IndexUpdateCtx) -> IndexUpdateCtx: - return dataclasses.replace(ctx, index=keep_top_host_idx(ctx.state, ctx.index)) - - return ctx +def _keep(ctx: IndexUpdateCtx) -> IndexUpdateCtx: + return dataclasses.replace(ctx, index=keep_top_host_idx(ctx.state, ctx.index)) @hook( @@ -187,8 +183,10 @@ def _remap_top_host_idx_after_sort(ctx: PostSortCtx) -> PostSortCtx: @hook( HookPoint.Partition, - when=lambda ctx: ctx.header.file.data_type == "synthetic_galaxies" - and "top_host_idx" in ctx.data_group.keys(), + when=lambda ctx: ( + ctx.header.file.data_type == "synthetic_galaxies" + and "top_host_idx" in ctx.data_group.keys() + ), ) def _partition_by_top_host_groups(ctx: PartitionCtx) -> Optional[TreePartition]: top_host_idx = ctx.data_group["top_host_idx"][:] diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index 1d7a1eda..e1d11ddd 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -537,6 +537,7 @@ def open_single_dataset( target, UnitConvention.COMOVING, sim_region, + open_kwargs, index, metadata_group, ) From ae6eed4b2325d82c6a69e69d8135c2ae7da06769 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 30 Apr 2026 13:48:55 -0500 Subject: [PATCH 080/139] Move verification code out of the dataset --- python/opencosmo/dataset/dataset.py | 30 ++++++++++----- python/opencosmo/dataset/state.py | 59 ++++++++--------------------- 2 files changed, 36 insertions(+), 53 deletions(-) diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index ee490ca7..d770b0e3 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -20,7 +20,7 @@ from opencosmo.column import Column from opencosmo.dataset.evaluate import build_evaluated_column, visit_dataset from opencosmo.dataset.formats import convert_data, verify_format -from opencosmo.index import empty, into_array, mask, 
project +from opencosmo.index import empty, into_array, mask, project, single_chunk from opencosmo.spatial import check from opencosmo.units.converters import get_scale_factor @@ -516,7 +516,7 @@ def filter(self, *masks: ColumnMask) -> Dataset: for m in masks: bool_mask &= m.apply(data) - new_state = self.__state.with_mask(bool_mask) + new_state = self.__state.take_rows(np.where(bool_mask)[0]) return Dataset(self.__header, new_state, self.__tree) def rows( @@ -728,14 +728,17 @@ def take( or if 'at' is invalid. """ + if at == "start": + return self.take_range(0, n) + elif at == "end": + return self.take_range(len(self) - n, len(self)) + elif at != "random": + raise ValueError(f"Unknown take type {at}") - new_state = self.__state.take(n, at) - - return Dataset( - self.__header, - new_state, - self.__tree, - ) + row_indices = np.random.choice(len(self), n, replace=False) + row_indices.sort() + new_state = self.__state.take_rows(row_indices) + return Dataset(self.__header, new_state, self.__tree) def take_range(self, start: int, end: int) -> Dataset: """ @@ -761,8 +764,15 @@ def take_range(self, start: int, end: int) -> Dataset: or if end is greater than start. 
""" + if start < 0 or end < 0: + raise ValueError("start and end must be positive.") + if end < start: + raise ValueError("end must be greater than start.") + if end > len(self): + raise ValueError("end must be less than the length of the dataset.") - new_state = self.__state.take_range(start, end) + take_index = single_chunk(start, end - start) + new_state = self.__state.take_rows(take_index) return Dataset( self.__header, diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index 050e4316..fb80b3d6 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from functools import reduce from typing import TYPE_CHECKING, Any, Optional, Sequence from weakref import finalize @@ -15,7 +16,7 @@ from opencosmo.dataset.output import get_derived_column_names, make_dataset_schema from opencosmo.handler.empty import EmptyHandler from opencosmo.handler.hdf5 import Hdf5Handler -from opencosmo.index.build import single_chunk +from opencosmo.index import single_chunk from opencosmo.index.mask import into_array from opencosmo.index.unary import get_range from opencosmo.plugins.contexts import ( @@ -36,7 +37,6 @@ from astropy import table from astropy.cosmology import Cosmology - from numpy.typing import NDArray from opencosmo.column.column import ConstructedColumn from opencosmo.handler.protocols import DataCache, DataHandler @@ -67,6 +67,19 @@ def sort_data( return fold(HookPoint.PostSort, PostSortCtx(state, data, np.argsort(order))).data +@dataclass(frozen=True) +class _DatasetState: + column_producers: Sequence[ConstructedColumn] + raw_data_handler: DataHandler + cache: DataCache + unit_handler: UnitHandler + header: OpenCosmoHeader + columns: dict[str, UUID] + region: Region + open_kwargs: dict[str, Any] + sort_by: Optional[tuple[str, bool]] + + class DatasetState: """ Holds mutable state required by the dataset. 
Cleans up the dataset to mostly focus @@ -325,7 +338,7 @@ def rows(self, metadata_columns: list = [], unit_kwargs: dict = {}): try: for start, end in chunk_ranges: - chunk = self.take_range(start, end) + chunk = self.take_rows(single_chunk(start, end - start)) data = chunk.get_data( metadata_columns=metadata_columns, unit_kwargs=unit_kwargs ) @@ -356,18 +369,6 @@ def get_metadata(self, columns=[]): metadata = {name: values[sorted_index] for name, values in metadata.items()} return metadata - def with_mask(self, mask: NDArray[np.bool_]): - return self.with_index(np.where(mask)[0]) - - def with_index(self, index: DataIndex): - index = fold(HookPoint.IndexUpdate, IndexUpdateCtx(self, index)).index - new_raw_handler = self.__raw_data_handler.take(index) - new_cache = self.__cache.take(index) - return self.__rebuild( - cache=new_cache, - raw_data_handler=new_raw_handler, - ) - def make_schema(self, name: Optional[str] = None): producers = list(self.__producers.values()) columns = set(self.__columns.keys()) @@ -470,34 +471,6 @@ def get_sorted_index(self): return sorted - def take(self, n: int, at: str): - """ - Take rows from the dataset. - """ - - if at == "start": - return self.take_range(0, n) - elif at == "end": - return self.take_range(len(self) - n, len(self)) - elif at == "random": - row_indices = np.random.choice(len(self), n, replace=False) - row_indices.sort() - return self.take_rows(row_indices) - - def take_range(self, start: int, end: int): - """ - Take a range of rows form the dataset. 
- """ - if start < 0 or end < 0: - raise ValueError("start and end must be positive.") - if end < start: - raise ValueError("end must be greater than start.") - if end > len(self): - raise ValueError("end must be less than the length of the dataset.") - - take_index = single_chunk(start, end - start) - return self.take_rows(take_index) - def take_rows(self, rows: DataIndex): if len(self) == 0: return self From 0c88fb5f5a871e50abde8870623378fb15258c11 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 1 May 2026 09:44:41 -0500 Subject: [PATCH 081/139] Break DatasetState apart for future changes --- python/opencosmo/dataset/build.py | 4 +- python/opencosmo/dataset/dataset.py | 70 +-- python/opencosmo/dataset/state.py | 798 +++++++++++++-------------- python/opencosmo/dtypes/diffsky.py | 3 +- python/opencosmo/io/iopen.py | 2 +- python/opencosmo/units/converters.py | 5 +- test/test_dataset.py | 18 +- uv.lock | 4 +- 8 files changed, 438 insertions(+), 466 deletions(-) diff --git a/python/opencosmo/dataset/build.py b/python/opencosmo/dataset/build.py index 64473fbc..edcd1d62 100644 --- a/python/opencosmo/dataset/build.py +++ b/python/opencosmo/dataset/build.py @@ -7,8 +7,8 @@ import h5py import numpy as np +import opencosmo.dataset.state as state from opencosmo.dataset import Dataset -from opencosmo.dataset.state import DatasetState from opencosmo.spatial.healpix import HealPixIndex from opencosmo.spatial.tree import Tree from opencosmo.spatial.utils import combine_upwards @@ -51,7 +51,7 @@ def build_dataset_from_data( metadata_group = {} data_descriptions = descriptions.get("data", {}) - new_state = DatasetState.in_memory( + new_state = state.state_in_memory( data_group, metadata_group, header, diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index d770b0e3..4e52dc49 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -17,10 +17,11 @@ from astropy.table import QTable # type: ignore from 
deprecated.sphinx import deprecated +import opencosmo.dataset.state as st from opencosmo.column import Column from opencosmo.dataset.evaluate import build_evaluated_column, visit_dataset from opencosmo.dataset.formats import convert_data, verify_format -from opencosmo.index import empty, into_array, mask, project, single_chunk +from opencosmo.index import empty, get_range, into_array, mask, project, single_chunk from opencosmo.spatial import check from opencosmo.units.converters import get_scale_factor @@ -83,10 +84,10 @@ def __enter__(self): return self def __exit__(self, *exc_details): - return self.__state.__exit__(*exc_details) + return st.exit_state(self.__state, *exc_details) def close(self): - return self.__state.__exit__() + return st.exit_state(self.__state) @property def header(self) -> OpenCosmoHeader: @@ -236,7 +237,7 @@ def get_metadata(self, columns: str | list[str] = []): if isinstance(columns, str): columns = [columns] - return self.__state.get_metadata(columns) + return st.get_metadata(self.__state, columns) def get_data( self, format="astropy", unpack=True, metadata_columns=[], **kwargs @@ -286,8 +287,8 @@ def get_data( else: unit_kwargs = {} - data = self.__state.get_data( - unit_kwargs=unit_kwargs, metadata_columns=metadata_columns + data = st.get_data( + self.__state, unit_kwargs=unit_kwargs, metadata_columns=metadata_columns ) # dict if unpack: data = { @@ -336,7 +337,7 @@ def bound(self, region: Region, select_by: Optional[str] = None): columns = check.find_coordinates_3d(self, self.dtype) check_region = region.into_base_convention( - self.__state.unit_handler, + self.__state.unit_handler, # type: ignore[arg-type] columns, self.__state.convention, { @@ -349,7 +350,7 @@ def bound(self, region: Region, select_by: Optional[str] = None): check_region = region if not self.__state.region.intersects(check_region): - new_state = self.__state.take_rows(empty()) + new_state = st.take_rows(self.__state, empty()) return Dataset(self.__header, new_state, 
self.__tree) if not self.__state.region.contains(check_region): @@ -365,7 +366,7 @@ def bound(self, region: Region, select_by: Optional[str] = None): contained_index = project(self.__state.raw_index, contained_index) intersects_index = project(self.__state.raw_index, intersects_index) - check_state = self.__state.take_rows(intersects_index) + check_state = st.take_rows(self.__state, intersects_index) check_dataset = Dataset( self.__header, check_state, @@ -386,7 +387,7 @@ def bound(self, region: Region, select_by: Optional[str] = None): [into_array(contained_index), into_array(new_intersects_index)] ) - new_state = self.__state.take_rows(new_index).with_region(check_region) + new_state = st.with_region(st.take_rows(self.__state, new_index), check_region) return Dataset(self.__header, new_state, self.__tree) @@ -516,7 +517,7 @@ def filter(self, *masks: ColumnMask) -> Dataset: for m in masks: bool_mask &= m.apply(data) - new_state = self.__state.take_rows(np.where(bool_mask)[0]) + new_state = st.take_rows(self.__state, np.where(bool_mask)[0]) return Dataset(self.__header, new_state, self.__tree) def rows( @@ -547,7 +548,7 @@ def rows( else: unit_kwargs = {} - for row in self.__state.rows(metadata_columns, unit_kwargs): + for row in st.iter_rows(self.__state, metadata_columns, unit_kwargs): output_data = row if not isinstance(output_data, dict): output_data = {self.columns[0]: row} @@ -608,10 +609,10 @@ def select( new_state = self.__state if derived_columns: - new_state = new_state.with_new_columns({}, False, **derived_columns) + new_state = st.with_new_columns(new_state, {}, False, **derived_columns) all_columns.update(derived_columns.keys()) - new_state = new_state.select(all_columns) + new_state = st.select(new_state, all_columns) return Dataset( self.__header, new_state, @@ -648,7 +649,7 @@ def drop(self, *columns: str | Iterable[str]) -> Dataset: col_group = {col_group} all_columns.update(col_group) - new_state = self.__state.select(all_columns, drop=True) + 
new_state = st.select(self.__state, all_columns, drop=True) return Dataset( self.__header, new_state, @@ -688,7 +689,7 @@ def sort_by(self, column: str, invert: bool = False) -> Dataset: """ - new_state = self.__state.sort_by(column, invert) + new_state = st.sort_by(self.__state, column, invert) return Dataset( self.__header, new_state, @@ -735,10 +736,12 @@ def take( elif at != "random": raise ValueError(f"Unknown take type {at}") + if n > len(self): + raise ValueError("Cannot take more rows than are in this dataset!") + row_indices = np.random.choice(len(self), n, replace=False) row_indices.sort() - new_state = self.__state.take_rows(row_indices) - return Dataset(self.__header, new_state, self.__tree) + return self.take_rows(row_indices) def take_range(self, start: int, end: int) -> Dataset: """ @@ -772,13 +775,7 @@ def take_range(self, start: int, end: int) -> Dataset: raise ValueError("end must be less than the length of the dataset.") take_index = single_chunk(start, end - start) - new_state = self.__state.take_rows(take_index) - - return Dataset( - self.__header, - new_state, - self.__tree, - ) + return self.take_rows(take_index) def take_rows(self, rows: np.ndarray | DataIndex): """ @@ -800,7 +797,13 @@ def take_rows(self, rows: np.ndarray | DataIndex): dataset. """ - new_state = self.__state.take_rows(rows) + row_range = get_range(rows) + if row_range[0] < 0 or row_range[1] > len(self): + raise ValueError( + "Row indices must be between 0 and the length of this dataset - 1!" 
+ ) + + new_state = st.take_rows(self.__state, rows) return Dataset(self.__header, new_state, self.__tree) def with_new_columns( @@ -858,8 +861,8 @@ def with_new_columns( """ if isinstance(descriptions, str): descriptions = {key: descriptions for key in new_columns.keys()} - new_state = self.__state.with_new_columns( - descriptions, allow_overwrite, **new_columns + new_state = st.with_new_columns( + self.__state, descriptions, allow_overwrite, **new_columns ) return Dataset(self.__header, new_state, self.__tree) @@ -878,7 +881,7 @@ def make_schema( The name of the dataset in the file. The default is "data". """ - schema = self.__state.make_schema(name) + schema = st.make_schema(self.__state, name) if self.__tree is not None: tree = self.__tree.apply_index(self.__state.raw_index) @@ -954,8 +957,13 @@ def with_units( """ - new_state = self.__state.with_units( - convention, conversions, columns, self.cosmology, self.redshift + new_state = st.with_units( + self.__state, + convention, + conversions, + columns, + self.cosmology, + self.redshift, ) if convention is not None: new_header = self.__header.with_units(convention) diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index fb80b3d6..1748e0d5 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -1,8 +1,9 @@ from __future__ import annotations +import dataclasses from dataclasses import dataclass from functools import reduce -from typing import TYPE_CHECKING, Any, Optional, Sequence +from typing import TYPE_CHECKING, Any, Generator, Optional from weakref import finalize import astropy.units as u @@ -18,7 +19,6 @@ from opencosmo.handler.hdf5 import Hdf5Handler from opencosmo.index import single_chunk from opencosmo.index.mask import into_array -from opencosmo.index.unary import get_range from opencosmo.plugins.contexts import ( DatasetInstantiateCtx, HookPoint, @@ -43,6 +43,7 @@ from opencosmo.header import OpenCosmoHeader from opencosmo.index import 
DataIndex from opencosmo.io.iopen import DatasetTarget + from opencosmo.io.schema import Schema from opencosmo.spatial.protocols import Region from opencosmo.units.handler import UnitHandler @@ -68,172 +69,44 @@ def sort_data( @dataclass(frozen=True) -class _DatasetState: - column_producers: Sequence[ConstructedColumn] +class DatasetState: + """ + Main state container for the Dataset. Functions for manipulating it can be found below. The dataclass + itself only exposes basic lookup operations. + """ + + producers: dict[UUID, ConstructedColumn] raw_data_handler: DataHandler cache: DataCache unit_handler: UnitHandler header: OpenCosmoHeader - columns: dict[str, UUID] + column_map: dict[str, UUID] region: Region open_kwargs: dict[str, Any] - sort_by: Optional[tuple[str, bool]] - - -class DatasetState: - """ - Holds mutable state required by the dataset. Cleans up the dataset to mostly focus - on very high-level operations. Not a user facing class. - """ + sort_key: Optional[tuple[str, bool]] - def __init__( - self, - column_producers: Sequence[ConstructedColumn], - raw_data_handler: DataHandler, - cache: DataCache, - unit_handler: UnitHandler, - header: OpenCosmoHeader, - columns: dict[str, UUID], - region: Region, - open_kwargs: dict[str, Any], - sort_by: Optional[tuple[str, bool]], - ): - self.__producers: dict[UUID, ConstructedColumn] = { - p.uuid: p for p in column_producers - } - self.__raw_data_handler = raw_data_handler - self.__cache = cache - self.__unit_handler = unit_handler - self.__header = header - self.__columns: dict[str, UUID] = columns - self.__region = region - self.__sort_by = sort_by - self.__cache.register_column_group(id(self), self.__columns) - self.__kwargs = open_kwargs - finalize(self, deregister_state, id(self), self.__cache) - - def __rebuild(self, **updates): - new = { - "raw_data_handler": self.__raw_data_handler, - "cache": self.__cache, - "column_producers": list(self.__producers.values()), - "unit_handler": self.__unit_handler, - 
"header": self.__header, - "columns": self.__columns, - "region": self.__region, - "sort_by": self.__sort_by, - "open_kwargs": self.__kwargs, - } | updates - return DatasetState(**new) - - def __exit__(self, *exec_details): - return None - - @classmethod - def from_target( - cls, - target: DatasetTarget, - unit_convention: UnitConvention, - region: Region, - open_kwargs: dict[str, Any], - index: Optional[DataIndex] = None, - metadata_group: Optional[str] = None, - ): - data_group = target["dataset_group"] - if "load" in data_group.keys(): - load_conditions = dict(data_group["load/if"].attrs) - else: - load_conditions = None - - handler = Hdf5Handler.from_columns( - target["columns"], - index, - metadata_group, - load_conditions, - ) - unit_handler = make_unit_handler_from_hdf5( - target["columns"], target["header"], unit_convention - ) - descriptions = handler.descriptions + def __post_init__(self): + self.cache.register_column_group(id(self), self.column_map) + finalize(self, deregister_state, id(self), self.cache) - producers = [ - RawColumn(cname, descriptions.get(cname, "None")) - for cname in handler.columns - ] - columns = {p.name: p.uuid for p in producers} - cache = ColumnCache.empty() - return DatasetState( - producers, - handler, - cache, - unit_handler, - target["header"], - columns, - region, - open_kwargs, - None, - ) - - @classmethod - def in_memory( - cls, - data_columns: dict, - metadata_columns: dict, - header: OpenCosmoHeader, - unit_convention: UnitConvention, - region: Region, - open_kwargs: dict[str, Any], - descriptions: Optional[dict[str, str]] = None, - index: Optional[DataIndex] = None, - ): - descriptions = descriptions or {} + @property + def columns(self) -> list[str]: + return list(self.column_map.keys()) - # Producers must be created first so their UUIDs are available for the cache. 
- producers = [ - RawColumn(cname, descriptions.get(cname, "None")) - for cname in data_columns.keys() - ] - columns = {p.name: p.uuid for p in producers} - - cache = ColumnCache.empty() - if data_columns: - uuid_data = {p.uuid: {p.name: data_columns[p.name]} for p in producers} - cache.add_data(uuid_data, descriptions) - if metadata_columns: - cache.add_metadata(dict(metadata_columns), {}) - - units: dict[str, u.Unit] = {} - for name, column in data_columns.items(): - units[name] = None - if isinstance(column, u.Quantity): - units[name] = column.unit - - unit_handler = make_unit_handler_from_units(units, header, unit_convention) - - return DatasetState( - producers, - EmptyHandler(), - cache, - unit_handler, - header, - columns, - region, - open_kwargs, - None, + @property + def meta_columns(self) -> list[str]: + columns = set(self.cache.metadata_columns).union( + self.raw_data_handler.metadata_columns ) - - def __len__(self): - if isinstance(self.__raw_data_handler, EmptyHandler): - return len(self.__cache) - return len(self.__raw_data_handler) + return list(columns) @property def descriptions(self): all_descriptions = {} - for producer in self.__producers.values(): + for producer in self.producers.values(): update = {name: producer.description for name in producer.produces} all_descriptions |= update - all_descriptions |= self.__cache.descriptions + all_descriptions |= self.cache.descriptions return { name: description @@ -243,301 +116,390 @@ def descriptions(self): @property def kwargs(self): - return self.__kwargs + return self.open_kwargs @property def raw_index(self): - if (si := self.get_sorted_index()) is not None: - ni = into_array(self.__raw_data_handler.index) + if (si := get_sorted_index(self)) is not None: + ni = into_array(self.raw_data_handler.index) return ni[si] - - return self.__raw_data_handler.index - - @property - def unit_handler(self): - return self.__unit_handler + return self.raw_data_handler.index @property def units(self): - units = 
self.__unit_handler.current_units + units = self.unit_handler.current_units return {name: units[name] for name in self.columns} @property def convention(self): - return self.__unit_handler.current_convention - - @property - def region(self): - return self.__region - - @property - def header(self): - return self.__header - - @property - def columns(self) -> list[str]: - return list(self.__columns.keys()) - - @property - def meta_columns(self) -> list[str]: - columns = set(self.__cache.metadata_columns).union( - self.__raw_data_handler.metadata_columns - ) - return list(columns) - - def get_data( - self, - ignore_sort: bool = False, - metadata_columns: list = [], - unit_kwargs: dict = {}, - ) -> table.QTable: - """ - Get the data for a given handler. - """ - state = fold(HookPoint.DatasetInstantiate, DatasetInstantiateCtx(self)).state - data = instantiate_dataset( - list(state.__producers.values()), - state.__columns, - state.__raw_data_handler, - state.__cache, - state.__unit_handler, - unit_kwargs, - metadata_columns, - None if (ignore_sort or state.__sort_by is None) else state.__sort_by[0], + return self.unit_handler.current_convention + + def __len__(state: DatasetState) -> int: + if isinstance(state.raw_data_handler, EmptyHandler): + return len(state.cache) + return len(state.raw_data_handler) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# Factory functions (replace classmethods) +# --------------------------------------------------------------------------- + + +def state_from_target( + target: DatasetTarget, + unit_convention: UnitConvention, + region: Region, + open_kwargs: dict[str, Any], + index: Optional[DataIndex] = None, + metadata_group: Optional[str] = None, +) -> DatasetState: + data_group = target["dataset_group"] + if "load" in data_group.keys(): + load_conditions = dict(data_group["load/if"].attrs) + else: + load_conditions = None + + handler = Hdf5Handler.from_columns( + 
target["columns"], + index, + metadata_group, + load_conditions, + ) + unit_handler = make_unit_handler_from_hdf5( + target["columns"], target["header"], unit_convention + ) + descriptions = handler.descriptions + + raw_producers = [ + RawColumn(cname, descriptions.get(cname, "None")) for cname in handler.columns + ] + column_map = {p.name: p.uuid for p in raw_producers} + producers: dict[UUID, ConstructedColumn] = {p.uuid: p for p in raw_producers} + cache = ColumnCache.empty() + return DatasetState( + producers=producers, + raw_data_handler=handler, + cache=cache, + unit_handler=unit_handler, + header=target["header"], + column_map=column_map, + region=region, + open_kwargs=open_kwargs, + sort_key=None, + ) + + +def state_in_memory( + data_columns: dict, + metadata_columns: dict, + header: OpenCosmoHeader, + unit_convention: UnitConvention, + region: Region, + open_kwargs: dict[str, Any], + descriptions: Optional[dict[str, str]] = None, + index: Optional[DataIndex] = None, +) -> DatasetState: + descriptions = descriptions or {} + + raw_producers = [ + RawColumn(cname, descriptions.get(cname, "None")) + for cname in data_columns.keys() + ] + column_map = {p.name: p.uuid for p in raw_producers} + producers: dict[UUID, ConstructedColumn] = {p.uuid: p for p in raw_producers} + + cache = ColumnCache.empty() + if data_columns: + uuid_data = {p.uuid: {p.name: data_columns[p.name]} for p in raw_producers} + cache.add_data(uuid_data, descriptions) + if metadata_columns: + cache.add_metadata(dict(metadata_columns), {}) + + units: dict[str, u.Unit] = {} + for name, column in data_columns.items(): + units[name] = None + if isinstance(column, u.Quantity): + units[name] = column.unit + + unit_handler = make_unit_handler_from_units(units, header, unit_convention) + + return DatasetState( + producers=producers, + raw_data_handler=EmptyHandler(), + cache=cache, + unit_handler=unit_handler, + header=header, + column_map=column_map, + region=region, + open_kwargs=open_kwargs, + 
sort_key=None, + ) + + +# --------------------------------------------------------------------------- +# Standalone functions (replace methods) +# --------------------------------------------------------------------------- + + +def exit_state(state: DatasetState, *exec_details): + return None + + +def get_data( + state: DatasetState, + ignore_sort: bool = False, + metadata_columns: list = [], + unit_kwargs: dict = {}, +) -> dict: + """ + Use a State to get the associated data. Most of the logic can be found in the + instantiate_dataset method. + """ + state = fold(HookPoint.DatasetInstantiate, DatasetInstantiateCtx(state)).state + data = instantiate_dataset( + list(state.producers.values()), + state.column_map, + state.raw_data_handler, + state.cache, + state.unit_handler, + unit_kwargs, + metadata_columns, + None if (ignore_sort or state.sort_key is None) else state.sort_key[0], + ) + + if missing := set(state.columns).difference(data.keys()): + raise RuntimeError( + f"Some columns are missing from the output! This is likely a bug. Please report it on GitHub. Missing: {missing}" ) - if missing := set(self.columns).difference(data.keys()): - raise RuntimeError( - f"Some columns are missing from the output! This is likely a bug. Please report it on GitHub. 
Missing: {missing}" - ) + if not ignore_sort: + data = sort_data(data, state.sort_key, state) - if not ignore_sort: - data = sort_data(data, self.__sort_by, self) + new_order = [c for c in state.columns] + if metadata_columns: + new_order.extend(metadata_columns) - new_order = [c for c in self.columns] - if metadata_columns: - new_order.extend(metadata_columns) + return {name: data[name] for name in new_order} - return {name: data[name] for name in new_order} - def rows(self, metadata_columns: list = [], unit_kwargs: dict = {}): - derived_to_collect = ( - set(self.columns) - .difference(self.__cache.columns) - .difference(self.__raw_data_handler.columns) - ) - derived_storage: dict[str, list[np.ndarray]] = { - name: [] for name in derived_to_collect - } - total_length = len(self) - chunk_ranges = [ - (i, min(i + 1000, total_length)) for i in range(0, total_length, 1000) - ] - if not chunk_ranges: - raise StopIteration - - try: - for start, end in chunk_ranges: - chunk = self.take_rows(single_chunk(start, end - start)) - data = chunk.get_data( - metadata_columns=metadata_columns, unit_kwargs=unit_kwargs - ) - for name in derived_to_collect: - derived_storage[name].append(data[name]) - - for i in range(len(chunk)): - yield {name: column[i] for name, column in data.items()} - all_derived = { - name: np.concatenate(arr) for name, arr in derived_storage.items() - } - derived_storage = resort(all_derived, self.get_sorted_index()) - if derived_storage: - uuid_keyed: dict = {} - for name, arr in derived_storage.items(): - uuid = self.__columns[name] - uuid_keyed.setdefault(uuid, {})[name] = arr - self.__cache.add_data(uuid_keyed, {}) - except GeneratorExit: - pass - except BaseException: - raise - - def get_metadata(self, columns=[]): - metadata = self.__raw_data_handler.get_metadata(columns) - sorted_index = self.get_sorted_index() - if sorted_index is not None: - metadata = {name: values[sorted_index] for name, values in metadata.items()} - return metadata - - def 
make_schema(self, name: Optional[str] = None): - producers = list(self.__producers.values()) - columns = set(self.__columns.keys()) - derived_names = get_derived_column_names(producers, columns) - if derived_names: - derived_data = ( - self.select(derived_names) - .with_units(self.__unit_handler.base_convention, {}, {}, None, None) - .get_data(ignore_sort=True) +def iter_rows( + state: DatasetState, + metadata_columns: list = [], + unit_kwargs: dict = {}, +) -> Generator: + """ + Iterate over the rows of a given DatasetState + """ + derived_to_collect = ( + set(state.columns) + .difference(state.cache.columns) + .difference(state.raw_data_handler.columns) + ) + derived_storage: dict[str, list[np.ndarray]] = { + name: [] for name in derived_to_collect + } + total_length = len(state) + chunk_ranges = [ + (i, min(i + 1000, total_length)) for i in range(0, total_length, 1000) + ] + if not chunk_ranges: + raise StopIteration + + try: + for start, end in chunk_ranges: + chunk = take_rows(state, single_chunk(start, end - start)) + data = get_data( + chunk, metadata_columns=metadata_columns, unit_kwargs=unit_kwargs ) - else: - derived_data = {} - return make_dataset_schema( - producers, - self.__raw_data_handler, - self.__cache, - self.__columns, - self.meta_columns, - self.__header, - self.__region, - derived_data, - name, - ) + for name in derived_to_collect: + derived_storage[name].append(data[name]) - def with_new_columns( - self, - descriptions: dict[str, str] = {}, - allow_overwrite: bool = False, - **new_columns: ConstructedColumn | np.ndarray | u.Quantity, - ): - """ - Add a set of derived columns to the dataset. A derived column is a column that - has been created based on the values in another column. 
- """ - new_producers, new_column_map, new_unit_handler = add_columns( - list(self.__producers.values()), - self.__unit_handler, - self.__cache, - self.__columns, - self.get_sorted_index(), - descriptions, - new_columns, - len(self), - allow_overwrite=allow_overwrite, + for i in range(len(chunk)): + yield {name: column[i] for name, column in data.items()} + all_derived = { + name: np.concatenate(arr) for name, arr in derived_storage.items() + } + derived_storage = resort(all_derived, get_sorted_index(state)) + if derived_storage: + uuid_keyed: dict = {} + for name, arr in derived_storage.items(): + uuid = state.column_map[name] + uuid_keyed.setdefault(uuid, {})[name] = arr + state.cache.add_data(uuid_keyed, {}) + except GeneratorExit: + pass + except BaseException: + raise + + +def get_metadata(state: DatasetState, columns: list = []) -> dict: + metadata = state.raw_data_handler.get_metadata(columns) + sorted_index = get_sorted_index(state) + if sorted_index is not None: + metadata = {name: values[sorted_index] for name, values in metadata.items()} + return metadata + + +def make_schema(state: DatasetState, name: Optional[str] = None) -> Schema: + """ + Get metadata columns. 
+ """ + producers = list(state.producers.values()) + columns = set(state.column_map.keys()) + derived_names = get_derived_column_names(producers, columns) + if derived_names: + selected = select(state, derived_names) + converted = with_units( + selected, state.unit_handler.base_convention, {}, {}, None, None ) - return self.__rebuild( - cache=self.__cache, - column_producers=new_producers, - columns=new_column_map, - unit_handler=new_unit_handler, + derived_data = get_data(converted, ignore_sort=True) + else: + derived_data = {} + return make_dataset_schema( + producers, + state.raw_data_handler, + state.cache, + state.column_map, + state.meta_columns, + state.header, + state.region, + derived_data, + name, + ) + + +def with_new_columns( + state: DatasetState, + descriptions: dict[str, str] = {}, + allow_overwrite: bool = False, + **new_columns: ConstructedColumn | np.ndarray | u.Quantity, +) -> DatasetState: + """ + Add columns to a given state + """ + new_producers_list, new_column_map, new_unit_handler = add_columns( + list(state.producers.values()), + state.unit_handler, + state.cache, + state.column_map, + get_sorted_index(state), + descriptions, + new_columns, + len(state), + allow_overwrite=allow_overwrite, + ) + return dataclasses.replace( + state, + producers={p.uuid: p for p in new_producers_list}, + column_map=new_column_map, + unit_handler=new_unit_handler, + ) + + +def with_region(state: DatasetState, region: Region) -> DatasetState: + return dataclasses.replace(state, region=region) + + +def select(state: DatasetState, columns: set[str], drop: bool = False) -> DatasetState: + """ + Select a set of columns + """ + selections, missing = get_column_selection(state.columns, columns) + if missing: + raise ValueError( + f"Columns are included that are not in this dataset: {missing}" ) + elif not selections and columns: + raise ValueError("No columns matched the provided wildcards!") - def with_region(self, region: Region): - """ - Return the same dataset 
but with a different region - """ - return self.__rebuild(region=region) - - def select(self, columns: set[str], drop=False): - """ - Select a subset of columns from the dataset. It is possible for a user to select - a derived column in the dataset, but not the columns it is derived from. - This class tracks any columns which are required to materialize the dataset but - are not in the final selection in self.__hidden. When the dataset is - materialized, the columns in self.__hidden are removed before the data is - returned to the user. - - """ - - selections, missing = get_column_selection(self.columns, columns) - if missing: - raise ValueError( - f"Columns are included that are not in this dataset: {missing}" - ) - elif not selections and columns: - raise ValueError("No columns matched the provided wildcards!") + if drop: + selections = set(state.columns) - selections - if drop: - selections = set(self.columns) - selections + return dataclasses.replace( + state, column_map={n: state.column_map[n] for n in selections} + ) - return self.__rebuild(columns={n: self.__columns[n] for n in selections}) - def sort_by(self, column_name: str, invert: bool): - if column_name not in self.columns: - raise ValueError(f"This dataset has no column {column_name}") +def sort_by(state: DatasetState, column_name: str, invert: bool) -> DatasetState: + if column_name not in state.columns: + raise ValueError(f"This dataset has no column {column_name}") + return dataclasses.replace(state, sort_key=(column_name, invert)) - return self.__rebuild(sort_by=(column_name, invert)) - def get_sorted_index(self): - if self.__sort_by is not None: - column = self.select({self.__sort_by[0]}).get_data(ignore_sort=True)[ - self.__sort_by[0] - ] - sorted = np.argsort(column) - if self.__sort_by[1]: - sorted = sorted[::-1] - - else: - sorted = None +def get_sorted_index(state: DatasetState) -> np.ndarray | None: + if state.sort_key is not None: + column = get_data(select(state, {state.sort_key[0]}), 
ignore_sort=True)[ + state.sort_key[0] + ] + sorted_idx = np.argsort(column) + if state.sort_key[1]: + sorted_idx = sorted_idx[::-1] + else: + sorted_idx = None - return sorted + return sorted_idx - def take_rows(self, rows: DataIndex): - if len(self) == 0: - return self - rows = fold(HookPoint.IndexUpdate, IndexUpdateCtx(self, rows)).index - row_range = get_range(rows) - if row_range[1] > len(self) or row_range[0] < 0: - raise ValueError( - "Row indices must be between 0 and the length of this dataset!" - ) - sorted = self.get_sorted_index() - if sorted is not None: - rows = np.sort(sorted[into_array(rows)]) - new_handler = self.__raw_data_handler.take(rows) - new_cache = self.__cache.take(rows) - - return self.__rebuild( - raw_data_handler=new_handler, - cache=new_cache, - ) - - def with_units( - self, - convention: Optional[str], - conversions: dict[u.Unit, u.Unit], - columns: dict[str, u.Unit], - cosmology: Cosmology, - redshift: float | table.Column, +def take_rows(state: DatasetState, rows: DataIndex) -> DatasetState: + """ + Take a set of rows. The associated "take" functions in the + dataset all delegate to this function. + """ + if len(state) == 0: + return state + rows = fold(HookPoint.IndexUpdate, IndexUpdateCtx(state, rows)).index + sorted_idx = get_sorted_index(state) + if sorted_idx is not None: + rows = np.sort(sorted_idx[into_array(rows)]) + new_handler = state.raw_data_handler.take(rows) + new_cache = state.cache.take(rows) + return dataclasses.replace(state, raw_data_handler=new_handler, cache=new_cache) + + +def with_units( + state: DatasetState, + convention: Optional[str], + conversions: dict[u.Unit, u.Unit], + columns: dict[str, u.Unit], + cosmology: Cosmology, + redshift: float | table.Column, +) -> DatasetState: + """ + Update the units of a given state. 
+ """ + if convention is None: + convention_ = state.unit_handler.current_convention + else: + convention_ = UnitConvention(convention) + + if ( + convention_ == UnitConvention.SCALEFREE + and UnitConvention(state.header.file.unit_convention) + != UnitConvention.SCALEFREE ): - """ - Change the unit convention - """ - - if convention is None: - convention_ = self.__unit_handler.current_convention - - else: - convention_ = UnitConvention(convention) - - if ( - convention_ == UnitConvention.SCALEFREE - and UnitConvention(self.header.file.unit_convention) - != UnitConvention.SCALEFREE - ): - raise ValueError( - f"Cannot convert units with convention {self.header.file.unit_convention} to convention scalefree" - ) - column_keys = set(columns.keys()) - missing_columns = column_keys - set(self.columns) - if missing_columns: - raise ValueError(f"Dataset does not have columns {missing_columns}") - - new_handler = self.__unit_handler.with_convention(convention_).with_conversions( - conversions, columns + raise ValueError( + f"Cannot convert units with convention {state.header.file.unit_convention} to convention scalefree" ) - - if convention_ == self.__unit_handler.current_convention: - cache = self.__cache.create_child() - else: - all_derived_names: set[str] = set() - all_derived_names = reduce( - lambda acc, col: acc.union( - col.produces if not isinstance(col, RawColumn) else set() - ), - self.__producers.values(), - all_derived_names, - ).intersection(self.columns) - columns_to_drop = all_derived_names.union(self.__raw_data_handler.columns) - cache = self.__cache.drop(columns_to_drop) - return self.__rebuild(unit_handler=new_handler, cache=cache) + column_keys = set(columns.keys()) + missing_columns = column_keys - set(state.columns) + if missing_columns: + raise ValueError(f"Dataset does not have columns {missing_columns}") + + new_handler = state.unit_handler.with_convention(convention_).with_conversions( + conversions, columns + ) + + if convention_ == 
state.unit_handler.current_convention: + cache = state.cache.create_child() + else: + all_derived_names: set[str] = set() + all_derived_names = reduce( + lambda acc, col: acc.union( + col.produces if not isinstance(col, RawColumn) else set() + ), + state.producers.values(), + all_derived_names, + ).intersection(state.columns) + columns_to_drop = all_derived_names.union(state.raw_data_handler.columns) + cache = state.cache.drop(columns_to_drop) + return dataclasses.replace(state, unit_handler=new_handler, cache=cache) diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py index 17ef116b..f40e75c5 100644 --- a/python/opencosmo/dtypes/diffsky.py +++ b/python/opencosmo/dtypes/diffsky.py @@ -7,6 +7,7 @@ import numpy as np from pydantic import BaseModel, ConfigDict, field_serializer +import opencosmo.dataset.state as st from opencosmo.column.column import EvaluatedColumn, EvaluateStrategy from opencosmo.index import into_array from opencosmo.index.ops import reindex_column @@ -91,7 +92,7 @@ def rebuild_top_host_idx(top_host_idx, index): def keep_top_host_idx(dataset: DatasetState, new_index: DataIndex): index_array = into_array(new_index) - top_host_idx = dataset.select({"top_host_idx"}).get_data()["top_host_idx"] + top_host_idx = st.get_data(st.select(dataset, {"top_host_idx"}))["top_host_idx"] unique_in_sample = np.unique(top_host_idx[index_array]) missing_hosts = np.setdiff1d(unique_in_sample, index_array) diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index e1d11ddd..a69bb91c 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -533,7 +533,7 @@ def open_single_dataset( rank = comm.Get_rank() index = from_range(chunk_boundaries[rank], chunk_boundaries[rank + 1]) - state = st.DatasetState.from_target( + state = st.state_from_target( target, UnitConvention.COMOVING, sim_region, diff --git a/python/opencosmo/units/converters.py b/python/opencosmo/units/converters.py index 
24853806..1388040c 100644 --- a/python/opencosmo/units/converters.py +++ b/python/opencosmo/units/converters.py @@ -6,6 +6,7 @@ import astropy.units as u from astropy.cosmology import units as cu +import opencosmo.dataset.state as st from opencosmo.units import UnitConvention if TYPE_CHECKING: @@ -173,12 +174,12 @@ def get_scale_factor(dataset: "DatasetState", cosmology, redshift): columns = set(dataset.columns) for column in KNOWN_SCALEFACTOR_COLUMNS: if column in columns: - col = dataset.select({column}).get_data()[column] + col = st.get_data(st.select(dataset, {column}))[column] return col for column in KNOWN_REDSHIFT_COLUMNS: if column in columns: - col = dataset.select({column}).get_data()[column] + col = st.get_data(st.select(dataset, {column}))[column] return 1 / (1 + col) return cosmology.scale_factor(redshift) diff --git a/test/test_dataset.py b/test/test_dataset.py index 43aedb30..d2095fa4 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -545,8 +545,8 @@ def test_rows_cache(input_path): # After iterating rows(), derived columns should be cached. 
state = dataset._Dataset__state - cache = state._DatasetState__cache - columns = state._DatasetState__columns + cache = state.cache + columns = state.column_map assert "fof_px" in cache.columns pairs = {(columns["fof_px"], "fof_px")} uuid_data = cache.get_data(pairs) @@ -578,7 +578,7 @@ def test_cache(input_path): def test_cache_select(input_path): dataset = oc.open(input_path) _ = dataset.get_data() - cache = dataset._Dataset__state._DatasetState__cache + cache = dataset._Dataset__state.cache assert set(dataset.columns) == cache.columns columns = np.random.choice(dataset.columns, 5, replace=False) dataset2 = dataset.select(columns) @@ -595,7 +595,7 @@ def test_cache_filter(input_path): dataset = oc.open(input_path) dataset = dataset.filter(oc.col("fof_halo_mass") > 1e14) - cache = dataset._Dataset__state._DatasetState__cache + cache = dataset._Dataset__state.cache assert ( len(cache.columns) > 0 or cache._ColumnCache__parent() is not None @@ -613,7 +613,7 @@ def test_cache_column_conversion(input_path): dataset2 = dataset.with_units(conversions={u.Mpc: u.lyr}) - cache = dataset2._Dataset__state._DatasetState__cache + cache = dataset2._Dataset__state.cache data2 = dataset2.get_data() assert len(cache.columns) == len(dataset2.columns) @@ -626,7 +626,7 @@ def test_cache_change_units(input_path): dataset.get_data() dataset2 = dataset.with_units("scalefree") - cache = dataset2._Dataset__state._DatasetState__cache + cache = dataset2._Dataset__state.cache assert ( len(cache.columns) == 0 and cache._ColumnCache__parent is None @@ -640,9 +640,9 @@ def test_cache_conversion_propogation(input_path): state = dataset._Dataset__state state2 = dataset2._Dataset__state - cache = state._DatasetState__cache - cache2 = state2._DatasetState__cache - col_to_uuid = state._DatasetState__columns + cache = state.cache + cache2 = state2.cache + col_to_uuid = state.column_map pairs = {(uuid, name) for name, uuid in col_to_uuid.items()} assert len(cache.columns) == len(dataset2.columns) # 
just to be safe diff --git a/uv.lock b/uv.lock index 07138531..0a16bfce 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12, <3.15" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", @@ -1025,7 +1025,7 @@ wheels = [ [[package]] name = "opencosmo" -version = "1.2.5" +version = "1.2.6" source = { editable = "." } dependencies = [ { name = "astropy" }, From 11662c7b9dee8034eefacb1df9662e992020cd15 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 1 May 2026 10:33:44 -0500 Subject: [PATCH 082/139] Update tests and fix changelog --- changes/+19969705.misc.rst | 1 + test/test_collection.py | 4 ++-- test/test_healpixmap.py | 4 ++-- test/test_select.py | 4 +--- 4 files changed, 6 insertions(+), 7 deletions(-) create mode 100644 changes/+19969705.misc.rst diff --git a/changes/+19969705.misc.rst b/changes/+19969705.misc.rst new file mode 100644 index 00000000..db75a8f2 --- /dev/null +++ b/changes/+19969705.misc.rst @@ -0,0 +1 @@ +The DatasetState object has been broken up to allow for more flexibility in instantiating datasets, laying the groundwork for future optimizations. 
diff --git a/test/test_collection.py b/test/test_collection.py index 89f439a4..5000c4b7 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -1298,8 +1298,8 @@ def test_data_cached_after_objects(halo_paths): dataset = ds["dm_particles"] state = dataset._Dataset__state - cache = state._DatasetState__cache - columns = state._DatasetState__columns # dict[str, UUID] + cache = state.cache + columns = state.column_map # dict[str, UUID] gpe_uuid = columns["gpe"] uuid_data = cache.get_data({(gpe_uuid, "gpe")}) assert uuid_data.get(gpe_uuid, {}).get("gpe") is not None diff --git a/test/test_healpixmap.py b/test/test_healpixmap.py index 97bf215f..26de7d08 100644 --- a/test/test_healpixmap.py +++ b/test/test_healpixmap.py @@ -112,14 +112,14 @@ def test_healpix_downgrade_doesnt_have_file_handle(healpix_map_path): # First dataset backed by actual file, data should be cached dataset = next(iter(ds.values())) - cache = dataset._Dataset__state._DatasetState__cache + cache = dataset._Dataset__state.cache assert len(cache.columns) > 1 # New dataset entirely in-memory, so no cache output.get_data() downgraded_dataset = next(iter(output.values())) cache = downgraded_dataset._Dataset__state._DatasetState__cache - handler = downgraded_dataset._Dataset__state._DatasetState__raw_data_handler + handler = downgraded_dataset._Dataset__state.raw_data_handler assert len(cache.columns) == len(dataset.columns) assert len(handler) == 0 diff --git a/test/test_select.py b/test/test_select.py index 5f1eac2c..fd9fa58e 100644 --- a/test/test_select.py +++ b/test/test_select.py @@ -88,9 +88,7 @@ def test_select_doesnt_alter_raw(input_path): selected = dataset.select(selected_cols) selected_data = selected.get_data() - raw_data = ( - dataset._Dataset__state._DatasetState__raw_data_handler._Hdf5Handler__columns - ) + raw_data = dataset._Dataset__state.raw_data_handler._Hdf5Handler__columns assert all(isinstance(raw_data[col], h5py.Dataset) for col in cols) assert all(data[col].unit == 
selected_data[col].unit for col in selected_cols) assert not all(np.all(data[col].value == raw_data[col][:]) for col in selected_cols) From 2fa491a8255e3c46ae55cb620a400db470f550bb Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 1 May 2026 10:56:29 -0500 Subject: [PATCH 083/139] One more test fix --- test/test_healpixmap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_healpixmap.py b/test/test_healpixmap.py index 26de7d08..85178398 100644 --- a/test/test_healpixmap.py +++ b/test/test_healpixmap.py @@ -118,7 +118,7 @@ def test_healpix_downgrade_doesnt_have_file_handle(healpix_map_path): # New dataset entirely in-memory, so no cache output.get_data() downgraded_dataset = next(iter(output.values())) - cache = downgraded_dataset._Dataset__state._DatasetState__cache + cache = downgraded_dataset._Dataset__state.cache handler = downgraded_dataset._Dataset__state.raw_data_handler assert len(cache.columns) == len(dataset.columns) assert len(handler) == 0 From 459cf180974c7c4e090c76c872e83b784aef4e8e Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 12 May 2026 10:37:39 -0400 Subject: [PATCH 084/139] Global random take implementation --- python/opencosmo/dataset/dataset.py | 15 +++--- python/opencosmo/dataset/take.py | 72 +++++++++++++++++++++++++++++ test/parallel/test_dataset_mpi.py | 26 +++++++++++ 3 files changed, 105 insertions(+), 8 deletions(-) create mode 100644 python/opencosmo/dataset/take.py create mode 100644 test/parallel/test_dataset_mpi.py diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 4e52dc49..0d476215 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -6,6 +6,7 @@ Callable, Generator, Iterable, + Literal, Mapping, Optional, TypeAlias, @@ -21,6 +22,7 @@ from opencosmo.column import Column from opencosmo.dataset.evaluate import build_evaluated_column, visit_dataset from opencosmo.dataset.formats import convert_data, 
verify_format +from opencosmo.dataset.take import get_random_take_index from opencosmo.index import empty, get_range, into_array, mask, project, single_chunk from opencosmo.spatial import check from opencosmo.units.converters import get_scale_factor @@ -697,9 +699,7 @@ def sort_by(self, column: str, invert: bool = False) -> Dataset: ) def take( - self, - n: int, - at: str = "random", + self, n: int, at: str = "random", mode: Literal["local", "global"] = "local" ) -> Dataset: """ Create a new dataset from some number of rows from this dataset. @@ -715,6 +715,9 @@ def take( at : str Where to take the rows from. One of "start", "end", or "random". The default is "random". + mode : str, "local" or "global", default = "local" + If working with MPI, whether the `n` is a per-rank or global + number Returns @@ -736,11 +739,7 @@ def take( elif at != "random": raise ValueError(f"Unknown take type {at}") - if n > len(self): - raise ValueError("Cannot take more rows than are in this dataset!") - - row_indices = np.random.choice(len(self), n, replace=False) - row_indices.sort() + row_indices = get_random_take_index(n, len(self), mode) return self.take_rows(row_indices) def take_range(self, start: int, end: int) -> Dataset: diff --git a/python/opencosmo/dataset/take.py b/python/opencosmo/dataset/take.py new file mode 100644 index 00000000..f7af78ca --- /dev/null +++ b/python/opencosmo/dataset/take.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +import numpy as np + +from opencosmo.mpi import get_comm_world, get_mpi, has_mpi + +if TYPE_CHECKING: + from opencosmo.index import DataIndex, IndexArray + + +def get_random_take_index( + n: int, + ds_length: int, + mode: Literal["local", "global"], +) -> DataIndex: + if mode == "global" and has_mpi(): + return get_random_take_index_mpi(n, ds_length) + + if n > ds_length: + raise ValueError("You cannot take more rows than exist in the dataset!") + + generator = 
np.random.default_rng() + rows = generator.choice(ds_length, n, replace=False) + return np.sort(rows) + + +def get_random_take_index_mpi(n: int, ds_length: int): + comm = get_comm_world() + assert comm is not None + lengths = comm.allgather(ds_length) + + if (total_length := np.sum(lengths)) < n: + raise ValueError( + f"Tried to take {n} rows but total length of data is {total_length}" + ) + + if comm.Get_rank() == 0: + rng = np.random.default_rng() + rows = np.sort(rng.choice(total_length, n, replace=False)) + else: + rows = None + return get_local_rows_simple(rows, lengths, comm) + + +def get_local_rows_simple(rows: IndexArray | None, lengths, comm): + chunk_ranges = np.zeros(len(lengths) + 1, dtype=np.int64) + chunk_ranges[1:] = np.cumsum(lengths) + if comm.Get_rank() == 0: + assert rows is not None + chunk_ranges_in_index = np.searchsorted(rows, chunk_ranges) + chunk_ranges_in_index = comm.bcast(chunk_ranges_in_index) + else: + chunk_ranges_in_index = comm.bcast(None) + + rank_num = comm.Get_rank() + n_rows_local = chunk_ranges_in_index[rank_num + 1] - chunk_ranges_in_index[rank_num] + + local_rows = np.empty(n_rows_local, dtype=np.int64) + + if comm.Get_rank() == 0: + scatter_lengths = chunk_ranges_in_index[1:] - chunk_ranges_in_index[:-1] + buffer_offsets = np.zeros_like(scatter_lengths) + buffer_offsets[1:] = np.cumsum(scatter_lengths)[:-1] + + buffspec = [rows, scatter_lengths, buffer_offsets, get_mpi().DOUBLE] + comm.Scatterv(buffspec, local_rows) + else: + comm.Scatterv([None, None, None, get_mpi().DOUBLE], local_rows) + + return local_rows - chunk_ranges[rank_num] diff --git a/test/parallel/test_dataset_mpi.py b/test/parallel/test_dataset_mpi.py new file mode 100644 index 00000000..c95b6e58 --- /dev/null +++ b/test/parallel/test_dataset_mpi.py @@ -0,0 +1,26 @@ +import numpy as np +import pytest +from opencosmo.mpi import get_comm_world +from pytest_mpi import parallel_assert + +import opencosmo as oc + + +@pytest.fixture +def input_path(snapshot_path): + 
return snapshot_path / "haloproperties.hdf5" + + +@pytest.mark.parallel(nprocs=4) +def test_take_global(input_path): + comm = get_comm_world() + + ds = oc.open(input_path) + total_length = sum(comm.allgather(len(ds))) + n_to_take = np.random.randint(total_length // 4, int(total_length * 0.75)) + n_to_take = comm.bcast(n_to_take) + + ds = ds.take(n_to_take, mode="global") + all_lengths = comm.allgather(len(ds)) + + parallel_assert(sum(all_lengths) == n_to_take) From 6a9b8c9d56b935b3a1bf97fba1dfa37bd3bc4764 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 12 May 2026 14:05:23 -0400 Subject: [PATCH 085/139] Global take_range implementation --- python/opencosmo/dataset/dataset.py | 17 +-- python/opencosmo/dataset/take.py | 109 +++++++++++++++++ test/parallel/test_dataset_mpi.py | 175 ++++++++++++++++++++++++++++ 3 files changed, 293 insertions(+), 8 deletions(-) diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 0d476215..c35a4423 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -22,8 +22,8 @@ from opencosmo.column import Column from opencosmo.dataset.evaluate import build_evaluated_column, visit_dataset from opencosmo.dataset.formats import convert_data, verify_format -from opencosmo.dataset.take import get_random_take_index -from opencosmo.index import empty, get_range, into_array, mask, project, single_chunk +from opencosmo.dataset.take import get_random_take_index, get_range_take_index +from opencosmo.index import empty, get_range, into_array, mask, project from opencosmo.spatial import check from opencosmo.units.converters import get_scale_factor @@ -733,16 +733,18 @@ def take( """ if at == "start": - return self.take_range(0, n) + return self.take_range(0, n, mode) elif at == "end": - return self.take_range(len(self) - n, len(self)) + return self.take_range(len(self) - n, len(self), mode) elif at != "random": raise ValueError(f"Unknown take type {at}") row_indices = 
get_random_take_index(n, len(self), mode) return self.take_rows(row_indices) - def take_range(self, start: int, end: int) -> Dataset: + def take_range( + self, start: int, end: int, mode: Literal["local", "global"] = "local" + ) -> Dataset: """ Create a new dataset from a row range in this dataset. We use standard indexing conventions, so the rows included will be start -> end - 1. @@ -770,10 +772,9 @@ def take_range(self, start: int, end: int) -> Dataset: raise ValueError("start and end must be positive.") if end < start: raise ValueError("end must be greater than start.") - if end > len(self): - raise ValueError("end must be less than the length of the dataset.") - take_index = single_chunk(start, end - start) + take_index = get_range_take_index(self.__state, start, end - start, mode) + return self.take_rows(take_index) def take_rows(self, rows: np.ndarray | DataIndex): diff --git a/python/opencosmo/dataset/take.py b/python/opencosmo/dataset/take.py index f7af78ca..26eec352 100644 --- a/python/opencosmo/dataset/take.py +++ b/python/opencosmo/dataset/take.py @@ -4,6 +4,8 @@ import numpy as np +import opencosmo.dataset.state as st +from opencosmo.index import single_chunk from opencosmo.mpi import get_comm_world, get_mpi, has_mpi if TYPE_CHECKING: @@ -26,6 +28,113 @@ def get_random_take_index( return np.sort(rows) +def get_range_take_index( + state: st.DatasetState, + start: int, + size: int, + mode: Literal["local", "global"], +): + if mode == "global" and has_mpi(): + return get_range_take_index_mpi(state, start, size) + + if start + size > len(state): + raise ValueError("end must be less than the length of the dataset.") + return single_chunk(start, size) + + +def get_range_take_index_mpi(state: st.DatasetState, start, size): + comm = get_comm_world() + assert comm is not None + lengths = np.array(comm.allgather(len(state)), dtype=np.int64) + total_length = int(np.sum(lengths)) + + if start + size > total_length: + raise ValueError( + f"Tried to take {start + 
size} rows but total length of data is {total_length}" + ) + + if state.sort_key is not None: + global_sort_order = get_global_sort_order(state) + + if comm.Get_rank() == 0: + assert global_sort_order is not None + n_ranks = comm.Get_size() + chunk_ranges = np.zeros(n_ranks + 1, dtype=np.int64) + chunk_ranges[1:] = np.cumsum(lengths) + + # Map each position in the global sort order to the rank that owns it. + rank_of_element = np.searchsorted( + chunk_ranges[1:], global_sort_order, side="right" + ) + # Count how many sorted elements before `start` belong to each rank; + # this is the local sorted-order start index for each rank's slice. + lo_per_rank = np.bincount(rank_of_element[:start], minlength=n_ranks) + # Count how many sorted elements in [start, start+size) belong to each rank. + count_per_rank = np.bincount( + rank_of_element[start : start + size], minlength=n_ranks + ) + else: + lo_per_rank = None + count_per_rank = None + + lo_per_rank = comm.bcast(lo_per_rank) + count_per_rank = comm.bcast(count_per_rank) + + rank = comm.Get_rank() + local_start = int(lo_per_rank[rank]) + local_size = int(count_per_rank[rank]) + + if local_size == 0: + return np.array([], dtype=np.int64) + return single_chunk(local_start, local_size) + + # Handle the case without sorting: contiguous global range + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + local_start = max(0, start - offset) + local_end = min(int(lengths[rank]), start + size - offset) + + if local_end <= local_start: + return np.array([], dtype=np.int64) + return single_chunk(local_start, local_end - local_start) + + +def get_global_sort_order(state: st.DatasetState): + comm = get_comm_world() + assert comm is not None + + assert state.sort_key is not None + sort_col, sort_desc = state.sort_key + raw = st.get_data(st.select(state, {sort_col}), ignore_sort=True)[sort_col] + local_values = np.asarray( + raw.value if hasattr(raw, "value") else raw, dtype=np.float64 + ) + + lengths = 
np.array(comm.allgather(len(local_values)), dtype=np.int64) + total_length = int(np.sum(lengths)) + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + + # Use comm.Reduce to get the full catalog on rank 0. + # Each rank writes its values at its global offset; summing gives the full array. + local_contribution = np.zeros(total_length, dtype=np.float64) + local_contribution[offset : offset + len(local_values)] = local_values + recv = np.zeros(total_length, dtype=np.float64) if rank == 0 else None + comm.Reduce(local_contribution, recv, op=get_mpi().SUM, root=0) + + # Other ranks return None + if rank != 0: + return None + + # Determine global sort range + assert recv is not None + global_sorted_phys = np.argsort(recv, kind="stable") + if sort_desc: + global_sorted_phys = global_sorted_phys[::-1] + + return global_sorted_phys + + def get_random_take_index_mpi(n: int, ds_length: int): comm = get_comm_world() assert comm is not None diff --git a/test/parallel/test_dataset_mpi.py b/test/parallel/test_dataset_mpi.py index c95b6e58..71d8e3d1 100644 --- a/test/parallel/test_dataset_mpi.py +++ b/test/parallel/test_dataset_mpi.py @@ -24,3 +24,178 @@ def test_take_global(input_path): all_lengths = comm.allgather(len(ds)) parallel_assert(sum(all_lengths) == n_to_take) + + +# ── take_range global, unsorted ────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_start(input_path): + """First n global rows land on the correct ranks with the right counts.""" + comm = get_comm_world() + ds = oc.open(input_path) + + lengths = np.array(comm.allgather(len(ds)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + + ds_taken = ds.take_range(0, n, mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max(0, min(int(lengths[rank]), n - offset)) + + parallel_assert( + len(ds_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got 
{len(ds_taken)}", + ) + parallel_assert(sum(comm.allgather(len(ds_taken))) == n) + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_end(input_path): + """Last n global rows land on the correct ranks with the right counts.""" + comm = get_comm_world() + ds = oc.open(input_path) + + lengths = np.array(comm.allgather(len(ds)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + global_start = total - n + + ds_taken = ds.take_range(global_start, total, mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), total - offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(ds_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(ds_taken)}", + ) + parallel_assert(sum(comm.allgather(len(ds_taken))) == n) + + +# ── take_range global, sorted ───────────────────────────────────────────────── +# +# The sorted tests verify value-level correctness: after a global range take on +# a sorted dataset, every selected value must satisfy the global threshold +# implied by the range position. We derive the expected threshold by gathering +# all values from all ranks before the take, then checking the invariant after. + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_sorted_start(input_path): + """Global start on sorted data selects the n globally smallest values.""" + comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass") + + total = sum(comm.allgather(len(ds))) + n = total // 3 + + # Gather original values to compute the expected threshold before the take. 
+ original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] + + ds_taken = ds.take_range(0, n, mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_sorted_end(input_path): + """Global end on sorted data selects the n globally largest values.""" + comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass") + + total = sum(comm.allgather(len(ds))) + n = total // 3 + + original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + ds_taken = ds.take_range(total - n, total, mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + +# ── take_range global, middle ───────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_middle(input_path): + """A middle window of global rows lands on the correct ranks with the right counts.""" + comm = get_comm_world() + ds = oc.open(input_path) + + lengths = np.array(comm.allgather(len(ds)), dtype=np.int64) + total = int(np.sum(lengths)) + global_start = total // 4 + global_end = 3 * total // 4 + + ds_taken = ds.take_range(global_start, global_end, mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), global_end 
- offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(ds_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(ds_taken)}", + ) + parallel_assert(sum(comm.allgather(len(ds_taken))) == global_end - global_start) + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_sorted_middle(input_path): + """A middle window on sorted data selects the correct globally-ranked values.""" + comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass") + + total = sum(comm.allgather(len(ds))) + global_start = total // 4 + global_end = 3 * total // 4 + size = global_end - global_start + + original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + sorted_all = np.sort(all_original) + lower_threshold = sorted_all[global_start] + upper_threshold = sorted_all[global_end - 1] + + ds_taken = ds.take_range(global_start, global_end, mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == size) + parallel_assert( + np.all(all_selected >= lower_threshold), + "some selected values fall below the lower global threshold", + ) + parallel_assert( + np.all(all_selected <= upper_threshold), + "some selected values exceed the upper global threshold", + ) From 1fdb721d6ab902f6614cc717512b00073b843f02 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 12 May 2026 15:55:06 -0400 Subject: [PATCH 086/139] no longer throw, fix take with --- python/opencosmo/dataset/dataset.py | 9 ++++-- python/opencosmo/dataset/take.py | 45 +++++++++++++++++++++-------- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index c35a4423..eba30011 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -22,7 +22,11 @@ from opencosmo.column import 
Column from opencosmo.dataset.evaluate import build_evaluated_column, visit_dataset from opencosmo.dataset.formats import convert_data, verify_format -from opencosmo.dataset.take import get_random_take_index, get_range_take_index +from opencosmo.dataset.take import ( + get_end_take_index, + get_random_take_index, + get_range_take_index, +) from opencosmo.index import empty, get_range, into_array, mask, project from opencosmo.spatial import check from opencosmo.units.converters import get_scale_factor @@ -735,7 +739,8 @@ def take( if at == "start": return self.take_range(0, n, mode) elif at == "end": - return self.take_range(len(self) - n, len(self), mode) + take_index = get_end_take_index(n, self.__state, mode) + return self.take_rows(take_index) elif at != "random": raise ValueError(f"Unknown take type {at}") diff --git a/python/opencosmo/dataset/take.py b/python/opencosmo/dataset/take.py index 26eec352..59f262d8 100644 --- a/python/opencosmo/dataset/take.py +++ b/python/opencosmo/dataset/take.py @@ -5,7 +5,7 @@ import numpy as np import opencosmo.dataset.state as st -from opencosmo.index import single_chunk +from opencosmo.index import empty, from_size, single_chunk from opencosmo.mpi import get_comm_world, get_mpi, has_mpi if TYPE_CHECKING: @@ -21,7 +21,7 @@ def get_random_take_index( return get_random_take_index_mpi(n, ds_length) if n > ds_length: - raise ValueError("You cannot take more rows than exist in the dataset!") + return from_size(ds_length) generator = np.random.default_rng() rows = generator.choice(ds_length, n, replace=False) @@ -37,8 +37,30 @@ def get_range_take_index( if mode == "global" and has_mpi(): return get_range_take_index_mpi(state, start, size) - if start + size > len(state): - raise ValueError("end must be less than the length of the dataset.") + ds_len = len(state) + if start + size > ds_len: + size = len(state) - ds_len + return single_chunk(start, size) + + +def get_end_take_index( + n: int, + state: st.DatasetState, + mode: 
Literal["local", "global"], +): + ds_length = len(state) + if mode == "global" and has_mpi(): + comm = get_comm_world() + assert comm is not None + total_length = np.sum(comm.allgather(ds_length)) + if n > total_length: + return from_size(ds_length) + + return get_range_take_index_mpi(state, total_length - n, n) + + if n > ds_length: + start = 0 + size = ds_length return single_chunk(start, size) @@ -48,10 +70,11 @@ def get_range_take_index_mpi(state: st.DatasetState, start, size): lengths = np.array(comm.allgather(len(state)), dtype=np.int64) total_length = int(np.sum(lengths)) + if start > total_length: + return empty() + if start + size > total_length: - raise ValueError( - f"Tried to take {start + size} rows but total length of data is {total_length}" - ) + size = total_length - start if state.sort_key is not None: global_sort_order = get_global_sort_order(state) @@ -141,9 +164,7 @@ def get_random_take_index_mpi(n: int, ds_length: int): lengths = comm.allgather(ds_length) if (total_length := np.sum(lengths)) < n: - raise ValueError( - f"Tried to take {n} rows but total length of data is {total_length}" - ) + return from_size(ds_length) if comm.Get_rank() == 0: rng = np.random.default_rng() @@ -173,9 +194,9 @@ def get_local_rows_simple(rows: IndexArray | None, lengths, comm): buffer_offsets = np.zeros_like(scatter_lengths) buffer_offsets[1:] = np.cumsum(scatter_lengths)[:-1] - buffspec = [rows, scatter_lengths, buffer_offsets, get_mpi().DOUBLE] + buffspec = [rows, scatter_lengths, buffer_offsets, get_mpi().INT64_T] comm.Scatterv(buffspec, local_rows) else: - comm.Scatterv([None, None, None, get_mpi().DOUBLE], local_rows) + comm.Scatterv([None, None, None, get_mpi().INT64_T], local_rows) return local_rows - chunk_ranges[rank_num] From 9b089e07abf41c69ff2ef3bd6e1806b6f5fb554f Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 12 May 2026 16:07:25 -0400 Subject: [PATCH 087/139] Fix unbound local error, fix test for new behavior --- 
python/opencosmo/dataset/take.py | 5 +++-- test/test_take.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/opencosmo/dataset/take.py b/python/opencosmo/dataset/take.py index 59f262d8..89dc904d 100644 --- a/python/opencosmo/dataset/take.py +++ b/python/opencosmo/dataset/take.py @@ -58,10 +58,11 @@ def get_end_take_index( return get_range_take_index_mpi(state, total_length - n, n) + start = ds_length - n if n > ds_length: start = 0 - size = ds_length - return single_chunk(start, size) + n = ds_length + return single_chunk(start, n) def get_range_take_index_mpi(state: st.DatasetState, start, size): diff --git a/test/test_take.py b/test/test_take.py index f87a0bce..79524588 100644 --- a/test/test_take.py +++ b/test/test_take.py @@ -58,5 +58,6 @@ def test_take_chain(input_path): def test_take_too_many(input_path): ds = oc.open(input_path) length = len(ds.get_data()) - with pytest.raises(ValueError): - ds.take(length + 1) + + new_ds = ds.take(length + 1) + assert len(new_ds) == len(ds) From aaef87ffe36738f9be7469942cad26a8c35db037 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 13 May 2026 12:45:14 -0400 Subject: [PATCH 088/139] Implement in structurecollection, implement helper index routine for Lightcone --- python/opencosmo/_lib/index.pyi | 3 + .../collection/lightcone/lightcone.py | 85 ++++++++++--------- .../collection/structure/structure.py | 17 +++- python/opencosmo/dataset/dataset.py | 11 ++- python/opencosmo/dataset/take.py | 35 ++++---- python/opencosmo/index/__init__.py | 2 + python/opencosmo/index/ops.py | 6 +- src/index.rs | 53 ++++++++++++ test/parallel/test_dataset_mpi.py | 55 ++++++++++++ test/test_lightcone.py | 2 +- test/test_take.py | 23 +++++ 11 files changed, 226 insertions(+), 66 deletions(-) diff --git a/python/opencosmo/_lib/index.pyi b/python/opencosmo/_lib/index.pyi index ca046099..52e72e1c 100644 --- a/python/opencosmo/_lib/index.pyi +++ b/python/opencosmo/_lib/index.pyi @@ -13,3 +13,6 @@ def 
take_chunked_from_simple( simple: IndexArray, start: IndexArray, size: IndexArray ) -> IndexArray: ... def reindex_column(index: IndexArray, index_column: IndexArray) -> IndexArray: ... +def rebuild_simple_by_ranges( + index: IndexArray, starts: IndexArray, sizes: IndexArray +) -> IndexArray: ... diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index db5bc37e..068f3c2a 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -9,6 +9,7 @@ Callable, Generator, Iterable, + Literal, Mapping, Optional, Self, @@ -25,6 +26,12 @@ from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn from opencosmo.dataset.evaluate import build_evaluated_column from opencosmo.dataset.formats import convert_data, verify_format +from opencosmo.dataset.take import ( + get_end_take_index, + get_random_take_index, + get_range_take_index, +) +from opencosmo.index import rebuild_by_ranges from opencosmo.io import iopen from opencosmo.io.schema import FileEntry, make_schema from opencosmo.plugins.contexts import ( @@ -69,7 +76,7 @@ def __init__( datasets: Mapping[Any, Dataset | Lightcone], z_range: Optional[tuple[float, float]] = None, hidden: Optional[set[str]] = None, - ordered_by: Optional[tuple[str, bool]] = None, + sort_key: Optional[tuple[str, bool]] = None, ): self.update(datasets) z_range = ( @@ -90,7 +97,7 @@ def __init__( hidden = set() self.__hidden = hidden - self.__ordered_by = ordered_by + self.__sort_key = sort_key def __repr__(self): """ @@ -301,8 +308,8 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): table = vstack(data_with_length, join_type="exact") - if self.__ordered_by is not None: - order = table.argsort(self.__ordered_by[0], reverse=self.__ordered_by[1]) + if self.__sort_key is not None: + order = table.argsort(self.__sort_key[0], reverse=self.__sort_key[1]) table = table[order] table = fold( 
HookPoint.PostSort, PostSortCtx(self, table, np.argsort(order)) @@ -408,9 +415,7 @@ def with_redshift_range(self, z_low: float, z_high: float): ) if len(new_dataset) > 0: new_datasets[key] = new_dataset - return Lightcone( - new_datasets, (z_low, z_high), self.__hidden, self.__ordered_by - ) + return Lightcone(new_datasets, (z_low, z_high), self.__hidden, self.__sort_key) def __map( self, @@ -445,7 +450,7 @@ def __map( if not output: output = zero_length_output if construct: - return Lightcone(output, self.z_range, hidden, self.__ordered_by) + return Lightcone(output, self.z_range, hidden, self.__sort_key) return output def __map_attribute(self, attribute): @@ -791,9 +796,9 @@ def select( additional_columns.add("redshift") hidden = hidden.union({"redshift"}) - if self.__ordered_by is not None and self.__ordered_by[0] not in all_columns: - additional_columns.add(self.__ordered_by[0]) - hidden = hidden.union({self.__ordered_by[0]}) + if self.__sort_key is not None and self.__sort_key[0] not in all_columns: + additional_columns.add(self.__sort_key[0]) + hidden = hidden.union({self.__sort_key[0]}) return self.__map( "select", @@ -838,7 +843,9 @@ def drop(self, *columns: str | Iterable[str]) -> Self: kept_columns = current_columns - dropped_columns return self.select(kept_columns) - def take(self, n: int, at: str = "random") -> "Lightcone": + def take( + self, n: int, at: str = "random", mode: Literal["local", "global"] = "local" + ) -> "Lightcone": """ Create a new dataset from some number of rows from this dataset. 
@@ -870,23 +877,21 @@ def take(self, n: int, at: str = "random") -> "Lightcone": "Number of rows to take must be less than number of rows in dataset" ) if at == "random": - indices = np.random.choice(len(self), n, replace=False) - indices = np.sort(indices) - return self.__take_rows(indices) + index = get_random_take_index(n, len(self), mode) - elif self.__ordered_by is not None: - index = lcutils.take_from_sorted(self, *self.__ordered_by, n=n, at=at) - return self.__take_rows(index) elif at == "start": - return self.take_range(0, n) + index = get_range_take_index(self, self.__sort_key, 0, n, mode) elif at == "end": - return self.take_range(len(self) - n, len(self)) + index = get_end_take_index(n, self, self.__sort_key, mode) else: raise ValueError( f'"at" should be one of ("start", "end", "random", got {at}' ) + return self.take_rows(index) - def take_range(self, start: int, end: int): + def take_range( + self, start: int, end: int, mode: Literal["local", "global"] = "local" + ): """ Create a new lightcone from a row range in this lightcone. We use standard indexing conventions, so the rows included will be start -> end - 1. Because @@ -913,12 +918,12 @@ def take_range(self, start: int, end: int): or if end is greater than start. 
""" - if start < 0 or end > len(self): - raise ValueError("Got row indices that are out of range!") + if start < 0: + raise ValueError("Tried to take negative rows!") - if self.__ordered_by is not None: + if self.__sort_key is not None: indices = lcutils.take_from_sorted( - self, *self.__ordered_by, end - start, at=start + self, *self.__sort_key, end - start, at=start ) return self.__take_rows(indices) @@ -937,7 +942,7 @@ def take_range(self, start: int, end: int): output[name] = dataset.take_range( clipped_starts[i] - starts[i], clipped_ends[i] - starts[i] ) - return Lightcone(output, self.z_range, self.__hidden, self.__ordered_by) + return Lightcone(output, self.z_range, self.__hidden, self.__sort_key) def take_rows(self, rows: np.ndarray): """ @@ -965,7 +970,7 @@ def take_rows(self, rows: np.ndarray): raise ValueError( "Rows must be between 0 and the length of this dataset - 1" ) - if self.__ordered_by is not None: + if self.__sort_key is not None: sort_index = self.__make_sort_index() rows = sort_index[rows] rows.sort() @@ -973,12 +978,12 @@ def take_rows(self, rows: np.ndarray): return self.__take_rows(rows) def __make_sort_index(self): - if self.__ordered_by is None: + if self.__sort_key is None: return None data = np.concatenate( - [ds.select(self.__ordered_by[0]).get_data("numpy") for ds in self.values()] + [ds.select(self.__sort_key[0]).get_data("numpy") for ds in self.values()] ) - if self.__ordered_by[1]: + if self.__sort_key[1]: data = -data return np.argsort(data) @@ -987,17 +992,15 @@ def __take_rows(self, rows: np.ndarray): Takes rows from this lightcone while ignoring sort. "rows" is assumed to be sorte. For internal use only. 
""" - ds_ends = np.cumsum(np.fromiter((len(ds) for ds in self.values()), dtype=int)) - partitions = np.searchsorted(rows, ds_ends) - splits = np.split(rows, partitions) - rs = 0 + sizes = np.fromiter((len(ds) for ds in self.values()), dtype=np.int64) + starts = np.zeros_like(sizes) + starts[1:] = np.cumsum(sizes)[:-1] + projected = rebuild_by_ranges(rows, (starts, sizes)) output = {} - for split, (name, dataset) in zip(splits, self.items()): - if len(split) > 0: - output[name] = dataset.take_rows(split - rs) - rs += len(dataset) + for (name, ds), index in zip(self.items(), projected): + output[name] = ds.take_rows(index) - return Lightcone(output, self.z_range, self.__hidden, self.__ordered_by) + return Lightcone(output, self.z_range, self.__hidden, self.__sort_key) def with_new_columns( self, @@ -1045,7 +1048,7 @@ def with_new_columns( else: raw[name] = column - if self.__ordered_by is not None: + if self.__sort_key is not None: sort_index = self.__make_sort_index() sort_index = np.argsort(sort_index) raw = {name: raw_data[sort_index] for name, raw_data in raw.items()} @@ -1061,7 +1064,7 @@ def with_new_columns( descriptions, allow_overwrite=allow_overwrite, **columns_input ) new_datasets[ds_name] = new_dataset - return Lightcone(new_datasets, self.z_range, self.__hidden, self.__ordered_by) + return Lightcone(new_datasets, self.z_range, self.__hidden, self.__sort_key) def sort_by(self, column: str, invert: bool = False): """ diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index 43570924..8c385e02 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -3,7 +3,16 @@ from collections import defaultdict from functools import partial, reduce from inspect import signature -from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Mapping, Optional +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + 
Iterable, + Literal, + Mapping, + Optional, +) from warnings import warn import numpy as np @@ -989,7 +998,9 @@ def with_units( self.__handler.make_derived(self.__source), ) - def take(self, n: int, at: str = "random"): + def take( + self, n: int, at: str = "random", mode: Literal["local", "global"] = "local" + ): """ Take some number of structures from the collection. See :py:meth:`opencosmo.Dataset.take`. @@ -1007,7 +1018,7 @@ def take(self, n: int, at: str = "random"): StructureCollection A new collection with the structures taken from the original. """ - new_source = self.__source.take(n, at) + new_source = self.__source.take(n, at, mode) new_handler = self.__handler.make_derived(self.__source) return StructureCollection( diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index eba30011..5981220d 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -294,7 +294,10 @@ def get_data( unit_kwargs = {} data = st.get_data( - self.__state, unit_kwargs=unit_kwargs, metadata_columns=metadata_columns + self.__state, + unit_kwargs=unit_kwargs, + metadata_columns=metadata_columns, + **kwargs, ) # dict if unpack: data = { @@ -739,7 +742,7 @@ def take( if at == "start": return self.take_range(0, n, mode) elif at == "end": - take_index = get_end_take_index(n, self.__state, mode) + take_index = get_end_take_index(n, self, self.__state.sort_key, mode) return self.take_rows(take_index) elif at != "random": raise ValueError(f"Unknown take type {at}") @@ -778,7 +781,9 @@ def take_range( if end < start: raise ValueError("end must be greater than start.") - take_index = get_range_take_index(self.__state, start, end - start, mode) + take_index = get_range_take_index( + self, self.__state.sort_key, start, end - start, mode + ) return self.take_rows(take_index) diff --git a/python/opencosmo/dataset/take.py b/python/opencosmo/dataset/take.py index 89dc904d..c97540de 100644 --- a/python/opencosmo/dataset/take.py 
+++ b/python/opencosmo/dataset/take.py @@ -1,10 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Optional import numpy as np -import opencosmo.dataset.state as st from opencosmo.index import empty, from_size, single_chunk from opencosmo.mpi import get_comm_world, get_mpi, has_mpi @@ -29,26 +28,28 @@ def get_random_take_index( def get_range_take_index( - state: st.DatasetState, + ds, + sort_key: Optional[tuple[str, bool]], start: int, size: int, mode: Literal["local", "global"], ): if mode == "global" and has_mpi(): - return get_range_take_index_mpi(state, start, size) + return get_range_take_index_mpi(ds, sort_key, start, size) - ds_len = len(state) + ds_len = len(ds) if start + size > ds_len: - size = len(state) - ds_len + size = len(ds) - ds_len return single_chunk(start, size) def get_end_take_index( n: int, - state: st.DatasetState, + ds, + sort_key, mode: Literal["local", "global"], ): - ds_length = len(state) + ds_length = len(ds) if mode == "global" and has_mpi(): comm = get_comm_world() assert comm is not None @@ -56,7 +57,7 @@ def get_end_take_index( if n > total_length: return from_size(ds_length) - return get_range_take_index_mpi(state, total_length - n, n) + return get_range_take_index_mpi(ds, sort_key, total_length - n, n) start = ds_length - n if n > ds_length: @@ -65,10 +66,10 @@ def get_end_take_index( return single_chunk(start, n) -def get_range_take_index_mpi(state: st.DatasetState, start, size): +def get_range_take_index_mpi(ds, sort_key, start, size): comm = get_comm_world() assert comm is not None - lengths = np.array(comm.allgather(len(state)), dtype=np.int64) + lengths = np.array(comm.allgather(len(ds)), dtype=np.int64) total_length = int(np.sum(lengths)) if start > total_length: @@ -77,8 +78,8 @@ def get_range_take_index_mpi(state: st.DatasetState, start, size): if start + size > total_length: size = total_length - start - if state.sort_key is not None: - 
global_sort_order = get_global_sort_order(state) + if sort_key is not None: + global_sort_order = get_global_sort_order(ds, sort_key) if comm.Get_rank() == 0: assert global_sort_order is not None @@ -123,13 +124,13 @@ def get_range_take_index_mpi(state: st.DatasetState, start, size): return single_chunk(local_start, local_end - local_start) -def get_global_sort_order(state: st.DatasetState): +def get_global_sort_order(ds, sort_key): comm = get_comm_world() assert comm is not None - assert state.sort_key is not None - sort_col, sort_desc = state.sort_key - raw = st.get_data(st.select(state, {sort_col}), ignore_sort=True)[sort_col] + assert sort_key is not None + sort_col, sort_desc = sort_key + raw = ds.select(sort_col).get_data("numpy", ignore_sort=True) local_values = np.asarray( raw.value if hasattr(raw, "value") else raw, dtype=np.float64 ) diff --git a/python/opencosmo/index/__init__.py b/python/opencosmo/index/__init__.py index 4bba8eab..a640c787 100644 --- a/python/opencosmo/index/__init__.py +++ b/python/opencosmo/index/__init__.py @@ -4,6 +4,7 @@ from .get import get_data from .in_range import n_in_range from .mask import into_array, mask +from .ops import rebuild_by_ranges from .project import project from .take import take from .unary import get_length, get_range @@ -34,4 +35,5 @@ "get_length", "get_range", "from_start_size_group", + "rebuild_by_ranges", ] diff --git a/python/opencosmo/index/ops.py b/python/opencosmo/index/ops.py index b841d5e4..54b44960 100644 --- a/python/opencosmo/index/ops.py +++ b/python/opencosmo/index/ops.py @@ -8,9 +8,13 @@ from opencosmo.index import into_array if TYPE_CHECKING: - from opencosmo.index import DataIndex + from opencosmo.index import ChunkedIndex, DataIndex def reindex_column(index: DataIndex, column: np.ndarray): column = column.astype(np.int64) return idxlib.reindex_column(into_array(index), column) + + +def rebuild_by_ranges(index: DataIndex, ranges: ChunkedIndex): + return idxlib.rebuild_simple_by_ranges(index, 
*ranges) diff --git a/src/index.rs b/src/index.rs index 1d76c58f..4e773432 100644 --- a/src/index.rs +++ b/src/index.rs @@ -6,6 +6,7 @@ pub(crate) mod index { use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1}; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; + use pyo3::types::PyList; use std::collections::HashMap; use std::iter::zip; @@ -314,4 +315,56 @@ pub(crate) mod index { } Array1::from_vec(output) } + + #[pyfunction(name = "rebuild_simple_by_ranges")] + fn rebuild_simple_by_ranges_py<'py>( + py: Python<'py>, + index: &Bound<'_, PyAny>, + range_starts: &Bound<'_, PyAny>, + range_sizes: &Bound<'_, PyAny>, + ) -> PyResult> { + let index_arr = unpack_index_array(index)?; + let (start_arr, size_arr) = unpack_chunked_index(range_starts, range_sizes)?; + let mut output = rebuild_simple_by_ranges( + index_arr.as_array(), + start_arr.as_array(), + size_arr.as_array(), + ); + PyList::new(py, output.drain(0..).map(|a| a.into_pyarray(py))) + } + + fn rebuild_simple_by_ranges( + index: ArrayView1<'_, i64>, + range_starts: ArrayView1<'_, i64>, + range_sizes: ArrayView1<'_, i64>, + ) -> Vec> { + let mut rs: i64 = 0; + let mut current_chunk_index: usize = 0; + let mut current_chunk_end: i64 = + range_starts[current_chunk_index] + range_sizes[current_chunk_index]; + let mut output_indices = Vec::new(); + let mut current_chunk_indices = Vec::new(); + for &idx in index { + if idx >= current_chunk_end { + output_indices.push(Array1::from_vec(current_chunk_indices)); + rs = current_chunk_end; + current_chunk_index += 1; + current_chunk_end = + range_starts[current_chunk_index] + range_sizes[current_chunk_index]; + current_chunk_indices = Vec::new(); + current_chunk_indices.push(idx - rs); + } else if idx < rs { + continue; + } else { + current_chunk_indices.push(idx - rs); + } + } + output_indices.push(Array1::from_vec(current_chunk_indices)); + if output_indices.len() < range_starts.len() { + for _ in 0..range_starts.len() - 
output_indices.len() { + output_indices.push(Array1::zeros(0)) + } + } + return output_indices; + } } diff --git a/test/parallel/test_dataset_mpi.py b/test/parallel/test_dataset_mpi.py index 71d8e3d1..956442d4 100644 --- a/test/parallel/test_dataset_mpi.py +++ b/test/parallel/test_dataset_mpi.py @@ -138,6 +138,61 @@ def test_take_range_global_sorted_end(input_path): ) +# ── take global end ─────────────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_take_global_end(input_path): + """take(n, at='end', mode='global') selects the last n rows across all ranks.""" + comm = get_comm_world() + ds = oc.open(input_path) + + lengths = np.array(comm.allgather(len(ds)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + global_start = total - n + + ds_taken = ds.take(n, at="end", mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), total - offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(ds_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(ds_taken)}", + ) + parallel_assert(sum(comm.allgather(len(ds_taken))) == n) + + +@pytest.mark.parallel(nprocs=4) +def test_take_global_end_sorted(input_path): + """take(n, at='end', mode='global') on sorted data selects the n globally largest values.""" + comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass") + + total = sum(comm.allgather(len(ds))) + n = total // 3 + + original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + ds_taken = ds.take(n, at="end", mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values 
fall below the global n-th largest threshold", + ) + + # ── take_range global, middle ───────────────────────────────────────────────── diff --git a/test/test_lightcone.py b/test/test_lightcone.py index b4efc42d..cc89bd2d 100644 --- a/test/test_lightcone.py +++ b/test/test_lightcone.py @@ -179,7 +179,7 @@ def test_lc_collection_range( def test_lc_collection_take_rows(haloproperties_600_path, haloproperties_601_path): ds = oc.open(haloproperties_600_path, haloproperties_601_path) - n_to_take = int(0.75 * len(ds)) + n_to_take = int(0.25 * len(ds)) rows = np.random.choice(len(ds), n_to_take, replace=False) rows.sort() ds_rows = ds.take_rows(rows) diff --git a/test/test_take.py b/test/test_take.py index 79524588..8ae08696 100644 --- a/test/test_take.py +++ b/test/test_take.py @@ -61,3 +61,26 @@ def test_take_too_many(input_path): new_ds = ds.take(length + 1) assert len(new_ds) == len(ds) + + +def test_take_end_too_many(input_path): + ds = oc.open(input_path) + length = len(ds) + + new_ds = ds.take(length + 1, at="end") + assert len(new_ds) == length + + +def test_take_end_sorted(input_path): + ds = oc.open(input_path) + cols = ds.columns + sort_col = cols[0] + n = 10 + + all_values = ds.select(sort_col).get_data("numpy") + threshold = np.sort(all_values)[-n] + + taken = ds.sort_by(sort_col).take(n, at="end").select(sort_col).get_data("numpy") + + assert len(taken) == n + assert np.all(taken >= threshold) From 4a7f1a88b32732475970cf1b328d47664e425d29 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 13 May 2026 22:05:47 -0400 Subject: [PATCH 089/139] Takes working for chunked index as well --- .../collection/lightcone/lightcone.py | 45 ++++-------- python/opencosmo/dataset/take.py | 28 ++++++- python/opencosmo/index/ops.py | 7 +- src/index.rs | 73 +++++++++++++++++++ test/test_lightcone.py | 2 +- 5 files changed, 122 insertions(+), 33 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py 
index 068f3c2a..387872ad 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -31,7 +31,7 @@ get_random_take_index, get_range_take_index, ) -from opencosmo.index import rebuild_by_ranges +from opencosmo.index import DataIndex, get_range, rebuild_by_ranges from opencosmo.io import iopen from opencosmo.io.schema import FileEntry, make_schema from opencosmo.plugins.contexts import ( @@ -308,7 +308,7 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): table = vstack(data_with_length, join_type="exact") - if self.__sort_key is not None: + if self.__sort_key is not None and not kwargs.get("ignore_sort", False): order = table.argsort(self.__sort_key[0], reverse=self.__sort_key[1]) table = table[order] table = fold( @@ -887,7 +887,7 @@ def take( raise ValueError( f'"at" should be one of ("start", "end", "random", got {at}' ) - return self.take_rows(index) + return self.__take_rows(index) def take_range( self, start: int, end: int, mode: Literal["local", "global"] = "local" @@ -921,30 +921,10 @@ def take_range( if start < 0: raise ValueError("Tried to take negative rows!") - if self.__sort_key is not None: - indices = lcutils.take_from_sorted( - self, *self.__sort_key, end - start, at=start - ) - return self.__take_rows(indices) - - ends = np.cumsum(np.fromiter((len(ds) for ds in self.values()), dtype=int)) - starts = np.insert(ends, 0, 0)[:-1] - clipped_starts = np.clip(starts, a_min=start, a_max=None) - clipped_ends = np.clip(ends, a_min=None, a_max=end) + index = get_range_take_index(self, self.__sort_key, start, end - start, mode) + return self.__take_rows(index) - output = {} - for i, (name, dataset) in enumerate(self.items()): - if starts[i] == clipped_starts[i] and ends[i] == clipped_ends[i]: - output[name] = dataset - elif clipped_starts[i] >= clipped_ends[i]: - continue - else: - output[name] = dataset.take_range( - clipped_starts[i] - starts[i], clipped_ends[i] - starts[i] 
- ) - return Lightcone(output, self.z_range, self.__hidden, self.__sort_key) - - def take_rows(self, rows: np.ndarray): + def take_rows(self, rows: DataIndex): """ Take the rows of a lightcone specified by the :code:`rows` argument. :code:`rows` should be an array of integers. @@ -965,8 +945,15 @@ def take_rows(self, rows: np.ndarray): lightcone. """ - rows = np.sort(rows) - if rows[-1] >= len(self) or rows[0] < 0: + if isinstance(rows, np.ndarray): + rows = np.sort(rows) + else: + order = np.argsort(rows[0]) + rows = (rows[0][order], rows[1][order]) + + index_range = get_range(rows) + + if index_range[0] < 0 or index_range[1] > len(self): raise ValueError( "Rows must be between 0 and the length of this dataset - 1" ) @@ -987,7 +974,7 @@ def __make_sort_index(self): data = -data return np.argsort(data) - def __take_rows(self, rows: np.ndarray): + def __take_rows(self, rows: DataIndex): """ Takes rows from this lightcone while ignoring sort. "rows" is assumed to be sorte. For internal use only. 
diff --git a/python/opencosmo/dataset/take.py b/python/opencosmo/dataset/take.py index c97540de..95578fb2 100644 --- a/python/opencosmo/dataset/take.py +++ b/python/opencosmo/dataset/take.py @@ -4,7 +4,7 @@ import numpy as np -from opencosmo.index import empty, from_size, single_chunk +from opencosmo.index import empty, from_size, into_array, single_chunk from opencosmo.mpi import get_comm_world, get_mpi, has_mpi if TYPE_CHECKING: @@ -27,6 +27,20 @@ def get_random_take_index( return np.sort(rows) +def apply_sort_index(rows: DataIndex, sort_index: np.ndarray) -> np.ndarray: + """Map logical sorted-order positions to physical row positions for either index type.""" + return np.sort(sort_index[into_array(rows)]) + + +def _get_sort_index(ds, sort_key: tuple[str, bool]) -> np.ndarray: + sort_col, sort_desc = sort_key + raw = ds.select(sort_col).get_data(ignore_sort=True) + values = np.asarray(raw.value, dtype=np.float64) + if sort_desc: + values = -values + return np.argsort(values, kind="stable") + + def get_range_take_index( ds, sort_key: Optional[tuple[str, bool]], @@ -39,7 +53,12 @@ def get_range_take_index( ds_len = len(ds) if start + size > ds_len: - size = len(ds) - ds_len + size = ds_len - start + + if sort_key is not None: + sort_index = _get_sort_index(ds, sort_key) + return np.sort(sort_index[start : start + size]) + return single_chunk(start, size) @@ -63,6 +82,11 @@ def get_end_take_index( if n > ds_length: start = 0 n = ds_length + + if sort_key is not None: + sort_index = _get_sort_index(ds, sort_key) + return np.sort(sort_index[start : start + n]) + return single_chunk(start, n) diff --git a/python/opencosmo/index/ops.py b/python/opencosmo/index/ops.py index 54b44960..22353d08 100644 --- a/python/opencosmo/index/ops.py +++ b/python/opencosmo/index/ops.py @@ -17,4 +17,9 @@ def reindex_column(index: DataIndex, column: np.ndarray): def rebuild_by_ranges(index: DataIndex, ranges: ChunkedIndex): - return idxlib.rebuild_simple_by_ranges(index, *ranges) + 
print(index) + match index: + case np.ndarray(): + return idxlib.rebuild_simple_by_ranges(index, *ranges) + case (np.ndarray(), np.ndarray()): + return idxlib.rebuild_chunked_by_ranges(*index, *ranges) diff --git a/src/index.rs b/src/index.rs index 4e773432..8eb8a268 100644 --- a/src/index.rs +++ b/src/index.rs @@ -316,6 +316,79 @@ pub(crate) mod index { Array1::from_vec(output) } + #[pyfunction(name = "rebuild_chunked_by_ranges")] + fn rebuild_chunked_by_ranges_py<'py>( + py: Python<'py>, + starts: &Bound<'_, PyAny>, + sizes: &Bound<'_, PyAny>, + range_starts: &Bound<'_, PyAny>, + range_sizes: &Bound<'_, PyAny>, + ) -> PyResult> { + let (start_arr, size_arr) = unpack_chunked_index(starts, sizes)?; + let (range_starts_arr, range_sizes_arr) = unpack_chunked_index(range_starts, range_sizes)?; + let mut output = rebuild_chunked_by_ranges( + start_arr.as_array(), + size_arr.as_array(), + range_starts_arr.as_array(), + range_sizes_arr.as_array(), + ); + PyList::new( + py, + output + .drain(0..) 
+ .map(|(st, si)| (st.into_pyarray(py), si.into_pyarray(py))), + ) + } + + fn rebuild_chunked_by_ranges( + starts: ArrayView1<'_, i64>, + sizes: ArrayView1<'_, i64>, + range_starts: ArrayView1<'_, i64>, + range_sizes: ArrayView1<'_, i64>, + ) -> Vec<(Array1, Array1)> { + let n_datasets = range_starts.len(); + let mut outputs: Vec<(Vec, Vec)> = + (0..n_datasets).map(|_| (Vec::new(), Vec::new())).collect(); + + if starts.len() == 0 || n_datasets == 0 { + return outputs + .into_iter() + .map(|(st, si)| (Array1::from_vec(st), Array1::from_vec(si))) + .collect(); + } + + let mut i = 0usize; + let mut j = 0usize; + + while i < starts.len() && j < n_datasets { + let chunk_start = starts[i]; + let chunk_end = chunk_start + sizes[i]; + let ds_start = range_starts[j]; + let ds_end = ds_start + range_sizes[j]; + + if chunk_end <= ds_start { + i += 1; + } else if chunk_start >= ds_end { + j += 1; + } else { + let overlap_start = chunk_start.max(ds_start); + let overlap_end = chunk_end.min(ds_end); + outputs[j].0.push(overlap_start - ds_start); + outputs[j].1.push(overlap_end - overlap_start); + + if chunk_end <= ds_end { + i += 1; + } else { + j += 1; + } + } + } + + outputs + .into_iter() + .map(|(st, si)| (Array1::from_vec(st), Array1::from_vec(si))) + .collect() + } #[pyfunction(name = "rebuild_simple_by_ranges")] fn rebuild_simple_by_ranges_py<'py>( py: Python<'py>, diff --git a/test/test_lightcone.py b/test/test_lightcone.py index cc89bd2d..87c94d3a 100644 --- a/test/test_lightcone.py +++ b/test/test_lightcone.py @@ -19,7 +19,7 @@ def haloproperties_601_path(lightcone_path): @pytest.fixture def all_files(): - return ["haloparticles.hdf5", "haloproperties.hdf5", "sodpropertybins.hdf5"] + return ["haloparticles.hdf5", "haloproperties.hdf5", "haloprofiles.hdf5"] @pytest.fixture From b77d875b0fae3ad0aa3df8026d84abee2f94f8fa Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 14 May 2026 12:07:43 -0400 Subject: [PATCH 090/139] Updates for lightcone, significant testing --- 
python/opencosmo/_lib/index.pyi | 6 + .../collection/lightcone/lightcone.py | 56 ++- .../collection/structure/structure.py | 28 +- python/opencosmo/dataset/dataset.py | 29 +- python/opencosmo/dataset/take.py | 48 +- python/opencosmo/index/ops.py | 1 - test/parallel/test_dataset_mpi.py | 53 +++ test/parallel/test_lc_mpi.py | 418 ++++++++++++++++++ test/test_lightcone.py | 24 +- 9 files changed, 613 insertions(+), 50 deletions(-) diff --git a/python/opencosmo/_lib/index.pyi b/python/opencosmo/_lib/index.pyi index 52e72e1c..0662bf7c 100644 --- a/python/opencosmo/_lib/index.pyi +++ b/python/opencosmo/_lib/index.pyi @@ -16,3 +16,9 @@ def reindex_column(index: IndexArray, index_column: IndexArray) -> IndexArray: . def rebuild_simple_by_ranges( index: IndexArray, starts: IndexArray, sizes: IndexArray ) -> IndexArray: ... +def rebuild_chunked_by_ranges( + starts: IndexArray, + sizes: IndexArray, + chunk_starts: IndexArray, + chunk_sizes: IndexArray, +) -> IndexArray: ... diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 387872ad..c667f473 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -30,10 +30,12 @@ get_end_take_index, get_random_take_index, get_range_take_index, + get_rows_take_index, ) -from opencosmo.index import DataIndex, get_range, rebuild_by_ranges +from opencosmo.index import get_range, into_array, rebuild_by_ranges from opencosmo.io import iopen from opencosmo.io.schema import FileEntry, make_schema +from opencosmo.mpi import has_mpi from opencosmo.plugins.contexts import ( HookPoint, LightconeInstantiateCtx, @@ -51,6 +53,7 @@ from opencosmo.dataset import Dataset from opencosmo.dtypes.hacc import HaccSimulationParameters from opencosmo.header import OpenCosmoHeader + from opencosmo.index import DataIndex from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema from opencosmo.spatial import 
Region @@ -859,11 +862,21 @@ def take( at : str Where to take the rows from. One of "start", "end", or "random". The default is "random". + mode : str, "local" or "global", default = "local" + Controls how ``n`` is interpreted when running under MPI. Has no + effect if you are not using MPI. + + * ``"local"`` (default): ``n`` rows are taken independently on + each rank. + * ``"global"``: ``n`` is the total number of rows to select across + all ranks combined. Each rank receives the portion of those rows + that it owns. If the dataset is sorted, ranks will coordinate + to take from the globally-sorted dataset. Returns ------- - dataset : Dataset - The new dataset with only the selected rows. + lightcone : Lightcone + The new lightcone with only the selected rows. Raises ------ @@ -872,17 +885,18 @@ def take( or if 'at' is invalid. """ - if n > len(self): - raise ValueError( - "Number of rows to take must be less than number of rows in dataset" - ) if at == "random": index = get_random_take_index(n, len(self), mode) - elif at == "start": index = get_range_take_index(self, self.__sort_key, 0, n, mode) + if self.__sort_key is not None and mode == "global" and has_mpi(): + sort_index = self.__make_sort_index() + index = np.sort(sort_index[into_array(index)]) elif at == "end": index = get_end_take_index(n, self, self.__sort_key, mode) + if self.__sort_key is not None and mode == "global" and has_mpi(): + sort_index = self.__make_sort_index() + index = np.sort(sort_index[into_array(index)]) else: raise ValueError( f'"at" should be one of ("start", "end", "random", got {at}' @@ -902,9 +916,19 @@ def take_range( Parameters ---------- start : int - The beginning of the range + The beginning of the range. end : int - The end of the range + The end of the range (exclusive). + mode : str, "local" or "global", default = "local" + Controls how ``n`` is interpreted when running under MPI. Has no + effect if you are not using MPI. 
+ + * ``"local"`` (default): ``n`` rows are taken independently on + each rank. + * ``"global"``: ``n`` is the total number of rows to select across + all ranks combined. Each rank receives the portion of those rows + that it owns. If the dataset is sorted, ranks will coordinate + to take from the globally-sorted dataset. Returns ------- @@ -922,6 +946,9 @@ def take_range( raise ValueError("Tried to take negative rows!") index = get_range_take_index(self, self.__sort_key, start, end - start, mode) + if self.__sort_key is not None and mode == "global" and has_mpi(): + sort_index = self.__make_sort_index() + index = np.sort(sort_index[into_array(index)]) return self.__take_rows(index) def take_rows(self, rows: DataIndex): @@ -945,22 +972,19 @@ def take_rows(self, rows: DataIndex): lightcone. """ + index_range = get_range(rows) if isinstance(rows, np.ndarray): rows = np.sort(rows) + index_range = (index_range[0], index_range[1] + 1) else: order = np.argsort(rows[0]) rows = (rows[0][order], rows[1][order]) - index_range = get_range(rows) - if index_range[0] < 0 or index_range[1] > len(self): raise ValueError( "Rows must be between 0 and the length of this dataset - 1" ) - if self.__sort_key is not None: - sort_index = self.__make_sort_index() - rows = sort_index[rows] - rows.sort() + rows = get_rows_take_index(self, rows, self.__sort_key) return self.__take_rows(rows) diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index 8c385e02..07b29f6e 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -1012,6 +1012,17 @@ def take( at : str, optional The method to use to take the structures. One of "random", "first", or "last". Default is "random". + mode : str, "local" or "global", default = "local" + Controls how ``n`` is interpreted when running under MPI. Has no + effect if you are not using MPI. 
+ + * ``"local"`` (default): ``n`` rows are taken independently on + each rank. + * ``"global"``: ``n`` is the total number of rows to select across + all ranks combined. Each rank receives the portion of those rows + that it owns. If the dataset is sorted, ranks will coordinate + to take from the globally-sorted dataset. + Returns ------- @@ -1030,7 +1041,9 @@ def take( self.__derived_columns, ) - def take_range(self, start: int, end: int): + def take_range( + self, start: int, end: int, mode: Literal["local", "global"] = "local" + ): """ Create a new collection from a row range in this collection. We use standard indexing conventions, so the rows included will be start -> end - 1. @@ -1041,6 +1054,17 @@ def take_range(self, start: int, end: int): The first row to get. end : int The last row to get. + mode : str, "local" or "global", default = "local" + Controls how ``n`` is interpreted when running under MPI. Has no + effect if you are not using MPI. + + * ``"local"`` (default): ``n`` rows are taken independently on + each rank. + * ``"global"``: ``n`` is the total number of rows to select across + all ranks combined. Each rank receives the portion of those rows + that it owns. If the dataset is sorted, ranks will coordinate + to take from the globally-sorted dataset. + Returns ------- @@ -1054,7 +1078,7 @@ def take_range(self, start: int, end: int): or if end is greater than start. """ - new_source = self.__source.take_range(start, end) + new_source = self.__source.take_range(start, end, mode) return StructureCollection( new_source, self.__header, diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 5981220d..9e2d4f29 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -723,9 +723,15 @@ def take( Where to take the rows from. One of "start", "end", or "random". The default is "random". 
mode : str, "local" or "global", default = "local" - If working with MPI, whether the `n` is a per-rank or global - number + Controls how ``n`` is interpreted when running under MPI. Has no + effect if you are not using MPI. + * ``"local"`` (default): ``n`` rows are taken independently on + each rank. + * ``"global"``: ``n`` is the total number of rows to select across + all ranks combined. Each rank receives the portion of those rows + that it owns. If the dataset is sorted, ranks will coordinate + to take from the globally-sorted dataset. Returns ------- @@ -760,14 +766,25 @@ def take_range( Parameters ---------- start : int - The beginning of the range + The beginning of the range. end : int - The end of the range + The end of the range (exclusive). + + mode : str, "local" or "global", default = "local" + Controls how ``n`` is interpreted when running under MPI. Has no + effect if you are not using MPI. + + * ``"local"`` (default): ``n`` rows are taken independently on + each rank. + * ``"global"``: ``n`` is the total number of rows to select across + all ranks combined. Each rank receives the portion of those rows + that it owns. If the dataset is sorted, ranks will coordinate + to take from the globally-sorted dataset. Returns ------- - table : astropy.table.Table - The table with only the rows from start to end. + dataset : Dataset + The new dataset with only the rows from start to end. 
Raises ------ diff --git a/python/opencosmo/dataset/take.py b/python/opencosmo/dataset/take.py index 95578fb2..dcf6e86e 100644 --- a/python/opencosmo/dataset/take.py +++ b/python/opencosmo/dataset/take.py @@ -8,6 +8,8 @@ from opencosmo.mpi import get_comm_world, get_mpi, has_mpi if TYPE_CHECKING: + from opencosmo.collection.lightcone import Lightcone + from opencosmo.dataset.dataset import Dataset from opencosmo.index import DataIndex, IndexArray @@ -24,25 +26,45 @@ def get_random_take_index( generator = np.random.default_rng() rows = generator.choice(ds_length, n, replace=False) - return np.sort(rows) + return apply_sort_index(rows) -def apply_sort_index(rows: DataIndex, sort_index: np.ndarray) -> np.ndarray: - """Map logical sorted-order positions to physical row positions for either index type.""" - return np.sort(sort_index[into_array(rows)]) +def apply_sort_index( + rows: DataIndex, sort_index: Optional[np.ndarray] = None +) -> np.ndarray: + """Return row positions in ascending physical order. + With sort_index: treats rows as logical sorted-order positions and maps + them to physical row positions. Without sort_index: rows are already + physical positions and are simply sorted. 
+ """ + arr = into_array(rows) + if sort_index is not None: + return np.sort(sort_index[arr]) + return np.sort(arr) -def _get_sort_index(ds, sort_key: tuple[str, bool]) -> np.ndarray: + +def _get_sort_index(ds: Dataset | Lightcone, sort_key: tuple[str, bool]) -> np.ndarray: sort_col, sort_desc = sort_key - raw = ds.select(sort_col).get_data(ignore_sort=True) - values = np.asarray(raw.value, dtype=np.float64) + values = ds.select(sort_col).get_data("numpy", ignore_sort=True) + assert isinstance(values, np.ndarray) if sort_desc: values = -values return np.argsort(values, kind="stable") +def get_rows_take_index( + ds: Dataset | Lightcone, rows: DataIndex, sort_key: Optional[tuple[str, bool]] +) -> DataIndex: + """Map user-provided logical (sorted-order) row positions to physical row positions.""" + if sort_key is None: + return rows + sort_index = _get_sort_index(ds, sort_key) + return apply_sort_index(rows, sort_index) + + def get_range_take_index( - ds, + ds: Dataset | Lightcone, sort_key: Optional[tuple[str, bool]], start: int, size: int, @@ -64,7 +86,7 @@ def get_range_take_index( def get_end_take_index( n: int, - ds, + ds: Dataset | Lightcone, sort_key, mode: Literal["local", "global"], ): @@ -90,7 +112,9 @@ def get_end_take_index( return single_chunk(start, n) -def get_range_take_index_mpi(ds, sort_key, start, size): +def get_range_take_index_mpi( + ds: Dataset | Lightcone, sort_key: Optional[tuple[str, bool]], start: int, size: int +): comm = get_comm_world() assert comm is not None lengths = np.array(comm.allgather(len(ds)), dtype=np.int64) @@ -148,7 +172,7 @@ def get_range_take_index_mpi(ds, sort_key, start, size): return single_chunk(local_start, local_end - local_start) -def get_global_sort_order(ds, sort_key): +def get_global_sort_order(ds: Dataset | Lightcone, sort_key: tuple[str, bool]): comm = get_comm_world() assert comm is not None @@ -200,7 +224,7 @@ def get_random_take_index_mpi(n: int, ds_length: int): return get_local_rows_simple(rows, lengths, 
comm) -def get_local_rows_simple(rows: IndexArray | None, lengths, comm): +def get_local_rows_simple(rows: IndexArray | None, lengths: list[int], comm): chunk_ranges = np.zeros(len(lengths) + 1, dtype=np.int64) chunk_ranges[1:] = np.cumsum(lengths) if comm.Get_rank() == 0: diff --git a/python/opencosmo/index/ops.py b/python/opencosmo/index/ops.py index 22353d08..9bcd29ae 100644 --- a/python/opencosmo/index/ops.py +++ b/python/opencosmo/index/ops.py @@ -17,7 +17,6 @@ def reindex_column(index: DataIndex, column: np.ndarray): def rebuild_by_ranges(index: DataIndex, ranges: ChunkedIndex): - print(index) match index: case np.ndarray(): return idxlib.rebuild_simple_by_ranges(index, *ranges) diff --git a/test/parallel/test_dataset_mpi.py b/test/parallel/test_dataset_mpi.py index 956442d4..a1ae6970 100644 --- a/test/parallel/test_dataset_mpi.py +++ b/test/parallel/test_dataset_mpi.py @@ -254,3 +254,56 @@ def test_take_range_global_sorted_middle(input_path): np.all(all_selected <= upper_threshold), "some selected values exceed the upper global threshold", ) + + +# ── inverted sort ───────────────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_sorted_inverted_start(input_path): + """Inverted sort: global start selects the n globally largest values.""" + comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass", invert=True) + + total = sum(comm.allgather(len(ds))) + n = total // 3 + + original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + ds_taken = ds.take_range(0, n, mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + 
+@pytest.mark.parallel(nprocs=4) +def test_take_range_global_sorted_inverted_end(input_path): + """Inverted sort: global end selects the n globally smallest values.""" + comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass", invert=True) + + total = sum(comm.allgather(len(ds))) + n = total // 3 + + original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] + + ds_taken = ds.take_range(total - n, total, mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index adf9a972..a4950143 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -608,6 +608,424 @@ def test_lightcone_stacking( assert next(iter(ds_new.values())).header.lightcone["z_range"] == ds_new.z_range +# ── take global ─────────────────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_global(haloproperties_600_path, haloproperties_601_path): + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path) + total_length = sum(comm.allgather(len(lc))) + n_to_take = np.random.randint(total_length // 4, int(total_length * 0.75)) + n_to_take = comm.bcast(n_to_take) + + lc = lc.take(n_to_take, mode="global") + all_lengths = comm.allgather(len(lc)) + + parallel_assert(sum(all_lengths) == n_to_take) + + +# ── take_range global, unsorted ─────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_start(haloproperties_600_path, haloproperties_601_path): + """First n global rows land on the correct ranks with the right 
counts.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path) + + lengths = np.array(comm.allgather(len(lc)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + + lc_taken = lc.take_range(0, n, mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max(0, min(int(lengths[rank]), n - offset)) + + parallel_assert( + len(lc_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(lc_taken)}", + ) + parallel_assert(sum(comm.allgather(len(lc_taken))) == n) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_end(haloproperties_600_path, haloproperties_601_path): + """Last n global rows land on the correct ranks with the right counts.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path) + + lengths = np.array(comm.allgather(len(lc)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + global_start = total - n + + lc_taken = lc.take_range(global_start, total, mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), total - offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(lc_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(lc_taken)}", + ) + parallel_assert(sum(comm.allgather(len(lc_taken))) == n) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_middle(haloproperties_600_path, haloproperties_601_path): + """A middle window of global rows lands on the correct ranks with the right counts.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path) + + lengths = np.array(comm.allgather(len(lc)), dtype=np.int64) + total = int(np.sum(lengths)) + global_start = total // 4 + global_end = 3 * total // 4 + + lc_taken = lc.take_range(global_start, global_end, mode="global") + + rank = comm.Get_rank() + 
offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), global_end - offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(lc_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(lc_taken)}", + ) + parallel_assert(sum(comm.allgather(len(lc_taken))) == global_end - global_start) + + +# ── take_range global, sorted ───────────────────────────────────────────────── +# +# The sorted tests verify value-level correctness: after a global range take on +# a sorted lightcone, every selected value must satisfy the global threshold +# implied by the range position. + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_sorted_start( + haloproperties_600_path, haloproperties_601_path +): + """Global start on sorted data selects the n globally smallest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass" + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] + + lc_taken = lc.take_range(0, n, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_sorted_end( + haloproperties_600_path, haloproperties_601_path +): + """Global end on sorted data selects the n globally largest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass" + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + 
all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + lc_taken = lc.take_range(total - n, total, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_sorted_middle( + haloproperties_600_path, haloproperties_601_path +): + """A middle window on sorted data selects the correct globally-ranked values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass" + ) + + total = sum(comm.allgather(len(lc))) + global_start = total // 4 + global_end = 3 * total // 4 + size = global_end - global_start + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + sorted_all = np.sort(all_original) + lower_threshold = sorted_all[global_start] + upper_threshold = sorted_all[global_end - 1] + + lc_taken = lc.take_range(global_start, global_end, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == size) + parallel_assert( + np.all(all_selected >= lower_threshold), + "some selected values fall below the lower global threshold", + ) + parallel_assert( + np.all(all_selected <= upper_threshold), + "some selected values exceed the upper global threshold", + ) + + +# ── take global end ─────────────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_global_end(haloproperties_600_path, haloproperties_601_path): + """take(n, at='end', mode='global') selects the last n rows across all ranks.""" + comm = get_comm_world() 
+ lc = oc.open(haloproperties_601_path, haloproperties_600_path) + + lengths = np.array(comm.allgather(len(lc)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + global_start = total - n + + lc_taken = lc.take(n, at="end", mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), total - offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(lc_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(lc_taken)}", + ) + parallel_assert(sum(comm.allgather(len(lc_taken))) == n) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_global_end_sorted(haloproperties_600_path, haloproperties_601_path): + """take(n, at='end', mode='global') on sorted data selects the n globally largest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass" + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + lc_taken = lc.take(n, at="end", mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + +# ── take(at="start") global, sorted ────────────────────────────────────────── +# +# take(n, at="start") is a distinct branch from take_range(0, n) in the +# Lightcone implementation; the sort-order → physical conversion lives +# separately in each branch and must be tested independently. 
+ + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_global_start_sorted(haloproperties_600_path, haloproperties_601_path): + """take(n, at='start', mode='global') on sorted data selects the n globally smallest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass" + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] + + lc_taken = lc.take(n, at="start", mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) + + +# ── inverted sort ───────────────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_sorted_inverted_start( + haloproperties_600_path, haloproperties_601_path +): + """Inverted sort: global start selects the n globally largest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass", invert=True + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + lc_taken = lc.take_range(0, n, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + +@pytest.mark.parallel(nprocs=4) +def 
test_lc_take_range_global_sorted_inverted_end( + haloproperties_600_path, haloproperties_601_path +): + """Inverted sort: global end selects the n globally smallest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass", invert=True + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] + + lc_taken = lc.take_range(total - n, total, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) + + +# ── single-step lightcone ───────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_global_single_step(haloproperties_600_path): + """Global take on a single-step lightcone produces the correct total count.""" + comm = get_comm_world() + lc = oc.open(haloproperties_600_path) + + total = sum(comm.allgather(len(lc))) + n_to_take = np.random.randint(total // 4, int(total * 0.75)) + n_to_take = comm.bcast(n_to_take) + + lc_taken = lc.take(n_to_take, mode="global") + parallel_assert(sum(comm.allgather(len(lc_taken))) == n_to_take) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_single_step_sorted(haloproperties_600_path): + """Sorted global take_range on a single-step lightcone selects the n globally smallest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_600_path).sort_by("fof_halo_mass") + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] 
+ + lc_taken = lc.take_range(0, n, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_single_step_sorted_inverted(haloproperties_600_path): + """Inverted sort on a single-step lightcone: global start selects the n globally largest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_600_path).sort_by("fof_halo_mass", invert=True) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + lc_taken = lc.take_range(0, n, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + def _get_expected_core_tags(group): raw_top_host = group["data"]["top_host_idx"][:] core_tag = group["data"]["core_tag"][:] diff --git a/test/test_lightcone.py b/test/test_lightcone.py index 87c94d3a..febfd326 100644 --- a/test/test_lightcone.py +++ b/test/test_lightcone.py @@ -197,19 +197,17 @@ def test_lc_collection_take_rows(haloproperties_600_path, haloproperties_601_pat assert np.all( data["fof_halo_mass"][sorted_index][rows] == sorted_tags["fof_halo_mass"] ) - toolkit_sorted_tags_mass = dict( - zip(sorted_tags["fof_halo_tag"], sorted_tags["fof_halo_mass"]) - ) - sorted_tags_mass = dict( - zip( - data["fof_halo_tag"][sorted_index][rows], - data["fof_halo_mass"][sorted_index][rows], - ) - ) - # Exact order is not deterministic, 
because many low_mass halos have the same mass, - # So we just make sure the tag->mass mapping is the same in the two datasets. - - assert toolkit_sorted_tags_mass == sorted_tags_mass + # Verify each returned (tag, mass) pair is internally consistent with the original + # dataset. We do not compare against a reference sort because sort stability + # determines which specific halo is selected among ties, and we don't require a + # particular choice — only that the returned tag actually belongs to a halo with + # the returned mass. + tag_order = np.argsort(data["fof_halo_tag"]) + all_tags_sorted = data["fof_halo_tag"][tag_order] + all_mass_by_tag = data["fof_halo_mass"][tag_order] + + positions = np.searchsorted(all_tags_sorted, sorted_tags["fof_halo_tag"]) + assert np.all(all_mass_by_tag[positions] == sorted_tags["fof_halo_mass"]) def test_lc_collection_derive( From c1cd8e366ea897e92d384cec98d45535343b2f10 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 14 May 2026 13:13:12 -0400 Subject: [PATCH 091/139] Fixes for a few tests --- python/opencosmo/collection/lightcone/lightcone.py | 6 +++--- python/opencosmo/dataset/dataset.py | 1 - python/opencosmo/dataset/take.py | 9 +++++---- test/test_lightcone.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index c667f473..89299886 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -889,12 +889,12 @@ def take( index = get_random_take_index(n, len(self), mode) elif at == "start": index = get_range_take_index(self, self.__sort_key, 0, n, mode) - if self.__sort_key is not None and mode == "global" and has_mpi(): + if self.__sort_key is not None and not (mode == "global" and has_mpi()): sort_index = self.__make_sort_index() index = np.sort(sort_index[into_array(index)]) elif at == "end": index = get_end_take_index(n, self, 
self.__sort_key, mode) - if self.__sort_key is not None and mode == "global" and has_mpi(): + if self.__sort_key is not None and not (mode == "global" and has_mpi()): sort_index = self.__make_sort_index() index = np.sort(sort_index[into_array(index)]) else: @@ -946,7 +946,7 @@ def take_range( raise ValueError("Tried to take negative rows!") index = get_range_take_index(self, self.__sort_key, start, end - start, mode) - if self.__sort_key is not None and mode == "global" and has_mpi(): + if self.__sort_key is not None and not (mode == "global" and has_mpi()): sort_index = self.__make_sort_index() index = np.sort(sort_index[into_array(index)]) return self.__take_rows(index) diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 9e2d4f29..f8d2d6f6 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -801,7 +801,6 @@ def take_range( take_index = get_range_take_index( self, self.__state.sort_key, start, end - start, mode ) - return self.take_rows(take_index) def take_rows(self, rows: np.ndarray | DataIndex): diff --git a/python/opencosmo/dataset/take.py b/python/opencosmo/dataset/take.py index dcf6e86e..ff847c35 100644 --- a/python/opencosmo/dataset/take.py +++ b/python/opencosmo/dataset/take.py @@ -78,8 +78,10 @@ def get_range_take_index( size = ds_len - start if sort_key is not None: - sort_index = _get_sort_index(ds, sort_key) - return np.sort(sort_index[start : start + size]) + # Return logical sorted positions; st.take_rows (for Dataset) will map + # them to physical positions via sorted_idx. Lightcone callers must apply + # their own sort mapping after calling this function. 
+ return np.arange(start, start + size, dtype=np.int64) return single_chunk(start, size) @@ -106,8 +108,7 @@ def get_end_take_index( n = ds_length if sort_key is not None: - sort_index = _get_sort_index(ds, sort_key) - return np.sort(sort_index[start : start + n]) + return np.arange(start, start + n, dtype=np.int64) return single_chunk(start, n) diff --git a/test/test_lightcone.py b/test/test_lightcone.py index febfd326..dca0af50 100644 --- a/test/test_lightcone.py +++ b/test/test_lightcone.py @@ -40,8 +40,8 @@ def test_create_theta_phi_coords(haloproperties_600_path, haloproperties_601_pat ra = (data["phi"] * u.rad).to(u.deg) dec = ((np.pi / 2 - data["theta"]) * u.rad).to(u.deg) - assert np.allclose(data["ra"], ra, rtol=1e-2) - assert np.allclose(data["dec"], dec, rtol=1e-2) + assert np.allclose(data["ra"], ra, atol=0.01, rtol=1e-2) + assert np.allclose(data["dec"], dec, atol=0.01, rtol=1e-2) def test_lightcone_physical_units(haloproperties_600_path): From 37905e804dd2e7c5555fd7be33d1831a83253a05 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 14 May 2026 13:50:55 -0400 Subject: [PATCH 092/139] Small updates, add changelog --- changes/240.bugfix.rst | 1 + changes/240.feature.rst | 1 + .../collection/lightcone/lightcone.py | 18 ++++---- .../collection/structure/structure.py | 19 ++++---- python/opencosmo/dataset/dataset.py | 14 +++--- python/opencosmo/dataset/take.py | 24 +++------- src/index.rs | 46 +++++++++---------- 7 files changed, 56 insertions(+), 67 deletions(-) create mode 100644 changes/240.bugfix.rst create mode 100644 changes/240.feature.rst diff --git a/changes/240.bugfix.rst b/changes/240.bugfix.rst new file mode 100644 index 00000000..958c8149 --- /dev/null +++ b/changes/240.bugfix.rst @@ -0,0 +1 @@ +Requesting more rows than exist via :py:meth:`take ` or :py:meth:`take_range ` no longer raises a ``ValueError``. Instead, all available rows are returned. 
This applies to :py:class:`Dataset `, :py:class:`Lightcone `, and :py:class:`StructureCollection `. diff --git a/changes/240.feature.rst b/changes/240.feature.rst new file mode 100644 index 00000000..0ef4ab54 --- /dev/null +++ b/changes/240.feature.rst @@ -0,0 +1 @@ +:py:meth:`take `, :py:meth:`take_range `, and their equivalents on :py:class:`Lightcone ` and :py:class:`StructureCollection ` now accept a ``mode`` keyword argument. Setting ``mode="global"`` when running under MPI causes ``n`` (or ``start``/``end``) to be interpreted across all ranks combined rather than per-rank. When the dataset is sorted, ranks coordinate to select from the globally-sorted order, so ``ds.sort_by("fof_halo_mass", invert=True).take(1000, at="start", mode="global")`` returns exactly the 1000 most massive halos distributed across all ranks. diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 89299886..72c31df4 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -32,7 +32,7 @@ get_range_take_index, get_rows_take_index, ) -from opencosmo.index import get_range, into_array, rebuild_by_ranges +from opencosmo.index import get_length, get_range, into_array, rebuild_by_ranges from opencosmo.io import iopen from opencosmo.io.schema import FileEntry, make_schema from opencosmo.mpi import has_mpi @@ -920,15 +920,15 @@ def take_range( end : int The end of the range (exclusive). mode : str, "local" or "global", default = "local" - Controls how ``n`` is interpreted when running under MPI. Has no - effect if you are not using MPI. + Controls how ``start`` and ``end`` are interpreted when running + under MPI. Has no effect if you are not using MPI. - * ``"local"`` (default): ``n`` rows are taken independently on + * ``"local"`` (default): the range is applied independently on each rank. - * ``"global"``: ``n`` is the total number of rows to select across - all ranks combined. 
Each rank receives the portion of those rows - that it owns. If the dataset is sorted, ranks will coordinate - to take from the globally-sorted dataset. + * ``"global"``: ``start`` and ``end`` index into the global row + space across all ranks combined. Each rank receives the portion + of that range it owns. If the lightcone is sorted, ranks will + coordinate to take from the globally-sorted lightcone. Returns ------- @@ -1009,6 +1009,8 @@ def __take_rows(self, rows: DataIndex): projected = rebuild_by_ranges(rows, (starts, sizes)) output = {} for (name, ds), index in zip(self.items(), projected): + if get_length(index) == 0: + continue output[name] = ds.take_rows(index) return Lightcone(output, self.z_range, self.__hidden, self.__sort_key) diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index 07b29f6e..6f8049e9 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -1055,21 +1055,20 @@ def take_range( end : int The last row to get. mode : str, "local" or "global", default = "local" - Controls how ``n`` is interpreted when running under MPI. Has no - effect if you are not using MPI. + Controls how ``start`` and ``end`` are interpreted when running + under MPI. Has no effect if you are not using MPI. - * ``"local"`` (default): ``n`` rows are taken independently on + * ``"local"`` (default): the range is applied independently on each rank. - * ``"global"``: ``n`` is the total number of rows to select across - all ranks combined. Each rank receives the portion of those rows - that it owns. If the dataset is sorted, ranks will coordinate - to take from the globally-sorted dataset. - + * ``"global"``: ``start`` and ``end`` index into the global row + space across all ranks combined. Each rank receives the portion + of that range it owns. If the collection is sorted, ranks will + coordinate to take from the globally-sorted collection. 
Returns ------- - table : astropy.table.Table - The table with only the rows from start to end. + collection : StructureCollection + The collection with only the rows from start to end. Raises ------ diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index f8d2d6f6..741fb8eb 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -771,15 +771,15 @@ def take_range( The end of the range (exclusive). mode : str, "local" or "global", default = "local" - Controls how ``n`` is interpreted when running under MPI. Has no - effect if you are not using MPI. + Controls how ``start`` and ``end`` are interpreted when running + under MPI. Has no effect if you are not using MPI. - * ``"local"`` (default): ``n`` rows are taken independently on + * ``"local"`` (default): the range is applied independently on each rank. - * ``"global"``: ``n`` is the total number of rows to select across - all ranks combined. Each rank receives the portion of those rows - that it owns. If the dataset is sorted, ranks will coordinate - to take from the globally-sorted dataset. + * ``"global"``: ``start`` and ``end`` index into the global row + space across all ranks combined. Each rank receives the portion + of that range it owns. If the dataset is sorted, ranks will + coordinate to take from the globally-sorted dataset. Returns ------- diff --git a/python/opencosmo/dataset/take.py b/python/opencosmo/dataset/take.py index ff847c35..645db2ac 100644 --- a/python/opencosmo/dataset/take.py +++ b/python/opencosmo/dataset/take.py @@ -77,19 +77,13 @@ def get_range_take_index( if start + size > ds_len: size = ds_len - start - if sort_key is not None: - # Return logical sorted positions; st.take_rows (for Dataset) will map - # them to physical positions via sorted_idx. Lightcone callers must apply - # their own sort mapping after calling this function. 
- return np.arange(start, start + size, dtype=np.int64) - return single_chunk(start, size) def get_end_take_index( n: int, ds: Dataset | Lightcone, - sort_key, + sort_key: Optional[tuple[str, bool]], mode: Literal["local", "global"], ): ds_length = len(ds) @@ -107,9 +101,6 @@ def get_end_take_index( start = 0 n = ds_length - if sort_key is not None: - return np.arange(start, start + n, dtype=np.int64) - return single_chunk(start, n) @@ -159,7 +150,7 @@ def get_range_take_index_mpi( local_size = int(count_per_rank[rank]) if local_size == 0: - return np.array([], dtype=np.int64) + return empty() return single_chunk(local_start, local_size) # Handle the case without sorting: contiguous global range @@ -187,14 +178,11 @@ def get_global_sort_order(ds: Dataset | Lightcone, sort_key: tuple[str, bool]): lengths = np.array(comm.allgather(len(local_values)), dtype=np.int64) total_length = int(np.sum(lengths)) rank = comm.Get_rank() - offset = int(np.sum(lengths[:rank])) - # Use comm.Reduce to get the full catalog on rank 0. - # Each rank writes its values at its global offset; summing gives the full array. 
- local_contribution = np.zeros(total_length, dtype=np.float64) - local_contribution[offset : offset + len(local_values)] = local_values - recv = np.zeros(total_length, dtype=np.float64) if rank == 0 else None - comm.Reduce(local_contribution, recv, op=get_mpi().SUM, root=0) + offsets = np.zeros(len(lengths), dtype=np.int64) + offsets[1:] = np.cumsum(lengths)[:-1] + recv = np.empty(total_length, dtype=np.float64) if rank == 0 else None + comm.Gatherv(local_values, [recv, lengths, offsets, get_mpi().DOUBLE], root=0) # Other ranks return None if rank != 0: diff --git a/src/index.rs b/src/index.rs index 8eb8a268..aa7c118e 100644 --- a/src/index.rs +++ b/src/index.rs @@ -411,33 +411,31 @@ pub(crate) mod index { range_starts: ArrayView1<'_, i64>, range_sizes: ArrayView1<'_, i64>, ) -> Vec> { - let mut rs: i64 = 0; - let mut current_chunk_index: usize = 0; - let mut current_chunk_end: i64 = - range_starts[current_chunk_index] + range_sizes[current_chunk_index]; - let mut output_indices = Vec::new(); - let mut current_chunk_indices = Vec::new(); + let n_ranges = range_starts.len(); + let mut outputs: Vec> = (0..n_ranges).map(|_| Vec::new()).collect(); + + if n_ranges == 0 || index.len() == 0 { + return outputs.into_iter().map(Array1::from_vec).collect(); + } + + let mut j = 0usize; for &idx in index { - if idx >= current_chunk_end { - output_indices.push(Array1::from_vec(current_chunk_indices)); - rs = current_chunk_end; - current_chunk_index += 1; - current_chunk_end = - range_starts[current_chunk_index] + range_sizes[current_chunk_index]; - current_chunk_indices = Vec::new(); - current_chunk_indices.push(idx - rs); - } else if idx < rs { - continue; - } else { - current_chunk_indices.push(idx - rs); + // Advance past all ranges whose end is at or before idx. + // Using a while loop handles the case where idx skips multiple ranges + // and prevents an OOB panic when advancing past the last range. 
+ while j < n_ranges && idx >= range_starts[j] + range_sizes[j] { + j += 1; } - } - output_indices.push(Array1::from_vec(current_chunk_indices)); - if output_indices.len() < range_starts.len() { - for _ in 0..range_starts.len() - output_indices.len() { - output_indices.push(Array1::zeros(0)) + if j >= n_ranges { + break; } + // idx falls before the start of range j (gap in a non-contiguous layout). + if idx < range_starts[j] { + continue; + } + outputs[j].push(idx - range_starts[j]); } - return output_indices; + + outputs.into_iter().map(Array1::from_vec).collect() } } From aebf117f6a6e0a8ef605ca3bc702a113e9eac244 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 14 May 2026 14:10:17 -0400 Subject: [PATCH 093/139] Fix for zero-length lightcones --- python/opencosmo/collection/lightcone/lightcone.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 72c31df4..487b8c05 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -32,7 +32,7 @@ get_range_take_index, get_rows_take_index, ) -from opencosmo.index import get_length, get_range, into_array, rebuild_by_ranges +from opencosmo.index import get_range, into_array, rebuild_by_ranges from opencosmo.io import iopen from opencosmo.io.schema import FileEntry, make_schema from opencosmo.mpi import has_mpi @@ -1009,10 +1009,13 @@ def __take_rows(self, rows: DataIndex): projected = rebuild_by_ranges(rows, (starts, sizes)) output = {} for (name, ds), index in zip(self.items(), projected): - if get_length(index) == 0: - continue output[name] = ds.take_rows(index) + if all(len(ds) == 0 for ds in output.values()): + output = {"data": next(iter(output.values()))} + else: + output = {name: ds for name, ds in output.items() if len(ds) != 0} + return Lightcone(output, self.z_range, self.__hidden, self.__sort_key) def 
with_new_columns( From 8b53edd90bf3e6b969e90a8cb1f6328aa3fb3a4b Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 14 May 2026 15:58:42 -0400 Subject: [PATCH 094/139] Fix sorted lightcones, disable tests that will be updated in different PR --- python/opencosmo/collection/lightcone/lightcone.py | 7 +++---- test/parallel/test_lc_mpi.py | 2 ++ test/test_lightcone.py | 2 ++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 487b8c05..8ce462a8 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -35,7 +35,6 @@ from opencosmo.index import get_range, into_array, rebuild_by_ranges from opencosmo.io import iopen from opencosmo.io.schema import FileEntry, make_schema -from opencosmo.mpi import has_mpi from opencosmo.plugins.contexts import ( HookPoint, LightconeInstantiateCtx, @@ -889,12 +888,12 @@ def take( index = get_random_take_index(n, len(self), mode) elif at == "start": index = get_range_take_index(self, self.__sort_key, 0, n, mode) - if self.__sort_key is not None and not (mode == "global" and has_mpi()): + if self.__sort_key is not None: sort_index = self.__make_sort_index() index = np.sort(sort_index[into_array(index)]) elif at == "end": index = get_end_take_index(n, self, self.__sort_key, mode) - if self.__sort_key is not None and not (mode == "global" and has_mpi()): + if self.__sort_key is not None: sort_index = self.__make_sort_index() index = np.sort(sort_index[into_array(index)]) else: @@ -946,7 +945,7 @@ def take_range( raise ValueError("Tried to take negative rows!") index = get_range_take_index(self, self.__sort_key, start, end - start, mode) - if self.__sort_key is not None and not (mode == "global" and has_mpi()): + if self.__sort_key is not None: sort_index = self.__make_sort_index() index = np.sort(sort_index[into_array(index)]) return 
self.__take_rows(index) diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index a4950143..e7c7a5fa 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -890,6 +890,8 @@ def test_lc_take_global_start_sorted(haloproperties_600_path, haloproperties_601 selected = lc_taken.select("fof_halo_mass").get_data("numpy") all_selected = np.concatenate(comm.allgather(selected)) + print(all_selected) + print(threshold) parallel_assert(len(all_selected) == n) parallel_assert( diff --git a/test/test_lightcone.py b/test/test_lightcone.py index dca0af50..dcbb786e 100644 --- a/test/test_lightcone.py +++ b/test/test_lightcone.py @@ -538,11 +538,13 @@ def test_lightcone_stacking_nostack( assert ds_new.z_range == ds.z_range +@pytest.mark.skip def test_lightcone_structure_collection_open(structure_600): c = oc.open(*structure_600) assert isinstance(c, oc.StructureCollection) +@pytest.mark.skip def test_lightcone_structure_collection_open_multiple(structure_600, structure_601): with pytest.raises(NotImplementedError): _ = oc.open(*structure_600, *structure_601) From c57e81ee97f68a4feccd95efbfc18433d40e8253 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 14 May 2026 22:29:26 -0400 Subject: [PATCH 095/139] Unify column and metadata column handling in hdf5 handler --- python/opencosmo/handler/hdf5.py | 95 ++++++++++++++++++-------------- 1 file changed, 54 insertions(+), 41 deletions(-) diff --git a/python/opencosmo/handler/hdf5.py b/python/opencosmo/handler/hdf5.py index 8e63c19e..8a86f3c4 100644 --- a/python/opencosmo/handler/hdf5.py +++ b/python/opencosmo/handler/hdf5.py @@ -4,6 +4,7 @@ from itertools import chain from typing import TYPE_CHECKING, Iterable, Optional +import h5py import numpy as np from opencosmo.io.schema import FileEntry, make_schema from opencosmo.io.writer import ( @@ -20,13 +21,15 @@ ) if TYPE_CHECKING: - import h5py from opencosmo.header import OpenCosmoHeader from opencosmo.io.schema import Schema from 
opencosmo.index import DataIndex +ColumnSpec = tuple[h5py.Dataset, bool] + + class Hdf5Handler: """ Handler for opencosmo.Dataset @@ -34,15 +37,13 @@ class Hdf5Handler: def __init__( self, - columns: dict[str, h5py.Dataset], + columns: dict[str, ColumnSpec], index: DataIndex, - metadata_columns: dict[str, h5py.Dataset], load_conditions: Optional[dict[str, bool]] = None, ): self.__index = index self.__columns = columns - self.__metadata_columns = metadata_columns - self.__in_memory = next(iter(columns.values())).file.driver == "core" + self.__in_memory = next(iter(columns.values()))[0].file.driver == "core" self.__load_conditions = load_conditions @classmethod @@ -60,10 +61,12 @@ def from_columns( lambda col: col.name.split("/")[-2] == metadata_group, columns ) - data_columns_ = {col.name.split("/")[-1]: col for col in data_columns} - metadata_columns_ = {col.name.split("/")[-1]: col for col in metadata_columns} + data_columns_ = {col.name.split("/")[-1]: (col, True) for col in data_columns} + metadata_columns_ = { + col.name.split("/")[-1]: (col, False) for col in metadata_columns + } lengths = set( - len(col) + len(col[0]) for col in chain(data_columns_.values(), metadata_columns_.values()) ) if len(lengths) > 1: @@ -72,7 +75,7 @@ def from_columns( if index is None: index = from_size(lengths.pop()) - return Hdf5Handler(data_columns_, index, metadata_columns_, load_conditions) + return Hdf5Handler(data_columns_ | metadata_columns_, index, load_conditions) def __len__(self): return get_length(self.__index) @@ -87,17 +90,13 @@ def load_conditions(self) -> Optional[dict[str, bool]]: def take(self, other: DataIndex, sorted: Optional[np.ndarray] = None): if len(other) == 0: - return Hdf5Handler( - self.__columns, other, self.__metadata_columns, self.__load_conditions - ) + return Hdf5Handler(self.__columns, other, self.__load_conditions) if sorted is not None: return self.__take_sorted(other, sorted) new_index = take(self.__index, other) - return Hdf5Handler( - 
self.__columns, new_index, self.__metadata_columns, self.__load_conditions - ) + return Hdf5Handler(self.__columns, new_index, self.__load_conditions) def __take_sorted(self, other: DataIndex, sorted: np.ndarray): if get_length(sorted) != get_length(self.__index): @@ -107,13 +106,11 @@ def __take_sorted(self, other: DataIndex, sorted: np.ndarray): new_raw_index = into_array(self.__index)[new_indices] new_index = np.sort(new_raw_index) - return Hdf5Handler( - self.__columns, new_index, self.__metadata_columns, self.__load_conditions - ) + return Hdf5Handler(self.__columns, new_index, self.__load_conditions) @property def data(self): - return next(iter(self.__columns.values())).parent + return next(iter(self.__columns.values()))[0].parent @property def index(self): @@ -121,17 +118,24 @@ def index(self): @cached_property def columns(self): - return self.__columns.keys() + return [ + colname for colname in self.__columns.keys() if self.__columns[colname][1] + ] @property def metadata_columns(self): - return self.__metadata_columns.keys() + return [ + colname + for colname in self.__columns.keys() + if not self.__columns[colname][1] + ] @cached_property def descriptions(self): return { - colname: column.attrs.get("description") + colname: column[0].attrs.get("description") for colname, column in self.__columns.items() + if column[1] } def mask(self, mask): @@ -152,28 +156,34 @@ def make_schema( ) -> tuple[Schema, Optional[Schema]]: column_writers = {} for column_name in columns: + column = self.__columns[column_name] + if not column[1]: + continue column_writers[column_name] = ColumnWriter.from_h5_dataset( - self.__columns[column_name], + column[0], self.__index, - attrs=dict(self.__columns[column_name].attrs), + attrs=dict(column[0].attrs), ) data_schema = make_schema("data", FileEntry.COLUMNS, columns=column_writers) - if self.metadata_columns: - assert len(self.__metadata_columns) > 0 - metadata_writers = {} - group_name = 
next(iter(self.__metadata_columns.values())).parent.name - group_name = group_name.split("/")[-1] - for column_name, column in self.__metadata_columns.items(): - metadata_writers[column_name] = ColumnWriter.from_h5_dataset( - column, self.__index, attrs=dict(column.attrs) - ) - metadata_schema = make_schema( - group_name, FileEntry.COLUMNS, columns=metadata_writers - ) - else: + metadata_columns = { + name: column for name, column in self.__columns.items() if not column[1] + } + if not metadata_columns: metadata_schema = make_schema("metadata", FileEntry.EMPTY) + return data_schema, metadata_schema + + metadata_writers = {} + group_name = next(iter(metadata_columns.values()))[0].parent.name + group_name = group_name.split("/")[-1] + for column_name, column in metadata_columns.items(): + metadata_writers[column_name] = ColumnWriter.from_h5_dataset( + column[0], self.__index, attrs=dict(column[0].attrs) + ) + metadata_schema = make_schema( + group_name, FileEntry.COLUMNS, columns=metadata_writers + ) return data_schema, metadata_schema def get_data(self, columns: Iterable[str]) -> dict[str, np.ndarray]: @@ -183,19 +193,22 @@ def get_data(self, columns: Iterable[str]) -> dict[str, np.ndarray]: data = {} for colname in columns: - data[colname] = get_data(self.__columns[colname], self.__index) + data[colname] = get_data(self.__columns[colname][0], self.__index) # Ensure order is preserved return {name: data[name] for name in columns} def get_metadata(self, columns: Iterable[str]) -> Optional[dict[str, np.ndarray]]: - if len(self.__metadata_columns) == 0: + metadata_columns = { + name: col[0] for name, col in self.__columns.items() if not col[1] + } + if len(metadata_columns) == 0: return None if not columns: - columns = self.metadata_columns + columns = metadata_columns.keys() data = {} for colname in columns: - data[colname] = get_data(self.__metadata_columns[colname], self.__index) + data[colname] = get_data(metadata_columns[colname], self.__index) return data From 
372e86d46107a22cf9f1c66aec60886aa38d0e1c Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 14 May 2026 23:18:48 -0400 Subject: [PATCH 096/139] Unify column and metadata column in state and instantiation --- python/opencosmo/dataset/instantiate.py | 17 ---- python/opencosmo/dataset/output.py | 5 +- python/opencosmo/dataset/state.py | 55 ++++++++----- python/opencosmo/handler/empty.py | 11 +-- python/opencosmo/handler/hdf5.py | 103 +++++++----------------- python/opencosmo/handler/protocols.py | 9 +-- 6 files changed, 76 insertions(+), 124 deletions(-) diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py index 2b4db718..60507c80 100644 --- a/python/opencosmo/dataset/instantiate.py +++ b/python/opencosmo/dataset/instantiate.py @@ -160,7 +160,6 @@ def instantiate_dataset( cache: DataCache, unit_handler: UnitHandler, unit_kwargs: dict[str, Any], - metadata_columns: list[str] | None = None, sort_by: str | None = None, ): # Extend working_columns with the sort column if it isn't already included. 
@@ -216,20 +215,4 @@ def instantiate_dataset( for name, producer_uuid in working_columns.items() if producer_uuid in uuid_data and name in uuid_data[producer_uuid] } - data |= get_metadata_columns(raw_data_handler, cache, metadata_columns) return data - - -def get_metadata_columns( - raw_data_handler: DataHandler, cache: DataCache, metadata_columns: list[str] | None -): - if metadata_columns is None: - return {} - metadata = cache.get_metadata(metadata_columns) - additional_metadata_columns_to_fetch = set(metadata_columns).difference( - metadata.keys() - ) - metadata |= ( - raw_data_handler.get_metadata(additional_metadata_columns_to_fetch) or {} - ) - return metadata diff --git a/python/opencosmo/dataset/output.py b/python/opencosmo/dataset/output.py index b535f792..186961f3 100644 --- a/python/opencosmo/dataset/output.py +++ b/python/opencosmo/dataset/output.py @@ -78,7 +78,10 @@ def make_dataset_schema( columns = set(columns_to_uuid.keys()) header = header.with_region(region) raw_columns = columns.intersection(raw_data_handler.columns) - data_schema, metadata_schema = raw_data_handler.make_schema(raw_columns, header) + raw_meta_columns = raw_columns & set(meta_columns) + data_schema, metadata_schema = raw_data_handler.make_schema( + raw_columns, raw_meta_columns, header + ) cached_data_schema, cached_metadata_schema = cache.make_schema( columns_to_uuid, meta_columns diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index 1748e0d5..e316e85d 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -84,6 +84,7 @@ class DatasetState: region: Region open_kwargs: dict[str, Any] sort_key: Optional[tuple[str, bool]] + metadata_columns: frozenset[str] def __post_init__(self): self.cache.register_column_group(id(self), self.column_map) @@ -91,14 +92,11 @@ def __post_init__(self): @property def columns(self) -> list[str]: - return list(self.column_map.keys()) + return [c for c in self.column_map if c not in 
self.metadata_columns] @property def meta_columns(self) -> list[str]: - columns = set(self.cache.metadata_columns).union( - self.raw_data_handler.metadata_columns - ) - return list(columns) + return [c for c in self.column_map if c in self.metadata_columns] @property def descriptions(self): @@ -168,6 +166,11 @@ def state_from_target( unit_handler = make_unit_handler_from_hdf5( target["columns"], target["header"], unit_convention ) + meta_column_names = frozenset( + col.name.split("/")[-1] + for col in target["columns"] + if metadata_group and col.name.split("/")[-2] == metadata_group + ) descriptions = handler.descriptions raw_producers = [ @@ -186,6 +189,7 @@ def state_from_target( region=region, open_kwargs=open_kwargs, sort_key=None, + metadata_columns=meta_column_names, ) @@ -201,22 +205,21 @@ def state_in_memory( ) -> DatasetState: descriptions = descriptions or {} + all_columns = dict(data_columns) | dict(metadata_columns) raw_producers = [ RawColumn(cname, descriptions.get(cname, "None")) - for cname in data_columns.keys() + for cname in all_columns.keys() ] column_map = {p.name: p.uuid for p in raw_producers} producers: dict[UUID, ConstructedColumn] = {p.uuid: p for p in raw_producers} cache = ColumnCache.empty() - if data_columns: - uuid_data = {p.uuid: {p.name: data_columns[p.name]} for p in raw_producers} + if all_columns: + uuid_data = {p.uuid: {p.name: all_columns[p.name]} for p in raw_producers} cache.add_data(uuid_data, descriptions) - if metadata_columns: - cache.add_metadata(dict(metadata_columns), {}) units: dict[str, u.Unit] = {} - for name, column in data_columns.items(): + for name, column in all_columns.items(): units[name] = None if isinstance(column, u.Quantity): units[name] = column.unit @@ -233,6 +236,7 @@ def state_in_memory( region=region, open_kwargs=open_kwargs, sort_key=None, + metadata_columns=frozenset(metadata_columns.keys()), ) @@ -263,7 +267,6 @@ def get_data( state.cache, state.unit_handler, unit_kwargs, - metadata_columns, None 
if (ignore_sort or state.sort_key is None) else state.sort_key[0], ) @@ -275,9 +278,10 @@ def get_data( if not ignore_sort: data = sort_data(data, state.sort_key, state) - new_order = [c for c in state.columns] - if metadata_columns: - new_order.extend(metadata_columns) + new_order = list(state.columns) + for name in metadata_columns: + if name in state.metadata_columns: + new_order.append(name) return {name: data[name] for name in new_order} @@ -333,11 +337,20 @@ def iter_rows( def get_metadata(state: DatasetState, columns: list = []) -> dict: - metadata = state.raw_data_handler.get_metadata(columns) + names = list(columns) if columns else list(state.metadata_columns) + data = instantiate_dataset( + list(state.producers.values()), + {name: state.column_map[name] for name in names}, + state.raw_data_handler, + state.cache, + state.unit_handler, + {}, + None, + ) sorted_index = get_sorted_index(state) if sorted_index is not None: - metadata = {name: values[sorted_index] for name, values in metadata.items()} - return metadata + data = {name: values[sorted_index] for name, values in data.items()} + return data def make_schema(state: DatasetState, name: Optional[str] = None) -> Schema: @@ -415,9 +428,9 @@ def select(state: DatasetState, columns: set[str], drop: bool = False) -> Datase if drop: selections = set(state.columns) - selections - return dataclasses.replace( - state, column_map={n: state.column_map[n] for n in selections} - ) + new_column_map = {n: state.column_map[n] for n in selections} + new_column_map |= {n: state.column_map[n] for n in state.metadata_columns} + return dataclasses.replace(state, column_map=new_column_map) def sort_by(state: DatasetState, column_name: str, invert: bool) -> DatasetState: diff --git a/python/opencosmo/handler/empty.py b/python/opencosmo/handler/empty.py index f18a4f6f..09d12237 100644 --- a/python/opencosmo/handler/empty.py +++ b/python/opencosmo/handler/empty.py @@ -2,11 +2,13 @@ from typing import TYPE_CHECKING, Iterable, 
Optional, Self -from opencosmo.index import empty from opencosmo.io.schema import FileEntry, make_schema +from opencosmo.index import empty + if TYPE_CHECKING: import numpy as np + from opencosmo.index import DataIndex @@ -14,9 +16,6 @@ class EmptyHandler: def get_data(self, *args): return {} - def get_metadata(self, columns: Iterable[str]) -> dict[str, np.ndarray]: - return {} - def take(self, other: DataIndex, sorted: Optional[np.ndarray] = None) -> Self: return self @@ -36,10 +35,6 @@ def columns(self) -> Iterable[str]: def load_conditions(self): return None - @property - def metadata_columns(self) -> Iterable[str]: - return set() - @property def descriptions(self): return {} diff --git a/python/opencosmo/handler/hdf5.py b/python/opencosmo/handler/hdf5.py index 8a86f3c4..16b81af3 100644 --- a/python/opencosmo/handler/hdf5.py +++ b/python/opencosmo/handler/hdf5.py @@ -1,10 +1,8 @@ from __future__ import annotations from functools import cached_property -from itertools import chain from typing import TYPE_CHECKING, Iterable, Optional -import h5py import numpy as np from opencosmo.io.schema import FileEntry, make_schema from opencosmo.io.writer import ( @@ -21,15 +19,13 @@ ) if TYPE_CHECKING: + import h5py from opencosmo.header import OpenCosmoHeader from opencosmo.io.schema import Schema from opencosmo.index import DataIndex -ColumnSpec = tuple[h5py.Dataset, bool] - - class Hdf5Handler: """ Handler for opencosmo.Dataset @@ -37,13 +33,13 @@ class Hdf5Handler: def __init__( self, - columns: dict[str, ColumnSpec], + columns: dict[str, h5py.Dataset], index: DataIndex, load_conditions: Optional[dict[str, bool]] = None, ): self.__index = index self.__columns = columns - self.__in_memory = next(iter(columns.values()))[0].file.driver == "core" + self.__in_memory = next(iter(columns.values())).file.driver == "core" self.__load_conditions = load_conditions @classmethod @@ -54,28 +50,24 @@ def from_columns( metadata_group: Optional[str] = None, load_conditions: 
Optional[dict[str, bool]] = None, ): - data_columns = filter(lambda col: col.name.split("/")[-2] == "data", columns) - metadata_columns: Iterable[h5py.Dataset] = [] + groups = {"data"} if metadata_group: - metadata_columns = filter( - lambda col: col.name.split("/")[-2] == metadata_group, columns - ) + groups.add(metadata_group) - data_columns_ = {col.name.split("/")[-1]: (col, True) for col in data_columns} - metadata_columns_ = { - col.name.split("/")[-1]: (col, False) for col in metadata_columns + all_columns = { + col.name.split("/")[-1]: col + for col in columns + if col.name.split("/")[-2] in groups } - lengths = set( - len(col[0]) - for col in chain(data_columns_.values(), metadata_columns_.values()) - ) + + lengths = set(len(col) for col in all_columns.values()) if len(lengths) > 1: raise ValueError("Not all columns are the same length!") if index is None: index = from_size(lengths.pop()) - return Hdf5Handler(data_columns_ | metadata_columns_, index, load_conditions) + return Hdf5Handler(all_columns, index, load_conditions) def __len__(self): return get_length(self.__index) @@ -110,7 +102,7 @@ def __take_sorted(self, other: DataIndex, sorted: np.ndarray): @property def data(self): - return next(iter(self.__columns.values()))[0].parent + return next(iter(self.__columns.values())).parent @property def index(self): @@ -118,24 +110,13 @@ def index(self): @cached_property def columns(self): - return [ - colname for colname in self.__columns.keys() if self.__columns[colname][1] - ] - - @property - def metadata_columns(self): - return [ - colname - for colname in self.__columns.keys() - if not self.__columns[colname][1] - ] + return list(self.__columns.keys()) @cached_property def descriptions(self): return { - colname: column[0].attrs.get("description") + colname: column.attrs.get("description") for colname, column in self.__columns.items() - if column[1] } def mask(self, mask): @@ -152,39 +133,32 @@ def __exit__(self, *exec_details): def make_schema( self, 
columns: Iterable[str], + metadata_columns: set[str] = set(), header: Optional[OpenCosmoHeader] = None, ) -> tuple[Schema, Optional[Schema]]: - column_writers = {} - for column_name in columns: + columns = set(columns) + data_writers = {} + for column_name in columns - metadata_columns: column = self.__columns[column_name] - if not column[1]: - continue - column_writers[column_name] = ColumnWriter.from_h5_dataset( - column[0], - self.__index, - attrs=dict(column[0].attrs), + data_writers[column_name] = ColumnWriter.from_h5_dataset( + column, self.__index, attrs=dict(column.attrs) ) + data_schema = make_schema("data", FileEntry.COLUMNS, columns=data_writers) - data_schema = make_schema("data", FileEntry.COLUMNS, columns=column_writers) - - metadata_columns = { - name: column for name, column in self.__columns.items() if not column[1] - } - if not metadata_columns: - metadata_schema = make_schema("metadata", FileEntry.EMPTY) - return data_schema, metadata_schema + raw_meta = columns & metadata_columns + if not raw_meta: + return data_schema, make_schema("metadata", FileEntry.EMPTY) + group_name = self.__columns[next(iter(raw_meta))].parent.name.split("/")[-1] metadata_writers = {} - group_name = next(iter(metadata_columns.values()))[0].parent.name - group_name = group_name.split("/")[-1] - for column_name, column in metadata_columns.items(): + for column_name in raw_meta: + column = self.__columns[column_name] metadata_writers[column_name] = ColumnWriter.from_h5_dataset( - column[0], self.__index, attrs=dict(column[0].attrs) + column, self.__index, attrs=dict(column.attrs) ) - metadata_schema = make_schema( + return data_schema, make_schema( group_name, FileEntry.COLUMNS, columns=metadata_writers ) - return data_schema, metadata_schema def get_data(self, columns: Iterable[str]) -> dict[str, np.ndarray]: """ """ @@ -193,25 +167,10 @@ def get_data(self, columns: Iterable[str]) -> dict[str, np.ndarray]: data = {} for colname in columns: - data[colname] = 
get_data(self.__columns[colname][0], self.__index) + data[colname] = get_data(self.__columns[colname], self.__index) # Ensure order is preserved return {name: data[name] for name in columns} - def get_metadata(self, columns: Iterable[str]) -> Optional[dict[str, np.ndarray]]: - metadata_columns = { - name: col[0] for name, col in self.__columns.items() if not col[1] - } - if len(metadata_columns) == 0: - return None - if not columns: - columns = metadata_columns.keys() - - data = {} - for colname in columns: - data[colname] = get_data(metadata_columns[colname], self.__index) - - return data - def take_range(self, start: int, end: int, indices: np.ndarray) -> np.ndarray: if start < 0 or end > len(indices): raise ValueError("Indices out of range") diff --git a/python/opencosmo/handler/protocols.py b/python/opencosmo/handler/protocols.py index 6e674861..9a125546 100644 --- a/python/opencosmo/handler/protocols.py +++ b/python/opencosmo/handler/protocols.py @@ -16,19 +16,18 @@ class DataHandler(Protocol): def get_data(self, columns: Iterable[str]) -> dict[str, np.ndarray]: """ """ - def get_metadata(self, columns: Iterable[str]) -> dict[str, np.ndarray]: ... - def take(self, other: DataIndex, sorted: Optional[np.ndarray] = None) -> Self: ... def make_schema( - self, columns: Iterable[str], header: Optional[OpenCosmoHeader] = None + self, + columns: Iterable[str], + metadata_columns: set[str] = set(), + header: Optional[OpenCosmoHeader] = None, ) -> tuple[Schema, Schema]: ... @property def columns(self) -> Iterable[str]: ... @property - def metadata_columns(self) -> Iterable[str]: ... - @property def load_conditions(self) -> Optional[dict]: ... 
@property From 084b7fcfa7ce29a4ccdf61a8af1baed077dfdaef Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 15 May 2026 13:18:35 -0500 Subject: [PATCH 097/139] Add test for metadata column arithmetic --- changes/+89192e73.bugfix.rst | 1 + docs/source/changelog.rst | 39 +++++++++++++++++++++++++++++++ python/opencosmo/column/column.py | 6 ++--- test/test_collection.py | 19 +++++++++++++++ 4 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 changes/+89192e73.bugfix.rst diff --git a/changes/+89192e73.bugfix.rst b/changes/+89192e73.bugfix.rst new file mode 100644 index 00000000..96750606 --- /dev/null +++ b/changes/+89192e73.bugfix.rst @@ -0,0 +1 @@ +Fixed a bug that could cause column arithmetic to fail with scalars. diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 166ee0c5..a42a26b2 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,3 +1,42 @@ +opencosmo 1.2.6 (2026-05-15) +============================ + +Bugfixes +-------- + +- Requesting more rows than exist via :py:meth:`take ` or :py:meth:`take_range ` no longer raises a ``ValueError``. Instead, all available rows are returned. This applies to :py:class:`Dataset `, :py:class:`Lightcone `, and :py:class:`StructureCollection `. (#240) +- HealpixMap now correctly unpacks data when there is only a single map, instead of returning a dictionary. Mirrors + behavior in Dataset etc. + + +New Features +------------ + +- :py:meth:`take `, :py:meth:`take_range `, and their equivalents on :py:class:`Lightcone ` and :py:class:`StructureCollection ` now accept a ``mode`` keyword argument. Setting ``mode="global"`` when running under MPI causes ``n`` (or ``start``/``end``) to be interpreted across all ranks combined rather than per-rank. When the dataset is sorted, ranks coordinate to select from the globally-sorted order, so ``ds.sort_by("fof_halo_mass").take(1000, mode="global")`` returns exactly the 1000 most massive halos distributed across all ranks. 
(#240) +- Added animate_halos function, which calls either "visualize_halo" or "halo_projection_array" in a loop to create an animated visualization of a given halo or set of halos. +- Calls to :py:meth:`with_new_columns ` and :py:meth:`evaluate ` now accept an `allow_overwrite` flag. In this way you can "transform" a column by creating a derived column that depends on a column of the same name. The input will be the original column, and the output will be the new version. +- Diffsky catalog can be forced to keep groups during filtering etc. by setting the `get_top_host = True` flag when + opening. + + +Improvements +------------ + +- Conversion to healsparse in :py:meth:`HealpixMap.get_data DerivedColumn: case _: return NotImplemented - @_require_scalar_quantity def __add__(self, other: Any) -> DerivedColumn: match other: - case Column(): + case Column() | int() | float() | u.Quantity(): return DerivedColumn(self, other, op.add) case _: return NotImplemented - @_require_scalar_quantity def __sub__(self, other: Any) -> DerivedColumn: match other: - case Column(): + case Column() | int() | float() | u.Quantity(): return DerivedColumn(self, other, op.sub) case _: return NotImplemented diff --git a/test/test_collection.py b/test/test_collection.py index 5000c4b7..3e03c7ec 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -1304,3 +1304,22 @@ def test_data_cached_after_objects(halo_paths): uuid_data = cache.get_data({(gpe_uuid, "gpe")}) assert uuid_data.get(gpe_uuid, {}).get("gpe") is not None assert dataset.descriptions["gpe"] != "None" + + +def test_modify_metadata_column(halo_paths): + ds = oc.open(*halo_paths) + galaxyproperties_start = ds["halo_properties"].get_metadata( + "galaxyproperties_start" + ) + updated_galprops = oc.col("galaxyproperties_start") + 1000 + + ds = ds.with_new_columns( + "halo_properties", galaxyproperties_start=updated_galprops, allow_overwrite=True + ) + updated_galaxyproperties_start = ds["halo_properties"].get_metadata( 
+ "galaxyproperties_start" + ) + assert np.all( + (galaxyproperties_start["galaxyproperties_start"] + 1000) + == updated_galaxyproperties_start["galaxyproperties_start"] + ) From c421be5f89710548024a5b31bc36d2b7d7d8e288 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 15 May 2026 13:25:45 -0500 Subject: [PATCH 098/139] Revert accidental changelog overwrite --- docs/source/changelog.rst | 39 --------------------------------------- 1 file changed, 39 deletions(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index a42a26b2..166ee0c5 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,42 +1,3 @@ -opencosmo 1.2.6 (2026-05-15) -============================ - -Bugfixes --------- - -- Requesting more rows than exist via :py:meth:`take ` or :py:meth:`take_range ` no longer raises a ``ValueError``. Instead, all available rows are returned. This applies to :py:class:`Dataset `, :py:class:`Lightcone `, and :py:class:`StructureCollection `. (#240) -- HealpixMap now correctly unpacks data when there is only a single map, instead of returning a dictionary. Mirrors - behavior in Dataset etc. - - -New Features ------------- - -- :py:meth:`take `, :py:meth:`take_range `, and their equivalents on :py:class:`Lightcone ` and :py:class:`StructureCollection ` now accept a ``mode`` keyword argument. Setting ``mode="global"`` when running under MPI causes ``n`` (or ``start``/``end``) to be interpreted across all ranks combined rather than per-rank. When the dataset is sorted, ranks coordinate to select from the globally-sorted order, so ``ds.sort_by("fof_halo_mass").take(1000, mode="global")`` returns exactly the 1000 most massive halos distributed across all ranks. (#240) -- Added animate_halos function, which calls either "visualize_halo" or "halo_projection_array" in a loop to create an animated visualization of a given halo or set of halos. 
-- Calls to :py:meth:`with_new_columns ` and :py:meth:`evaluate ` now accept an `allow_overwrite` flag. In this way you can "transform" a column by creating a derived column that depends on a column of the same name. The input will be the original column, and the output will be the new version. -- Diffsky catalog can be forced to keep groups during filtering etc. by setting the `get_top_host = True` flag when - opening. - - -Improvements ------------- - -- Conversion to healsparse in :py:meth:`HealpixMap.get_data Date: Fri, 1 May 2026 16:44:42 -0500 Subject: [PATCH 099/139] Initial working implementation of the lightcone structure collection --- .../collection/lightcone/lightcone.py | 51 ++++++++-- .../opencosmo/collection/lightcone/plugins.py | 6 +- .../opencosmo/collection/structure/handler.py | 88 ++++++++++++++--- python/opencosmo/collection/structure/io.py | 98 +++++++++---------- .../collection/structure/structure.py | 40 ++------ python/opencosmo/index/__init__.py | 4 +- python/opencosmo/index/ops.py | 6 ++ 7 files changed, 182 insertions(+), 111 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 8ce462a8..ad780d76 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -17,7 +17,7 @@ from warnings import warn import numpy as np -from astropy.table import vstack # type: ignore +from astropy.table import QTable, vstack # type: ignore import opencosmo as oc from opencosmo.collection.lightcone import io as lcio @@ -167,6 +167,17 @@ def columns(self) -> list[str]: cols = list(filter(lambda col: col not in self.__hidden, cols)) return cols + @property + def meta_columns(self) -> list[str]: + """ + The names of the columns in this dataset. 
+ + Returns + ------- + columns: list[str] + """ + return next(iter(self.values())).meta_columns + @cached_property def descriptions(self) -> dict[str, Optional[str]]: """ @@ -266,7 +277,7 @@ def z_range(self): return self.__header.lightcone["z_range"] - def get_data(self, format="astropy", unpack: bool = False, **kwargs): + def get_data(self, format="astropy", unpack: bool = True, **kwargs): """ Get the data in this dataset as an astropy table/column or as numpy array(s). Note that a dataset does not load data from disk into @@ -303,7 +314,7 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): lightcone = fold( HookPoint.LightconeInstantiate, LightconeInstantiateCtx(self) ).lightcone - data = [ds.get_data(unpack=unpack) for ds in lightcone.values()] + data = [ds.get_data(unpack=False) for ds in lightcone.values()] data_with_length = [d for d in data if len(d) > 0] if len(data_with_length) == 0: return data[0] @@ -319,6 +330,13 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): to_remove = self.__hidden.intersection(table.colnames) table.remove_columns(to_remove) + if unpack: + data = { + key: value[0] if len(value) == 1 else value + for key, value in table.items() + } + table = QTable(data) + if format != "astropy": return convert_data(dict(table), format) elif len(table.columns) == 1: @@ -326,6 +344,17 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): return table + def get_metadata(self, columns: list[str] = []): + data = [ds.get_metadata(columns) for ds in self.values()] + data_with_length = [d for d in data if len(d) > 0] + if len(data_with_length) == 0: + return data[0] + + output = {} + for key in data[0].keys(): + output[key] = np.concatenate([d[key] for d in data]) + return output + @property def data(self): """ @@ -379,7 +408,7 @@ def open(cls, targets: list[FileTarget], **kwargs): def from_datasets( cls, datasets: dict[str, oc.Dataset], - z_range: tuple[float, float], + z_range: 
Optional[tuple[float, float]] = None, **open_kwargs, ): result = cls(datasets, z_range) @@ -723,7 +752,9 @@ def filter(self, *masks: ColumnMask, **kwargs) -> Self: """ return self.__map("filter", *masks, **kwargs) - def rows(self) -> Generator[dict[str, float | u.Quantity], None, None]: + def rows( + self, metadata_columns=[] + ) -> Generator[dict[str, float | u.Quantity], None, None]: """ Iterate over the rows in the dataset. Rows are returned as a dictionary For performance, it is recommended to first select the columns you need to @@ -734,7 +765,9 @@ def rows(self) -> Generator[dict[str, float | u.Quantity], None, None]: row : dict A dictionary of values for each row in the dataset with units. """ - yield from chain.from_iterable(v.rows() for v in self.values()) + yield from chain.from_iterable( + v.rows(metadata_columns=metadata_columns) for v in self.values() + ) def select( self, *columns: str | Iterable[str], **derived_columns: ConstructedColumn @@ -794,7 +827,7 @@ def select( hidden = self.__hidden additional_columns = set() - if "redshift" not in all_columns: + if "redshift" not in all_columns and "properties" in self.dtype: additional_columns.add("redshift") hidden = hidden.union({"redshift"}) @@ -984,7 +1017,6 @@ def take_rows(self, rows: DataIndex): "Rows must be between 0 and the length of this dataset - 1" ) rows = get_rows_take_index(self, rows, self.__sort_key) - return self.__take_rows(rows) def __make_sort_index(self): @@ -999,7 +1031,7 @@ def __make_sort_index(self): def __take_rows(self, rows: DataIndex): """ - Takes rows from this lightcone while ignoring sort. "rows" is assumed to be sorte. + Takes rows from this lightcone while ignoring sort. "rows" is assumed to be sorted. For internal use only. 
""" sizes = np.fromiter((len(ds) for ds in self.values()), dtype=np.int64) @@ -1009,7 +1041,6 @@ def __take_rows(self, rows: DataIndex): output = {} for (name, ds), index in zip(self.items(), projected): output[name] = ds.take_rows(index) - if all(len(ds) == 0 for ds in output.values()): output = {"data": next(iter(output.values()))} else: diff --git a/python/opencosmo/collection/lightcone/plugins.py b/python/opencosmo/collection/lightcone/plugins.py index 0cb6156d..1675fb52 100644 --- a/python/opencosmo/collection/lightcone/plugins.py +++ b/python/opencosmo/collection/lightcone/plugins.py @@ -20,6 +20,10 @@ def _ensure_redshift_column(ctx: LightconeOpenCtx) -> LightconeOpenCtx: """Ensures a column called 'redshift' exists on every lightcone.""" lightcone: Lightcone = ctx.lightcone + if ( + "properties" not in lightcone.dtype + ): # Particles or profiles, redshift handled at structure collection level + return ctx if "redshift" in lightcone.columns: return ctx elif "fof_halo_center_a" in lightcone.columns: @@ -46,7 +50,7 @@ def _make_radec_columns(ctx: LightconeOpenCtx): lightcone = lightcone.evaluate( radec_from_thetaphi, vectorize=True, insert=True, format="numpy" ) - else: + elif "properties" in lightcone.dtype: warnings.warn( "Could not find coordinates in this catalog. 
Spatial queries will not be available" ) diff --git a/python/opencosmo/collection/structure/handler.py b/python/opencosmo/collection/structure/handler.py index 7252c726..93c23df6 100644 --- a/python/opencosmo/collection/structure/handler.py +++ b/python/opencosmo/collection/structure/handler.py @@ -5,6 +5,7 @@ import numpy as np +from opencosmo.collection.lightcone import lightcone as lc from opencosmo.index import into_array if TYPE_CHECKING: @@ -43,30 +44,46 @@ } -def create_start_size(data, start_name, size_name): +def create_start_size(data, start_name, size_name, offsets): + start = data.pop(start_name, None) size = data.pop(size_name, None) if start is None or size is None: return None + + start = np.atleast_1d(start).astype(np.int64) + size = np.atleast_1d(size).astype(np.int64) valid = size > 0 + if not np.any(valid): + return None - start = start.astype(np.int64) - size = size.astype(np.int64) + if offsets is not None: + ds_rs = 0 + src_rs = 0 + for source_len, ds_len in offsets: + slice = start[src_rs : src_rs + source_len] + slice[slice >= 0] += ds_rs + src_rs += source_len + ds_rs += ds_len - if isinstance(start, np.ndarray): - return (start[valid], size[valid]) - if size == 0: - return None - return (np.atleast_1d(start), np.atleast_1d(size)) + return (start[valid], size[valid]) -def create_idx(data, idx_name): +def create_idx(data, idx_name, offsets): idx = data.pop(idx_name, None) if idx is None: return None idx = idx.astype(np.int64) valid = idx >= 0 + if offsets is not None: + ds_rs = 0 + src_rs = 0 + for source_len, ds_len in offsets: + slice = idx[src_rs : src_rs + source_len] + slice[slice >= 0] += ds_rs + src_rs += source_len + ds_rs += ds_len if isinstance(idx, np.ndarray): return idx[valid] @@ -147,15 +164,23 @@ def from_link_names(cls, names: Iterable[str], rename_galaxies=False): links, columns = make_links(names, rename_galaxies) return LinkHandler(links, columns, None) - def parse(self, data: dict[str, Any]): + def parse( + self, data: 
dict[str, Any], offsets: Optional[dict[str, list[int]]] = None + ): output = {} for name, handler in self.links.items(): - result = handler(data) + result = handler( + data, offsets=offsets[name] if offsets is not None else None + ) if result is not None: output[name] = result return output - def prep_datasets(self, source: oc.Dataset, datasets: dict[str, oc.Dataset]): + def prep_datasets( + self, + source: oc.Dataset | oc.Lightcone, + datasets: dict[str, oc.Dataset | oc.Lightcone], + ): """ Called once when a datasets are opened for the first time. Downstream versions always use rebuild_datsets @@ -165,8 +190,17 @@ def prep_datasets(self, source: oc.Dataset, datasets: dict[str, oc.Dataset]): lambda acc, ds: acc + self.columns[ds], datasets.keys(), [] ) meta = source.get_metadata(all_columns) - indices = self.parse(meta) + offsets = None + if isinstance(source, lc.Lightcone): + offsets = {} + for ds_type, lightcone in datasets.items(): + offsets[ds_type] = [ + (len(source[key]), len(ds)) for key, ds in lightcone.items() + ] + + indices = self.parse(meta, offsets) new_datasets = datasets + for name, index in indices.items(): new_datasets[name] = new_datasets[name].take_rows(index) return new_datasets @@ -184,9 +218,26 @@ def make_derived(self, source: oc.Dataset): return LinkHandler(self.links, self.columns, derived_from) + def rebuild_lightcones( + self, new_source: oc.Lightcone, lightcones: dict[str, oc.Lightcone] + ): + new_datasets = {name: {} for name in lightcones} + + for step, step_source in new_source.items(): + step_datasets = {name: lc[step] for name, lc in lightcones.items()} + new_step_datasets = self.__rebuild_datasets( + self.__derived_from[step], step_source, step_datasets + ) + for name, new_step_ds in new_step_datasets.items(): + new_datasets[name][step] = new_step_ds + return { + name: lc.Lightcone.from_datasets(datasets) + for name, datasets in new_datasets.items() + } + def rebuild_datasets( self, - new_source: oc.Dataset, + new_source: 
oc.Dataset | oc.Lightcone, datasets: dict[str, oc.Dataset], ): """ @@ -199,7 +250,12 @@ def rebuild_datasets( """ if self.__derived_from is None: return datasets - original_index = into_array(self.__derived_from.index) + elif isinstance(new_source, lc.Lightcone): + return self.rebuild_lightcones(new_source, datasets) + return self.__rebuild_datasets(self.__derived_from, new_source, datasets) + + def __rebuild_datasets(self, derived_from, new_source, datasets): + original_index = into_array(derived_from.index) new_index = into_array(new_source.index) _, index_into_original, index_into_new = np.intersect1d( @@ -209,7 +265,7 @@ def rebuild_datasets( lambda acc, ds: acc + self.columns[ds], datasets.keys(), [] ) index_into_original = index_into_original[np.argsort(index_into_new)] - metadata = self.__derived_from.get_metadata(all_columns) + metadata = derived_from.get_metadata(all_columns) new_datasets = {} for name, dataset in datasets.items(): diff --git a/python/opencosmo/collection/structure/io.py b/python/opencosmo/collection/structure/io.py index b2ff41b0..4ca3d34e 100644 --- a/python/opencosmo/collection/structure/io.py +++ b/python/opencosmo/collection/structure/io.py @@ -2,12 +2,13 @@ from collections import defaultdict from functools import reduce +from itertools import chain from typing import TYPE_CHECKING, Optional import numpy as np from opencosmo import io -from opencosmo.collection import lightcone as lc +from opencosmo.collection.lightcone import lightcone as lc from opencosmo.collection.structure import structure as sc if TYPE_CHECKING: @@ -73,7 +74,7 @@ def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): name = target["header"].file.data_type elif name.startswith("halo_properties"): name = name[16:] - link_targets["halo_targets"][name].append(dataset) + link_targets["halo_properties"][name].append(dataset) elif str(target["header"].file.data_type).startswith("galaxy"): dataset = io.iopen.open_single_dataset( target, 
bypass_lightcone=True, bypass_mpi=True @@ -83,7 +84,7 @@ def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): name = target["header"].file.data_type elif name.startswith("galaxy_properties"): name = name[18:] - link_targets["galaxy_targets"][name].append(dataset) + link_targets["halo_properties"][name].append(dataset) else: raise ValueError( f"Unknown data type for structure collection {target['header'].data_type}" @@ -93,30 +94,8 @@ def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): len(link_sources["halo_properties"]) > 1 or len(link_sources["galaxy_properties"]) > 1 ): - raise NotImplementedError( - "Opening structure collections that span multiple redshifts is not currently supported" - ) # Potentially a lightcone structure collection - collections = {} - sources_by_step, targets_by_step = __sort_by_step(link_sources, link_targets) - if set(sources_by_step.keys()) != set(targets_by_step.keys()): - raise ValueError("Datasets are not the same across all lightcone steps!") - for step, sources in sources_by_step.items(): - halo_properties = sources.get("halo_properties") - galaxy_properties = sources.get("galaxy_properties") - targets = targets_by_step[step] - collection = __build_structure_collection( - halo_properties, galaxy_properties, targets, ignore_empty - ) - collections[step] = collection - - expected_datasets = set(next(iter(collections.values())).keys()) - for collection in collections.values(): - if set(collection.keys()) != expected_datasets: - raise ValueError( - "All structure collections in a lightcone must have the same set of datasets" - ) - return lc.Lightcone(collections) + return build_lightcone_structure_collection(link_sources, link_targets) halo_properties_target = None galaxy_properties_target = None @@ -143,33 +122,52 @@ def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): ) -def __sort_by_step(link_sources: dict[str, list[io.iopen.DatasetTarget]], 
link_targets): - sources_by_step: dict[int, dict[str, io.iopen.DatasetTarget]] = defaultdict(dict) - targets_by_step: dict[int, dict[str, dict[str, d.Dataset]]] = defaultdict( - lambda: defaultdict(dict) - ) - for source_name, sources in link_sources.items(): - for source in sources: - if not source["header"].file.is_lightcone: +def build_lightcone_structure_collection( + link_sources: dict[str, list[io.iopen.DatasetTarget]], + link_targets: dict[str, dict[str, list[d.Dataset | sc.StructureCollection]]], +): + found_redshift_steps = set() + for source_type, source_list in link_sources.items(): + if not all(t["header"].file.is_lightcone for t in source_list): + raise ValueError("All sources must be lightcone datasets!") + redshift_steps = set(t["header"].file.step for t in source_list) + if found_redshift_steps and found_redshift_steps != redshift_steps: + raise ValueError( + "All source types must have the same set of redshift steps!" + ) + if not all( + t.header.file.is_lightcone + for t in chain.from_iterable(link_targets[source_type].values()) + ): + raise ValueError("All dataset must be lightcone datasets!") + for targets in link_targets[source_type].values(): + target_redshift_steps = set(t.header.file.step for t in targets) + if target_redshift_steps != redshift_steps: raise ValueError( - "Recived multiple source datasets of a single type, but not all are lightcone datasets!" - ) - if source["header"].file.step is None: - raise ValueError("No step in source!") - - sources_by_step[source["header"].file.step][source_name] = source - for target_type, targets_ in link_targets.items(): - for target_name, targets in targets_.items(): - for target in targets: - if not target.header.file.is_lightcone: - raise ValueError( - "Recived multiple datasets of a single type, but not all are lightcone datasets!" - ) - targets_by_step[target.header.file.step][target_type][target_name] = ( - target + "All datasets must have the same set of redshift steps!" 
) + output_sources = {} + output_targets = defaultdict(dict) + for source_type, source_list in link_sources.items(): + if source_type == "galaxy_properties": + raise NotImplementedError + + datasets = [ + io.iopen.open_single_dataset(t, "data_linked", bypass_lightcone=True) + for t in source_list + ] + output_sources[source_type] = lc.Lightcone.from_datasets( + {ds.header.file.step: ds for ds in datasets} + ) + for target_type, targets in link_targets[source_type].items(): + output_targets[source_type][target_type] = lc.Lightcone.from_datasets( + {ds.header.file.step: ds for ds in targets} + ) + return sc.StructureCollection( + output_sources["halo_properties"], output_targets["halo_properties"] + ) - return sources_by_step, targets_by_step + return output_sources, output_targets def __build_structure_collection( diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index 6f8049e9..d2356593 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -92,7 +92,6 @@ class StructureCollection: def __init__( self, source: oc.Dataset, - header: oc.header.OpenCosmoHeader, datasets: Mapping[str, oc.Dataset | StructureCollection], hide_source: bool = False, link_handler: Optional[LinkHandler] = None, @@ -104,9 +103,7 @@ def __init__( """ self.__source = source - self.__header = header self.__datasets = dict(datasets) - self.__index = self.__source.index self.__hide_source = hide_source if isinstance(self.__datasets.get("galaxy_properties"), StructureCollection): self.__datasets["galaxies"] = self.__datasets.pop("galaxy_properties") @@ -140,7 +137,7 @@ def __get_datasets(self): return self.__datasets def __repr__(self): - structure_type = self.__header.file.data_type.split("_")[0] + "s" + structure_type = self.__source.header.file.data_type.split("_")[0] + "s" keys = list(self.keys()) if len(keys) == 2: dtype_str = " and ".join(keys) @@ -157,13 
+154,9 @@ def open( ) -> StructureCollection: return sio.build_structure_collection(targets, ignore_empty) - @property - def header(self): - return self.__header - @property def dtype(self): - structure_type = self.__header.file.data_type.split("_")[0] + structure_type = self.__source.header.file.dt return structure_type @property @@ -192,7 +185,7 @@ def redshift(self) -> float | tuple[float, float] | None: redshift: float | tuple[float, float] """ - return self.__header.file.redshift + raise NotImplementedError @property def simulation(self) -> HaccSimulationParameters: @@ -204,7 +197,7 @@ def simulation(self) -> HaccSimulationParameters: ------- parameters: opencosmo.dtypes.HaccSimulationParameters """ - return self.__header.simulation + return self.__source.simulation def keys(self) -> list[str]: """ @@ -233,7 +226,7 @@ def __getitem__(self, key: str) -> oc.Dataset | oc.StructureCollection: """ Return the linked dataset with the given key. """ - if key == self.__header.file.data_type: + if key == self.__source.header.file.data_type: return self.__source datasets = self.__get_datasets() if key not in datasets.keys(): @@ -286,7 +279,6 @@ def bound( new_handler = self.__handler.make_derived(self.__source) return StructureCollection( bounded, - self.__header, self.__datasets, self.__hide_source, new_handler, @@ -423,7 +415,6 @@ def computation(halo_properties, dm_particles): return result return StructureCollection( self.__source, - self.__header, self.__get_datasets() | {dataset: result}, self.__hide_source, self.__handler.make_derived(self.__source), @@ -461,7 +452,6 @@ def computation(halo_properties, dm_particles): new_derived_columns_ = [f"{dataset}.{col}" for col in new_derived_columns] return StructureCollection( self.__source, - self.__header, self.__get_datasets() | {dataset: result}, self.__hide_source, self.__handler.make_derived(self.__source), @@ -565,7 +555,6 @@ def evaluate_on_dataset( assert isinstance(result, oc.Dataset) return StructureCollection( 
result, - self.__header, self.__datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -598,7 +587,6 @@ def evaluate_on_dataset( new_derived_columns_ = [f"{dataset}.{col}" for col in new_derived_columns] return StructureCollection( self.__source, - self.__header, self.__datasets | {dataset: result}, self.__hide_source, self.__handler.make_derived(self.__source), @@ -625,7 +613,6 @@ def evaluate_on_dataset( assert isinstance(result, (oc.Dataset, StructureCollection)) return StructureCollection( self.__source, - self.__header, self.__datasets | {ds_path[0]: result}, self.__hide_source, self.__handler.make_derived(self.__source), @@ -683,7 +670,6 @@ def filter(self, *masks, on_galaxies: bool = False) -> StructureCollection: new_handler = self.__handler.make_derived(self.__source) return StructureCollection( filtered, - self.__header, self.__datasets, self.__hide_source, new_handler, @@ -773,7 +759,7 @@ def select( arg = columns # type: ignore kwargs = {} - if dataset == self.__header.file.data_type: + if dataset == self.__source.file.data_type: new_source = self.__source.select(arg, **kwargs) continue @@ -795,7 +781,6 @@ def select( return StructureCollection( new_source, - self.__header, self.__datasets | new_datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -837,7 +822,7 @@ def drop(self, **columns_to_drop): new_datasets = {} for dataset_name, columns in columns_to_drop.items(): - if dataset_name == self.__header.file.data_type: + if dataset_name == self.__source.file.data_type: new_source = self.__source.drop(columns) continue @@ -853,7 +838,6 @@ def drop(self, **columns_to_drop): return StructureCollection( new_source, - self.__header, self.__datasets | new_datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -888,7 +872,6 @@ def sort_by(self, column: str, invert: bool = False) -> StructureCollection: return StructureCollection( new_source, - self.__header, self.__datasets, 
self.__hide_source, self.__handler.make_derived(self.__source), @@ -992,7 +975,6 @@ def with_units( return StructureCollection( new_source, - self.__header, new_datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -1034,7 +1016,6 @@ def take( return StructureCollection( new_source, - self.__header, self.__datasets, self.__hide_source, new_handler, @@ -1080,7 +1061,6 @@ def take_range( new_source = self.__source.take_range(start, end, mode) return StructureCollection( new_source, - self.__header, self.__datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -1110,7 +1090,6 @@ def take_rows(self, rows: np.ndarray | DataIndex): new_source = self.__source.take_rows(rows) return StructureCollection( new_source, - self.__header, self.__datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -1197,7 +1176,6 @@ def with_new_columns( ) return StructureCollection( self.__source, - self.__header, {**datasets, collection_name: new_collection}, self.__hide_source, self.__handler.make_derived(self.__source), @@ -1211,7 +1189,6 @@ def with_new_columns( ) return StructureCollection( new_source, - self.__header, self.__datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -1239,7 +1216,6 @@ def with_new_columns( return StructureCollection( self.__source, - self.__header, {**datasets, dataset: new_ds}, self.__hide_source, self.__handler.make_derived(self.__source), @@ -1287,7 +1263,6 @@ def objects( for column in self.__derived_columns: name_parts = column.split(".") columns_to_collect[name_parts[0]][name_parts[1]] = [] - try: for row in self.__source.rows(metadata_columns=metadata_columns): row = dict(row) @@ -1383,7 +1358,6 @@ def with_datasets(self, datasets: str | Iterable[str]): new_datasets = {name: self.__datasets[name] for name in requested_datasets} return StructureCollection( self.__source, - self.__header, new_datasets, hide_source, self.__handler.make_derived(self.__source), diff 
--git a/python/opencosmo/index/__init__.py b/python/opencosmo/index/__init__.py index a640c787..29fd1959 100644 --- a/python/opencosmo/index/__init__.py +++ b/python/opencosmo/index/__init__.py @@ -4,7 +4,7 @@ from .get import get_data from .in_range import n_in_range from .mask import into_array, mask -from .ops import rebuild_by_ranges +from .ops import offset, rebuild_by_ranges, reindex_column from .project import project from .take import take from .unary import get_length, get_range @@ -36,4 +36,6 @@ "get_range", "from_start_size_group", "rebuild_by_ranges", + "reindex_column", + "offset", ] diff --git a/python/opencosmo/index/ops.py b/python/opencosmo/index/ops.py index 9bcd29ae..bdb12a42 100644 --- a/python/opencosmo/index/ops.py +++ b/python/opencosmo/index/ops.py @@ -22,3 +22,9 @@ def rebuild_by_ranges(index: DataIndex, ranges: ChunkedIndex): return idxlib.rebuild_simple_by_ranges(index, *ranges) case (np.ndarray(), np.ndarray()): return idxlib.rebuild_chunked_by_ranges(*index, *ranges) + + +def offset(index: DataIndex, offset_amount: int): + if isinstance(index, np.ndarray): + return index + offset_amount + return (index[0] + offset_amount, index[1]) From 580c89ac4d13fa6f2ccb4baaf2d62a6ff9b78762 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 7 May 2026 09:59:28 -0500 Subject: [PATCH 100/139] Generate ra, theta, chi, redshift, from raw coordinates --- .../opencosmo/collection/lightcone/plugins.py | 39 ++++++++ .../opencosmo/collection/structure/handler.py | 3 +- python/opencosmo/collection/structure/io.py | 50 +++++++--- python/opencosmo/column/column.py | 99 ++++++++++++++++++- test/test_diffsky.py | 1 - test/test_lightcone.py | 5 +- 6 files changed, 176 insertions(+), 21 deletions(-) diff --git a/python/opencosmo/collection/lightcone/plugins.py b/python/opencosmo/collection/lightcone/plugins.py index 1675fb52..83fc9169 100644 --- a/python/opencosmo/collection/lightcone/plugins.py +++ b/python/opencosmo/collection/lightcone/plugins.py @@ -8,6 
+8,7 @@ import numpy as np import opencosmo as oc +from opencosmo.column import norm_cols from opencosmo.plugins.contexts import HookPoint from opencosmo.plugins.hook import hook @@ -16,6 +17,31 @@ from opencosmo.plugins.contexts import LightconeOpenCtx +@hook( + HookPoint.LightconeOpen, + when=lambda ctx: ctx.lightcone.dtype in ["halo_properties", "galaxy_properties"], +) +def _ensure_coordinates(ctx: LightconeOpenCtx) -> LightconeOpenCtx: + known_columns = set(ctx.lightcone.columns) + if known_columns.issuperset(("phi", "theta")) or known_columns.issuperset( + ("ra", "dec") + ): + return ctx + + prefix = "fof_halo_center" + if ctx.lightcone.dtype == "galaxy_properties": + prefix = "gal_center" + + coord_columns = {coord: f"{prefix}_{coord}" for coord in ["x", "y", "z"]} + if not set(ctx.lightcone.columns).issuperset(coord_columns.values()): + raise ValueError("Unable to find coordinate columns for this lightcone dataset") + chi = norm_cols(*list(coord_columns.values())) + phi = oc.col(coord_columns["y"]).arctan2(oc.col(coord_columns["x"])) + theta = (oc.col(coord_columns["z"]) / oc.col("chi")).arccos() + new_lightcone = ctx.lightcone.with_new_columns(chi=chi, phi=phi, theta=theta) + return dataclasses.replace(ctx, lightcone=new_lightcone) + + @hook(HookPoint.LightconeOpen) def _ensure_redshift_column(ctx: LightconeOpenCtx) -> LightconeOpenCtx: """Ensures a column called 'redshift' exists on every lightcone.""" @@ -32,6 +58,12 @@ def _ensure_redshift_column(ctx: LightconeOpenCtx) -> LightconeOpenCtx: z_col = oc.col("redshift_true") elif "zp" in lightcone.columns: z_col = oc.col("zp") + elif "chi" in lightcone.columns: + lightcone = lightcone.evaluate( + redshift_from_chi, cosmology=lightcone.cosmology, vectorize=True + ) + return dataclasses.replace(ctx, lightcone=lightcone) + else: raise ValueError( "Unable to find a redshift or scale factor column for this lightcone dataset" @@ -62,3 +94,10 @@ def radec_from_thetaphi(theta, phi): theta_deg = theta * 180 / np.pi 
phi_deg = phi * 180 / np.pi return {"ra": phi_deg * u.deg, "dec": (90.0 - theta_deg) * u.deg} + + +def redshift_from_chi(chi, cosmology): + import astropy.cosmology.units as cu + + redshift = chi.to(cu.redshift, cu.redshift_distance(cosmology, kind="comoving")) + return {"redshift": redshift} diff --git a/python/opencosmo/collection/structure/handler.py b/python/opencosmo/collection/structure/handler.py index 93c23df6..12ac01d9 100644 --- a/python/opencosmo/collection/structure/handler.py +++ b/python/opencosmo/collection/structure/handler.py @@ -45,7 +45,6 @@ def create_start_size(data, start_name, size_name, offsets): - start = data.pop(start_name, None) size = data.pop(size_name, None) if start is None or size is None: @@ -170,7 +169,7 @@ def parse( output = {} for name, handler in self.links.items(): result = handler( - data, offsets=offsets[name] if offsets is not None else None + data, offsets=offsets.get(name) if offsets is not None else None ) if result is not None: output[name] = result diff --git a/python/opencosmo/collection/structure/io.py b/python/opencosmo/collection/structure/io.py index 4ca3d34e..8a49bf8b 100644 --- a/python/opencosmo/collection/structure/io.py +++ b/python/opencosmo/collection/structure/io.py @@ -84,7 +84,7 @@ def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): name = target["header"].file.data_type elif name.startswith("galaxy_properties"): name = name[18:] - link_targets["halo_properties"][name].append(dataset) + link_targets["galaxy_properties"][name].append(dataset) else: raise ValueError( f"Unknown data type for structure collection {target['header'].data_type}" @@ -127,6 +127,7 @@ def build_lightcone_structure_collection( link_targets: dict[str, dict[str, list[d.Dataset | sc.StructureCollection]]], ): found_redshift_steps = set() + print(link_sources.keys(), link_targets.keys()) for source_type, source_list in link_sources.items(): if not all(t["header"].file.is_lightcone for t in source_list): 
raise ValueError("All sources must be lightcone datasets!") @@ -146,8 +147,36 @@ def build_lightcone_structure_collection( raise ValueError( "All datasets must have the same set of redshift steps!" ) - output_sources = {} - output_targets = defaultdict(dict) + if ( + len(link_sources.get("galaxy_properties", [])) > 0 + and "galaxy_properties" in link_targets + ): + # Galaxy properties and galaxy particles + datasets = [ + io.iopen.open_single_dataset( + t, + "data_linked", + bypass_lightcone=True, + bypass_mpi=len(link_sources.get("halo_properties", [])) > 0, + ) + for t in link_sources["galaxy_properties"] + ] + galaxy_lightcone = lc.Lightcone.from_datasets( + {ds.header.file.step: ds for ds in datasets} + ) + galaxy_target_datasets = {} + for target_type, targets in link_targets[source_type].items(): + galaxy_target_datasets[target_type] = lc.Lightcone.from_datasets( + {ds.header.file.step: ds for ds in targets} + ) + collection = sc.StructureCollection(galaxy_lightcone, galaxy_target_datasets) + if len(link_sources.get("halo_properties", [])) > 0: + link_targets["halo_properties"]["galaxy_properties"] = collection + else: + return collection + + print(link_targets) + assert False for source_type, source_list in link_sources.items(): if source_type == "galaxy_properties": raise NotImplementedError @@ -176,7 +205,7 @@ def __build_structure_collection( link_targets: dict[str, dict[str, d.Dataset | sc.StructureCollection]], ignore_empty: bool, ): - if galaxy_properties_target is not None and "galaxy_targets" in link_targets: + if galaxy_properties_target is not None and "galaxy_properties" in link_targets: # Galaxy properties and galaxy particles source_dataset = io.iopen.open_single_dataset( galaxy_properties_target, @@ -189,25 +218,25 @@ def __build_structure_collection( collection = sc.StructureCollection( source_dataset, source_dataset.header, - link_targets["galaxy_targets"], + link_targets["galaxy_properties"], ) if halo_properties_target is not None: - 
link_targets["halo_targets"]["galaxy_properties"] = collection + link_targets["halo_properties"]["galaxy_properties"] = collection else: return collection if ( halo_properties_target is not None and galaxy_properties_target is not None - and "galaxy_targets" not in link_targets + and "galaxy_properties" not in link_targets ): # Halo properties and galaxy properties, but no galaxy particles galaxy_properties = io.iopen.open_single_dataset( galaxy_properties_target, bypass_lightcone=True, bypass_mpi=True ) - link_targets["halo_targets"]["galaxy_properties"] = galaxy_properties + link_targets["halo_properties"]["galaxy_properties"] = galaxy_properties - if halo_properties_target is not None and link_targets["halo_targets"]: + if halo_properties_target is not None and link_targets["halo_properties"]: source_dataset = io.iopen.open_single_dataset( halo_properties_target, metadata_group="data_linked", bypass_lightcone=True ) @@ -216,6 +245,5 @@ def __build_structure_collection( return sc.StructureCollection( source_dataset, - source_dataset.header, - link_targets["halo_targets"], + link_targets["halo_properties"], ) diff --git a/python/opencosmo/column/column.py b/python/opencosmo/column/column.py index 580bc8c2..1bbd7997 100644 --- a/python/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -155,6 +155,57 @@ def _sqrt(left: np.ndarray | u.Unit, right: None): return left**0.5 +def _require_dimensionless(unit: u.UnitBase, func_name: str) -> None: + if not unit.is_equivalent(u.dimensionless_unscaled): + raise UnitsError( + f"{func_name} requires a dimensionless input, got unit '{unit}'" + ) + + +def _arcsin(left: Any, right: Any) -> Any: + if isinstance(left, u.UnitBase): + _require_dimensionless(left, "arcsin") + return u.rad + if isinstance(left, u.Quantity): + _require_dimensionless(left.unit, "arcsin") + return np.arcsin(left.value) * u.rad + return np.arcsin(left) + + +def _arccos(left: Any, right: Any) -> Any: + if isinstance(left, u.UnitBase): + 
_require_dimensionless(left, "arccos") + return u.rad + if isinstance(left, u.Quantity): + _require_dimensionless(left.unit, "arccos") + return np.arccos(left.value) * u.rad + return np.arccos(left) + + +def _arctan2(left: Any, right: Any) -> Any: + left_is_unit = isinstance(left, u.UnitBase) + right_is_unit = isinstance(right, u.UnitBase) + if left_is_unit or right_is_unit: + if not (left_is_unit and right_is_unit) or not left.is_equivalent(right): + raise UnitsError( + "arctan2 requires both inputs to have equivalent units or both to be unitless" + ) + return u.rad + left_is_qty = isinstance(left, u.Quantity) + right_is_qty = isinstance(right, u.Quantity) + if left_is_qty != right_is_qty: + raise UnitsError( + "arctan2 requires both inputs to have equivalent units or both to be unitless" + ) + if left_is_qty: + if not left.unit.is_equivalent(right.unit): + raise UnitsError( + f"arctan2 inputs have incompatible units: '{left.unit}' and '{right.unit}'" + ) + return np.arctan2(left.value, right.value) * u.rad + return np.arctan2(left, right) + + class Column: """ Represents a reference to a column with a given name. Column reference @@ -303,6 +354,27 @@ def sqrt(self) -> DerivedColumn: """ return DerivedColumn(self, None, _sqrt) + def arcsin(self) -> DerivedColumn: + """ + Create a derived column containing the arcsine of this column (in radians). + The column must be dimensionless. + """ + return DerivedColumn(self, None, _arcsin) + + def arccos(self) -> DerivedColumn: + """ + Create a derived column containing the arccosine of this column (in radians). + The column must be dimensionless. + """ + return DerivedColumn(self, None, _arccos) + + def arctan2(self, other: ColumnOrScalar) -> DerivedColumn: + """ + Create a derived column containing arctan2(self, other) in radians. + Both columns must be dimensionless. 
+ """ + return DerivedColumn(self, other, _arctan2) + class ConstructedColumn(Protocol): pass @@ -336,11 +408,11 @@ def get_units(self, values: dict[str, u.Quantity]) -> dict[str, u.Unit]: ... class RawColumn: - def __init__(self, name, description, alias=None, _dep_uuid=None): + def __init__(self, name, description, alias=None, _dep_uuid=None, _uuid=None): self.__name = name self.__description = description self.__alias = alias - self.__uuid = uuid4() + self.__uuid = _uuid if _uuid is not None else uuid4() self.__dep_uuid: UUID | None = _dep_uuid @property @@ -356,7 +428,11 @@ def bind(self, name_to_uuid: dict[str, UUID]) -> RawColumn: return self dep_uuid = name_to_uuid[self.__name] return RawColumn( - self.__name, self.__description, alias=self.__alias, _dep_uuid=dep_uuid + self.__name, + self.__description, + alias=self.__alias, + _dep_uuid=dep_uuid, + _uuid=self.__uuid, ) @property @@ -427,13 +503,14 @@ def __init__( output_name: Optional[str] = None, _dep_map: dict[str, UUID] | None = None, no_cache: bool = False, + _uuid: UUID | None = None, ): self.lhs = lhs self.rhs = rhs self.name = output_name self.operation = operation self.description = description if description is not None else "None" - self.__uuid = uuid4() + self.__uuid = _uuid if _uuid is not None else uuid4() self.__dep_map: dict[str, UUID] | None = _dep_map self.__no_cache = no_cache @@ -459,6 +536,7 @@ def bind(self, name_to_uuid: dict[str, UUID]) -> DerivedColumn: self.description, self.name, _dep_map=dep_map, + _uuid=self.__uuid, ) def _traverse_names(self) -> set[str]: @@ -600,6 +678,15 @@ def exp10(self, expected_unit_container: u.LogUnit = u.DexUnit): def sqrt(self): return DerivedColumn(self, None, _sqrt) + def arcsin(self) -> DerivedColumn: + return DerivedColumn(self, None, _arcsin) + + def arccos(self) -> DerivedColumn: + return DerivedColumn(self, None, _arccos) + + def arctan2(self, other: ColumnOrScalar) -> DerivedColumn: + return DerivedColumn(self, other, _arctan2) + def 
evaluate(self, data: dict[str, np.ndarray], *args) -> np.ndarray: lhs: np.typing.ArrayLike rhs: Optional[np.typing.ArrayLike] @@ -635,6 +722,7 @@ def __init__( description: Optional[str] = None, _dep_map: dict[str, UUID] | None = None, no_cache: bool = False, + _uuid: UUID | None = None, **kwargs: Any, ): self.__func = func @@ -647,7 +735,7 @@ def __init__( self.__batch_size = batch_size self.__no_cache = no_cache self.description = description - self.__uuid = uuid4() + self.__uuid = _uuid if _uuid is not None else uuid4() self.__dep_map = _dep_map @property @@ -676,6 +764,7 @@ def bind(self, name_to_uuid: dict[str, UUID]) -> EvaluatedColumn: self.description, _dep_map=dep_map, no_cache=self.__no_cache, + _uuid=self.__uuid, **self.__kwargs, ) diff --git a/test/test_diffsky.py b/test/test_diffsky.py index 4e6261cb..42b95905 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -418,7 +418,6 @@ def test_keep_top_host_take_random(core_path_475, core_path_487): assert_top_host_idx_correct(data, core_map) assert np.all(data["top_host_idx"] >= 0) assert_all_group_members_present(data, core_map) - print(len(ds)) def test_keep_top_host_take_start(core_path_475, core_path_487): diff --git a/test/test_lightcone.py b/test/test_lightcone.py index dcbb786e..83177aee 100644 --- a/test/test_lightcone.py +++ b/test/test_lightcone.py @@ -40,8 +40,9 @@ def test_create_theta_phi_coords(haloproperties_600_path, haloproperties_601_pat ra = (data["phi"] * u.rad).to(u.deg) dec = ((np.pi / 2 - data["theta"]) * u.rad).to(u.deg) - assert np.allclose(data["ra"], ra, atol=0.01, rtol=1e-2) - assert np.allclose(data["dec"], dec, atol=0.01, rtol=1e-2) + + assert np.allclose(data["ra"], ra, atol=0.0001, rtol=1e-2) + assert np.allclose(data["dec"], dec, atol=0.0001, rtol=1e-2) def test_lightcone_physical_units(haloproperties_600_path): From 7d9d7dd93ff32b15edfab54f8733655f2750c7a4 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 7 May 2026 13:01:13 -0500 Subject: [PATCH 101/139] 
Linking to galaxies in a lightcone structure collection is working --- .../opencosmo/collection/lightcone/plugins.py | 3 +- .../opencosmo/collection/structure/handler.py | 36 +++++++++++--- python/opencosmo/collection/structure/io.py | 48 ++++++++----------- .../collection/structure/structure.py | 6 +-- 4 files changed, 54 insertions(+), 39 deletions(-) diff --git a/python/opencosmo/collection/lightcone/plugins.py b/python/opencosmo/collection/lightcone/plugins.py index 83fc9169..bcfdb1c2 100644 --- a/python/opencosmo/collection/lightcone/plugins.py +++ b/python/opencosmo/collection/lightcone/plugins.py @@ -4,6 +4,7 @@ import warnings from typing import TYPE_CHECKING +import astropy.cosmology.units as cu import astropy.units as u import numpy as np @@ -97,7 +98,5 @@ def radec_from_thetaphi(theta, phi): def redshift_from_chi(chi, cosmology): - import astropy.cosmology.units as cu - redshift = chi.to(cu.redshift, cu.redshift_distance(cosmology, kind="comoving")) return {"redshift": redshift} diff --git a/python/opencosmo/collection/structure/handler.py b/python/opencosmo/collection/structure/handler.py index 12ac01d9..b7be48cf 100644 --- a/python/opencosmo/collection/structure/handler.py +++ b/python/opencosmo/collection/structure/handler.py @@ -1,11 +1,13 @@ from __future__ import annotations +from collections import defaultdict from functools import partial, reduce -from typing import TYPE_CHECKING, Any, Iterable, Optional +from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, cast import numpy as np from opencosmo.collection.lightcone import lightcone as lc +from opencosmo.collection.structure import structure as sc from opencosmo.index import into_array if TYPE_CHECKING: @@ -152,7 +154,7 @@ def __init__( self, links, columns, - derived_from: Optional[oc.Dataset], + derived_from: Optional[oc.Dataset | oc.Lightcone], ): self.__derived_from = derived_from self.links = links @@ -164,7 +166,9 @@ def from_link_names(cls, names: Iterable[str], 
rename_galaxies=False): return LinkHandler(links, columns, None) def parse( - self, data: dict[str, Any], offsets: Optional[dict[str, list[int]]] = None + self, + data: dict[str, Any], + offsets: Optional[dict[str, list[tuple[int, int]]]] = None, ): output = {} for name, handler in self.links.items(): @@ -184,7 +188,6 @@ def prep_datasets( Called once when a datasets are opened for the first time. Downstream versions always use rebuild_datsets """ - all_columns: list[str] = reduce( lambda acc, ds: acc + self.columns[ds], datasets.keys(), [] ) @@ -192,7 +195,12 @@ def prep_datasets( offsets = None if isinstance(source, lc.Lightcone): offsets = {} + for ds_type, lightcone in datasets.items(): + if isinstance(lightcone, sc.StructureCollection): + lightcone = lightcone["galaxy_properties"] + + assert isinstance(lightcone, lc.Lightcone) offsets[ds_type] = [ (len(source[key]), len(ds)) for key, ds in lightcone.items() ] @@ -218,12 +226,19 @@ def make_derived(self, source: oc.Dataset): return LinkHandler(self.links, self.columns, derived_from) def rebuild_lightcones( - self, new_source: oc.Lightcone, lightcones: dict[str, oc.Lightcone] + self, + new_source: oc.Lightcone, + lightcones: dict[str, oc.Lightcone | sc.StructureCollection], ): - new_datasets = {name: {} for name in lightcones} + new_datasets: dict[str, lc.Lightcone | sc.StructureCollection] = defaultdict( + dict + ) + if "galaxies" in lightcones: + raise NotImplementedError() for step, step_source in new_source.items(): step_datasets = {name: lc[step] for name, lc in lightcones.items()} + assert isinstance(self.__derived_from, lc.Lightcone) new_step_datasets = self.__rebuild_datasets( self.__derived_from[step], step_source, step_datasets ) @@ -237,7 +252,7 @@ def rebuild_lightcones( def rebuild_datasets( self, new_source: oc.Dataset | oc.Lightcone, - datasets: dict[str, oc.Dataset], + datasets: Mapping[str, oc.Dataset | oc.Lightcone | sc.StructureCollection], ): """ We have a few guarantees here: @@ -250,6 
+265,13 @@ def rebuild_datasets( if self.__derived_from is None: return datasets elif isinstance(new_source, lc.Lightcone): + assert all( + isinstance(ds_, (lc.Lightcone | sc.StructureCollection)) + for ds_ in datasets.values() + ) + datasets = cast( + "dict[str, sc.StructureCollection | lc.Lightcone]", datasets + ) return self.rebuild_lightcones(new_source, datasets) return self.__rebuild_datasets(self.__derived_from, new_source, datasets) diff --git a/python/opencosmo/collection/structure/io.py b/python/opencosmo/collection/structure/io.py index 8a49bf8b..a0a9b3b3 100644 --- a/python/opencosmo/collection/structure/io.py +++ b/python/opencosmo/collection/structure/io.py @@ -7,6 +7,7 @@ import numpy as np +from opencosmo import dataset as d from opencosmo import io from opencosmo.collection.lightcone import lightcone as lc from opencosmo.collection.structure import structure as sc @@ -14,7 +15,6 @@ if TYPE_CHECKING: import h5py - from opencosmo import dataset as d from opencosmo.io.iopen import FileTarget ALLOWED_LINKS = { # h5py.Files that can serve as a link holder and @@ -126,8 +126,7 @@ def build_lightcone_structure_collection( link_sources: dict[str, list[io.iopen.DatasetTarget]], link_targets: dict[str, dict[str, list[d.Dataset | sc.StructureCollection]]], ): - found_redshift_steps = set() - print(link_sources.keys(), link_targets.keys()) + found_redshift_steps: set[int] = set() for source_type, source_list in link_sources.items(): if not all(t["header"].file.is_lightcone for t in source_list): raise ValueError("All sources must be lightcone datasets!") @@ -165,9 +164,9 @@ def build_lightcone_structure_collection( {ds.header.file.step: ds for ds in datasets} ) galaxy_target_datasets = {} - for target_type, targets in link_targets[source_type].items(): + for target_type, targets in link_targets["galaxy_properties"].items(): galaxy_target_datasets[target_type] = lc.Lightcone.from_datasets( - {ds.header.file.step: ds for ds in targets} + {ds.header.file.step: ds 
for ds in targets} # type: ignore # already asserted this step exists ) collection = sc.StructureCollection(galaxy_lightcone, galaxy_target_datasets) if len(link_sources.get("halo_properties", [])) > 0: @@ -175,28 +174,24 @@ def build_lightcone_structure_collection( else: return collection - print(link_targets) - assert False - for source_type, source_list in link_sources.items(): - if source_type == "galaxy_properties": - raise NotImplementedError - - datasets = [ - io.iopen.open_single_dataset(t, "data_linked", bypass_lightcone=True) - for t in source_list - ] - output_sources[source_type] = lc.Lightcone.from_datasets( - {ds.header.file.step: ds for ds in datasets} - ) - for target_type, targets in link_targets[source_type].items(): - output_targets[source_type][target_type] = lc.Lightcone.from_datasets( - {ds.header.file.step: ds for ds in targets} - ) - return sc.StructureCollection( - output_sources["halo_properties"], output_targets["halo_properties"] + source_list = link_sources["halo_properties"] + source_datasets = [ + io.iopen.open_single_dataset(t, "data_linked", bypass_lightcone=True) + for t in source_list + ] + source_lightcone = lc.Lightcone.from_datasets( + {ds.header.file.step: ds for ds in source_datasets} ) - - return output_sources, output_targets + output_targets = {} + for target_type, targets in link_targets[source_type].items(): + if isinstance(targets, (d.Dataset, sc.StructureCollection)): + output_targets[target_type] = targets + continue + + output_targets[target_type] = lc.Lightcone.from_datasets( + {ds.header.file.step: ds for ds in targets} + ) + return sc.StructureCollection(source_lightcone, output_targets) def __build_structure_collection( @@ -217,7 +212,6 @@ def __build_structure_collection( source_dataset = remove_empty(source_dataset) collection = sc.StructureCollection( source_dataset, - source_dataset.header, link_targets["galaxy_properties"], ) if halo_properties_target is not None: diff --git 
a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index d2356593..eefd555d 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -188,7 +188,7 @@ def redshift(self) -> float | tuple[float, float] | None: raise NotImplementedError @property - def simulation(self) -> HaccSimulationParameters: + def simulation(self) -> HaccSimulationParameters | None: """ Get the parameters of the simulation this dataset is drawn from. @@ -664,7 +664,7 @@ def filter(self, *masks, on_galaxies: bool = False) -> StructureCollection: galaxy_properties = self["galaxy_properties"] assert isinstance(galaxy_properties, oc.Dataset) filtered = filter_source_by_dataset( - galaxy_properties, self.__source, self.__header, *masks + galaxy_properties, self.__source, self.__source.header, *masks ) new_handler = self.__handler.make_derived(self.__source) @@ -759,7 +759,7 @@ def select( arg = columns # type: ignore kwargs = {} - if dataset == self.__source.file.data_type: + if dataset == self.__source.header.file.data_type: new_source = self.__source.select(arg, **kwargs) continue From 4df804d47094dc60247044a60ca3a7aba30c655f Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 14 May 2026 21:26:40 -0400 Subject: [PATCH 102/139] Add rebuilding for nested galaxy structure collections --- .../opencosmo/collection/lightcone/plugins.py | 2 +- .../opencosmo/collection/structure/handler.py | 70 ++++++++++++++++- .../collection/structure/structure.py | 2 +- test/test_lightcone.py | 7 -- test/test_structure_collection.py | 76 +++++++++++++++++++ 5 files changed, 144 insertions(+), 13 deletions(-) create mode 100644 test/test_structure_collection.py diff --git a/python/opencosmo/collection/lightcone/plugins.py b/python/opencosmo/collection/lightcone/plugins.py index bcfdb1c2..1e8478b0 100644 --- a/python/opencosmo/collection/lightcone/plugins.py +++ 
b/python/opencosmo/collection/lightcone/plugins.py @@ -48,7 +48,7 @@ def _ensure_redshift_column(ctx: LightconeOpenCtx) -> LightconeOpenCtx: """Ensures a column called 'redshift' exists on every lightcone.""" lightcone: Lightcone = ctx.lightcone if ( - "properties" not in lightcone.dtype + "particles" in lightcone.dtype or "profiles" in lightcone.dtype ): # Particles or profiles, redshift handled at structure collection level return ctx if "redshift" in lightcone.columns: diff --git a/python/opencosmo/collection/structure/handler.py b/python/opencosmo/collection/structure/handler.py index b7be48cf..4152c479 100644 --- a/python/opencosmo/collection/structure/handler.py +++ b/python/opencosmo/collection/structure/handler.py @@ -233,21 +233,83 @@ def rebuild_lightcones( new_datasets: dict[str, lc.Lightcone | sc.StructureCollection] = defaultdict( dict ) - if "galaxies" in lightcones: - raise NotImplementedError() + + lc_datasets = { + k: v for k, v in lightcones.items() if isinstance(v, lc.Lightcone) + } + sc_datasets = { + k: v for k, v in lightcones.items() if isinstance(v, sc.StructureCollection) + } for step, step_source in new_source.items(): - step_datasets = {name: lc[step] for name, lc in lightcones.items()} + step_datasets = {name: ds[step] for name, ds in lc_datasets.items()} assert isinstance(self.__derived_from, lc.Lightcone) new_step_datasets = self.__rebuild_datasets( self.__derived_from[step], step_source, step_datasets ) for name, new_step_ds in new_step_datasets.items(): new_datasets[name][step] = new_step_ds - return { + + result: dict[str, lc.Lightcone | sc.StructureCollection] = { name: lc.Lightcone.from_datasets(datasets) for name, datasets in new_datasets.items() } + for name, galaxy_sc in sc_datasets.items(): + result[name] = self.__rebuild_sc_from_lightcone(new_source, name, galaxy_sc) + return result + + def __rebuild_sc_from_lightcone( + self, + new_source: oc.Lightcone, + name: str, + galaxy_sc: sc.StructureCollection, + ) -> 
sc.StructureCollection: + """ + Rebuild a StructureCollection linked dataset (e.g. galaxies) after the + parent halo lightcone has been filtered. + + After prep_datasets the galaxy SC rows are in halo order, grouped by + step: [step_0 galaxies][step_1 galaxies].... We compute cumulative + size boundaries per step to find which chunks in the post-prep SC + correspond to halos that survive the filter, then call take_rows on + the SC directly — avoiding the need to wrap per-step + StructureCollections inside a Lightcone. + """ + assert isinstance(self.__derived_from, lc.Lightcone) + col_names = self.columns[name] + size_col_name = next(c for c in col_names if "size" in c) + + all_starts: list[np.ndarray] = [] + all_sizes: list[np.ndarray] = [] + galaxy_offset = 0 + + for step, step_source in new_source.items(): + step_derived = self.__derived_from[step] + meta = step_derived.get_metadata(col_names) + size_col = meta[size_col_name].astype(np.int64) + + original_index = into_array(step_derived.index) + new_index = into_array(step_source.index) + _, idx_into_original, idx_into_new = np.intersect1d( + original_index, new_index, assume_unique=True, return_indices=True + ) + idx_into_original = idx_into_original[np.argsort(idx_into_new)] + + # chunk_boundaries[i] = start of halo i's galaxies within this + # step's portion of the post-prep SC. Add galaxy_offset to get + # the global position across all steps. 
+ chunk_boundaries = np.zeros(len(size_col) + 1, dtype=np.int64) + np.cumsum(size_col, out=chunk_boundaries[1:]) + + valid = size_col[idx_into_original] > 0 + all_starts.append(chunk_boundaries[idx_into_original[valid]] + galaxy_offset) + all_sizes.append(size_col[idx_into_original[valid]]) + + galaxy_offset += int(size_col.sum()) + + starts = np.concatenate(all_starts) if all_starts else np.array([], dtype=np.int64) + sizes = np.concatenate(all_sizes) if all_sizes else np.array([], dtype=np.int64) + return galaxy_sc.take_rows((starts, sizes)) def rebuild_datasets( self, diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index eefd555d..d629a0d6 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -822,7 +822,7 @@ def drop(self, **columns_to_drop): new_datasets = {} for dataset_name, columns in columns_to_drop.items(): - if dataset_name == self.__source.file.data_type: + if dataset_name == self.__source.header.file.data_type: new_source = self.__source.drop(columns) continue diff --git a/test/test_lightcone.py b/test/test_lightcone.py index 83177aee..924b2da9 100644 --- a/test/test_lightcone.py +++ b/test/test_lightcone.py @@ -539,13 +539,6 @@ def test_lightcone_stacking_nostack( assert ds_new.z_range == ds.z_range -@pytest.mark.skip def test_lightcone_structure_collection_open(structure_600): c = oc.open(*structure_600) assert isinstance(c, oc.StructureCollection) - - -@pytest.mark.skip -def test_lightcone_structure_collection_open_multiple(structure_600, structure_601): - with pytest.raises(NotImplementedError): - _ = oc.open(*structure_600, *structure_601) diff --git a/test/test_structure_collection.py b/test/test_structure_collection.py new file mode 100644 index 00000000..bab638ba --- /dev/null +++ b/test/test_structure_collection.py @@ -0,0 +1,76 @@ +import numpy as np +import pytest + +import opencosmo as oc + + 
+@pytest.fixture +def halos_600_path(lightcone_path): + properties = lightcone_path / "step_600" / "haloproperties.hdf5" + particles = lightcone_path / "step_600" / "haloparticles.hdf5" + profiles = lightcone_path / "step_600" / "haloprofiles.hdf5" + return [properties, particles, profiles] + + +@pytest.fixture +def galaxies_600_path(lightcone_path): + properties = lightcone_path / "step_600" / "galaxyproperties.hdf5" + particles = lightcone_path / "step_600" / "galaxyparticles.hdf5" + return [properties, particles] + + +@pytest.fixture +def halos_601_path(lightcone_path): + properties = lightcone_path / "step_601" / "haloproperties.hdf5" + particles = lightcone_path / "step_601" / "haloparticles.hdf5" + profiles = lightcone_path / "step_601" / "haloprofiles.hdf5" + return [properties, particles, profiles] + + +@pytest.fixture +def galaxies_601_path(lightcone_path): + properties = lightcone_path / "step_601" / "galaxyproperties.hdf5" + particles = lightcone_path / "step_601" / "galaxyparticles.hdf5" + return [properties, particles] + + +def test_open_lightcone_structure(halos_600_path, halos_601_path): + ds = oc.open(*halos_600_path, *halos_601_path) + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10).halos(): + gravity_particle_tags = ( + halo["dm_particles"].select("fof_halo_tag").get_data("numpy") + ) + assert np.all(gravity_particle_tags == halo["halo_properties"]["fof_halo_tag"]) + halo_bin_tags = ( + halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy") + ) + assert np.all(halo_bin_tags == halo["halo_properties"]["fof_halo_tag"]) + + +def test_open_lightcone_structure_with_galaxies( + halos_600_path, galaxies_600_path, halos_601_path, galaxies_601_path +): + ds = oc.open( + *halos_600_path, *galaxies_600_path, *halos_601_path, *galaxies_601_path + ) + ds = ds.filter(oc.col("sod_halo_mass") > 1e14) + + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10).halos(): + gravity_particle_tags = ( + 
halo["dm_particles"].select("fof_halo_tag").get_data("numpy") + ) + assert np.all(gravity_particle_tags == halo["halo_properties"]["fof_halo_tag"]) + halo_bin_tags = ( + halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy") + ) + assert np.all(halo_bin_tags == halo["halo_properties"]["fof_halo_tag"]) + for galaxy in halo["galaxies"].galaxies(): + assert ( + galaxy["galaxy_properties"]["fof_halo_tag"] + == halo["halo_properties"]["fof_halo_tag"] + ) + + print(galaxy.keys()) + tags = galaxy["star_particles"].select("gal_tag").get_data("numpy") + assert np.all(tags == galaxy["galaxy_properties"]["gal_tag"]) + assert False From e3b7ff17fee26cd654d70b0e433e93d212a1b35d Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 15 May 2026 16:02:28 -0500 Subject: [PATCH 103/139] Updates to fix galaxy linking --- .../collection/lightcone/lightcone.py | 6 - .../opencosmo/collection/structure/handler.py | 154 ++++-------------- python/opencosmo/collection/structure/io.py | 115 +++++++++++-- .../collection/structure/structure.py | 4 +- python/opencosmo/dataset/dataset.py | 1 + test/test_collection.py | 2 + test/test_structure_collection.py | 18 +- 7 files changed, 161 insertions(+), 139 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index ad780d76..e96316c1 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -1005,12 +1005,6 @@ def take_rows(self, rows: DataIndex): """ index_range = get_range(rows) - if isinstance(rows, np.ndarray): - rows = np.sort(rows) - index_range = (index_range[0], index_range[1] + 1) - else: - order = np.argsort(rows[0]) - rows = (rows[0][order], rows[1][order]) if index_range[0] < 0 or index_range[1] > len(self): raise ValueError( diff --git a/python/opencosmo/collection/structure/handler.py b/python/opencosmo/collection/structure/handler.py index 4152c479..47207701 100644 --- 
a/python/opencosmo/collection/structure/handler.py +++ b/python/opencosmo/collection/structure/handler.py @@ -1,17 +1,16 @@ from __future__ import annotations -from collections import defaultdict from functools import partial, reduce -from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, cast +from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional import numpy as np from opencosmo.collection.lightcone import lightcone as lc -from opencosmo.collection.structure import structure as sc from opencosmo.index import into_array if TYPE_CHECKING: import opencosmo as oc + from opencosmo.collection.structure import structure as sc """ A tale in 3 acts: @@ -93,6 +92,27 @@ def create_idx(data, idx_name, offsets): return np.atleast_1d(idx) +def build_lightcone_index(old_source: lc.Lightcone, new_source: lc.Lightcone): + index = np.zeros(len(new_source), dtype=np.int64) + offset = 0 + rs = 0 + for name, ds in old_source.items(): + if name not in new_source.keys(): + offset += len(ds) + continue + original_index = into_array(ds.index) + new_index = into_array(new_source[name].index) + _, index_into_original, index_into_new = np.intersect1d( + original_index, new_index, assume_unique=True, return_indices=True + ) + index_into_original = index_into_original[np.argsort(index_into_new)] + + index[rs : rs + len(index_into_original)] = index_into_original + offset + offset += len(ds) + rs += len(index_into_original) + return index + + def make_links(keys, rename_galaxies=False): starts = list(filter(lambda key: "start" in key, keys)) sizes = list(filter(lambda key: "size" in key, keys)) @@ -192,20 +212,10 @@ def prep_datasets( lambda acc, ds: acc + self.columns[ds], datasets.keys(), [] ) meta = source.get_metadata(all_columns) - offsets = None - if isinstance(source, lc.Lightcone): - offsets = {} - - for ds_type, lightcone in datasets.items(): - if isinstance(lightcone, sc.StructureCollection): - lightcone = lightcone["galaxy_properties"] - - assert 
isinstance(lightcone, lc.Lightcone) - offsets[ds_type] = [ - (len(source[key]), len(ds)) for key, ds in lightcone.items() - ] - - indices = self.parse(meta, offsets) + # Offsets are now baked into the metadata columns at construction time + # (see build_lightcone_structure_collection in io.py), so no per-step + # offset calculation is needed here. + indices = self.parse(meta, offsets=None) new_datasets = datasets for name, index in indices.items(): @@ -225,92 +235,6 @@ def make_derived(self, source: oc.Dataset): return LinkHandler(self.links, self.columns, derived_from) - def rebuild_lightcones( - self, - new_source: oc.Lightcone, - lightcones: dict[str, oc.Lightcone | sc.StructureCollection], - ): - new_datasets: dict[str, lc.Lightcone | sc.StructureCollection] = defaultdict( - dict - ) - - lc_datasets = { - k: v for k, v in lightcones.items() if isinstance(v, lc.Lightcone) - } - sc_datasets = { - k: v for k, v in lightcones.items() if isinstance(v, sc.StructureCollection) - } - - for step, step_source in new_source.items(): - step_datasets = {name: ds[step] for name, ds in lc_datasets.items()} - assert isinstance(self.__derived_from, lc.Lightcone) - new_step_datasets = self.__rebuild_datasets( - self.__derived_from[step], step_source, step_datasets - ) - for name, new_step_ds in new_step_datasets.items(): - new_datasets[name][step] = new_step_ds - - result: dict[str, lc.Lightcone | sc.StructureCollection] = { - name: lc.Lightcone.from_datasets(datasets) - for name, datasets in new_datasets.items() - } - for name, galaxy_sc in sc_datasets.items(): - result[name] = self.__rebuild_sc_from_lightcone(new_source, name, galaxy_sc) - return result - - def __rebuild_sc_from_lightcone( - self, - new_source: oc.Lightcone, - name: str, - galaxy_sc: sc.StructureCollection, - ) -> sc.StructureCollection: - """ - Rebuild a StructureCollection linked dataset (e.g. galaxies) after the - parent halo lightcone has been filtered. 
- - After prep_datasets the galaxy SC rows are in halo order, grouped by - step: [step_0 galaxies][step_1 galaxies].... We compute cumulative - size boundaries per step to find which chunks in the post-prep SC - correspond to halos that survive the filter, then call take_rows on - the SC directly — avoiding the need to wrap per-step - StructureCollections inside a Lightcone. - """ - assert isinstance(self.__derived_from, lc.Lightcone) - col_names = self.columns[name] - size_col_name = next(c for c in col_names if "size" in c) - - all_starts: list[np.ndarray] = [] - all_sizes: list[np.ndarray] = [] - galaxy_offset = 0 - - for step, step_source in new_source.items(): - step_derived = self.__derived_from[step] - meta = step_derived.get_metadata(col_names) - size_col = meta[size_col_name].astype(np.int64) - - original_index = into_array(step_derived.index) - new_index = into_array(step_source.index) - _, idx_into_original, idx_into_new = np.intersect1d( - original_index, new_index, assume_unique=True, return_indices=True - ) - idx_into_original = idx_into_original[np.argsort(idx_into_new)] - - # chunk_boundaries[i] = start of halo i's galaxies within this - # step's portion of the post-prep SC. Add galaxy_offset to get - # the global position across all steps. 
- chunk_boundaries = np.zeros(len(size_col) + 1, dtype=np.int64) - np.cumsum(size_col, out=chunk_boundaries[1:]) - - valid = size_col[idx_into_original] > 0 - all_starts.append(chunk_boundaries[idx_into_original[valid]] + galaxy_offset) - all_sizes.append(size_col[idx_into_original[valid]]) - - galaxy_offset += int(size_col.sum()) - - starts = np.concatenate(all_starts) if all_starts else np.array([], dtype=np.int64) - sizes = np.concatenate(all_sizes) if all_sizes else np.array([], dtype=np.int64) - return galaxy_sc.take_rows((starts, sizes)) - def rebuild_datasets( self, new_source: oc.Dataset | oc.Lightcone, @@ -326,28 +250,22 @@ def rebuild_datasets( """ if self.__derived_from is None: return datasets - elif isinstance(new_source, lc.Lightcone): - assert all( - isinstance(ds_, (lc.Lightcone | sc.StructureCollection)) - for ds_ in datasets.values() - ) - datasets = cast( - "dict[str, sc.StructureCollection | lc.Lightcone]", datasets - ) - return self.rebuild_lightcones(new_source, datasets) return self.__rebuild_datasets(self.__derived_from, new_source, datasets) def __rebuild_datasets(self, derived_from, new_source, datasets): - original_index = into_array(derived_from.index) - new_index = into_array(new_source.index) + if isinstance(derived_from, lc.Lightcone): + index_into_original = build_lightcone_index(derived_from, new_source) + else: + original_index = into_array(derived_from.index) + new_index = into_array(new_source.index) - _, index_into_original, index_into_new = np.intersect1d( - original_index, new_index, assume_unique=True, return_indices=True - ) + _, index_into_original, index_into_new = np.intersect1d( + original_index, new_index, assume_unique=True, return_indices=True + ) + index_into_original = index_into_original[np.argsort(index_into_new)] all_columns: list[str] = reduce( lambda acc, ds: acc + self.columns[ds], datasets.keys(), [] ) - index_into_original = index_into_original[np.argsort(index_into_new)] metadata = 
derived_from.get_metadata(all_columns) new_datasets = {} diff --git a/python/opencosmo/collection/structure/io.py b/python/opencosmo/collection/structure/io.py index a0a9b3b3..40d39359 100644 --- a/python/opencosmo/collection/structure/io.py +++ b/python/opencosmo/collection/structure/io.py @@ -7,10 +7,12 @@ import numpy as np +import opencosmo as oc from opencosmo import dataset as d from opencosmo import io from opencosmo.collection.lightcone import lightcone as lc from opencosmo.collection.structure import structure as sc +from opencosmo.collection.structure.handler import LINK_ALIASES if TYPE_CHECKING: import h5py @@ -122,6 +124,76 @@ def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): ) +def _apply_offset_corrections( + source_by_step: dict[int, d.Dataset], + targets_by_step: dict[str, dict[int, d.Dataset | sc.StructureCollection]], +) -> dict[int, d.Dataset]: + """ + Correct step-local _start and _idx metadata columns to be globally correct + before stacking per-step source datasets into a Lightcone. + + For _start columns: apply a lazy DerivedColumn offset (oc.col(name) + offset). + For _idx columns: apply the offset eagerly (only to non-negative values). + + targets_by_step keys may be either file-level group name prefixes (e.g. + "sodbighaloparticles_dm_particles") or alias values (e.g. "galaxy_properties"). + Column names always use file-level prefixes (e.g. "sodbighaloparticles_dm_particles_start"), + so we match via direct lookup then fall back to a LINK_ALIASES alias lookup. 
+ """ + steps = sorted(source_by_step) + + type_step_offset: dict[str, dict[int, int]] = {} + for target_type, step_map in targets_by_step.items(): + cumulative = 0 + per_step: dict[int, int] = {} + for step in steps: + per_step[step] = cumulative + ds = step_map.get(step) + if ds is not None: + cumulative += len(ds) + type_step_offset[target_type] = per_step + + corrected: dict[int, d.Dataset] = {} + for step in steps: + step_ds = source_by_step[step] + meta_cols = set(step_ds.meta_columns) + updates: dict = {} + + for col in meta_cols: + if col.endswith("_start"): + prefix = col[:-6] + is_idx = False + elif col.endswith("_idx"): + prefix = col[:-4] + is_idx = True + else: + continue + + # Try direct match (target_type key == file-level prefix), then + # fall back to alias lookup for cases like "galaxyproperties" -> "galaxy_properties". + offset = None + if prefix in type_step_offset: + offset = type_step_offset[prefix][step] + elif prefix in LINK_ALIASES and LINK_ALIASES[prefix] in type_step_offset: + offset = type_step_offset[LINK_ALIASES[prefix]][step] + + if not offset: + continue + + if is_idx: + arr = step_ds.get_metadata([col])[col].copy() + arr[arr >= 0] += offset + updates[col] = arr + else: + updates[col] = oc.col(col) + offset + + if updates: + step_ds = step_ds.with_new_columns(allow_overwrite=True, **updates) + corrected[step] = step_ds + + return corrected + + def build_lightcone_structure_collection( link_sources: dict[str, list[io.iopen.DatasetTarget]], link_targets: dict[str, dict[str, list[d.Dataset | sc.StructureCollection]]], @@ -151,7 +223,7 @@ def build_lightcone_structure_collection( and "galaxy_properties" in link_targets ): # Galaxy properties and galaxy particles - datasets = [ + galaxy_datasets = [ io.iopen.open_single_dataset( t, "data_linked", @@ -160,13 +232,19 @@ def build_lightcone_structure_collection( ) for t in link_sources["galaxy_properties"] ] - galaxy_lightcone = lc.Lightcone.from_datasets( - {ds.header.file.step: ds for ds 
in datasets} + galaxy_source_by_step = {ds.header.file.step: ds for ds in galaxy_datasets} + galaxy_targets_by_step = { + target_type: {ds.header.file.step: ds for ds in targets} # type: ignore + for target_type, targets in link_targets["galaxy_properties"].items() + } + galaxy_source_by_step = _apply_offset_corrections( + galaxy_source_by_step, galaxy_targets_by_step ) + galaxy_lightcone = lc.Lightcone.from_datasets(galaxy_source_by_step) galaxy_target_datasets = {} for target_type, targets in link_targets["galaxy_properties"].items(): galaxy_target_datasets[target_type] = lc.Lightcone.from_datasets( - {ds.header.file.step: ds for ds in targets} # type: ignore # already asserted this step exists + {ds.header.file.step: ds for ds in targets} # type: ignore ) collection = sc.StructureCollection(galaxy_lightcone, galaxy_target_datasets) if len(link_sources.get("halo_properties", [])) > 0: @@ -174,20 +252,35 @@ def build_lightcone_structure_collection( else: return collection - source_list = link_sources["halo_properties"] - source_datasets = [ + halo_source_list = link_sources["halo_properties"] + halo_datasets = [ io.iopen.open_single_dataset(t, "data_linked", bypass_lightcone=True) - for t in source_list + for t in halo_source_list ] - source_lightcone = lc.Lightcone.from_datasets( - {ds.header.file.step: ds for ds in source_datasets} + halo_source_by_step = {ds.header.file.step: ds for ds in halo_datasets} + halo_targets_by_step: dict[str, dict[int, d.Dataset]] = {} + for target_type, targets in link_targets["halo_properties"].items(): + if isinstance(targets, sc.StructureCollection): + # For a nested SC (e.g. galaxies), target_type is its source dtype + # ("galaxy_properties"), so targets[target_type] returns the source + # Lightcone, giving us per-step sizes for offset accounting. 
+ inner_lc = targets[target_type] + assert isinstance(inner_lc, lc.Lightcone) + halo_targets_by_step[target_type] = dict(inner_lc) + elif isinstance(targets, list): + halo_targets_by_step[target_type] = { + ds.header.file.step: ds for ds in targets + } + halo_source_by_step = _apply_offset_corrections( + halo_source_by_step, halo_targets_by_step ) + source_lightcone = lc.Lightcone.from_datasets(halo_source_by_step) + output_targets = {} - for target_type, targets in link_targets[source_type].items(): + for target_type, targets in link_targets["halo_properties"].items(): if isinstance(targets, (d.Dataset, sc.StructureCollection)): output_targets[target_type] = targets continue - output_targets[target_type] = lc.Lightcone.from_datasets( {ds.header.file.step: ds for ds in targets} ) diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index d629a0d6..a136854b 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -112,7 +112,9 @@ def __init__( self.__handler = LinkHandler.from_link_names( self.__source.meta_columns, "galaxies" in self.__datasets ) - datasets = self.__handler.prep_datasets(self.__source, self.__datasets) + self.__datasets = self.__handler.prep_datasets( + self.__source, self.__datasets + ) else: self.__handler = link_handler diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 741fb8eb..0e87d25e 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -823,6 +823,7 @@ def take_rows(self, rows: np.ndarray | DataIndex): dataset. 
""" + row_range = get_range(rows) if row_range[0] < 0 or row_range[1] > len(self): raise ValueError( diff --git a/test/test_collection.py b/test/test_collection.py index 3e03c7ec..1c88ce0c 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -1323,3 +1323,5 @@ def test_modify_metadata_column(halo_paths): (galaxyproperties_start["galaxyproperties_start"] + 1000) == updated_galaxyproperties_start["galaxyproperties_start"] ) + assert "galaxyproperties_start" not in ds["halo_properties"].columns + assert "galaxyproperties_start" not in ds["halo_properties"].get_data("numpy") diff --git a/test/test_structure_collection.py b/test/test_structure_collection.py index bab638ba..7a76b5a1 100644 --- a/test/test_structure_collection.py +++ b/test/test_structure_collection.py @@ -47,13 +47,25 @@ def test_open_lightcone_structure(halos_600_path, halos_601_path): assert np.all(halo_bin_tags == halo["halo_properties"]["fof_halo_tag"]) +def test_open_lightcone_galaxy_structure_collection( + galaxies_600_path, + galaxies_601_path, +): + ds = oc.open(*galaxies_600_path, *galaxies_601_path) + for galaxy in ds.take(100).galaxies(): + if "star_particles" not in galaxy: + continue + tags = galaxy["star_particles"].select("gal_tag").get_data("numpy") + assert np.all(tags == galaxy["galaxy_properties"]["gal_tag"]) + + def test_open_lightcone_structure_with_galaxies( halos_600_path, galaxies_600_path, halos_601_path, galaxies_601_path ): ds = oc.open( *halos_600_path, *galaxies_600_path, *halos_601_path, *galaxies_601_path ) - ds = ds.filter(oc.col("sod_halo_mass") > 1e14) + ds = ds.filter(oc.col("sod_halo_mass") > 1e14).take(10) for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10).halos(): gravity_particle_tags = ( @@ -70,7 +82,7 @@ def test_open_lightcone_structure_with_galaxies( == halo["halo_properties"]["fof_halo_tag"] ) - print(galaxy.keys()) + if "star_particles" not in galaxy: + continue tags = galaxy["star_particles"].select("gal_tag").get_data("numpy") 
assert np.all(tags == galaxy["galaxy_properties"]["gal_tag"]) - assert False From f3ab876c74a106622838ed6217c197261cae4589 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 18 May 2026 16:18:02 -0500 Subject: [PATCH 104/139] Implement filtering with derived columns --- changes/+bfffa1ad.feature.rst | 1 + python/opencosmo/column/column.py | 115 ++++++++++++++++++---------- python/opencosmo/dataset/columns.py | 7 +- python/opencosmo/dataset/dataset.py | 9 +-- python/opencosmo/dataset/graph.py | 2 +- test/test_filters.py | 12 ++- 6 files changed, 96 insertions(+), 50 deletions(-) create mode 100644 changes/+bfffa1ad.feature.rst diff --git a/changes/+bfffa1ad.feature.rst b/changes/+bfffa1ad.feature.rst new file mode 100644 index 00000000..f0ffb6ef --- /dev/null +++ b/changes/+bfffa1ad.feature.rst @@ -0,0 +1 @@ +:py:meth:`Dataset.filter ` now accepts masks created from expressions built from column arithmetic. diff --git a/python/opencosmo/column/column.py b/python/opencosmo/column/column.py index 580bc8c2..db26a391 100644 --- a/python/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -18,7 +18,6 @@ import astropy.units as u # type: ignore import numpy as np -from astropy import table # type: ignore from opencosmo.column.evaluate import ( EvaluateStrategy, @@ -31,6 +30,8 @@ if TYPE_CHECKING: from uuid import UUID + from astropy import table + from opencosmo import Dataset from opencosmo.index import DataIndex @@ -94,7 +95,7 @@ def wrapper(self: Any, other: Any) -> Any: return wrapper -ColumnOrScalar = Union["Column", "DerivedColumn", int, float] +ColumnOrScalar = Union["ConstructedColumn", int, float, u.Quantity] def _log10( @@ -197,26 +198,30 @@ def __init__(self, name: str): self.name = name self.description = None + @property + def requires(self): + return {self.name} + def __eq__(self, other: float | u.Quantity) -> ColumnMask: # type: ignore - return ColumnMask(self.name, other, op.eq) + return ColumnMask(self, other, op.eq) def 
__ne__(self, other: float | u.Quantity) -> ColumnMask: # type: ignore - return ColumnMask(self.name, other, op.ne) + return ColumnMask(self, other, op.ne) def __gt__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.name, other, op.gt) + return ColumnMask(self, other, op.gt) def __ge__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.name, other, op.ge) + return ColumnMask(self, other, op.ge) def __lt__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.name, other, op.lt) + return ColumnMask(self, other, op.lt) def __le__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.name, other, op.le) + return ColumnMask(self, other, op.le) def isin(self, other: Iterable[float | u.Quantity]) -> ColumnMask: - return ColumnMask(self.name, other, np.isin) + return ColumnMask(self, other, np.isin, force_binary=True) @_require_scalar_quantity def __rmul__(self, other: Any) -> DerivedColumn: @@ -430,6 +435,7 @@ def __init__( ): self.lhs = lhs self.rhs = rhs + self.name = output_name self.operation = operation self.description = description if description is not None else "None" @@ -451,6 +457,10 @@ def bind(self, name_to_uuid: dict[str, UUID]) -> DerivedColumn: producing it at the time this column was registered with a dataset. Returns a new bound DerivedColumn; does not mutate this instance. 
""" + required_names = self._traverse_names() + if missing := required_names.difference(name_to_uuid): + raise ValueError(f"Derived column depends on unknown columns {missing}") + dep_map = {name: name_to_uuid[name] for name in self._traverse_names()} return DerivedColumn( self.lhs, @@ -600,9 +610,30 @@ def exp10(self, expected_unit_container: u.LogUnit = u.DexUnit): def sqrt(self): return DerivedColumn(self, None, _sqrt) + def __eq__(self, other: float | u.Quantity) -> ColumnMask: # type: ignore + return ColumnMask(self, other, op.eq) + + def __ne__(self, other: float | u.Quantity) -> ColumnMask: # type: ignore + return ColumnMask(self, other, op.ne) + + def __gt__(self, other: float | u.Quantity) -> ColumnMask: + return ColumnMask(self, other, op.gt) + + def __ge__(self, other: float | u.Quantity) -> ColumnMask: + return ColumnMask(self, other, op.ge) + + def __lt__(self, other: float | u.Quantity) -> ColumnMask: + return ColumnMask(self, other, op.lt) + + def __le__(self, other: float | u.Quantity) -> ColumnMask: + return ColumnMask(self, other, op.le) + + def isin(self, other: Iterable[float | u.Quantity]) -> ColumnMask: + return ColumnMask(self, other, np.isin, force_binary=True) + def evaluate(self, data: dict[str, np.ndarray], *args) -> np.ndarray: - lhs: np.typing.ArrayLike - rhs: Optional[np.typing.ArrayLike] + lhs: Any + rhs: Any match self.lhs: case DerivedColumn(): lhs = self.lhs.evaluate(data) @@ -817,36 +848,42 @@ class ColumnMask: def __init__( self, - name: str, - value: float | u.Quantity, + left: ColumnOrScalar, + right: ColumnOrScalar, operator: Callable[[table.Column, float | u.Quantity], np.ndarray], + force_binary: bool = False, ): - self.name = name - self.value = value + self.left = left + self.right = right self.operator = operator - @property - def requires(self): - return {self.name} - - def apply(self, column: u.Quantity | np.ndarray) -> np.ndarray: - """ - mask the dataset based on the mask. 
- """ - # Astropy's errors are good enough here - if isinstance(column, table.Table): - column = column[self.name] - - if isinstance(self.value, u.Quantity) and isinstance(column, u.Quantity): - if self.value.unit != column.unit: - raise ValueError( - f"Incompatible units in fiter: {self.value.unit} and {column.unit}" - ) - - elif isinstance(column, u.Quantity): - return self.operator(column.value, self.value) + def apply(self, ds: Dataset): + match self.left: + case Column(): + left = ds.select(self.left.name).get_data() + case DerivedColumn(): + left = ds.select(data=self.left).get_data() + case _: + left = self.left - return self.operator(column, self.value) # type: ignore + right_selected = False + match self.right: + case Column(): + right = ds.select(self.right.name).get_data() + right_selected = True + case DerivedColumn(): + right = ds.select(data=self.right).get_data() + right_selected = True + case _: + right = self.right + if ( + isinstance(left, u.Quantity) + and not isinstance(right, u.Quantity) + and not right_selected + ): + return self.operator(left.value, right) + result = self.operator(left, right) + return result def __and__(self, other: Self | CompoundColumnMask): return CompoundColumnMask(self, other, lambda left, right: left & right) @@ -879,7 +916,7 @@ def __and__(self, other: ColumnMask | Self): def __or__(self, other: ColumnMask | Self): return CompoundColumnMask(self, other, lambda left, right: left | right) - def apply(self, data): - left_mask = self.__left.apply(data) - right_mask = self.__right.apply(data) + def apply(self, ds: Dataset): + left_mask = self.__left.apply(ds) + right_mask = self.__right.apply(ds) return self.__op(left_mask, right_mask) diff --git a/python/opencosmo/dataset/columns.py b/python/opencosmo/dataset/columns.py index 8d287680..e8ffab6b 100644 --- a/python/opencosmo/dataset/columns.py +++ b/python/opencosmo/dataset/columns.py @@ -72,8 +72,11 @@ def __categorize_columns( new_derived_columns.append(column) 
new_column_names.extend(column.produces) case Column(): - producer = RawColumn( - column.name, descriptions.get(colname, None), alias=colname + producer = DerivedColumn( + lhs=column, + rhs=None, + operation=lambda x, _: x, + output_name=colname, ) new_derived_columns.append(producer) new_column_names.extend(producer.produces) diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 741fb8eb..3e49a4fb 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -1,6 +1,5 @@ from __future__ import annotations -from functools import reduce from typing import ( TYPE_CHECKING, Callable, @@ -518,13 +517,9 @@ def filter(self, *masks: ColumnMask) -> Dataset: """ if not masks: return self - required_columns: set[str] = reduce( - lambda acc, r: acc | r.requires, masks, set() - ) - data = self.select(required_columns).get_data(unpack=False) - bool_mask = np.ones(len(data), dtype=bool) + bool_mask = np.ones(len(self), dtype=bool) for m in masks: - bool_mask &= m.apply(data) + bool_mask &= m.apply(self) new_state = st.take_rows(self.__state, np.where(bool_mask)[0]) return Dataset(self.__header, new_state, self.__tree) diff --git a/python/opencosmo/dataset/graph.py b/python/opencosmo/dataset/graph.py index 32ce073b..4c5e85a7 100644 --- a/python/opencosmo/dataset/graph.py +++ b/python/opencosmo/dataset/graph.py @@ -40,7 +40,7 @@ def validate_column_producers( f"Tried to derive columns from unknown columns: {node.produces}" ) - return get_derived_units(dependency_graph, unit_handler.base_units) + return get_derived_units(dependency_graph, unit_handler.current_units) def build_dependency_graph( diff --git a/test/test_filters.py b/test/test_filters.py index 8c64d8fa..142edd6c 100644 --- a/test/test_filters.py +++ b/test/test_filters.py @@ -5,6 +5,7 @@ import opencosmo as oc from opencosmo import col +from opencosmo.column import offset_3d @pytest.fixture @@ -31,6 +32,7 @@ def 
test_multi_filters_single_column(input_path, max_mass): ds = ds.filter(col("sod_halo_mass") > 0, col("sod_halo_mass") < max_mass) data = ds.get_data() + assert data["sod_halo_mass"].min() > 0 assert data["sod_halo_mass"].max() < max_mass @@ -138,7 +140,6 @@ def test_or_filter(input_path): ds = ds.filter(high_mass | low_mass) data = ds.select("fof_halo_mass").get_data("numpy") - assert len(data) > 0 assert np.all((data > 1e14) | (data < 1e12)) assert np.any(data > 1e14) assert np.any(data < 1e12) @@ -169,3 +170,12 @@ def test_filter_tree(input_path): assert np.all(data["sod_halo_cdelta"] < 5) assert np.any(data["fof_halo_mass"] > 1e14) assert np.any(data["fof_halo_mass"] < 1e12) + + +def test_filter_by_derived(input_path): + col = offset_3d("fof_halo_center", "fof_halo_com") / oc.col("sod_halo_radius") + ds = oc.open(input_path) + + ds = ds.filter(col > 0.1) + data = ds.select(xoff=col).get_data() + assert np.all(data > 0.1) From cbf5ea13209b3e29a031c80b26b8e0cb6c4d6f38 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 18 May 2026 16:21:42 -0500 Subject: [PATCH 105/139] Implement additional testing --- test/test_filters.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/test_filters.py b/test/test_filters.py index 142edd6c..6c246f20 100644 --- a/test/test_filters.py +++ b/test/test_filters.py @@ -179,3 +179,18 @@ def test_filter_by_derived(input_path): ds = ds.filter(col > 0.1) data = ds.select(xoff=col).get_data() assert np.all(data > 0.1) + + +def test_column_comparison(input_path): + ds = oc.open(input_path) + + ds = ds.filter(oc.col("fof_halo_center_x") < oc.col("fof_halo_center_y")) + data = ds.select("fof_halo_center_x", "fof_halo_center_y").get_data() + assert np.all(data["fof_halo_center_x"] < data["fof_halo_center_y"]) + + +def test_filter_bad_units(input_path): + ds = oc.open(input_path) + + with pytest.raises(u.UnitConversionError): + ds = ds.filter(oc.col("fof_halo_center_x") < 10 * u.kg) From 
5169c613a6671f9043a476a4598d3eb6bb0f4945 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 18 May 2026 16:24:28 -0500 Subject: [PATCH 106/139] Changelog, additional test --- changes/+5661c945.bugfix.rst | 1 + test/test_select.py | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 changes/+5661c945.bugfix.rst diff --git a/changes/+5661c945.bugfix.rst b/changes/+5661c945.bugfix.rst new file mode 100644 index 00000000..cd158d54 --- /dev/null +++ b/changes/+5661c945.bugfix.rst @@ -0,0 +1 @@ +Fix a bug that could cause renamed columns to be instantiated without the correct units diff --git a/test/test_select.py b/test/test_select.py index fd9fa58e..afb7f575 100644 --- a/test/test_select.py +++ b/test/test_select.py @@ -129,3 +129,10 @@ def test_select_with_derived(input_path): assert com_columns.issubset(data.columns) assert np.all(data["gal_px"] == data["gal_mass"] * data["gal_com_vx"]) + + +def test_select_single_check_units(input_path): + ds = oc.open(input_path) + ds = ds.select("fof_halo_center_x", center_x=oc.col("fof_halo_center_x")) + data = ds.get_data() + assert np.all(data["fof_halo_center_x"] == data["center_x"]) From 268f299fe08135be334e79960cd5c9a38367222f Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 18 May 2026 18:49:50 -0500 Subject: [PATCH 107/139] Remove unneeded force_binary flag --- python/opencosmo/column/column.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/opencosmo/column/column.py b/python/opencosmo/column/column.py index db26a391..b47868f0 100644 --- a/python/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -221,7 +221,7 @@ def __le__(self, other: float | u.Quantity) -> ColumnMask: return ColumnMask(self, other, op.le) def isin(self, other: Iterable[float | u.Quantity]) -> ColumnMask: - return ColumnMask(self, other, np.isin, force_binary=True) + return ColumnMask(self, other, np.isin) @_require_scalar_quantity def __rmul__(self, other: Any) -> DerivedColumn: @@ 
-629,7 +629,7 @@ def __le__(self, other: float | u.Quantity) -> ColumnMask: return ColumnMask(self, other, op.le) def isin(self, other: Iterable[float | u.Quantity]) -> ColumnMask: - return ColumnMask(self, other, np.isin, force_binary=True) + return ColumnMask(self, other, np.isin) def evaluate(self, data: dict[str, np.ndarray], *args) -> np.ndarray: lhs: Any @@ -851,7 +851,6 @@ def __init__( left: ColumnOrScalar, right: ColumnOrScalar, operator: Callable[[table.Column, float | u.Quantity], np.ndarray], - force_binary: bool = False, ): self.left = left self.right = right From 4a860bac46efa9acac0c6bc39daf7c33ecd2bfbd Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 19 May 2026 14:09:54 -0500 Subject: [PATCH 108/139] Implement writing and re-reading for lightcone structure collections --- python/opencosmo/collection/lightcone/io.py | 5 +- .../collection/lightcone/lightcone.py | 38 +++++-- python/opencosmo/collection/protocols.py | 4 - .../opencosmo/collection/structure/handler.py | 19 ++-- python/opencosmo/collection/structure/io.py | 106 ++++++++++++++---- .../collection/structure/structure.py | 37 +++--- python/opencosmo/dataset/dataset.py | 17 ++- python/opencosmo/dataset/output.py | 5 +- python/opencosmo/dataset/state.py | 15 ++- python/opencosmo/io/iopen.py | 15 ++- python/opencosmo/io/schema.py | 8 ++ python/opencosmo/io/serial.py | 13 +++ test/test_structure_collection.py | 56 +++++++++ 13 files changed, 269 insertions(+), 69 deletions(-) diff --git a/python/opencosmo/collection/lightcone/io.py b/python/opencosmo/collection/lightcone/io.py index d5ee892f..8373af0f 100644 --- a/python/opencosmo/collection/lightcone/io.py +++ b/python/opencosmo/collection/lightcone/io.py @@ -58,7 +58,8 @@ def combine_adjacent_datasets_mpi( def combine_adjacent_datasets( ordered_datasets: dict[str, ocds.Dataset] | dict[str, dict[str, ocds.Dataset]], - min_dataset_size=100_000, + min_dataset_size: int, + no_stack: bool, ): is_single = 
isinstance(next(iter(ordered_datasets.values())), ocds.Dataset) datasets: dict[str, dict[str, ocds.Dataset]] @@ -80,7 +81,7 @@ def combine_adjacent_datasets( ) for key, step_datasets in datasets.items(): - if running_sum < min_dataset_size: + if not no_stack and running_sum < min_dataset_size: running_sum += sum(len(ds) for ds in step_datasets.values()) output_datasets[current_key].append(step_datasets) continue diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index e96316c1..a03338ff 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -265,6 +265,17 @@ def simulation(self) -> HaccSimulationParameters: """ return self.__header.simulation + @property + def sorted_by(self) -> Optional[str]: + """ + The column this dataset is sorted by. If not sorted, returns None. + + Returns + ------- + column: Optional[str] + """ + return self.__sort_key[0] if self.__sort_key is not None else None + @property def z_range(self): """ @@ -331,11 +342,11 @@ def get_data(self, format="astropy", unpack: bool = True, **kwargs): to_remove = self.__hidden.intersection(table.colnames) table.remove_columns(to_remove) if unpack: - data = { + output_data = { key: value[0] if len(value) == 1 else value for key, value in table.items() } - table = QTable(data) + table = QTable(output_data) if format != "astropy": return convert_data(dict(table), format) @@ -407,7 +418,7 @@ def open(cls, targets: list[FileTarget], **kwargs): @classmethod def from_datasets( cls, - datasets: dict[str, oc.Dataset], + datasets: Mapping[int, oc.Dataset], z_range: Optional[tuple[float, float]] = None, **open_kwargs, ): @@ -487,13 +498,15 @@ def __map( def __map_attribute(self, attribute): return {k: getattr(v, attribute) for k, v in self.items()} - def make_schema(self, name: str = "", _min_size=100_000) -> Schema: + def make_schema( + self, name: str = "", _min_size=100_000, 
no_stack: bool = False + ) -> Schema: datasets = lcio.order_by_redshift_range(self) for key in datasets: if isinstance(datasets[key], Lightcone): datasets[key] = dict(datasets[key]) output_datasets = lcio.combine_adjacent_datasets( - datasets, min_dataset_size=_min_size + datasets, min_dataset_size=_min_size, no_stack=no_stack ) children = {} @@ -1106,7 +1119,7 @@ def with_new_columns( new_datasets[ds_name] = new_dataset return Lightcone(new_datasets, self.z_range, self.__hidden, self.__sort_key) - def sort_by(self, column: str, invert: bool = False): + def sort_by(self, column: Optional[str], invert: bool = False): """ Sort this dataset by the values in a given column. By default sorting is in ascending order (least to greatest). Pass invert = True to sort in descending @@ -1124,9 +1137,9 @@ def sort_by(self, column: str, invert: bool = False): Parameters ---------- - column : str + column : Optional[str] The column in the halo_properties or galaxy_properties dataset to - order the collection by. + order the collection by. Pass None to remove sorting. invert : bool, default = False If False (the default), ordering will be from least to greatest. 
@@ -1140,9 +1153,14 @@ def sort_by(self, column: str, invert: bool = False): """ - if column not in self.columns: + if column is None: + sort_key = None + elif column not in self.columns: raise ValueError(f"Column {column} does not exist in this dataset!") - return Lightcone(dict(self), self.z_range, self.__hidden, (column, invert)) + else: + sort_key = (column, invert) + + return Lightcone(dict(self), self.z_range, self.__hidden, sort_key) def with_units( self, diff --git a/python/opencosmo/collection/protocols.py b/python/opencosmo/collection/protocols.py index 3d5b9382..2c4c0bbc 100644 --- a/python/opencosmo/collection/protocols.py +++ b/python/opencosmo/collection/protocols.py @@ -5,7 +5,6 @@ if TYPE_CHECKING: from opencosmo.column.column import ColumnMask from opencosmo.dataset import Dataset - from opencosmo.header import OpenCosmoHeader from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema @@ -38,9 +37,6 @@ def make_schema(self) -> Schema: ... @property def dtype(self) -> str | dict[str, str]: ... - @property - def header(self) -> OpenCosmoHeader | dict[str, OpenCosmoHeader]: ... - def __getitem__(self, key: str) -> Union[Dataset, "Collection"]: ... def keys(self) -> Iterable[str]: ... def values(self) -> Iterable[Union[Dataset, "Collection"]]: ... diff --git a/python/opencosmo/collection/structure/handler.py b/python/opencosmo/collection/structure/handler.py index 47207701..8f2bad54 100644 --- a/python/opencosmo/collection/structure/handler.py +++ b/python/opencosmo/collection/structure/handler.py @@ -289,25 +289,30 @@ def __rebuild_datasets(self, derived_from, new_source, datasets): ) return new_datasets - def resort(self, source: oc.Dataset, datasets: dict[str, oc.Dataset]): + def resort( + self, source: oc.Dataset | oc.Lightcone, datasets: dict[str, oc.Dataset] + ): """ Data is always written in its original order, whether or not it has been sorted. This is to preserve the spatial index. 
However, when linked datasets are rebuilt they are rebuilt in the sorted order. This method re-sorts them based on the index from the original data. """ + + is_sorted = source.sorted_by is not None + if not is_sorted: + return datasets + all_columns: list[str] = reduce( lambda acc, ds: acc + self.columns[ds], datasets.keys(), [] ) all_columns = list( filter(lambda name: "idx" in name or "size" in name, all_columns) ) - - sort_index = np.argsort(into_array(source.index)) - - if np.all(sort_index[1:] >= sort_index[:-1]): - # Already sorted. Carry on! - return datasets + if isinstance(source, lc.Lightcone): + raise NotImplementedError + else: + sort_index = np.argsort(source.index) meta = source.get_metadata(all_columns) output = {} diff --git a/python/opencosmo/collection/structure/io.py b/python/opencosmo/collection/structure/io.py index 40d39359..1e962bff 100644 --- a/python/opencosmo/collection/structure/io.py +++ b/python/opencosmo/collection/structure/io.py @@ -1,9 +1,9 @@ from __future__ import annotations from collections import defaultdict -from functools import reduce +from functools import partial from itertools import chain -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Mapping, Optional, TypeGuard import numpy as np @@ -16,6 +16,7 @@ if TYPE_CHECKING: import h5py + from mpi4py import MPI from opencosmo.io.iopen import FileTarget @@ -39,6 +40,10 @@ def remove_empty(dataset): return dataset +def is_dataset(ds: Any) -> TypeGuard[d.Dataset]: + return isinstance(ds, d.Dataset) + + def validate_linked_groups(groups: dict[str, h5py.Group]): if "halo_properties" in groups: if "data_linked" not in groups["halo_properties"].keys(): @@ -59,9 +64,13 @@ def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): link_targets: dict[str, dict[str, list[d.Dataset | sc.StructureCollection]]] = ( defaultdict(lambda: defaultdict(list)) ) - dataset_targets: list[io.iopen.DatasetTarget] = reduce( - lambda acc, t: acc + 
t["dataset_targets"], targets, [] - ) + + dataset_targets: list[io.iopen.DatasetTarget] = [] + for t in targets: + dataset_targets.extend(t["dataset_targets"]) + for datasets in t["dataset_groups"].values(): + dataset_targets.extend(datasets) + for target in dataset_targets: if target["header"].file.data_type == "halo_properties": link_sources["halo_properties"].append(target) @@ -71,7 +80,14 @@ def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): dataset = io.iopen.open_single_dataset( target, bypass_lightcone=True, bypass_mpi=True ) - name = target["dataset_group"].name.split("/")[-1] + name_source = target["dataset_group"] + if ( + "particles" in name_source.parent.name + or "profiles" in target["dataset_group"].parent.name + ): + name_source = target["dataset_group"].parent + name = name_source.name.split("/")[-1] + if not name: name = target["header"].file.data_type elif name.startswith("halo_properties"): @@ -125,8 +141,8 @@ def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): def _apply_offset_corrections( - source_by_step: dict[int, d.Dataset], - targets_by_step: dict[str, dict[int, d.Dataset | sc.StructureCollection]], + source_by_step: Mapping[int, d.Dataset], + targets_by_step: Mapping[str, Mapping[int, d.Dataset | sc.StructureCollection]], ) -> dict[int, d.Dataset]: """ Correct step-local _start and _idx metadata columns to be globally correct @@ -232,9 +248,14 @@ def build_lightcone_structure_collection( ) for t in link_sources["galaxy_properties"] ] - galaxy_source_by_step = {ds.header.file.step: ds for ds in galaxy_datasets} - galaxy_targets_by_step = { - target_type: {ds.header.file.step: ds for ds in targets} # type: ignore + galaxy_source_by_step: dict[int, d.Dataset] = {} + for ds in galaxy_datasets: + assert ds.header.file.step is not None + galaxy_source_by_step[ds.header.file.step] = ds + galaxy_targets_by_step: dict[ + str, Mapping[int, d.Dataset | sc.StructureCollection] + ] = { + 
target_type: {ds.header.file.step: ds for ds in targets} # type: ignore[misc] for target_type, targets in link_targets["galaxy_properties"].items() } galaxy_source_by_step = _apply_offset_corrections( @@ -248,7 +269,7 @@ def build_lightcone_structure_collection( ) collection = sc.StructureCollection(galaxy_lightcone, galaxy_target_datasets) if len(link_sources.get("halo_properties", [])) > 0: - link_targets["halo_properties"]["galaxy_properties"] = collection + link_targets["halo_properties"]["galaxy_properties"] = collection # type: ignore[assignment] else: return collection @@ -257,7 +278,10 @@ def build_lightcone_structure_collection( io.iopen.open_single_dataset(t, "data_linked", bypass_lightcone=True) for t in halo_source_list ] - halo_source_by_step = {ds.header.file.step: ds for ds in halo_datasets} + halo_source_by_step: dict[int, d.Dataset] = {} + for ds in halo_datasets: + assert ds.header.file.step is not None + halo_source_by_step[ds.header.file.step] = ds halo_targets_by_step: dict[str, dict[int, d.Dataset]] = {} for target_type, targets in link_targets["halo_properties"].items(): if isinstance(targets, sc.StructureCollection): @@ -268,9 +292,12 @@ def build_lightcone_structure_collection( assert isinstance(inner_lc, lc.Lightcone) halo_targets_by_step[target_type] = dict(inner_lc) elif isinstance(targets, list): - halo_targets_by_step[target_type] = { - ds.header.file.step: ds for ds in targets - } + step_map: dict[int, d.Dataset] = {} + for ds in targets: + assert isinstance(ds, d.Dataset) + assert ds.header.file.step is not None + step_map[ds.header.file.step] = ds + halo_targets_by_step[target_type] = step_map halo_source_by_step = _apply_offset_corrections( halo_source_by_step, halo_targets_by_step ) @@ -281,9 +308,13 @@ def build_lightcone_structure_collection( if isinstance(targets, (d.Dataset, sc.StructureCollection)): output_targets[target_type] = targets continue - output_targets[target_type] = lc.Lightcone.from_datasets( - 
{ds.header.file.step: ds for ds in targets} - ) + output_targets_of_type: dict[int, d.Dataset] = {} + for ds in targets: + assert isinstance(ds, d.Dataset) + assert ds.header.file.step is not None + output_targets_of_type[ds.header.file.step] = ds + + output_targets[target_type] = lc.Lightcone.from_datasets(output_targets_of_type) return sc.StructureCollection(source_lightcone, output_targets) @@ -334,3 +365,40 @@ def __build_structure_collection( source_dataset, link_targets["halo_properties"], ) + + +def do_idx_update(data: np.ndarray, comm: Optional[MPI.Comm] = None): + if comm is None: + return np.arange(len(data)) + lengths = comm.allgather(len(data)) + offsets = np.insert(np.cumsum(lengths), 0, 0) + offset = offsets[comm.Get_rank()] + result = np.arange(offset, offset + len(data)) + return result + + +def do_start_update(data: np.ndarray, size: np.ndarray, comm: Optional[MPI.Comm]): + psum = np.insert(np.cumsum(size), 0, 0)[:-1] + if comm is None: + return psum + lengths = comm.allgather(np.sum(size)) + offsets = np.insert(np.cumsum(lengths), 0, 0) + offset = offsets[comm.Get_rank()] + return psum + offset + + +def rebuild_data_linked(source_schema): + if source_schema.type == io.schema.FileEntry.LIGHTCONE: + for key, value in source_schema.children.items(): + source_schema.children[key] = rebuild_data_linked(value) + return source_schema + + for colname, column in source_schema.children["data_linked"].columns.items(): + if "idx" in colname: + column.set_transformation(do_idx_update) + elif "start" in colname: + size_colname = colname.replace("start", "size") + size_data = source_schema.children["data_linked"].columns[size_colname].data + updater = partial(do_start_update, size=size_data) + column.set_transformation(updater) + return source_schema diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index a136854b..06dccf47 100644 --- a/python/opencosmo/collection/structure/structure.py +++ 
b/python/opencosmo/collection/structure/structure.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import defaultdict -from functools import partial, reduce +from functools import reduce from inspect import signature from typing import ( TYPE_CHECKING, @@ -18,6 +18,7 @@ import numpy as np import opencosmo as oc +from opencosmo.collection.lightcone import lightcone as lc from opencosmo.collection.structure import evaluate from opencosmo.collection.structure import io as sio from opencosmo.index.unary import get_length @@ -31,6 +32,7 @@ from opencosmo.column.column import ConstructedColumn from opencosmo.dtypes import HaccSimulationParameters + from opencosmo.header import OpenCosmoHeader from opencosmo.index import DataIndex from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema @@ -140,12 +142,16 @@ def __get_datasets(self): def __repr__(self): structure_type = self.__source.header.file.data_type.split("_")[0] + "s" + is_lightcone = isinstance(self.__source, lc.Lightcone) keys = list(self.keys()) if len(keys) == 2: dtype_str = " and ".join(keys) else: dtype_str = ", ".join(keys[:-1]) + ", and " + keys[-1] - return f"Collection of {structure_type} with {dtype_str}" + header = f"Collection of {structure_type} {'on a lightcone ' if is_lightcone else ' '}with {dtype_str}\n" + source_repr = self.__source.__repr__().split("\n", maxsplit=1)[1] + + return header + source_repr def __len__(self): return len(self.__source) @@ -154,7 +160,8 @@ def __len__(self): def open( cls, targets: list[FileTarget], ignore_empty=True, **kwargs ) -> StructureCollection: - return sio.build_structure_collection(targets, ignore_empty) + result = sio.build_structure_collection(targets, ignore_empty) + return result @property def dtype(self): @@ -168,6 +175,10 @@ def cosmology(self) -> astropy.cosmology.Cosmology: """ return self.__source.cosmology + @property + def header(self) -> OpenCosmoHeader: + return self.__source.header + @property def 
properties(self) -> list[str]: """ @@ -1388,25 +1399,17 @@ def make_schema(self, name: Optional[str] = None) -> Schema: children = {} source_name = self.__source.dtype datasets = self.__handler.resort(self.__source, self.__get_datasets()) + schema_kwargs: dict[str, Any] = ( + {"no_stack": True} if isinstance(self.__source, lc.Lightcone) else {} + ) - source_schema = self.__source.make_schema() - for colname, column in source_schema.children["data_linked"].columns.items(): - if "idx" in colname: - column.set_transformation(do_idx_update) - elif "start" in colname: - size_colname = colname.replace("start", "size") - size_data = ( - source_schema.children["data_linked"].columns[size_colname].data - ) - updater = partial(do_start_update, size=size_data) - column.set_transformation(updater) - - children[source_name] = source_schema + source_schema = self.__source.make_schema(**schema_kwargs) + children[source_name] = sio.rebuild_data_linked(source_schema) for name, dataset in datasets.items(): if name == "galaxies": name = "galaxy_properties" - ds_schema = dataset.make_schema() + ds_schema = dataset.make_schema(**schema_kwargs) if not isinstance(dataset, StructureCollection): children[name] = ds_schema continue diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 0e87d25e..e323de4b 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -218,6 +218,17 @@ def simulation(self) -> Optional[HaccSimulationParameters]: """ return getattr(self.__header, "simulation", None) + @property + def sorted_by(self) -> Optional[str]: + """ + The column this dataset is sorted by. If not sorted, returns None. 
+ + Returns + ------- + column: Optional[str] + """ + return self.__state.sort_key[0] if self.__state.sort_key is not None else None + @property @deprecated( version="1.1.0", @@ -665,7 +676,7 @@ def drop(self, *columns: str | Iterable[str]) -> Dataset: self.__tree, ) - def sort_by(self, column: str, invert: bool = False) -> Dataset: + def sort_by(self, column: Optional[str], invert: bool = False) -> Dataset: """ Sort this dataset by the values in a given column. By default sorting is in ascending order (least to greatest). Pass invert = True to sort in descending @@ -683,9 +694,9 @@ def sort_by(self, column: str, invert: bool = False) -> Dataset: Parameters ---------- - column : str + column : Optional[str] The column in the halo_properties or galaxy_properties dataset to - order the collection by. + order the collection by. Pass :code:`None` to remove sorting. invert : bool, default = False If False (the default), ordering will be from least to greatest. diff --git a/python/opencosmo/dataset/output.py b/python/opencosmo/dataset/output.py index 186961f3..e09accd3 100644 --- a/python/opencosmo/dataset/output.py +++ b/python/opencosmo/dataset/output.py @@ -87,7 +87,10 @@ def make_dataset_schema( columns_to_uuid, meta_columns ) - build_derived_writers(producers, derived_data, data_schema, cached_data_schema) + data_producers = [ + prod for prod in producers if not prod.produces.issubset(meta_columns) + ] + build_derived_writers(data_producers, derived_data, data_schema, cached_data_schema) attributes = {} if (load_conditions := raw_data_handler.load_conditions) is not None: diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index e316e85d..77ec88b3 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -358,7 +358,7 @@ def make_schema(state: DatasetState, name: Optional[str] = None) -> Schema: Get metadata columns. 
""" producers = list(state.producers.values()) - columns = set(state.column_map.keys()) + columns = set(state.column_map.keys()).difference(state.metadata_columns) derived_names = get_derived_column_names(producers, columns) if derived_names: selected = select(state, derived_names) @@ -433,10 +433,17 @@ def select(state: DatasetState, columns: set[str], drop: bool = False) -> Datase return dataclasses.replace(state, column_map=new_column_map) -def sort_by(state: DatasetState, column_name: str, invert: bool) -> DatasetState: - if column_name not in state.columns: +def sort_by( + state: DatasetState, column_name: Optional[str], invert: bool +) -> DatasetState: + if column_name is None: + sort_key = None + elif column_name not in state.columns: raise ValueError(f"This dataset has no column {column_name}") - return dataclasses.replace(state, sort_key=(column_name, invert)) + else: + sort_key = (column_name, invert) + + return dataclasses.replace(state, sort_key=sort_key) def get_sorted_index(state: DatasetState) -> np.ndarray | None: diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index a69bb91c..c4adfee6 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -11,6 +11,7 @@ import opencosmo as oc from opencosmo import collection as occ +from opencosmo.collection.structure import structure as sc from opencosmo.dataset import state as st from opencosmo.dataset.mpi import partition from opencosmo.header import OpenCosmoHeader, read_header @@ -143,7 +144,7 @@ def __open_single_file( == FileType.STRUCTURE_COLLECTION ): # Structure collection - return occ.StructureCollection.open([target]) + return occ.StructureCollection.open([target], **open_kwargs) elif target["dataset_groups"]: # Sometimes, lightcones have multiple datasets per slice if all( @@ -152,6 +153,16 @@ def __open_single_file( ): return occ.Lightcone.open([target]) + # Lightcone structure collection + elif ( + target["dataset_group_types"].get("halo_properties") 
== FileType.LIGHTCONE + or target["dataset_group_types"].get("galaxy_properties") + == FileType.LIGHTCONE + ): + result = sc.StructureCollection.open([target], **open_kwargs) + print(result) + return result + datasets = { name: __open_dataset_targets_for_sim_collection( targets, target["dataset_group_types"][name] @@ -552,7 +563,7 @@ def open_single_dataset( return __open_healpix_map(dataset, sim_region) elif header.file.is_lightcone and not bypass_lightcone: return occ.Lightcone.from_datasets( - {"data": dataset}, header.lightcone["z_range"], **open_kwargs + {0: dataset}, header.lightcone["z_range"], **open_kwargs ) return dataset diff --git a/python/opencosmo/io/schema.py b/python/opencosmo/io/schema.py index 416a1267..7e6fa4f6 100644 --- a/python/opencosmo/io/schema.py +++ b/python/opencosmo/io/schema.py @@ -28,6 +28,14 @@ class Schema(NamedTuple): attributes: dict[str, Any] +def dataset_schema_length(schema: Schema) -> Optional[int]: + if schema.type != FileEntry.DATASET: + return None + + column = next(iter(schema.children["data"].columns.values())) + return column.shape[0] + + def empty_schema(name: str, type_: FileEntry) -> Schema: return Schema(name, type_, {}, {}, {}) diff --git a/python/opencosmo/io/serial.py b/python/opencosmo/io/serial.py index 15dab840..0653eae9 100644 --- a/python/opencosmo/io/serial.py +++ b/python/opencosmo/io/serial.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING +from opencosmo.io.schema import dataset_schema_length + if TYPE_CHECKING: import h5py @@ -10,17 +12,26 @@ def allocate(group: h5py.File | h5py.Group, schema: Schema): for column_name, column_writer in schema.columns.items(): + if column_writer.shape[0] == 0: + continue group.require_dataset(column_name, column_writer.shape, column_writer.dtype) for child_name, child_schema in schema.children.items(): + if dataset_schema_length(child_schema) == 0: + continue + child_group = group.require_group(child_name) allocate(child_group, child_schema) def write_columns(group: 
h5py.File | h5py.Group, schema: Schema): for column_path, column_writer in schema.columns.items(): + if column_writer.shape[0] == 0: + continue group[column_path][:] = column_writer.data group[column_path].attrs.update(column_writer.attrs) for child_name, child_schema in schema.children.items(): + if dataset_schema_length(child_schema) == 0: + continue write_columns(group[child_name], child_schema) @@ -30,4 +41,6 @@ def write_metadata(group: h5py.File | h5py.Group, schema: Schema): metadata_group.attrs.update(metadata) for child_name, child_schema in schema.children.items(): + if dataset_schema_length(child_schema) == 0: + continue write_metadata(group[child_name], child_schema) diff --git a/test/test_structure_collection.py b/test/test_structure_collection.py index 7a76b5a1..ba23e677 100644 --- a/test/test_structure_collection.py +++ b/test/test_structure_collection.py @@ -34,6 +34,27 @@ def galaxies_601_path(lightcone_path): return [properties, particles] +def verify_halo(halo): + gravity_particle_tags = ( + halo["dm_particles"].select("fof_halo_tag").get_data("numpy") + ) + assert np.all(gravity_particle_tags == halo["halo_properties"]["fof_halo_tag"]) + halo_bin_tags = halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy") + assert np.all(halo_bin_tags == halo["halo_properties"]["fof_halo_tag"]) + if "galaxy" not in halo: + return + for galaxy in halo["galaxies"].galaxies(): + assert ( + galaxy["galaxy_properties"]["fof_halo_tag"] + == halo["halo_properties"]["fof_halo_tag"] + ) + + if "star_particles" not in galaxy: + continue + tags = galaxy["star_particles"].select("gal_tag").get_data("numpy") + assert np.all(tags == galaxy["galaxy_properties"]["gal_tag"]) + + def test_open_lightcone_structure(halos_600_path, halos_601_path): ds = oc.open(*halos_600_path, *halos_601_path) for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10).halos(): @@ -86,3 +107,38 @@ def test_open_lightcone_structure_with_galaxies( continue tags = 
galaxy["star_particles"].select("gal_tag").get_data("numpy") assert np.all(tags == galaxy["galaxy_properties"]["gal_tag"]) + + +def test_write_lightcone_structure(halos_600_path, halos_601_path, tmp_path): + ds = ( + oc.open(*halos_600_path, *halos_601_path) + .filter(oc.col("fof_halo_mass") > 1e14) + .take(1000) + ) + halo_tags_start = set() + halo_tags_end = set() + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos(): + halo_tags_start.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos(): + halo_tags_end.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + oc.write(tmp_path / "halos.hdf5", ds) + ds_new = oc.open(tmp_path / "halos.hdf5") + + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos() + ): + assert halo["halo_properties"]["fof_halo_tag"] in halo_tags_start + verify_halo(halo) + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos() + ): + assert halo["halo_properties"]["fof_halo_tag"] in halo_tags_end + verify_halo(halo) + + +def test_write_lightcone_structure_short(halos_600_path, halos_601_path, tmp_path): + ds = oc.open(*halos_600_path, *halos_601_path).take(20) + oc.write(tmp_path / "halos.hdf5", ds) From 54fe1c9c331ac039c247028c5b3c16b66f5c59bf Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 19 May 2026 16:06:42 -0500 Subject: [PATCH 109/139] Make updates for sorted lightcone structure collection --- .../collection/lightcone/lightcone.py | 17 ++-- .../opencosmo/collection/structure/handler.py | 77 +++++++++++-------- python/opencosmo/dataset/dataset.py | 4 +- python/opencosmo/dataset/state.py | 7 +- test/test_structure_collection.py | 19 ++++- 5 files changed, 83 insertions(+), 41 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index a03338ff..d74f7b57 
100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -18,6 +18,7 @@ import numpy as np from astropy.table import QTable, vstack # type: ignore +from deprecated import deprecated import opencosmo as oc from opencosmo.collection.lightcone import io as lcio @@ -355,18 +356,24 @@ def get_data(self, format="astropy", unpack: bool = True, **kwargs): return table - def get_metadata(self, columns: list[str] = []): + def get_metadata(self, columns: str | list[str] = [], ignore_sort: bool = False): data = [ds.get_metadata(columns) for ds in self.values()] - data_with_length = [d for d in data if len(d) > 0] - if len(data_with_length) == 0: - return data[0] output = {} for key in data[0].keys(): output[key] = np.concatenate([d[key] for d in data]) - return output + if ignore_sort or self.__sort_key is None: + return output + order = np.argsort(self.select(self.__sort_key[0]).get_data("numpy")) + if self.__sort_key[1]: + order = order[::-1] + return {name: arr[order] for name, arr in output.items()} @property + @deprecated( + version="1.1.0", + reason="Accessing data through the .data attribute is deprecated and will be removed in a future version. Use get_data()", + ) def data(self): """ Return the data in the dataset in astropy format. 
The value of this diff --git a/python/opencosmo/collection/structure/handler.py b/python/opencosmo/collection/structure/handler.py index 8f2bad54..8db23a0b 100644 --- a/python/opencosmo/collection/structure/handler.py +++ b/python/opencosmo/collection/structure/handler.py @@ -141,6 +141,52 @@ def make_links(keys, rename_galaxies=False): return output, columns +def resort_datasets( + source: oc.Dataset | oc.Lightcone, + datasets: Mapping[str, oc.Dataset | oc.Lightcone | oc.StructureCollection], + columns: dict[str, list[str]], +): + all_columns: list[str] = reduce( + lambda acc, ds: acc + columns[ds], datasets.keys(), [] + ) + all_columns = list( + filter(lambda name: "idx" in name or "size" in name, all_columns) + ) + sort_column = next(filter(lambda c: "start" in c or "idx" in c, all_columns)) + unsorted_meta_column = source.get_metadata(sort_column, ignore_sort=True) + sorted_meta_column = source.get_metadata(sort_column) + + argsort_meta_column = np.argsort(sorted_meta_column[sort_column]) + + sort_index = argsort_meta_column[ + np.searchsorted( + sorted_meta_column[sort_column], + unsorted_meta_column[sort_column], + sorter=argsort_meta_column, + ) + ] + + meta = source.get_metadata(all_columns) + output = {} + for name, dataset in datasets.items(): + if len(columns[name]) == 1: + valid_rows = meta[columns[name][0]] >= 0 + new_dataset = dataset.take_rows(sort_index[valid_rows]) + else: + size_column = [name for name in columns[name] if "size" in name] + assert len(size_column) == 1 + size_column_data = meta[size_column[0]].astype(np.int64) + chunk_boundaries = np.zeros(len(size_column_data) + 1, dtype=np.int64) + _ = np.cumsum(size_column_data, out=chunk_boundaries[1:]) + starts = chunk_boundaries[sort_index] + sizes = size_column_data[sort_index] + valid = sizes > 0 + idx = (starts[valid], sizes[valid]) + new_dataset = dataset.take_rows(idx) + output[name] = new_dataset + return output + + class LinkHandler: """ This needs some explanation. 
We break the "don't mutate state" rule pretty hard here. @@ -303,36 +349,7 @@ def resort( if not is_sorted: return datasets - all_columns: list[str] = reduce( - lambda acc, ds: acc + self.columns[ds], datasets.keys(), [] - ) - all_columns = list( - filter(lambda name: "idx" in name or "size" in name, all_columns) - ) - if isinstance(source, lc.Lightcone): - raise NotImplementedError - else: - sort_index = np.argsort(source.index) - - meta = source.get_metadata(all_columns) - output = {} - for name, dataset in datasets.items(): - if len(self.columns[name]) == 1: - valid_rows = meta[self.columns[name][0]] >= 0 - new_dataset = dataset.take_rows(sort_index[valid_rows]) - else: - size_column = [name for name in self.columns[name] if "size" in name] - assert len(size_column) == 1 - size_column_data = meta[size_column[0]].astype(np.int64) - chunk_boundaries = np.zeros(len(size_column_data) + 1, dtype=np.int64) - _ = np.cumsum(size_column_data, out=chunk_boundaries[1:]) - starts = chunk_boundaries[sort_index] - sizes = size_column_data[sort_index] - valid = sizes > 0 - idx = (starts[valid], sizes[valid]) - new_dataset = dataset.take_rows(idx) - output[name] = new_dataset - return output + return resort_datasets(source, datasets, self.columns) def rebuild_row_index( diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index e323de4b..90795d41 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -250,11 +250,11 @@ def data(self) -> QTable | u.Quantity: # Also the point is that there's MORE data than just the table return self.get_data("astropy") - def get_metadata(self, columns: str | list[str] = []): + def get_metadata(self, columns: str | list[str] = [], ignore_sort: bool = False): if isinstance(columns, str): columns = [columns] - return st.get_metadata(self.__state, columns) + return st.get_metadata(self.__state, columns, ignore_sort) def get_data( self, format="astropy", unpack=True, 
metadata_columns=[], **kwargs diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py index 77ec88b3..b01ffa20 100644 --- a/python/opencosmo/dataset/state.py +++ b/python/opencosmo/dataset/state.py @@ -336,7 +336,9 @@ def iter_rows( raise -def get_metadata(state: DatasetState, columns: list = []) -> dict: +def get_metadata( + state: DatasetState, columns: list = [], ignore_sort: bool = False +) -> dict: names = list(columns) if columns else list(state.metadata_columns) data = instantiate_dataset( list(state.producers.values()), @@ -347,6 +349,9 @@ def get_metadata(state: DatasetState, columns: list = []) -> dict: {}, None, ) + if ignore_sort: + return data + sorted_index = get_sorted_index(state) if sorted_index is not None: data = {name: values[sorted_index] for name, values in data.items()} diff --git a/test/test_structure_collection.py b/test/test_structure_collection.py index ba23e677..ea82192c 100644 --- a/test/test_structure_collection.py +++ b/test/test_structure_collection.py @@ -139,6 +139,19 @@ def test_write_lightcone_structure(halos_600_path, halos_601_path, tmp_path): verify_halo(halo) -def test_write_lightcone_structure_short(halos_600_path, halos_601_path, tmp_path): - ds = oc.open(*halos_600_path, *halos_601_path).take(20) - oc.write(tmp_path / "halos.hdf5", ds) +def test_data_link_sort_write_lightcone(halos_600_path, halos_601_path, tmp_path): + collection = oc.open(*halos_600_path, *halos_601_path) + collection = collection.filter(oc.col("sod_halo_mass") > 10**14).sort_by( + "fof_halo_mass" + ) + output = tmp_path / "halos.hdf5" + oc.write(output, collection) + new_collection = oc.open(output).take(10) + assert np.all( + collection["halo_properties"].select("sod_halo_mass").get_data("numpy") > 10**14 + ) + for halo in new_collection.objects(("halo_profiles",)): + assert np.all( + halo["halo_properties"]["fof_halo_tag"] + == halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy")[0] + ) From 
30367556ec7536ad1c0aca19239f0704194220fe Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 19 May 2026 16:49:10 -0500 Subject: [PATCH 110/139] Writes with galaxy linking --- .../opencosmo/collection/lightcone/plugins.py | 12 +++++-- python/opencosmo/collection/structure/io.py | 9 ++++- .../collection/structure/structure.py | 2 +- python/opencosmo/io/iopen.py | 2 +- test/test_structure_collection.py | 34 +++++++++++++++++++ 5 files changed, 53 insertions(+), 6 deletions(-) diff --git a/python/opencosmo/collection/lightcone/plugins.py b/python/opencosmo/collection/lightcone/plugins.py index 1e8478b0..904d1a6d 100644 --- a/python/opencosmo/collection/lightcone/plugins.py +++ b/python/opencosmo/collection/lightcone/plugins.py @@ -61,7 +61,9 @@ def _ensure_redshift_column(ctx: LightconeOpenCtx) -> LightconeOpenCtx: z_col = oc.col("zp") elif "chi" in lightcone.columns: lightcone = lightcone.evaluate( - redshift_from_chi, cosmology=lightcone.cosmology, vectorize=True + redshift_from_chi, + cosmology=lightcone.cosmology, + vectorize=True, ) return dataclasses.replace(ctx, lightcone=lightcone) @@ -97,6 +99,10 @@ def radec_from_thetaphi(theta, phi): return {"ra": phi_deg * u.deg, "dec": (90.0 - theta_deg) * u.deg} -def redshift_from_chi(chi, cosmology): - redshift = chi.to(cu.redshift, cu.redshift_distance(cosmology, kind="comoving")) +def redshift_from_chi(chi: u.Quantity, cosmology): + distance = chi.to(u.Mpc, cu.with_H0(cosmology.H0)) + + redshift = distance.to( + cu.redshift, cu.redshift_distance(cosmology, kind="comoving") + ) return {"redshift": redshift} diff --git a/python/opencosmo/collection/structure/io.py b/python/opencosmo/collection/structure/io.py index 1e962bff..7c91fefb 100644 --- a/python/opencosmo/collection/structure/io.py +++ b/python/opencosmo/collection/structure/io.py @@ -97,7 +97,14 @@ def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): dataset = io.iopen.open_single_dataset( target, bypass_lightcone=True, 
bypass_mpi=True ) - name = target["dataset_group"].name.split("/")[-1] + name_source = target["dataset_group"] + if ( + "particles" in name_source.parent.name + or "profiles" in target["dataset_group"].parent.name + ): + name_source = target["dataset_group"].parent + name = name_source.name.split("/")[-1] + if not name: name = target["header"].file.data_type elif name.startswith("galaxy_properties"): diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index 06dccf47..f1b9d53d 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -1395,7 +1395,7 @@ def galaxies(self, *args, **kwargs): else: raise AttributeError("This collection does not contain galaxies!") - def make_schema(self, name: Optional[str] = None) -> Schema: + def make_schema(self, name: Optional[str] = None, **kwargs) -> Schema: children = {} source_name = self.__source.dtype datasets = self.__handler.resort(self.__source, self.__get_datasets()) diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index c4adfee6..525db140 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -89,6 +89,7 @@ def open_files(paths: list[Path], open_kwargs: dict[str, Any]): if len(valid_targets) > 1: collection_type = __determine_multi_file_collection_type(valid_targets) return collection_type.open(valid_targets, **open_kwargs) + return __open_single_file(valid_targets[0], open_kwargs) @@ -160,7 +161,6 @@ def __open_single_file( == FileType.LIGHTCONE ): result = sc.StructureCollection.open([target], **open_kwargs) - print(result) return result datasets = { diff --git a/test/test_structure_collection.py b/test/test_structure_collection.py index ea82192c..ff475acf 100644 --- a/test/test_structure_collection.py +++ b/test/test_structure_collection.py @@ -139,6 +139,40 @@ def test_write_lightcone_structure(halos_600_path, halos_601_path, tmp_path): 
verify_halo(halo) +def test_write_lightcone_structure_with_galaxies( + halos_600_path, halos_601_path, galaxies_600_path, galaxies_601_path, tmp_path +): + ds = ( + oc.open( + *halos_600_path, *halos_601_path, *galaxies_600_path, *galaxies_601_path + ) + .filter(oc.col("fof_halo_mass") > 1e14) + .take(1000) + ) + halo_tags_start = set() + halo_tags_end = set() + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos(): + halo_tags_start.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos(): + halo_tags_end.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + oc.write(tmp_path / "halos.hdf5", ds) + ds_new = oc.open(tmp_path / "halos.hdf5") + + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos() + ): + assert halo["halo_properties"]["fof_halo_tag"] in halo_tags_start + verify_halo(halo) + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos() + ): + assert halo["halo_properties"]["fof_halo_tag"] in halo_tags_end + verify_halo(halo) + + def test_data_link_sort_write_lightcone(halos_600_path, halos_601_path, tmp_path): collection = oc.open(*halos_600_path, *halos_601_path) collection = collection.filter(oc.col("sod_halo_mass") > 10**14).sort_by( From c4f1e4a6713185aff3df4d9d8698f40553fc1bda Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 20 May 2026 13:06:45 -0500 Subject: [PATCH 111/139] Implement fixes for parallel output --- python/opencosmo/collection/lightcone/io.py | 7 +- .../collection/lightcone/lightcone.py | 4 +- .../opencosmo/collection/lightcone/stack.py | 14 +- python/opencosmo/collection/structure/io.py | 5 +- python/opencosmo/io/mpi.py | 1 + python/opencosmo/io/verify.py | 17 +- .../parallel/test_structure_collection_mpi.py | 213 ++++++++++++++++++ 7 files changed, 253 insertions(+), 8 deletions(-) create mode 100644 
test/parallel/test_structure_collection_mpi.py diff --git a/python/opencosmo/collection/lightcone/io.py b/python/opencosmo/collection/lightcone/io.py index 8373af0f..a2531c7d 100644 --- a/python/opencosmo/collection/lightcone/io.py +++ b/python/opencosmo/collection/lightcone/io.py @@ -21,7 +21,8 @@ def order_by_redshift_range(datasets: dict[str, ocds.Dataset]): def combine_adjacent_datasets_mpi( ordered_datasets: dict[str, dict[str, ocds.Dataset]], - min_dataset_size, + min_dataset_size: int, + no_stack: bool, ): MIN_DATASET_SIZE = 100_000 comm = get_comm_world() @@ -42,7 +43,7 @@ def combine_adjacent_datasets_mpi( rs += comm.allreduce(length) output_datasets[current_key].append(ordered_datasets[step]) - if rs > MIN_DATASET_SIZE: + if rs > MIN_DATASET_SIZE or no_stack: rs = 0 output = OrderedDict() @@ -71,7 +72,7 @@ def combine_adjacent_datasets( datasets = ordered_datasets # type: ignore if get_comm_world() is not None: - return combine_adjacent_datasets_mpi(datasets, min_dataset_size) + return combine_adjacent_datasets_mpi(datasets, min_dataset_size, no_stack) running_sum = 0 diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index d74f7b57..28acf492 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -530,7 +530,9 @@ def make_schema( min(header_zrange[1], my_zrange[1]), ) - child_schemas = stack_lightcone_datasets_in_schema(datasets, step, zrange) + child_schemas = stack_lightcone_datasets_in_schema( + datasets, step, zrange, no_stack + ) child_schemas = { f"{step}_{name}": schema for name, schema in child_schemas.items() } diff --git a/python/opencosmo/collection/lightcone/stack.py b/python/opencosmo/collection/lightcone/stack.py index 904e3c17..e25f9861 100644 --- a/python/opencosmo/collection/lightcone/stack.py +++ b/python/opencosmo/collection/lightcone/stack.py @@ -149,6 +149,7 @@ def stack_lightcone_datasets_in_schema( 
datasets: dict[str, list[ds.Dataset]], name: Optional[str], redshift_range: Optional[tuple[float, float]], + no_stack: bool = True, ): n_datasets = sum(len(lst) for lst in datasets.values()) if n_datasets == 1 and get_comm_world() is None: @@ -162,9 +163,12 @@ def stack_lightcone_datasets_in_schema( schema_children = {} ds_groups = get_all_keys(datasets, get_comm_world()) for ds_group in ds_groups: + schema_name = ds_group if len(datasets) > 1 else name ds_list = datasets.get(ds_group, []) ds_list = list(filter(lambda ds: len(ds) > 0, ds_list)) if len(ds_list) == 0: + if no_stack: + continue get_stacked_lightcone_order([], -1) sync_headers(ds_list, None) continue @@ -174,10 +178,13 @@ def stack_lightcone_datasets_in_schema( max_level = int(index_names[-1][-1]) assert all(isinstance(dataset, ds.Dataset) for dataset in ds_list) + if no_stack: + assert len(schemas) == 1 + schema_children[schema_name] = schemas[0] + continue new_data_group = stack_data_groups( [schema.children["data"] for schema in schemas] ) - order = get_stacked_lightcone_order(ds_list, max_level) updater = partial(update_order, order=order) slice_sizes = [len(dataset) for dataset in ds_list] @@ -202,8 +209,11 @@ def stack_lightcone_datasets_in_schema( "index": new_index_group, "header": header_schema, } + if all("data_linked" in c.children for c in schemas): + children["data_linked"] = stack_data_groups( + [schema.children["data_linked"] for schema in schemas] + ) - schema_name = ds_group if len(datasets) > 1 else name assert schema_name is not None schema_children[schema_name] = make_schema( schema_name, diff --git a/python/opencosmo/collection/structure/io.py b/python/opencosmo/collection/structure/io.py index 7c91fefb..939e79cd 100644 --- a/python/opencosmo/collection/structure/io.py +++ b/python/opencosmo/collection/structure/io.py @@ -395,7 +395,10 @@ def do_start_update(data: np.ndarray, size: np.ndarray, comm: Optional[MPI.Comm] def rebuild_data_linked(source_schema): - if source_schema.type 
== io.schema.FileEntry.LIGHTCONE: + if ( + source_schema.type == io.schema.FileEntry.LIGHTCONE + and "data" not in source_schema.children + ): for key, value in source_schema.children.items(): source_schema.children[key] = rebuild_data_linked(value) return source_schema diff --git a/python/opencosmo/io/mpi.py b/python/opencosmo/io/mpi.py index 6d2daeb8..9790e19f 100644 --- a/python/opencosmo/io/mpi.py +++ b/python/opencosmo/io/mpi.py @@ -72,6 +72,7 @@ def write_parallel(file: Path, file_schema: Schema): results = comm.allgather(CombineState.VALID) except ValueError: results = comm.allgather(CombineState.INVALID) + raise except ZeroLengthError: results = comm.allgather(CombineState.ZERO_LENGTH) if any(rs == CombineState.INVALID for rs in results): diff --git a/python/opencosmo/io/verify.py b/python/opencosmo/io/verify.py index efbfc0f1..c34bd96b 100644 --- a/python/opencosmo/io/verify.py +++ b/python/opencosmo/io/verify.py @@ -146,7 +146,20 @@ def verify_structure_collection_data(schema: Schema): raise ValueError("No valid link holder found in schema!") for child_name, child_schema in schema.children.items(): - if child_name == link_holder: + if child_schema.type == FileEntry.LIGHTCONE and child_name == link_holder: + for grandchild_name, grandchild_schema in child_schema.children.items(): + has_link = any( + map( + lambda cn: "data_linked" in cn, + grandchild_schema.children.keys(), + ) + ) + if not has_link: + raise ValueError( + f'Source dataset {child_name}/{grandchild_name} does not have expected "data_linked" group' + ) + + elif child_name == link_holder: has_link = any( map(lambda cn: "data_linked" in cn, child_schema.children.keys()) ) @@ -160,5 +173,7 @@ def verify_structure_collection_data(schema: Schema): verify_dataset_data(child_schema) case FileEntry.STRUCTURE_COLLECTION: verify_structure_collection_data(child_schema) + case FileEntry.LIGHTCONE: + verify_lightcone_collection_schema(child_schema) case _: raise ValueError("Got an unknown child for 
structure collection!") diff --git a/test/parallel/test_structure_collection_mpi.py b/test/parallel/test_structure_collection_mpi.py new file mode 100644 index 00000000..fd4ea956 --- /dev/null +++ b/test/parallel/test_structure_collection_mpi.py @@ -0,0 +1,213 @@ +import os +import shutil + +import numpy as np +import pytest +from mpi4py import MPI +from opencosmo.mpi import get_comm_world + +import opencosmo as oc + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" + + +@pytest.fixture +def halos_600_path(lightcone_path): + properties = lightcone_path / "step_600" / "haloproperties.hdf5" + particles = lightcone_path / "step_600" / "haloparticles.hdf5" + profiles = lightcone_path / "step_600" / "haloprofiles.hdf5" + return [properties, particles, profiles] + + +@pytest.fixture +def galaxies_600_path(lightcone_path): + properties = lightcone_path / "step_600" / "galaxyproperties.hdf5" + particles = lightcone_path / "step_600" / "galaxyparticles.hdf5" + return [properties, particles] + + +@pytest.fixture +def halos_601_path(lightcone_path): + properties = lightcone_path / "step_601" / "haloproperties.hdf5" + particles = lightcone_path / "step_601" / "haloparticles.hdf5" + profiles = lightcone_path / "step_601" / "haloprofiles.hdf5" + return [properties, particles, profiles] + + +@pytest.fixture +def galaxies_601_path(lightcone_path): + properties = lightcone_path / "step_601" / "galaxyproperties.hdf5" + particles = lightcone_path / "step_601" / "galaxyparticles.hdf5" + return [properties, particles] + + +@pytest.fixture +def per_test_dir( + tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest +): + """ + Creates a unique directory for each test and deletes it after the test finishes. + + Uses tmp_path_factory so you can control base temp location via pytest's + tempdir handling, and also so it can be used from broader-scoped fixtures + if needed. 
+ """ + # request.node.nodeid is unique across parameterizations; sanitize for filesystem + nodeid = ( + request.node.nodeid.replace("/", "_") + .replace("::", "__") + .replace("[", "_") + .replace("]", "_") + ) + + path = tmp_path_factory.mktemp(nodeid) + comm = MPI.COMM_WORLD + path_to_return = comm.bcast(path) + + try: + yield path_to_return + finally: + # Close out storage pressure immediately after each test + if IN_GITHUB_ACTIONS: + shutil.rmtree(path, ignore_errors=True) + + +def verify_halo(halo): + gravity_particle_tags = ( + halo["dm_particles"].select("fof_halo_tag").get_data("numpy") + ) + assert np.all(gravity_particle_tags == halo["halo_properties"]["fof_halo_tag"]) + halo_bin_tags = halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy") + assert np.all(halo_bin_tags == halo["halo_properties"]["fof_halo_tag"]) + if "galaxy" not in halo: + return + for galaxy in halo["galaxies"].galaxies(): + assert ( + galaxy["galaxy_properties"]["fof_halo_tag"] + == halo["halo_properties"]["fof_halo_tag"] + ) + + if "star_particles" not in galaxy: + continue + tags = galaxy["star_particles"].select("gal_tag").get_data("numpy") + assert np.all(tags == galaxy["galaxy_properties"]["gal_tag"]) + + +@pytest.mark.parallel(nprocs=4) +def test_open_lightcone_structure_with_galaxies( + halos_600_path, galaxies_600_path, halos_601_path, galaxies_601_path +): + ds = oc.open( + *halos_600_path, *galaxies_600_path, *halos_601_path, *galaxies_601_path + ) + ds = ds.filter(oc.col("sod_halo_mass") > 1e14).take(10) + + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10).halos(): + verify_halo(halo) + + +@pytest.mark.parallel(nprocs=4) +def test_write_lightcone_structure(halos_600_path, halos_601_path, per_test_dir): + comm = get_comm_world() + ds = ( + oc.open( + *halos_600_path, + *halos_601_path, + ) + .filter(oc.col("fof_halo_mass") > 1e14) + .take(1000) + ) + halo_tags_start = set() + halo_tags_end = set() + for halo in ds.filter(oc.col("sod_halo_mass") > 
1e14).take(10, at="start").halos(): + halo_tags_start.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos(): + halo_tags_end.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + all_halos_read = set( + np.concatenate( + comm.allgather( + ds["halo_properties"].select("fof_halo_tag").get_data("numpy") + ) + ) + ) + + oc.write(per_test_dir / "halos.hdf5", ds) + ds_new = oc.open(per_test_dir / "halos.hdf5") + + all_halos = set( + np.concatenate( + comm.allgather( + ds_new["halo_properties"].select("fof_halo_tag").get_data("numpy") + ) + ) + ) + + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos() + ): + verify_halo(halo) + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos() + ): + verify_halo(halo) + + assert halo_tags_start.issubset(all_halos) + assert halo_tags_end.issubset(all_halos) + assert all_halos_read == all_halos + + +@pytest.mark.parallel(nprocs=4) +def test_write_lightcone_structure_with_galaxies( + halos_600_path, halos_601_path, galaxies_600_path, galaxies_601_path, per_test_dir +): + comm = get_comm_world() + ds = ( + oc.open( + *halos_600_path, *halos_601_path, *galaxies_600_path, *galaxies_601_path + ) + .filter(oc.col("fof_halo_mass") > 1e14) + .take(1000) + ) + halo_tags_start = set() + halo_tags_end = set() + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos(): + halo_tags_start.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos(): + halo_tags_end.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + all_halos_read = set( + np.concatenate( + comm.allgather( + ds["halo_properties"].select("fof_halo_tag").get_data("numpy") + ) + ) + ) + + oc.write(per_test_dir / "halos.hdf5", ds) + ds_new = oc.open(per_test_dir / "halos.hdf5") + 
+ all_halos = set( + np.concatenate( + comm.allgather( + ds_new["halo_properties"].select("fof_halo_tag").get_data("numpy") + ) + ) + ) + + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos() + ): + verify_halo(halo) + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos() + ): + verify_halo(halo) + + assert halo_tags_start.issubset(all_halos) + assert halo_tags_end.issubset(all_halos) + assert all_halos_read == all_halos From f00531c04ae9bc57f4b2450913bc6df7cccdf784 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 20 May 2026 14:23:48 -0500 Subject: [PATCH 112/139] Implement lightcone convinience methods on structure collection --- .../collection/structure/structure.py | 138 +++++++++++++++++- 1 file changed, 135 insertions(+), 3 deletions(-) diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index f1b9d53d..d475cabd 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -190,15 +190,49 @@ def properties(self) -> list[str]: @property def redshift(self) -> float | tuple[float, float] | None: """ - For snapshots, return the redshift or redshift range - this dataset was drawn from. + For snapshots, return the redshift this dataset was drawn from. Returns ------- redshift: float | tuple[float, float] """ - raise NotImplementedError + if isinstance(self.__source, lc.Lightcone): + raise AttributeError( + "This is a lightcone structure collection. Use .z_range to get the redshift range." + ) + return self.__source.header.file.redshift + + @property + def z_range(self) -> tuple[float, float]: + """ + The redshift range covered by this lightcone structure collection. + + Returns + ------- + z_range: tuple[float, float] + + Raises + ------ + AttributeError + If this is not a lightcone structure collection. 
+ """ + if not isinstance(self.__source, lc.Lightcone): + raise AttributeError( + "This is not a lightcone structure collection. Use .redshift to get the redshift." + ) + return self.__source.z_range + + @property + def sorted_by(self) -> Optional[str]: + """ + The column this collection is currently sorted by, or ``None`` if unsorted. + + Returns + ------- + column : Optional[str] + """ + return self.__source.sorted_by @property def simulation(self) -> HaccSimulationParameters | None: @@ -298,6 +332,104 @@ def bound( self.__derived_columns, ) + def with_redshift_range(self, z_low: float, z_high: float) -> StructureCollection: + """ + Restrict this lightcone structure collection to a specific redshift range. + This is more efficient than filtering on the redshift column directly because + it prunes entire redshift steps before row-level filtering, and it updates + the z_range metadata on the returned collection. + + Parameters + ---------- + z_low : float + The lower bound of the redshift range (inclusive). + z_high : float + The upper bound of the redshift range (inclusive). + + Returns + ------- + result : StructureCollection + A new StructureCollection restricted to the given redshift range. + + Raises + ------ + AttributeError + If this is not a lightcone structure collection. + ValueError + If the requested range does not overlap the available redshift range. + """ + if not isinstance(self.__source, lc.Lightcone): + raise AttributeError( + "with_redshift_range is only available on lightcone structure collections." + ) + new_source = self.__source.with_redshift_range(z_low, z_high) + new_handler = self.__handler.make_derived(self.__source) + return StructureCollection( + new_source, + self.__datasets, + self.__hide_source, + new_handler, + self.__derived_columns, + ) + + def cone_search(self, center, radius) -> StructureCollection: + """ + Search for structures within an angular distance of a point on the sky. 
+ Equivalent to ``collection.bound(oc.make_cone(center, radius))``. + + Parameters + ---------- + center : tuple | astropy.coordinates.SkyCoord + Center of the search cone. If a tuple with no units, assumed to be + (RA, Dec) in degrees. + radius : float | astropy.units.Quantity + Angular radius of the search cone. If no units, assumed to be degrees. + + Returns + ------- + result : StructureCollection + + Raises + ------ + AttributeError + If this is not a lightcone structure collection. + """ + if not isinstance(self.__source, lc.Lightcone): + raise AttributeError( + "cone_search is only available on lightcone structure collections." + ) + region = oc.make_cone(center, radius) + return self.bound(region) + + def box_search(self, p1, p2) -> StructureCollection: + """ + Search for structures within a rectangular region of the sky (defined by + RA/Dec corners). Equivalent to ``collection.bound(oc.make_skybox(p1, p2))``. + + Parameters + ---------- + p1 : tuple | astropy.coordinates.SkyCoord + One corner of the box. If a tuple with no units, assumed to be + (RA, Dec) in degrees. + p2 : tuple | astropy.coordinates.SkyCoord + The opposite corner of the box. + + Returns + ------- + result : StructureCollection + + Raises + ------ + AttributeError + If this is not a lightcone structure collection. + """ + if not isinstance(self.__source, lc.Lightcone): + raise AttributeError( + "box_search is only available on lightcone structure collections." 
+ ) + region = oc.make_skybox(p1, p2) + return self.bound(region) + def evaluate( self, func: Callable, From f57924559b0f534cca73755b4f1d32aedc084587 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 20 May 2026 14:26:51 -0500 Subject: [PATCH 113/139] Add changelog --- changes/+342c1e1e.feature.rst | 1 + changes/+82359f21.feature.rst | 1 + 2 files changed, 2 insertions(+) create mode 100644 changes/+342c1e1e.feature.rst create mode 100644 changes/+82359f21.feature.rst diff --git a/changes/+342c1e1e.feature.rst b/changes/+342c1e1e.feature.rst new file mode 100644 index 00000000..04d44b02 --- /dev/null +++ b/changes/+342c1e1e.feature.rst @@ -0,0 +1 @@ +The :py:class:`StructureCollection ` now has a more complete string representation diff --git a/changes/+82359f21.feature.rst b/changes/+82359f21.feature.rst new file mode 100644 index 00000000..89e3d06c --- /dev/null +++ b/changes/+82359f21.feature.rst @@ -0,0 +1 @@ +The :py:class:`StructureCollection ` object now supports multi-step lightcones. 
From 15f90f268eea0d79db913bb68cbe10a935d6728b Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 20 May 2026 14:48:37 -0500 Subject: [PATCH 114/139] Fixed a bug that could cause failures when one rank had no data to write for a multi-dimension column --- python/opencosmo/io/mpi.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/opencosmo/io/mpi.py b/python/opencosmo/io/mpi.py index 9790e19f..f4122700 100644 --- a/python/opencosmo/io/mpi.py +++ b/python/opencosmo/io/mpi.py @@ -463,7 +463,8 @@ def __write_column( if writer is not None: data = writer.get_data(new_comm) else: - data = np.empty((0,), dtype=ds.dtype) + shape = (0,) + ds.shape[1:] + data = np.empty(shape, dtype=ds.dtype) ds.write_direct(data, dest_sel=np.s_[offset : offset + len(data)]) From 48682fcb4340e83db285c6a1e9f56d3d494d8c56 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 20 May 2026 15:52:20 -0500 Subject: [PATCH 115/139] Fix a bug that may be causing some MPI tests to fail periodically --- python/opencosmo/collection/lightcone/lightcone.py | 6 +++--- python/opencosmo/collection/lightcone/stack.py | 2 +- test/parallel/test_lc_mpi.py | 5 ++++- test/parallel/test_structure_collection_mpi.py | 1 + 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 28acf492..bbe3f7b1 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -17,7 +17,7 @@ from warnings import warn import numpy as np -from astropy.table import QTable, vstack # type: ignore +from astropy.table import vstack # type: ignore from deprecated import deprecated import opencosmo as oc @@ -342,12 +342,12 @@ def get_data(self, format="astropy", unpack: bool = True, **kwargs): to_remove = self.__hidden.intersection(table.colnames) table.remove_columns(to_remove) - if unpack: + if len(table) == 1 and unpack: output_data = { 
key: value[0] if len(value) == 1 else value for key, value in table.items() } - table = QTable(output_data) + return convert_data(output_data, format) if format != "astropy": return convert_data(dict(table), format) diff --git a/python/opencosmo/collection/lightcone/stack.py b/python/opencosmo/collection/lightcone/stack.py index e25f9861..313cc715 100644 --- a/python/opencosmo/collection/lightcone/stack.py +++ b/python/opencosmo/collection/lightcone/stack.py @@ -149,7 +149,7 @@ def stack_lightcone_datasets_in_schema( datasets: dict[str, list[ds.Dataset]], name: Optional[str], redshift_range: Optional[tuple[float, float]], - no_stack: bool = True, + no_stack: bool = False, ): n_datasets = sum(len(lst) for lst in datasets.values()) if n_datasets == 1 and get_comm_world() is None: diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index e7c7a5fa..b7a00ce9 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -243,6 +243,7 @@ def test_box_search_write(haloproperties_600_path, per_test_dir): # Each rank works with a pixel it owns so the search is guaranteed to find data. pixel = np.random.choice(ds.region.pixels) + print(pixel) ra_center, dec_center = pix2ang(ds.region.nside, pixel, lonlat=True, nest=True) # Write with a wider box, refine with a narrower one after re-open. 
@@ -258,7 +259,9 @@ def test_box_search_write(haloproperties_600_path, per_test_dir): ds = ds.box_search(p1_inner, p2_inner) new_ds = new_ds.box_search(p1_inner, p2_inner) - assert set(ds.get_data()["fof_halo_tag"]) == set(new_ds.get_data()["fof_halo_tag"]) + assert set(ds.get_data(unpack=False)["fof_halo_tag"]) == set( + new_ds.get_data(unpack=False)["fof_halo_tag"] + ) @pytest.mark.parallel(nprocs=4) diff --git a/test/parallel/test_structure_collection_mpi.py b/test/parallel/test_structure_collection_mpi.py index fd4ea956..e3e840f9 100644 --- a/test/parallel/test_structure_collection_mpi.py +++ b/test/parallel/test_structure_collection_mpi.py @@ -134,6 +134,7 @@ def test_write_lightcone_structure(halos_600_path, halos_601_path, per_test_dir) ) ) + print("ASDFASDFASFDFASD") oc.write(per_test_dir / "halos.hdf5", ds) ds_new = oc.open(per_test_dir / "halos.hdf5") From c40b36477a209ebaaebaee98da286061d82bbf0d Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 20 May 2026 15:59:52 -0500 Subject: [PATCH 116/139] Remove debugginc code --- python/opencosmo/collection/lightcone/lightcone.py | 2 +- test/parallel/test_lc_mpi.py | 3 --- test/parallel/test_structure_collection_mpi.py | 1 - 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index bbe3f7b1..4d6bf272 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -519,7 +519,7 @@ def make_schema( for step, datasets in output_datasets.items(): if len(datasets) == 0: - stack_lightcone_datasets_in_schema(datasets, None, None) + stack_lightcone_datasets_in_schema(datasets, None, None, no_stack) continue all_datasets = list(chain(*tuple(lst for lst in datasets.values()))) diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index b7a00ce9..cb54b45f 100644 --- a/test/parallel/test_lc_mpi.py +++ 
b/test/parallel/test_lc_mpi.py @@ -243,7 +243,6 @@ def test_box_search_write(haloproperties_600_path, per_test_dir): # Each rank works with a pixel it owns so the search is guaranteed to find data. pixel = np.random.choice(ds.region.pixels) - print(pixel) ra_center, dec_center = pix2ang(ds.region.nside, pixel, lonlat=True, nest=True) # Write with a wider box, refine with a narrower one after re-open. @@ -893,8 +892,6 @@ def test_lc_take_global_start_sorted(haloproperties_600_path, haloproperties_601 selected = lc_taken.select("fof_halo_mass").get_data("numpy") all_selected = np.concatenate(comm.allgather(selected)) - print(all_selected) - print(threshold) parallel_assert(len(all_selected) == n) parallel_assert( diff --git a/test/parallel/test_structure_collection_mpi.py b/test/parallel/test_structure_collection_mpi.py index e3e840f9..fd4ea956 100644 --- a/test/parallel/test_structure_collection_mpi.py +++ b/test/parallel/test_structure_collection_mpi.py @@ -134,7 +134,6 @@ def test_write_lightcone_structure(halos_600_path, halos_601_path, per_test_dir) ) ) - print("ASDFASDFASFDFASD") oc.write(per_test_dir / "halos.hdf5", ds) ds_new = oc.open(per_test_dir / "halos.hdf5") From c1ad6ffcd0a216230d628702a6c950cc36798d90 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 20 May 2026 16:10:39 -0500 Subject: [PATCH 117/139] Fix a test that was incorrectly failing --- test/parallel/test_lc_mpi.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index cb54b45f..5dfdca9b 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -445,8 +445,9 @@ def test_write_diffsky_some_missing_no_stack( ds.pop(475) assert len(ds.keys()) == 1 - columns_to_check = comm.bcast(np.random.choice(ds.columns, 10, replace=False)) - columns_to_check = np.insert(columns_to_check, 0, "gal_id") + # columns_to_check = comm.bcast(np.random.choice(ds.columns, 10, replace=False)) + # 
columns_to_check = np.insert(columns_to_check, 0, "gal_id") + columns_to_check = list(ds.columns) original_data = ds.select(columns_to_check).get_data("numpy") @@ -463,9 +464,8 @@ def test_write_diffsky_some_missing_no_stack( columns_to_check.sort() for column_name in columns_to_check: - if column_name == "gal_id": + if column_name in ["gal_id", "top_host_idx"]: continue - column_name = str(column_name) column_data_original = np.concat(comm.allgather(original_data.pop(column_name))) column_data_written = np.concat(comm.allgather(written_data.pop(column_name))) parallel_assert( From 310de7192cd041f7b044c90c31e15b842001af6c Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 20 May 2026 16:18:24 -0500 Subject: [PATCH 118/139] One more test fix --- test/parallel/test_lc_mpi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index 5dfdca9b..643f126e 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -141,8 +141,8 @@ def test_healpix_write(haloproperties_600_path, per_test_dir): new_ds = new_ds.bound(region2) ds = ds.bound(region2) - rank_tags = ds.select("fof_halo_tag").get_data() - new_rank_tags = new_ds.select("fof_halo_tag").get_data() + rank_tags = ds.select("fof_halo_tag").get_data("numpy", unpack=False) + new_rank_tags = new_ds.select("fof_halo_tag").get_data("numpy", unpack=False) all_tags = np.concatenate(comm.allgather(rank_tags)) all_new_tags = np.concatenate(comm.allgather(new_rank_tags)) From 05d3121d9df22142a9653b0a460b25161c3f32aa Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 20 May 2026 16:38:35 -0500 Subject: [PATCH 119/139] More robust setup for lightcone test --- test/parallel/test_lc_mpi.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index 643f126e..f85f0bb2 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py 
@@ -239,9 +239,9 @@ def test_box_search_chain_failure(haloproperties_600_path): @pytest.mark.parallel(nprocs=4) def test_box_search_write(haloproperties_600_path, per_test_dir): """Written box-search result supports a narrower refinement search on re-open.""" + comm = get_comm_world() ds = oc.open(haloproperties_600_path) - # Each rank works with a pixel it owns so the search is guaranteed to find data. pixel = np.random.choice(ds.region.pixels) ra_center, dec_center = pix2ang(ds.region.nside, pixel, lonlat=True, nest=True) @@ -258,9 +258,12 @@ def test_box_search_write(haloproperties_600_path, per_test_dir): ds = ds.box_search(p1_inner, p2_inner) new_ds = new_ds.box_search(p1_inner, p2_inner) - assert set(ds.get_data(unpack=False)["fof_halo_tag"]) == set( - new_ds.get_data(unpack=False)["fof_halo_tag"] - ) + original_halo_tags = ds.select("fof_halo_tag").get_data("numpy", unpack=False) + written_halo_tags = ds.select("fof_halo_tag").get_data("numpy", unpack=False) + + all_original_tags = np.concat(comm.allgather(original_halo_tags)) + all_written_tags = np.concat(comm.allgather(written_halo_tags)) + parallel_assert(np.all(all_original_tags == all_written_tags)) @pytest.mark.parallel(nprocs=4) From af8bf6e97b02c745d65a0feb4c84762dc10bbf02 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 21 May 2026 11:26:30 -0500 Subject: [PATCH 120/139] Implement get_pixels and pixel_search on lightcone --- pyproject.toml | 1 + .../collection/lightcone/lightcone.py | 33 +++++++++ .../opencosmo/collection/lightcone/utils.py | 74 ++++++++++++++++--- python/opencosmo/dataset/dataset.py | 4 + python/opencosmo/io/iopen.py | 4 +- python/opencosmo/spatial/tree.py | 35 ++++++++- test/test_diffsky.py | 32 ++++++++ test/test_lightcone.py | 32 ++++++++ uv.lock | 14 ++++ 9 files changed, 216 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5540c1ba..cfa53d68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ 
"numpy>=2.0,<2.5", "click (>=8.2.1,<9.0.0)", "rustworkx>=0.17.1,<1.0", + "returns>=0.27.0", ] [project.urls] diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 4d6bf272..ae780ae2 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -16,6 +16,7 @@ ) from warnings import warn +import healpy as hp import numpy as np from astropy.table import vstack # type: ignore from deprecated import deprecated @@ -46,6 +47,7 @@ if TYPE_CHECKING: import astropy.units as u # type: ignore + import numpy.typing as npt from astropy.coordinates import SkyCoord from astropy.cosmology import Cosmology @@ -289,6 +291,15 @@ def z_range(self): return self.__header.lightcone["z_range"] + def get_pixels(self, nside: int = 128): + """The healpix pixels this lightcone covers""" + + level = np.log2(nside) + if not level.is_integer() or level < 0: + raise ValueError("nside must be a positive power of two!") + + return lcutils.get_pixels(self, int(level)) + def get_data(self, format="astropy", unpack: bool = True, **kwargs): """ Get the data in this dataset as an astropy table/column or as @@ -629,6 +640,28 @@ def box_search(self, p1: tuple | SkyCoord, p2: tuple | SkyCoord): region = oc.make_skybox(p1, p2) return self.bound(region) + def pixel_search(self, pixels: npt.NDArray[np.int_], nside: int = 64): + level = np.log2(nside) + if not level.is_integer() or level < 0: + raise ValueError("nside must be a positive power of two!") + level = int(level) + pixels = np.atleast_1d(pixels) + pixels = np.unique(pixels) + if ( + not np.isdtype(pixels.dtype, "integral") + or pixels[0] < 0 + or pixels[-1] > hp.nside2npix(nside) + ): + raise ValueError("Pixels must be a 1d array of positive integers") + output = {} + for name, ds in self.items(): + if isinstance(ds, Lightcone): + output[name] = ds.pixel_search(pixels) + continue + rows = ds.tree.project_on_index(level, 
ds.index, pixels) + output[name] = ds.take_rows(rows) + return Lightcone(output, self.z_range, self.__hidden, self.__sort_key) + def evaluate( self, func: Callable, diff --git a/python/opencosmo/collection/lightcone/utils.py b/python/opencosmo/collection/lightcone/utils.py index c5799fa8..d6a9e1d1 100644 --- a/python/opencosmo/collection/lightcone/utils.py +++ b/python/opencosmo/collection/lightcone/utils.py @@ -1,18 +1,19 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Optional, Sequence +import healpy as hp import numpy as np +from returns.maybe import Maybe, Nothing -from opencosmo.collection.lightcone import lightcone as oclc +from opencosmo.collection.lightcone import lightcone as lc +from opencosmo.dataset import dataset as ds if TYPE_CHECKING: from astropy.table import Table - from opencosmo import Dataset - -def get_redshift_range(datasets: Sequence[Dataset | oclc.Lightcone]): +def get_redshift_range(datasets: Sequence[ds.Dataset | lc.Lightcone]): redshift_ranges = list(map(get_single_redshift_range, datasets)) min_z = min(rr[0] for rr in redshift_ranges) max_z = max(rr[1] for rr in redshift_ranges) @@ -20,8 +21,8 @@ def get_redshift_range(datasets: Sequence[Dataset | oclc.Lightcone]): return (min_z, max_z) -def get_single_redshift_range(dataset: Dataset | oclc.Lightcone): - if isinstance(dataset, oclc.Lightcone): +def get_single_redshift_range(dataset: ds.Dataset | lc.Lightcone): + if isinstance(dataset, lc.Lightcone): return dataset.z_range redshift_range = dataset.header.lightcone["z_range"] if redshift_range is not None: @@ -34,7 +35,7 @@ def get_single_redshift_range(dataset: Dataset | oclc.Lightcone): return (min_redshift, max_redshift) -def is_in_range(dataset: Dataset, z_low: float, z_high: float): +def is_in_range(dataset: ds.Dataset, z_low: float, z_high: float): z_range = dataset.header.lightcone["z_range"] if z_range is None: z_range = get_single_redshift_range(dataset) @@ 
-54,7 +55,7 @@ def sort_table(table: Table, column: str, invert: bool): def take_from_sorted( - lightcone: "oclc.Lightcone", sort_by: str, invert: bool, n: int, at: str | int + lightcone: lc.Lightcone, sort_by: str, invert: bool, n: int, at: str | int ): column = np.concatenate( [ds.select(sort_by).get_data("numpy") for ds in lightcone.values()] @@ -75,3 +76,58 @@ def take_from_sorted( sorted_indices = np.sort(sort_index) return sorted_indices + + +def determine_max_level(lightcone: lc.Lightcone, requested_level: int) -> Maybe[int]: + """ + Find the common level that can be used by all the trees and is at least equal to the + requested level. + """ + + max_level = Nothing + for ds_ in lightcone.values(): + if isinstance(ds_, lc.Lightcone): + ds_level = determine_max_level(ds_, requested_level) + else: + assert isinstance(ds_, ds.Dataset) + ds_level = Maybe.from_optional(ds_.tree).map(lambda t: t.max_level) + max_level = ds_level.lash(lambda _: ds_level) + max_level = max_level.bind( + lambda ml: ds_level.map(lambda dl: ml if dl >= ml else dl) + ) + return max_level + + +def raise_missing_spatial_index(): + raise ValueError("Lightcone does not have a spatial index!") + + +def get_pixels( + lightcone: lc.Lightcone, level: int, is_occupied: Optional[np.ndarray] = None +): + # We know nside is a power of two at this point + available_level = ( + determine_max_level(lightcone, level).lash(raise_missing_spatial_index).unwrap() + ) + if level > available_level: + raise ValueError( + f"The maximum available nside for this lightcone is {2**available_level}, but {2**level} was requested" + ) + + if is_occupied is None: + is_occupied = np.zeros(hp.nside2npix(2**level), dtype=bool) + + for ds_ in lightcone.values(): + if isinstance(ds_, lc.Lightcone): + is_occupied = get_pixels(ds_, level, is_occupied) + + assert isinstance(ds_, ds.Dataset) + tree = ds_.tree + if tree is None: + raise ValueError( + "One or more datasets in this lightcone does not have a spatial index!" 
+ ) + read_level = tree.max_level if tree.max_level < level else level + ds_pixels = tree.get_occupied_partitions(read_level, ds_.index) + is_occupied[ds_pixels] = True + return np.where(is_occupied)[0] diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 8e2be64a..52b0a6f9 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -228,6 +228,10 @@ def sorted_by(self) -> Optional[str]: """ return self.__state.sort_key[0] if self.__state.sort_key is not None else None + @property + def tree(self) -> Optional[Tree]: + return self.__tree + @property @deprecated( version="1.1.0", diff --git a/python/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py index 525db140..d3c9f67f 100644 --- a/python/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -511,7 +511,7 @@ def open_single_dataset( if header.file.region is not None: sim_region = from_model(header.file.region) elif header.file.is_lightcone and tree is not None: - pixels = tree.get_full_index(tree.max_level) + pixels = tree.get_partitions_with_data(tree.max_level) sim_region = HealpixRegion(pixels, nside=2**tree.max_level) elif header.file.data_type == "healpix_map": assert header.healpix_map["full_sky"] @@ -602,7 +602,7 @@ def __expand_lightcone_region(region, tree): pixels = pixels[:, None] * npix_ratio + np.arange(npix_ratio) pixels = pixels.flatten() - full_pixels = tree.get_full_index(tree.max_level) + full_pixels = tree.get_partitions_with_data(tree.max_level) full_pixels = np.intersect1d(pixels, full_pixels) return HealpixRegion(full_pixels, 2**tree.max_level) diff --git a/python/opencosmo/spatial/tree.py b/python/opencosmo/spatial/tree.py index 3c1fb05a..bbcbfaed 100644 --- a/python/opencosmo/spatial/tree.py +++ b/python/opencosmo/spatial/tree.py @@ -13,7 +13,14 @@ MPI = None # type: ignore -from opencosmo.index import from_size, from_start_size_group, get_data, n_in_range +from opencosmo.index import ( + from_size, + 
from_start_size_group, + get_data, + into_array, + n_in_range, + project, +) from opencosmo.io.schema import FileEntry, make_schema from opencosmo.io.writer import ( ColumnCombineStrategy, @@ -177,7 +184,7 @@ def __init__( def max_level(self): return self.__max_level - def get_full_index(self, level: int): + def get_partitions_with_data(self, level: int): if level > self.max_level: raise ValueError( "Requested level is greater than the max level of this tree!" @@ -186,6 +193,30 @@ def get_full_index(self, level: int): return np.where(sizes > 0)[0] + def get_occupied_partitions(self, level: int, index: DataIndex): + if level > self.max_level: + raise ValueError( + "Level must be less than or equal to the max level of this tree" + ) + starts = self.__columns[f"level_{level}/start"][:] + return np.searchsorted(starts, into_array(index), side="right") - 1 + + def project_on_index( + self, level: int, index: DataIndex, partitions: Optional[DataIndex] + ): + if level > self.max_level: + raise ValueError( + "Level must be less than or equal to the max level of this tree" + ) + starts = self.__columns[f"level_{level}/start"] + sizes = self.__columns[f"level_{level}/size"] + if partitions is None: + partitions = from_size(len(starts)) + + return project( + index, (get_data(starts, partitions), get_data(sizes, partitions)) + ) + def partition( self, n_partitions: int, counts: h5py.Group, min_level: Optional[int] = None ) -> Sequence[TreePartition]: diff --git a/test/test_diffsky.py b/test/test_diffsky.py index 42b95905..acfb3553 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -3,6 +3,7 @@ import astropy.units as u import h5py +import healpy as hp import numpy as np import pytest from opencosmo.spatial.region import HealpixRegion @@ -302,6 +303,37 @@ def test_region(core_path_475, core_path_487): assert len(ds.region.pixels) == 1610 +def test_lc_collection_pixel_search(haloproperties_600_path, haloproperties_601_path): + ds = oc.open(haloproperties_601_path, 
haloproperties_600_path) + pixels = ds.get_pixels(64) + + all_coordinates = ds.select("theta", "phi").get_data("numpy") + all_pixels = hp.ang2pix( + 64, all_coordinates["theta"], all_coordinates["phi"], nest=True + ) + + assert np.all(np.unique(all_pixels) == pixels) + + pixels_to_search = np.sort(np.random.choice(pixels, 20, replace=False)) + ds_bound = ds.pixel_search(pixels_to_search) + + bound_coordinates = ds_bound.select("ra", "dec").get_data("numpy") + bound_pixels = np.unique( + hp.ang2pix( + 64, + bound_coordinates["ra"], + bound_coordinates["dec"], + lonlat=True, + nest=True, + ) + ) + assert np.all(pixels_to_search == bound_pixels) + expected_index = np.isin(all_pixels, pixels_to_search) + found_halo_tags = ds_bound.select("fof_halo_tag").get_data() + expected_halo_tags = ds.select("fof_halo_tag").get_data()[expected_index] + assert np.all(found_halo_tags == expected_halo_tags) + + def test_open_bad_data(core_path_475, core_path_487, invalid_data_path): with pytest.raises(ValueError, match=str(invalid_data_path)): oc.open(core_path_475, core_path_487, invalid_data_path) diff --git a/test/test_lightcone.py b/test/test_lightcone.py index 924b2da9..97c5fc33 100644 --- a/test/test_lightcone.py +++ b/test/test_lightcone.py @@ -1,4 +1,5 @@ import astropy.units as u +import healpy as hp import numpy as np import pytest from astropy.cosmology import units as cu @@ -70,6 +71,37 @@ def test_lc_collection_restrict_z(haloproperties_600_path, haloproperties_601_pa assert np.sum(masked_redshifts) == len(redshifts) +def test_lc_collection_pixel_search(haloproperties_600_path, haloproperties_601_path): + ds = oc.open(haloproperties_601_path, haloproperties_600_path) + pixels = ds.get_pixels(64) + + all_coordinates = ds.select("theta", "phi").get_data("numpy") + all_pixels = hp.ang2pix( + 64, all_coordinates["theta"], all_coordinates["phi"], nest=True + ) + + assert np.all(np.unique(all_pixels) == pixels) + + pixels_to_search = np.sort(np.random.choice(pixels, 20, 
replace=False)) + ds_bound = ds.pixel_search(pixels_to_search) + + bound_coordinates = ds_bound.select("ra", "dec").get_data("numpy") + bound_pixels = np.unique( + hp.ang2pix( + 64, + bound_coordinates["ra"], + bound_coordinates["dec"], + lonlat=True, + nest=True, + ) + ) + assert np.all(pixels_to_search == bound_pixels) + expected_index = np.isin(all_pixels, pixels_to_search) + found_halo_tags = ds_bound.select("fof_halo_tag").get_data() + expected_halo_tags = ds.select("fof_halo_tag").get_data()[expected_index] + assert np.all(found_halo_tags == expected_halo_tags) + + def test_lc_collection_write( haloproperties_600_path, haloproperties_601_path, tmp_path ): diff --git a/uv.lock b/uv.lock index 0a16bfce..9a508ca9 100644 --- a/uv.lock +++ b/uv.lock @@ -1037,6 +1037,7 @@ dependencies = [ { name = "healsparse" }, { name = "numpy" }, { name = "pydantic" }, + { name = "returns" }, { name = "rustworkx" }, ] @@ -1088,6 +1089,7 @@ requires-dist = [ { name = "numpy", specifier = ">=2.0,<2.5" }, { name = "pyarrow", marker = "extra == 'io'", specifier = ">=21.0.0" }, { name = "pydantic", specifier = ">=2.10.6,<3.0.0" }, + { name = "returns", specifier = ">=0.27.0" }, { name = "rustworkx", specifier = ">=0.17.1,<1.0" }, ] provides-extras = ["io"] @@ -1708,6 +1710,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d7/8e/7540e8a2036f79a125c1d2ebadf69ed7901608859186c856fa0388ef4197/requests-2.33.1-py3-none-any.whl", hash = "sha256:4e6d1ef462f3626a1f0a0a9c42dd93c63bad33f9f1c1937509b8c5c8718ab56a", size = 64947, upload-time = "2026-03-30T16:09:13.83Z" }, ] +[[package]] +name = "returns" +version = "0.27.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a9/d5/a3208e3193848ecfec2adb689f474cc6c66c4b7e1711c31528c5b3cfbc93/returns-0.27.0.tar.gz", hash = "sha256:f70a452dd81e6d024c97523683aba85076b15e00874e723afb23bcf3aa4ecea2", size = 105261, upload-time 
= "2026-04-14T07:11:21.206Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/9c/1db56b3a26b56abde214833556c83af7e9817999bc05163c9e1bf200952d/returns-0.27.0-py3-none-any.whl", hash = "sha256:a84f243b9a17e9b96c16b709f5cca819550c38a2fc4db725ee2539354956af12", size = 160133, upload-time = "2026-04-14T07:11:23.187Z" }, +] + [[package]] name = "roman-numerals" version = "4.1.0" From 7ed44ea354380df3a19c536ead8d9507d3413350 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 21 May 2026 11:42:50 -0500 Subject: [PATCH 121/139] Add docstrings and implement on structure collection --- .../collection/lightcone/lightcone.py | 60 +++++++++- .../collection/structure/structure.py | 111 ++++++++++++++++-- test/test_collection.py | 14 +++ test/test_structure_collection.py | 11 ++ 4 files changed, 181 insertions(+), 15 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index ae780ae2..c48677dc 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -291,8 +291,33 @@ def z_range(self): return self.__header.lightcone["z_range"] - def get_pixels(self, nside: int = 128): - """The healpix pixels this lightcone covers""" + def get_pixels(self, nside: int = 64): + """ + Return the HEALPix pixels occupied by this lightcone at a given resolution. + + Pixel indices are returned in nested ordering. The ``nside`` parameter + controls angular resolution: larger values produce finer pixels. The + requested resolution may not exceed the resolution of the spatial index + stored in the file. + + Parameters + ---------- + nside : int, default = 64 + The HEALPix resolution parameter. Must be a positive power of two. + + Returns + ------- + pixels : numpy.ndarray[int] + HEALPix pixel indices (nested ordering) occupied by this lightcone + at the given resolution. 
+ + Raises + ------ + ValueError + If ``nside`` is not a positive power of two, if ``nside`` exceeds + the maximum resolution of the spatial index, or if the lightcone + does not have a spatial index. + """ level = np.log2(nside) if not level.is_integer() or level < 0: @@ -641,6 +666,37 @@ def box_search(self, p1: tuple | SkyCoord, p2: tuple | SkyCoord): return self.bound(region) def pixel_search(self, pixels: npt.NDArray[np.int_], nside: int = 64): + """ + Return the subset of this lightcone that falls within a set of HEALPix pixels. + + Pixels must be specified in nested ordering and must be valid indices at + the given ``nside``. Duplicate pixel indices are ignored. Use + :py:meth:`get_pixels ` to discover + which pixels this lightcone covers. + + Parameters + ---------- + pixels : array_like[int] + HEALPix pixel indices to query, in nested ordering. Must be a 1-D + array of non-negative integers. Values must be less than + ``healpy.nside2npix(nside)``. + nside : int, default = 64 + The HEALPix resolution parameter. Must be a positive power of two + and must not exceed the resolution of the spatial index stored in + the file. + + Returns + ------- + lightcone : opencosmo.Lightcone + A new lightcone containing only the objects that fall within the + specified pixels. + + Raises + ------ + ValueError + If ``nside`` is not a positive power of two, or if ``pixels`` + contains values that are out of range for the given ``nside``. 
+ """ level = np.log2(nside) if not level.is_integer() or level < 0: raise ValueError("nside must be a positive power of two!") diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index d475cabd..56b69ca5 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import defaultdict -from functools import reduce +from functools import reduce, wraps from inspect import signature from typing import ( TYPE_CHECKING, @@ -78,6 +78,18 @@ def do_start_update(data: np.ndarray, size: np.ndarray, comm: Optional[MPI.Comm] return psum + offset +def _lightcone_only(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if not isinstance(self._StructureCollection__source, lc.Lightcone): + raise AttributeError( + f"{func.__name__} is only available on lightcone structure collections." + ) + return func(self, *args, **kwargs) + + return wrapper + + class StructureCollection: """ A collection of datasets that contain both high-level properties @@ -332,6 +344,7 @@ def bound( self.__derived_columns, ) + @_lightcone_only def with_redshift_range(self, z_low: float, z_high: float) -> StructureCollection: """ Restrict this lightcone structure collection to a specific redshift range. @@ -358,10 +371,7 @@ def with_redshift_range(self, z_low: float, z_high: float) -> StructureCollectio ValueError If the requested range does not overlap the available redshift range. """ - if not isinstance(self.__source, lc.Lightcone): - raise AttributeError( - "with_redshift_range is only available on lightcone structure collections." 
- ) + assert isinstance(self.__source, lc.Lightcone) new_source = self.__source.with_redshift_range(z_low, z_high) new_handler = self.__handler.make_derived(self.__source) return StructureCollection( @@ -372,6 +382,7 @@ def with_redshift_range(self, z_low: float, z_high: float) -> StructureCollectio self.__derived_columns, ) + @_lightcone_only def cone_search(self, center, radius) -> StructureCollection: """ Search for structures within an angular distance of a point on the sky. @@ -394,13 +405,10 @@ def cone_search(self, center, radius) -> StructureCollection: AttributeError If this is not a lightcone structure collection. """ - if not isinstance(self.__source, lc.Lightcone): - raise AttributeError( - "cone_search is only available on lightcone structure collections." - ) region = oc.make_cone(center, radius) return self.bound(region) + @_lightcone_only def box_search(self, p1, p2) -> StructureCollection: """ Search for structures within a rectangular region of the sky (defined by @@ -423,13 +431,90 @@ def box_search(self, p1, p2) -> StructureCollection: AttributeError If this is not a lightcone structure collection. """ - if not isinstance(self.__source, lc.Lightcone): - raise AttributeError( - "box_search is only available on lightcone structure collections." - ) region = oc.make_skybox(p1, p2) return self.bound(region) + @_lightcone_only + def get_pixels(self, nside: int = 64) -> np.ndarray: + """ + Return the HEALPix pixels occupied by this lightcone structure collection + at a given resolution. + + Pixel indices are returned in nested ordering. The ``nside`` parameter + controls angular resolution: larger values produce finer pixels. The + requested resolution may not exceed the resolution of the spatial index + stored in the file. + + Parameters + ---------- + nside : int, default = 64 + The HEALPix resolution parameter. Must be a positive power of two. 
+ + Returns + ------- + pixels : numpy.ndarray[int] + HEALPix pixel indices (nested ordering) occupied by this collection + at the given resolution. + + Raises + ------ + AttributeError + If this is not a lightcone structure collection. + ValueError + If ``nside`` is not a positive power of two, if ``nside`` exceeds + the maximum resolution of the spatial index, or if the lightcone + does not have a spatial index. + """ + assert isinstance(self.__source, lc.Lightcone) + return self.__source.get_pixels(nside) + + @_lightcone_only + def pixel_search(self, pixels: np.ndarray, nside: int = 64) -> StructureCollection: + """ + Return the subset of this lightcone structure collection that falls within + a set of HEALPix pixels. + + Pixels must be specified in nested ordering and must be valid indices at + the given ``nside``. Duplicate pixel indices are ignored. Use + :py:meth:`get_pixels ` to + discover which pixels this collection covers. + + Parameters + ---------- + pixels : array_like[int] + HEALPix pixel indices to query, in nested ordering. Must be a 1-D + array of non-negative integers. Values must be less than + ``healpy.nside2npix(nside)``. + nside : int, default = 64 + The HEALPix resolution parameter. Must be a positive power of two + and must not exceed the resolution of the spatial index stored in + the file. + + Returns + ------- + result : StructureCollection + A new collection containing only the structures that fall within + the specified pixels. + + Raises + ------ + AttributeError + If this is not a lightcone structure collection. + ValueError + If ``nside`` is not a positive power of two, or if ``pixels`` + contains values that are out of range for the given ``nside``. 
+ """ + assert isinstance(self.__source, lc.Lightcone) + new_source = self.__source.pixel_search(pixels, nside) + new_handler = self.__handler.make_derived(self.__source) + return StructureCollection( + new_source, + self.__datasets, + self.__hide_source, + new_handler, + self.__derived_columns, + ) + def evaluate( self, func: Callable, diff --git a/test/test_collection.py b/test/test_collection.py index 1c88ce0c..a8d52b67 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -108,6 +108,20 @@ def test_open_structures(halo_paths, galaxy_paths): assert isinstance(c3, oc.StructureCollection) +def test_call_lightcone_fails(halo_paths, galaxy_paths): + ds = oc.open(*halo_paths, *galaxy_paths) + with pytest.raises(AttributeError): + ds.get_pixels() + with pytest.raises(AttributeError): + ds.cone_search() + with pytest.raises(AttributeError): + ds.box_search() + with pytest.raises(AttributeError): + ds.pixel_search() + with pytest.raises(AttributeError): + ds.with_redshift_range() + + def test_multi_filter(multi_path): collection = oc.open(multi_path) collection = collection.filter(oc.col("sod_halo_mass") > 0) diff --git a/test/test_structure_collection.py b/test/test_structure_collection.py index ff475acf..a08866d8 100644 --- a/test/test_structure_collection.py +++ b/test/test_structure_collection.py @@ -189,3 +189,14 @@ def test_data_link_sort_write_lightcone(halos_600_path, halos_601_path, tmp_path halo["halo_properties"]["fof_halo_tag"] == halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy")[0] ) + + +def test_redshift_bound(halos_600_path, halos_601_path, tmp_path): + collection = oc.open(*halos_600_path, *halos_601_path) + collection = collection.with_redshift_range(0.038, 0.039) + + collection = collection.filter(oc.col("sod_halo_mass") > 10**14) + for halo in collection.halos(): + redshift = halo["halo_properties"]["redshift"] + assert redshift > 0.038 and redshift < 0.039 + verify_halo(halo) From 
b30d35771d484b7bf47e80caaaa0ee4977caec76 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 21 May 2026 11:45:30 -0500 Subject: [PATCH 122/139] Add changelog --- changes/+6f102bf6.feature.rst | 1 + changes/+e13be920.feature.rst | 1 + 2 files changed, 2 insertions(+) create mode 100644 changes/+6f102bf6.feature.rst create mode 100644 changes/+e13be920.feature.rst diff --git a/changes/+6f102bf6.feature.rst b/changes/+6f102bf6.feature.rst new file mode 100644 index 00000000..85ffc520 --- /dev/null +++ b/changes/+6f102bf6.feature.rst @@ -0,0 +1 @@ +Add the :py:meth:`Lightcone.get_pixels() ` method to query a lightcone based on healpix pixels. From dac42c0bdef042b55ba193ccf77543e6ab37dd19 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 21 May 2026 11:51:45 -0500 Subject: [PATCH 123/139] Fix type errors --- python/opencosmo/collection/lightcone/utils.py | 4 ++-- python/opencosmo/column/column.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/opencosmo/collection/lightcone/utils.py b/python/opencosmo/collection/lightcone/utils.py index d6a9e1d1..42ca3ee9 100644 --- a/python/opencosmo/collection/lightcone/utils.py +++ b/python/opencosmo/collection/lightcone/utils.py @@ -84,7 +84,7 @@ def determine_max_level(lightcone: lc.Lightcone, requested_level: int) -> Maybe[ requested level. 
""" - max_level = Nothing + max_level: Maybe[int] = Nothing for ds_ in lightcone.values(): if isinstance(ds_, lc.Lightcone): ds_level = determine_max_level(ds_, requested_level) @@ -98,7 +98,7 @@ def determine_max_level(lightcone: lc.Lightcone, requested_level: int) -> Maybe[ return max_level -def raise_missing_spatial_index(): +def raise_missing_spatial_index(_): raise ValueError("Lightcone does not have a spatial index!") diff --git a/python/opencosmo/column/column.py b/python/opencosmo/column/column.py index 875cf04d..8ff8c2db 100644 --- a/python/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -696,6 +696,7 @@ def arccos(self) -> DerivedColumn: def arctan2(self, other: ColumnOrScalar) -> DerivedColumn: return DerivedColumn(self, other, _arctan2) + def __eq__(self, other: float | u.Quantity) -> ColumnMask: # type: ignore return ColumnMask(self, other, op.eq) From 667cb53bf4075d89a14ec40f79957a893a9617e8 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 21 May 2026 12:44:51 -0500 Subject: [PATCH 124/139] Some fixes and new tests --- pyproject.toml | 1 - .../collection/lightcone/lightcone.py | 10 +++--- .../opencosmo/collection/lightcone/utils.py | 36 +++++++++---------- test/test_collection.py | 8 ++--- test/test_diffsky.py | 14 ++++---- uv.lock | 14 -------- 6 files changed, 32 insertions(+), 51 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cfa53d68..5540c1ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ dependencies = [ "numpy>=2.0,<2.5", "click (>=8.2.1,<9.0.0)", "rustworkx>=0.17.1,<1.0", - "returns>=0.27.0", ] [project.urls] diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index c48677dc..f82410e7 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -703,16 +703,14 @@ def pixel_search(self, pixels: npt.NDArray[np.int_], nside: int = 64): level = int(level) 
pixels = np.atleast_1d(pixels) pixels = np.unique(pixels) - if ( - not np.isdtype(pixels.dtype, "integral") - or pixels[0] < 0 - or pixels[-1] > hp.nside2npix(nside) - ): + if not np.isdtype(pixels.dtype, "integral") or len(pixels) == 0: + raise ValueError("Pixels must be a 1d array of positive integers") + if pixels[0] < 0 or pixels[-1] >= hp.nside2npix(nside): raise ValueError("Pixels must be a 1d array of positive integers") output = {} for name, ds in self.items(): if isinstance(ds, Lightcone): - output[name] = ds.pixel_search(pixels) + output[name] = ds.pixel_search(pixels, nside) continue rows = ds.tree.project_on_index(level, ds.index, pixels) output[name] = ds.take_rows(rows) diff --git a/python/opencosmo/collection/lightcone/utils.py b/python/opencosmo/collection/lightcone/utils.py index 42ca3ee9..107c100e 100644 --- a/python/opencosmo/collection/lightcone/utils.py +++ b/python/opencosmo/collection/lightcone/utils.py @@ -4,7 +4,6 @@ import healpy as hp import numpy as np -from returns.maybe import Maybe, Nothing from opencosmo.collection.lightcone import lightcone as lc from opencosmo.dataset import dataset as ds @@ -78,37 +77,32 @@ def take_from_sorted( return sorted_indices -def determine_max_level(lightcone: lc.Lightcone, requested_level: int) -> Maybe[int]: +def determine_max_level(lightcone: lc.Lightcone) -> Optional[int]: """ - Find the common level that can be used by all the trees and is at least equal to the - requested level. + Return the minimum tree max_level across all datasets in the lightcone, or + None if any dataset has no spatial index. 
""" - - max_level: Maybe[int] = Nothing + max_level: Optional[int] = None for ds_ in lightcone.values(): if isinstance(ds_, lc.Lightcone): - ds_level = determine_max_level(ds_, requested_level) + ds_level = determine_max_level(ds_) else: assert isinstance(ds_, ds.Dataset) - ds_level = Maybe.from_optional(ds_.tree).map(lambda t: t.max_level) - max_level = ds_level.lash(lambda _: ds_level) - max_level = max_level.bind( - lambda ml: ds_level.map(lambda dl: ml if dl >= ml else dl) - ) + ds_level = ds_.tree.max_level if ds_.tree is not None else None + if ds_level is None: + return None + if max_level is None or ds_level < max_level: + max_level = ds_level return max_level -def raise_missing_spatial_index(_): - raise ValueError("Lightcone does not have a spatial index!") - - def get_pixels( lightcone: lc.Lightcone, level: int, is_occupied: Optional[np.ndarray] = None ): # We know nside is a power of two at this point - available_level = ( - determine_max_level(lightcone, level).lash(raise_missing_spatial_index).unwrap() - ) + available_level = determine_max_level(lightcone) + if available_level is None: + raise ValueError("Lightcone does not have a spatial index!") if level > available_level: raise ValueError( f"The maximum available nside for this lightcone is {2**available_level}, but {2**level} was requested" @@ -119,7 +113,9 @@ def get_pixels( for ds_ in lightcone.values(): if isinstance(ds_, lc.Lightcone): - is_occupied = get_pixels(ds_, level, is_occupied) + lightcone_pixels = get_pixels(ds_, level, is_occupied) + is_occupied[lightcone_pixels] = True + continue assert isinstance(ds_, ds.Dataset) tree = ds_.tree diff --git a/test/test_collection.py b/test/test_collection.py index a8d52b67..8cbebec9 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -113,13 +113,13 @@ def test_call_lightcone_fails(halo_paths, galaxy_paths): with pytest.raises(AttributeError): ds.get_pixels() with pytest.raises(AttributeError): - ds.cone_search() + 
ds.cone_search(None, None) with pytest.raises(AttributeError): - ds.box_search() + ds.box_search(None, None) with pytest.raises(AttributeError): - ds.pixel_search() + ds.pixel_search(None) with pytest.raises(AttributeError): - ds.with_redshift_range() + ds.with_redshift_range(0.0, 1.0) def test_multi_filter(multi_path): diff --git a/test/test_diffsky.py b/test/test_diffsky.py index acfb3553..3c491951 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -303,19 +303,21 @@ def test_region(core_path_475, core_path_487): assert len(ds.region.pixels) == 1610 -def test_lc_collection_pixel_search(haloproperties_600_path, haloproperties_601_path): - ds = oc.open(haloproperties_601_path, haloproperties_600_path) +def test_lc_collection_pixel_search_with_synthetics(core_path_475, core_path_487): + ds = oc.open(core_path_487, core_path_475, synth_cores=True) pixels = ds.get_pixels(64) - all_coordinates = ds.select("theta", "phi").get_data("numpy") + all_coordinates = ds.select("ra", "dec").get_data("numpy") all_pixels = hp.ang2pix( - 64, all_coordinates["theta"], all_coordinates["phi"], nest=True + 64, all_coordinates["ra"], all_coordinates["dec"], nest=True, lonlat=True ) assert np.all(np.unique(all_pixels) == pixels) pixels_to_search = np.sort(np.random.choice(pixels, 20, replace=False)) ds_bound = ds.pixel_search(pixels_to_search) + for ds_ in ds_bound.values(): + assert isinstance(ds_, oc.Lightcone) bound_coordinates = ds_bound.select("ra", "dec").get_data("numpy") bound_pixels = np.unique( @@ -329,8 +331,8 @@ def test_lc_collection_pixel_search(haloproperties_600_path, haloproperties_601_ ) assert np.all(pixels_to_search == bound_pixels) expected_index = np.isin(all_pixels, pixels_to_search) - found_halo_tags = ds_bound.select("fof_halo_tag").get_data() - expected_halo_tags = ds.select("fof_halo_tag").get_data()[expected_index] + found_halo_tags = ds_bound.select("gal_id").get_data() + expected_halo_tags = ds.select("gal_id").get_data()[expected_index] assert 
np.all(found_halo_tags == expected_halo_tags) diff --git a/uv.lock b/uv.lock index 9a508ca9..0a16bfce 100644 --- a/uv.lock +++ b/uv.lock @@ -1037,7 +1037,6 @@ dependencies = [ { name = "healsparse" }, { name = "numpy" }, { name = "pydantic" }, - { name = "returns" }, { name = "rustworkx" }, ] @@ -1089,7 +1088,6 @@ requires-dist = [ { name = "numpy", specifier = ">=2.0,<2.5" }, { name = "pyarrow", marker = "extra == 'io'", specifier = ">=21.0.0" }, { name = "pydantic", specifier = ">=2.10.6,<3.0.0" }, - { name = "returns", specifier = ">=0.27.0" }, { name = "rustworkx", specifier = ">=0.17.1,<1.0" }, ] provides-extras = ["io"] @@ -1710,18 +1708,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d7/8e/7540e8a2036f79a125c1d2ebadf69ed7901608859186c856fa0388ef4197/requests-2.33.1-py3-none-any.whl", hash = "sha256:4e6d1ef462f3626a1f0a0a9c42dd93c63bad33f9f1c1937509b8c5c8718ab56a", size = 64947, upload-time = "2026-03-30T16:09:13.83Z" }, ] -[[package]] -name = "returns" -version = "0.27.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a9/d5/a3208e3193848ecfec2adb689f474cc6c66c4b7e1711c31528c5b3cfbc93/returns-0.27.0.tar.gz", hash = "sha256:f70a452dd81e6d024c97523683aba85076b15e00874e723afb23bcf3aa4ecea2", size = 105261, upload-time = "2026-04-14T07:11:21.206Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ab/9c/1db56b3a26b56abde214833556c83af7e9817999bc05163c9e1bf200952d/returns-0.27.0-py3-none-any.whl", hash = "sha256:a84f243b9a17e9b96c16b709f5cca819550c38a2fc4db725ee2539354956af12", size = 160133, upload-time = "2026-04-14T07:11:23.187Z" }, -] - [[package]] name = "roman-numerals" version = "4.1.0" From 3eb8ab60931d7cddd826e0812e233d4f830276ef Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 21 May 2026 14:11:43 -0500 Subject: [PATCH 125/139] Implement projection in rust --- python/opencosmo/_lib/index.pyi | 5 ++++ 
python/opencosmo/index/get.py | 2 +- python/opencosmo/index/project.py | 10 +++++-- src/index.rs | 46 +++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 3 deletions(-) diff --git a/python/opencosmo/_lib/index.pyi b/python/opencosmo/_lib/index.pyi index 0662bf7c..43f3c11b 100644 --- a/python/opencosmo/_lib/index.pyi +++ b/python/opencosmo/_lib/index.pyi @@ -22,3 +22,8 @@ def rebuild_chunked_by_ranges( chunk_starts: IndexArray, chunk_sizes: IndexArray, ) -> IndexArray: ... +def project_chunked_on_simple( + simple: IndexArray, + chunk_starts: IndexArray, + chunk_sizes: IndexArray, +) -> IndexArray: ... diff --git a/python/opencosmo/index/get.py b/python/opencosmo/index/get.py index 98a66905..2f5c2a18 100644 --- a/python/opencosmo/index/get.py +++ b/python/opencosmo/index/get.py @@ -14,7 +14,7 @@ def get_data(data: h5py.Dataset | np.ndarray, index: DataIndex): if get_length(index) == 0: - return np.array([]) + return np.array([], data.dtype) match index: case np.ndarray(): return get_data_simple(data, index) diff --git a/python/opencosmo/index/project.py b/python/opencosmo/index/project.py index 5957be77..57865a08 100644 --- a/python/opencosmo/index/project.py +++ b/python/opencosmo/index/project.py @@ -4,6 +4,8 @@ import numpy as np +from opencosmo._lib import index as idxlib + from . 
import into_array if TYPE_CHECKING: @@ -29,8 +31,12 @@ def __project_simple_on_simple(source: SimpleIndex, other: SimpleIndex) -> Simpl return np.where(isin)[0] -def __project_chunked_on_simple(source: SimpleIndex, other: ChunkedIndex) -> DataIndex: - return project(source, into_array(other)) +def __project_chunked_on_simple( + source: SimpleIndex, other: ChunkedIndex +) -> SimpleIndex: + if len(other[0]) == 0: + return np.array([], dtype=np.int64) + return idxlib.project_chunked_on_simple(source, *other) def __project_simple_on_chunked(source: ChunkedIndex, other: SimpleIndex) -> DataIndex: diff --git a/src/index.rs b/src/index.rs index aa7c118e..961fcb06 100644 --- a/src/index.rs +++ b/src/index.rs @@ -406,6 +406,52 @@ pub(crate) mod index { PyList::new(py, output.drain(0..).map(|a| a.into_pyarray(py))) } + #[pyfunction(name = "project_chunked_on_simple")] + fn project_chunked_on_simple_py<'py>( + py: Python<'py>, + simple: &Bound<'_, PyAny>, + start: &Bound<'_, PyAny>, + size: &Bound<'_, PyAny>, + ) -> PyResult>> { + let simple_arr = unpack_index_array(simple)?; + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + let result = project_chunked_on_simple( + simple_arr.as_array(), + start_arr.as_array(), + size_arr.as_array(), + ); + Ok(result.into_pyarray(py)) + } + + fn project_chunked_on_simple( + simple: ArrayView1<'_, i64>, + start: ArrayView1<'_, i64>, + size: ArrayView1<'_, i64>, + ) -> Array1 { + let mut output: Vec = Vec::new(); + let n_chunks = start.len(); + if simple.is_empty() || n_chunks == 0 { + return Array1::from_vec(output); + } + + let mut chunk_idx = 0usize; + for (i, &val) in simple.iter().enumerate() { + // Advance past chunks whose end is at or before val. + while chunk_idx < n_chunks && val >= start[chunk_idx] + size[chunk_idx] { + chunk_idx += 1; + } + if chunk_idx >= n_chunks { + break; + } + // val is before the current chunk; skip. 
+ if val < start[chunk_idx] { + continue; + } + output.push(i as i64); + } + Array1::from_vec(output) + } + fn rebuild_simple_by_ranges( index: ArrayView1<'_, i64>, range_starts: ArrayView1<'_, i64>, From 4c511ff12c4ac1bfb08da2d59753a16db015ae19 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 22 May 2026 07:45:46 -0500 Subject: [PATCH 126/139] Updated healpix containment algorithm to improve spatial query speed --- changes/+78480439.improvement.rst | 1 + python/opencosmo/dataset/dataset.py | 6 ++++-- python/opencosmo/spatial/healpix.py | 18 +++++++++++++++++- python/opencosmo/spatial/relations.py | 3 ++- test/spatial/test_2d.py | 2 +- 5 files changed, 25 insertions(+), 5 deletions(-) create mode 100644 changes/+78480439.improvement.rst diff --git a/changes/+78480439.improvement.rst b/changes/+78480439.improvement.rst new file mode 100644 index 00000000..15811ff0 --- /dev/null +++ b/changes/+78480439.improvement.rst @@ -0,0 +1 @@ +Improved spatial querying on lightcones, which should result in significant speedup for larger regions. 
diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 52b0a6f9..ac3b62c6 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -406,8 +406,10 @@ def bound(self, region: Region, select_by: Optional[str] = None): else: new_intersects_index = np.array([], dtype=np.int64) - new_index = np.concatenate( - [into_array(contained_index), into_array(new_intersects_index)] + new_index = np.sort( + np.concatenate( + [into_array(contained_index), into_array(new_intersects_index)] + ) ) new_state = st.with_region(st.take_rows(self.__state, new_index), check_region) diff --git a/python/opencosmo/spatial/healpix.py b/python/opencosmo/spatial/healpix.py index 19a44adb..442b6db4 100644 --- a/python/opencosmo/spatial/healpix.py +++ b/python/opencosmo/spatial/healpix.py @@ -2,7 +2,9 @@ from typing import TYPE_CHECKING +import healpy as hp import numpy as np +from astropy.coordinates import SkyCoord from opencosmo.index import into_array from opencosmo.spatial.region import HealpixRegion @@ -47,4 +49,18 @@ def query( raise ValueError("Didn't recieve a 2D region!") nside = 2**level intersects = region.get_healpix_intersections(nside) - return {level: (np.array([]), intersects)} + boundaries = ( + hp.boundaries(nside, intersects, nest=True) + .transpose( + 0, + 2, + 1, + ) + .reshape(-1, 3) + ) + coords = SkyCoord(*hp.vec2ang(boundaries, lonlat=True), unit="deg") + coord_is_contained = region.contains(coords) + pixel_is_contained = np.all(coord_is_contained.reshape(-1, 4), axis=1) + return { + level: (intersects[pixel_is_contained], intersects[~pixel_is_contained]) + } diff --git a/python/opencosmo/spatial/relations.py b/python/opencosmo/spatial/relations.py index 0aae3ff7..76686d15 100644 --- a/python/opencosmo/spatial/relations.py +++ b/python/opencosmo/spatial/relations.py @@ -160,7 +160,8 @@ def __healpix_contains_other(region: HealpixRegion, other) -> bool: intersections = 
other.get_healpix_intersections(region.nside) except AttributeError: raise ValueError(f"Expected a 2D Sky Region but received {type(other)}") - return len(np.intersect1d(intersections, region.pixels)) == len(intersections) + print(intersections) + return bool(np.all(np.isin(intersections, region.pixels))) # --------------------------------------------------------------------------- diff --git a/test/spatial/test_2d.py b/test/spatial/test_2d.py index 363017e7..5613e4b6 100644 --- a/test/spatial/test_2d.py +++ b/test/spatial/test_2d.py @@ -1,9 +1,9 @@ import astropy.units as u import numpy as np from astropy.coordinates import SkyCoord +from opencosmo.spatial.relations import contains_2d, intersects_2d import opencosmo as oc -from opencosmo.spatial.relations import contains_2d, intersects_2d # --------------------------------------------------------------------------- # Cone search (bound with ConeRegion) From 334652d13e52e448dc4e02b39d82296aecf18a31 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 22 May 2026 07:55:10 -0500 Subject: [PATCH 127/139] Remove debugging code --- python/opencosmo/spatial/relations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/opencosmo/spatial/relations.py b/python/opencosmo/spatial/relations.py index 76686d15..fd19481f 100644 --- a/python/opencosmo/spatial/relations.py +++ b/python/opencosmo/spatial/relations.py @@ -160,7 +160,6 @@ def __healpix_contains_other(region: HealpixRegion, other) -> bool: intersections = other.get_healpix_intersections(region.nside) except AttributeError: raise ValueError(f"Expected a 2D Sky Region but received {type(other)}") - print(intersections) return bool(np.all(np.isin(intersections, region.pixels))) From fff56d8d60042e25b62604e5d2cfaef930411e34 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 22 May 2026 14:27:38 -0500 Subject: [PATCH 128/139] Update release pipeline for pre-releases --- .github/workflows/release.yaml | 26 ++++++++++++++++++++++++-- 1 file changed, 24 
insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index fa455c9c..789a6fdc 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -12,6 +12,8 @@ jobs: runs-on: ubuntu-latest permissions: contents: write + outputs: + prerelease: ${{ steps.check-version.outputs.prerelease }} steps: - uses: actions/checkout@v4 with: @@ -19,6 +21,15 @@ jobs: token: ${{ secrets.RELEASE_PAT }} ref: 'release' + - name: Check if pre-release + id: check-version + run: | + if echo "${{ inputs.version }}" | grep -qE '(a|b|rc|alpha|beta|dev)[0-9]*$'; then + echo "prerelease=true" >> "$GITHUB_OUTPUT" + else + echo "prerelease=false" >> "$GITHUB_OUTPUT" + fi + - name: Install uv uses: astral-sh/setup-uv@v6 with: @@ -31,9 +42,11 @@ jobs: run: uv run towncrier build --draft --version ${{ inputs.version }} > release_notes.rst - name: Build changelog + if: steps.check-version.outputs.prerelease == 'false' run: uv run towncrier build --yes --version ${{ inputs.version }} - name: Commit, tag, and push release branch + if: steps.check-version.outputs.prerelease == 'false' run: | git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" @@ -44,6 +57,14 @@ jobs: git push origin release --force git push origin ${{ inputs.version }} + - name: Tag pre-release + if: steps.check-version.outputs.prerelease == 'true' + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git tag ${{ inputs.version }} + git push origin ${{ inputs.version }} + - name: Upload release notes uses: actions/upload-artifact@v4 with: @@ -105,7 +126,7 @@ jobs: run: uv publish --token ${{ secrets.PYPI_TOKEN }} github-release: - needs: publish + needs: [publish, prepare] runs-on: ubuntu-latest permissions: contents: write @@ -128,4 +149,5 @@ jobs: tag: ${{ inputs.version }} bodyFile: release_notes.rst artifacts: dist/* - 
makeLatest: true + prerelease: ${{ needs.prepare.outputs.prerelease == 'true' }} + makeLatest: ${{ needs.prepare.outputs.prerelease == 'false' }} From 09141d2b85d9ae58b3666c265378c31523f85da0 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 22 May 2026 14:31:03 -0500 Subject: [PATCH 129/139] Update bumpmyversion --- .bumpversion.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.bumpversion.toml b/.bumpversion.toml index b5c93c25..88a74202 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "1.2.5" +current_version = "1.2.6" parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)" serialize = ["{major}.{minor}.{patch}"] search = "{current_version}" From b9b5f9b828a4c53857e2d8a7326c727b93b3f68b Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 22 May 2026 14:34:09 -0500 Subject: [PATCH 130/139] Update version locations --- .bumpversion.toml | 5 ++++- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.bumpversion.toml b/.bumpversion.toml index 88a74202..8d2d2872 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -22,10 +22,13 @@ post_commit_hooks = [] allow_shell_hooks = true [[tool.bumpversion.files]] -filename = "src/opencosmo/__init__.py" +filename = "python/opencosmo/__init__.py" [[tool.bumpversion.files]] filename = "pyproject.toml" [[tool.bumpversion.files]] filename = "docs/source/conf.py" + +[[tool.bumpversion.files]] +filename = "Cargo.toml" diff --git a/Cargo.toml b/Cargo.toml index 9cb40fba..64325640 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "opencosmo" -version = "1.2.4" +version = "1.2.6" edition = "2021" [lib] From eee2b72580e63a9e5b1243776cd9698eed6a81cc Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 22 May 2026 14:41:12 -0500 Subject: [PATCH 131/139] Update bumpversion config --- .bumpversion.toml | 14 ++++++++++++-- .github/workflows/release.yaml | 2 ++ 2 files changed, 14 insertions(+), 2
deletions(-) diff --git a/.bumpversion.toml b/.bumpversion.toml index 8d2d2872..68658318 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,7 +1,10 @@ [tool.bumpversion] current_version = "1.2.6" -parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)" -serialize = ["{major}.{minor}.{patch}"] +parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)(?P<pre_l>a|b|rc)?(?P<pre_n>\\d+)?" +serialize = [ + "{major}.{minor}.{patch}{pre_l}{pre_n}", + "{major}.{minor}.{patch}", +] search = "{current_version}" replace = "{new_version}" regex = false @@ -21,6 +24,13 @@ pre_commit_hooks = [] post_commit_hooks = [] allow_shell_hooks = true +[tool.bumpversion.parts.pre_l] +optional_value = "final" +values = ["a", "b", "rc", "final"] + +[tool.bumpversion.parts.pre_n] +first_value = "1" + [[tool.bumpversion.files]] filename = "python/opencosmo/__init__.py" diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 789a6fdc..b08ecd43 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -62,6 +62,8 @@ jobs: run: | git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" + git add --all + git commit -m "Release ${{ inputs.version }}" git tag ${{ inputs.version }} git push origin ${{ inputs.version }} From baa75ae73f7f73e54294f1e28bf2e9955f6cc72b Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 22 May 2026 14:46:34 -0500 Subject: [PATCH 132/139] Allow release to run on main for pre-release versions --- .github/workflows/release.yaml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index b08ecd43..5b18e27e 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -15,21 +15,23 @@ jobs: outputs: prerelease: ${{ steps.check-version.outputs.prerelease }} steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - token: ${{ secrets.RELEASE_PAT }} - ref: 'release' - - name: Check if
pre-release id: check-version run: | if echo "${{ inputs.version }}" | grep -qE '(a|b|rc|alpha|beta|dev)[0-9]*$'; then echo "prerelease=true" >> "$GITHUB_OUTPUT" + echo "ref=main" >> "$GITHUB_OUTPUT" else echo "prerelease=false" >> "$GITHUB_OUTPUT" + echo "ref=release" >> "$GITHUB_OUTPUT" fi + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + token: ${{ secrets.RELEASE_PAT }} + ref: ${{ steps.check-version.outputs.ref }} + - name: Install uv uses: astral-sh/setup-uv@v6 with: From 452a9795ba8660c8a58a9b1a6c0a69e3b581fc8b Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 22 May 2026 15:27:53 -0500 Subject: [PATCH 133/139] Correctly serialize hyphenated pre-release --- .bumpversion.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.toml b/.bumpversion.toml index 68658318..0fa58ebf 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,8 +1,8 @@ [tool.bumpversion] current_version = "1.2.6" -parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)(?P<pre_l>a|b|rc)?(?P<pre_n>\\d+)?" +parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)[-.]?(?P<pre_l>a|b|rc)?(?P<pre_n>\\d+)?"
serialize = [ - "{major}.{minor}.{patch}{pre_l}{pre_n}", + "{major}.{minor}.{patch}-{pre_l}{pre_n}", "{major}.{minor}.{patch}", ] search = "{current_version}" From 6eeb91035823962b55d2d639d3145673198e07e9 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Sun, 24 May 2026 17:43:28 -0500 Subject: [PATCH 134/139] Add jax output, allow evaluate to use other formats --- Cargo.lock | 2 +- pyproject.toml | 1 + python/opencosmo/column/column.py | 29 ++++- python/opencosmo/column/evaluate.py | 102 +++++---------- python/opencosmo/dataset/dataset.py | 12 +- python/opencosmo/dataset/evaluate.py | 30 ++--- python/opencosmo/dataset/formats.py | 165 +++++++++++++++++++++++- python/opencosmo/dataset/instantiate.py | 11 +- test/test_formats.py | 14 ++ uv.lock | 105 ++++++++++++++- 10 files changed, 357 insertions(+), 114 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4b5b9a4c..c8611404 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -96,7 +96,7 @@ checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "opencosmo" -version = "1.2.4" +version = "1.2.6" dependencies = [ "numpy", "pyo3", diff --git a/pyproject.toml b/pyproject.toml index 5540c1ba..f12d3c99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dev = [ "pip>=25.1.1", "pytest>=8.3.4,<9.0.0", "pytest-timeout>=2.4.0", + "jax>=0.10.1", ] docs = [ "sphinx>=9.0.0", diff --git a/python/opencosmo/column/column.py b/python/opencosmo/column/column.py index 8ff8c2db..d8dd419a 100644 --- a/python/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -865,11 +865,6 @@ def get_units(self, units: dict[str, np.ndarray]): def evaluate(self, data: dict[str, np.ndarray], index: DataIndex | None): data = {name: data[name] for name in self.__requires} chunk_sizes = index[1] if isinstance(index, tuple) else None - if self.__format != "astropy": - data = { - name: val.value if isinstance(val, u.Quantity) else val - for name, val in data.items() - } if self.batch_size 
> 0: length = len(next(iter(data.values()))) @@ -886,13 +881,33 @@ def evaluate(self, data: dict[str, np.ndarray], index: DataIndex | None): case EvaluateStrategy.VECTORIZE: return evaluate_vectorized(data, self.__func, self.__kwargs, index) case EvaluateStrategy.ROW_WISE: - return evaluate_rows(data, self.__func, self.__kwargs) + return evaluate_rows(data, self.__func, self.__kwargs, self.__format) case EvaluateStrategy.CHUNKED: if chunk_sizes is None: raise ValueError( "Cannot evaluate in CHUNKED strategy with a non-chunked index" ) - return evaluate_chunks(data, self.__func, self.__kwargs, chunk_sizes) + return evaluate_chunks( + data, self.__func, self.__kwargs, chunk_sizes, self.__format + ) + + def evaluate_for_storage( + self, data: dict[str, np.ndarray], index: DataIndex | None + ) -> dict[str, np.ndarray]: + """ + Evaluate and return numpy-formatted output suitable for the column + cache. Input arrives in the numpy/astropy form used internally, so + it is first converted to the user's requested format before the + function runs, and the output is converted back to numpy. 
+ """ + from opencosmo.dataset.formats import to_format_dict, to_numpy_dict + + required = {name: data[name] for name in self.__requires} + converted = to_format_dict(required, self.__format) + output = self.evaluate(converted, index) + if not isinstance(output, dict): + output = {next(iter(self.__produces)): output} + return to_numpy_dict(output) def evaluate_one(self, dataset: Dataset): match self.__strategy: diff --git a/python/opencosmo/column/evaluate.py b/python/opencosmo/column/evaluate.py index 20decfc1..d4822c84 100644 --- a/python/opencosmo/column/evaluate.py +++ b/python/opencosmo/column/evaluate.py @@ -3,11 +3,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Callable -import astropy.units as u import numpy as np -from opencosmo.evaluate import insert_data - if TYPE_CHECKING: from opencosmo import Dataset @@ -18,77 +15,48 @@ class EvaluateStrategy(Enum): CHUNKED = "chunked" -def evaluate_rows(data: dict[str, np.ndarray], func: Callable, kwargs: dict[str, Any]): +def evaluate_rows( + data: dict[str, Any], + func: Callable, + kwargs: dict[str, Any], + format: str, +): + from opencosmo.dataset.formats import stack_rows + data_length = len(next(iter(data.values()))) - storage = {} + per_column: dict[str, list] = {} for i in range(data_length): iterable_inputs = {name: values[i] for name, values in data.items()} output = func(**iterable_inputs, **kwargs) if not isinstance(output, dict): output = {func.__name__: output} - if i == 0: - storage = __make_row_based_output_from_first_values(output, data_length) - continue - insert_data(storage, i, output) - return storage - - -def __make_row_based_output_from_first_values(values, data_length): - storage = {} - for name, value in values.items(): - try: - shape = (data_length,) + value.shape - except AttributeError: - shape = (data_length,) - try: - dtype = value.dtype - except AttributeError: - dtype = type(value) - column_storage = np.zeros(shape, dtype=dtype) - if isinstance(value, u.Quantity): - 
column_storage *= value.unit - column_storage[0] = value - storage[name] = column_storage - - return storage + for name, value in output.items(): + per_column.setdefault(name, []).append(value) + return {name: stack_rows(values, format) for name, values in per_column.items()} def evaluate_chunks( - data: dict[str, np.ndarray], + data: dict[str, Any], func: Callable, kwargs: dict[str, Any], chunk_sizes: np.ndarray, + format: str, ): - data_length = len(next(iter(data.values()))) + from opencosmo.dataset.formats import concat_chunks chunk_splits = np.cumsum(chunk_sizes) - storage = {} - input_data = {name: np.split(arr, chunk_splits) for name, arr in data.items()} - for i in range(len(chunk_splits)): - chunk_input_data = {name: split[i] for name, split in input_data.items()} + starts = np.concatenate([[0], chunk_splits[:-1]]) + per_column: dict[str, list] = {} + for start, end in zip(starts, chunk_splits): + chunk_input_data = { + name: arr[int(start) : int(end)] for name, arr in data.items() + } output = func(**chunk_input_data, **kwargs) if not isinstance(output, dict): output = {func.__name__: output} - if i == 0: - storage = __make_chunked_based_output_from_first_values(output, data_length) - continue - for name, values in output.items(): - storage[name][chunk_splits[i - 1] : chunk_splits[i]] = values - return storage - - -def __make_chunked_based_output_from_first_values(values, data_length): - storage = {} - for name, value in values.items(): - shape = (data_length,) + value.shape[1:] - dtype = value.dtype - column_storage = np.zeros(shape, dtype=dtype) - if isinstance(value, u.Quantity): - column_storage *= value.unit - column_storage[0 : len(value)] = value - storage[name] = column_storage - - return storage + for name, value in output.items(): + per_column.setdefault(name, []).append(value) + return {name: concat_chunks(chunks, format) for name, chunks in per_column.items()} def evaluate_vectorized(data, func, kwargs, index): @@ -105,29 +73,25 @@ def 
do_first_evaluation( kwargs: dict[str, Any], dataset: Dataset, ): + from opencosmo.dataset.formats import fetch_as_dict + eval_strategy = EvaluateStrategy(strategy) + columns = list(dataset.columns) match eval_strategy: case EvaluateStrategy.VECTORIZE: - values = dataset.take(1).get_data(format, unpack=False) - try: - values = dict(values) - except TypeError: - values = {dataset.columns[0]: values} - + values = fetch_as_dict(dataset.take(1), columns, format, unpack=False) return func(**values, **kwargs), eval_strategy case EvaluateStrategy.ROW_WISE: - values = dataset.take(1).get_data(format, unpack=True) - try: - values = dict(values) - except TypeError: - values = {dataset.columns[0]: values} + values = fetch_as_dict(dataset.take(1), columns, format, unpack=False) + values = {name: container[0] for name, container in values.items()} return func(**values, **kwargs), eval_strategy case EvaluateStrategy.CHUNKED: index = dataset.index assert isinstance(index, tuple) first_chunk_size = index[1][0] - first_chunk = dataset.take(first_chunk_size, at="start").get_data(format) - first_chunk = dict(first_chunk) + first_chunk = fetch_as_dict( + dataset.take(first_chunk_size, at="start"), columns, format + ) return func(**first_chunk, **kwargs), eval_strategy diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index ac3b62c6..7066b2ef 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -270,7 +270,7 @@ def get_data( on the data. The method supports output into several different formats, including - "astropy", "numpy", "pandas", "polars", and "pyarrow". Although astropy + "astropy", "numpy", "pandas", "polars", "jax", and "arrow". Although astropy and numpy are core dependencies of OpenCosmo, the remaining formats require you to have the relevant libraries installed in your python environment. 
This method will check that it can import the necessary @@ -286,7 +286,7 @@ def get_data( ---------- output: str, default="astropy" The format to output the data in. - Currently supported are "astropy", "numpy", "pandas", "polars", "arrow" + Currently supported are "astropy", "numpy", "pandas", "polars", "arrow", "jax" Returns ------- @@ -478,8 +478,11 @@ def baryon_fraction_bias(sod_halo_mass_gas, sod_halo_mass, cosmology): as the function. Otherwise the data will be returned directly. format: str, default = astropy - Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. Unit information is preserved only when the function + returns astropy Quantities; outputs in other formats are stored without unit metadata. 
allow_overwrite: bool, default = False batch_size: int, default = -1 @@ -496,6 +499,7 @@ def baryon_fraction_bias(sod_halo_mass_gas, sod_halo_mass, cosmology): result : Dataset | dict[str, np.ndarray | astropy.units.Quantity] The new dataset with the evaluated column(s) or the results as numpy arrays or astropy quantities """ + verify_format(format) evaluated_column = build_evaluated_column( self, func, vectorize, insert, format, batch_size, evaluate_kwargs ) diff --git a/python/opencosmo/dataset/evaluate.py b/python/opencosmo/dataset/evaluate.py index 38968dfd..e2f84210 100644 --- a/python/opencosmo/dataset/evaluate.py +++ b/python/opencosmo/dataset/evaluate.py @@ -10,6 +10,7 @@ from opencosmo.column.column import EvaluatedColumn from opencosmo.column.evaluate import EvaluateStrategy, do_first_evaluation +from opencosmo.dataset.formats import concat_chunks, fetch_as_dict from opencosmo.evaluate import ( insert_data, make_output_from_first_values, @@ -27,10 +28,6 @@ def build_evaluated_column( dataset, func, vectorize, insert, format, batch_size, evaluate_kwargs ): - if format not in ["astropy", "numpy"]: - raise ValueError( - f"Evaluate only supports numpy and astropy format, got: {format}" - ) kwarg_columns = set(evaluate_kwargs.keys()).intersection(dataset.columns) if kwarg_columns: raise ValueError( @@ -69,12 +66,7 @@ def visit_dataset( ) -> dict[str, np.ndarray]: if column.batch_size > 0: return visit_dataset_batched(column, dataset) - requires_names = column.requires_names - data = dataset.select(requires_names).get_data(format=column.format) - try: - data = dict(data) - except (TypeError, ValueError): - data = {next(iter(requires_names)): data} + data = fetch_as_dict(dataset, column.requires_names, column.format) output = column.evaluate(data, dataset.index) if not isinstance(output, dict): assert len(column.produces) == 1 @@ -89,24 +81,22 @@ def visit_dataset_batched(column: EvaluatedColumn, dataset: Dataset): output = defaultdict(list) - requires_names = 
column.requires_names for start, end in np.lib.stride_tricks.sliding_window_view(ranges, 2): - batch_data = ( - dataset.select(requires_names) - .take_range(start, end) - .get_data(format=column.format, unpack=False) + batch_data = fetch_as_dict( + dataset.take_range(start, end), + column.requires_names, + column.format, + unpack=False, ) - try: - batch_data = dict(batch_data) - except TypeError: - batch_data = {next(iter(requires_names)): batch_data} batch_output = column.evaluate(batch_data, None) if batch_output is not None and not isinstance(batch_output, dict): batch_output = {column.produces.pop(): batch_output} for name, column_batch in batch_output.items(): output[name].append(column_batch) - full_output = {name: np.concat(out) for name, out in output.items()} + full_output = { + name: concat_chunks(out, column.format) for name, out in output.items() + } return full_output diff --git a/python/opencosmo/dataset/formats.py b/python/opencosmo/dataset/formats.py index 321916d2..2ac5f711 100644 --- a/python/opencosmo/dataset/formats.py +++ b/python/opencosmo/dataset/formats.py @@ -1,11 +1,15 @@ from __future__ import annotations from importlib import import_module +from typing import TYPE_CHECKING, Any, Iterable import astropy.units as u import numpy as np from astropy.table import Column, QTable +if TYPE_CHECKING: + from opencosmo import Dataset + def verify_format(output_format: str): match output_format: @@ -19,6 +23,8 @@ def verify_format(output_format: str): import_name = "pyarrow" case "polars": import_name = "polars" + case "jax": + import_name = "jax" case _: raise ValueError(f"Unknown data output format {output_format}") @@ -39,13 +45,153 @@ def convert_data(data: dict[str, np.ndarray], output_format: str): case "astropy": return __convert_to_astropy(data) case "numpy": - return __convert_to_numpy(data) + return convert_to_numpy(data) case "pandas": return __convert_to_pandas(data) case "polars": return __convert_to_polars(data) case "arrow": return 
__convert_to_arrow(data) + case "jax": + return __convert_to_jax(data) + case _: + raise ValueError(f"Unknown data output format {output_format}") + + +def fetch_as_dict( + dataset: Dataset, + requires_names: Iterable[str], + output_format: str, + unpack: bool = True, +) -> dict[str, Any]: + """ + Fetch the requested columns and return them as a {name: container} dict in + the user's requested format. Routes through astropy so that Quantities (with + units) survive into the conversion step; other formats then receive plain + values via to_format_dict. + """ + requires_names = list(requires_names) + raw = dataset.select(requires_names).get_data(format="astropy", unpack=unpack) + if isinstance(raw, QTable): + raw = {name: raw[name] for name in raw.colnames} + elif not isinstance(raw, dict): + raw = {requires_names[0]: raw} + if output_format == "astropy": + return raw + return to_format_dict(raw, output_format) + + +def to_format_dict(data: dict[str, np.ndarray], output_format: str) -> dict: + """ + Convert each column of a numpy/astropy dict to the requested format, + preserving the dict shape. Unlike convert_data, this never wraps the + result in a higher-level container (DataFrame, QTable, ...). Used to + feed user-supplied evaluate functions when the upstream data is + numpy-shaped (e.g. when reading from the column cache). 
+ """ + if output_format == "astropy": + return data + + def strip(value): + return value.value if isinstance(value, (u.Quantity, Column)) else value + + match output_format: + case "numpy": + return {k: strip(v) for k, v in data.items()} + case "jax": + import jax.numpy as jnp + + return {k: jnp.asarray(strip(v)) for k, v in data.items()} + case "pandas": + import pandas as pd + + return {k: pd.Series(strip(v)) for k, v in data.items()} + case "polars": + import polars as pl + + return {k: pl.Series(values=strip(v)) for k, v in data.items()} + case "arrow": + import pyarrow as pa # type: ignore + + return {k: pa.array(strip(v)) for k, v in data.items()} + case _: + raise ValueError(f"Unknown data output format {output_format}") + + +def to_numpy_dict(data: dict) -> dict[str, np.ndarray]: + """ + Convert each value in a dict-of-format-arrays back to a numpy array + suitable for the column cache. Astropy Quantities are preserved so + that downstream unit handling continues to work; other formats are + converted to plain numpy with no unit information. + """ + result: dict[str, np.ndarray] = {} + for name, value in data.items(): + if isinstance(value, (u.Quantity, np.ndarray)): + result[name] = value + else: + result[name] = np.asarray(value) + return result + + +def stack_rows(values: list, output_format: str): + """ + Stack a list of per-row values into a 1-D container in the target format. + Used by row-wise evaluation strategies to assemble output without + preallocation, which would break for formats with immutable arrays + (e.g. jax). 
+ """ + match output_format: + case "astropy": + if values and isinstance(values[0], u.Quantity): + return u.Quantity(values) + return np.array(values) + case "numpy": + return np.array(values) + case "jax": + import jax.numpy as jnp + + return jnp.array(values) + case "pandas": + import pandas as pd + + return pd.Series(values) + case "polars": + import polars as pl + + return pl.Series(values=values) + case "arrow": + import pyarrow as pa # type: ignore + + return pa.array(values) + case _: + raise ValueError(f"Unknown data output format {output_format}") + + +def concat_chunks(chunks: list, output_format: str): + """ + Concatenate a list of per-chunk arrays into a single container in the + target format. + """ + match output_format: + case "astropy" | "numpy": + return np.concatenate(chunks) + case "jax": + import jax.numpy as jnp + + return jnp.concatenate(chunks) + case "pandas": + import pandas as pd + + return pd.concat(chunks, ignore_index=True) + case "polars": + import polars as pl + + return pl.concat(chunks) + case "arrow": + import pyarrow as pa # type: ignore + + return pa.concat_arrays(chunks) case _: raise ValueError(f"Unknown data output format {output_format}") @@ -62,7 +208,7 @@ def __convert_to_astropy(data: dict[str, np.ndarray]) -> QTable: return QTable(data, copy=False) -def __convert_to_numpy( +def convert_to_numpy( data: dict[str, np.ndarray], ) -> dict[str, np.ndarray] | np.ndarray: converted_data = dict( @@ -82,7 +228,7 @@ def __convert_to_numpy( def __convert_to_pandas(data: dict[str, np.ndarray]): import pandas as pd - numpy_data = __convert_to_numpy(data) + numpy_data = convert_to_numpy(data) if isinstance(numpy_data, np.ndarray): # only one column return pd.Series(numpy_data, name=next(iter(data.keys()))) @@ -92,7 +238,7 @@ def __convert_to_pandas(data: dict[str, np.ndarray]): def __convert_to_arrow(data: dict[str, np.ndarray]): import pyarrow as pa # type: ignore - numpy_data = __convert_to_numpy(data) + numpy_data = 
convert_to_numpy(data) if isinstance(numpy_data, np.ndarray): return pa.array(numpy_data) @@ -106,8 +252,17 @@ def __convert_to_arrow(data: dict[str, np.ndarray]): def __convert_to_polars(data: dict[str, np.ndarray]): import polars as pl - numpy_data = __convert_to_numpy(data) + numpy_data = convert_to_numpy(data) if isinstance(numpy_data, np.ndarray): return pl.Series(name=next(iter(data.keys())), values=numpy_data) return pl.from_dict(data) # type: ignore + + +def __convert_to_jax(data: dict[str, np.ndarray]): + import jax.numpy as jnp + + output_data = convert_to_numpy(data) + if isinstance(output_data, np.ndarray): + return jnp.asarray(output_data) + return {key: jnp.asarray(value) for key, value in output_data.items()} diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py index 60507c80..79d4ce75 100644 --- a/python/opencosmo/dataset/instantiate.py +++ b/python/opencosmo/dataset/instantiate.py @@ -4,7 +4,7 @@ import rustworkx as rx -from opencosmo.column.column import RawColumn +from opencosmo.column.column import EvaluatedColumn, RawColumn from opencosmo.dataset.graph import build_dependency_graph if TYPE_CHECKING: @@ -101,9 +101,12 @@ def build_derived_columns( name: all_data[dep_uuid][name] for name, dep_uuid in producer.dep_map.items() } - output = producer.evaluate(input_data, index) - if not isinstance(output, dict): - output = {next(iter(producer.produces)): output} + if isinstance(producer, EvaluatedColumn): + output = producer.evaluate_for_storage(input_data, index) + else: + output = producer.evaluate(input_data, index) + if not isinstance(output, dict): + output = {next(iter(producer.produces)): output} new_derived[producer.uuid] = output if not producer.no_cache: to_cache[producer.uuid] = output diff --git a/test/test_formats.py b/test/test_formats.py index 533fd383..d5b46d63 100644 --- a/test/test_formats.py +++ b/test/test_formats.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np import pandas 
as pd import polars as pl @@ -28,6 +29,12 @@ def test_return_pyarrow(input_path): assert all(isinstance(v, pa.Array) for v in data.values()) +def test_return_jax(input_path): + data = oc.open(input_path).get_data("jax") + assert isinstance(data, dict) + assert all(isinstance(v, jnp.ndarray) for v in data.values()) + + def test_return_pandas_single(input_path): dataset = oc.open(input_path) column = np.random.choice(dataset.columns) @@ -48,3 +55,10 @@ def test_return_pyarrow_single(input_path): column = np.random.choice(dataset.columns) data = dataset.select(column).get_data("arrow") assert isinstance(data, pa.Array) + + +def test_return_jax_single(input_path): + dataset = oc.open(input_path) + column = np.random.choice(dataset.columns) + data = dataset.select(column).get_data("jax") + assert isinstance(data, jnp.ndarray) diff --git a/uv.lock b/uv.lock index 0a16bfce..53519f59 100644 --- a/uv.lock +++ b/uv.lock @@ -6,10 +6,14 @@ resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'emscripten'", "python_full_version >= '3.14' and platform_machine == 'arm64' and sys_platform == 'darwin'", "(python_full_version >= '3.14' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", - "python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version < '3.14' and platform_machine == 'arm64' and sys_platform == 'darwin'", - "(python_full_version < '3.14' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'emscripten'", + 
"python_full_version < '3.13' and sys_platform == 'emscripten'", + "python_full_version == '3.13.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version < '3.13' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "(python_full_version == '3.13.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "(python_full_version < '3.13' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", ] [[package]] @@ -557,6 +561,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "jax" +version = "0.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jaxlib" }, + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "opt-einsum" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/24/49/b082387119c4a6bc7596296bbdc6bce034628cdd2845ebb27304cbca3624/jax-0.10.1.tar.gz", hash = "sha256:11672410faf8752429eb9a131de203dc488a2a3a012d509baa2b39878008810d", size = 2718178, upload-time = "2026-05-20T14:54:09.441Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl", hash = "sha256:47f3192c76e9e3358de1b106a8af5e943fccb10510903f25d96ea53652729134", size = 3150973, upload-time = "2026-05-20T14:51:30.066Z" }, +] + +[[package]] +name = "jaxlib" +version = "0.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { 
name = "ml-dtypes" }, + { name = "numpy" }, + { name = "scipy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/53/b4/fdb6e989b142d8a8d2093f342cbc5323fe0d4a7217fd899c8ddf9e108a5a/jaxlib-0.10.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7a4399c7429c87ee6f7ab1e5712c5548ecfabde974ee9f4a12957fea4e35efb8", size = 60816137, upload-time = "2026-05-20T14:52:53.028Z" }, + { url = "https://files.pythonhosted.org/packages/a7/29/0b4eaaca005708751ff301903a2ca760dcc34175dadb7536498c57a7de85/jaxlib-0.10.1-cp312-cp312-manylinux_2_27_aarch64.whl", hash = "sha256:12126603ba472300c62480f86d20972d563143041039d9dea349ad510aa6123c", size = 80294034, upload-time = "2026-05-20T14:52:56.941Z" }, + { url = "https://files.pythonhosted.org/packages/38/69/2912ab63036e21c72748019e1d8e09e8a1fc3368b3e83fc27898a1858575/jaxlib-0.10.1-cp312-cp312-manylinux_2_27_x86_64.whl", hash = "sha256:f3cdf5b7f48470ab5455ab79aab746419694ccb6b52651cc2ce5fb27def03588", size = 85828774, upload-time = "2026-05-20T14:53:01.749Z" }, + { url = "https://files.pythonhosted.org/packages/ec/8f/993ea419eca6f34fe12613e22a03b93f40e5b1e8e0df18d4060e1313a1fc/jaxlib-0.10.1-cp312-cp312-win_amd64.whl", hash = "sha256:0acf3f8e7dca9074c0327f0f61502845792ca9f82fab23b841b00daa78e85488", size = 64830187, upload-time = "2026-05-20T14:53:05.932Z" }, + { url = "https://files.pythonhosted.org/packages/cf/76/3b637d4def229015a3035a7b44fac0dcf2536ae337540cdbffc651334d4e/jaxlib-0.10.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4167213aa00f14bb0d8fbd90f9ded75e976f71ce8baf8c3c44e04c8fb80ea0c1", size = 60815855, upload-time = "2026-05-20T14:53:11.718Z" }, + { url = "https://files.pythonhosted.org/packages/5b/76/9a1971bc9edb8728a7ba86d2693127ee46add9811230b8452321415fd4e9/jaxlib-0.10.1-cp313-cp313-manylinux_2_27_aarch64.whl", hash = "sha256:e53223d8f6861d33c02dd02343fa464700401ff5363784d37c33407218e328b8", size = 80293947, upload-time = "2026-05-20T14:53:15.408Z" }, + { url = 
"https://files.pythonhosted.org/packages/20/1d/69a0ba52fb546261e71a7209378ee6059950e9c088a2a18355e01509f474/jaxlib-0.10.1-cp313-cp313-manylinux_2_27_x86_64.whl", hash = "sha256:bb073a1224e659e01e8d32d47c000edb52ec2aa8ba97ec22b2228b3a46e5c167", size = 85829861, upload-time = "2026-05-20T14:53:19.773Z" }, + { url = "https://files.pythonhosted.org/packages/a9/df/48659e2ee57705c63a51525f810fe3e0c87af4ca9f89d4738281a872d58e/jaxlib-0.10.1-cp313-cp313-win_amd64.whl", hash = "sha256:6449f1d4a22324f5f02c843360475783f9fd1d353fe711806cbf4e927d1360ae", size = 64828863, upload-time = "2026-05-20T14:53:24.562Z" }, + { url = "https://files.pythonhosted.org/packages/c7/34/3f7c95ee1b2555d611f836988a49b522c04b8d186e0528f91d45118089bf/jaxlib-0.10.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:72a7db1242d6b7773340e430c244e73a8d13f82ce83933c0529a8b4b1520dcf3", size = 60934063, upload-time = "2026-05-20T14:53:28.549Z" }, + { url = "https://files.pythonhosted.org/packages/3a/84/855a2395d299e00e1d9ffd0aecd516b52b81780f3e4ef537527be6d8c1fc/jaxlib-0.10.1-cp313-cp313t-manylinux_2_27_aarch64.whl", hash = "sha256:649c26ca92e9bbffb3c35226b37893916f20e6b89fb3911ce79b39e5cfb27b46", size = 80402500, upload-time = "2026-05-20T14:53:32.444Z" }, + { url = "https://files.pythonhosted.org/packages/be/35/153e91a9c770a981d525d845b3f4cdb71a1e119681594a33908e9536bdff/jaxlib-0.10.1-cp313-cp313t-manylinux_2_27_x86_64.whl", hash = "sha256:cfe75d8a17e0d33a7bed27f32d7a5344a66e8d4af7073973f396e14ff4a9c503", size = 85941210, upload-time = "2026-05-20T14:53:36.982Z" }, + { url = "https://files.pythonhosted.org/packages/0b/a1/c4d4c0530313c50dd1ba07fff480cfd0c5f18c5ec49742f4a52a6edfd95f/jaxlib-0.10.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:9a657554dbe56f3691d377cb99d00b29df14e5638e579e57965f9f69c32c9315", size = 60826189, upload-time = "2026-05-20T14:53:40.73Z" }, + { url = 
"https://files.pythonhosted.org/packages/98/71/529b2439b88491e0806ee9fa6191ecf90f4447dfe092e80bd19577c85260/jaxlib-0.10.1-cp314-cp314-manylinux_2_27_aarch64.whl", hash = "sha256:93cd9989404f86a21b50c7ff4e850a31f55509664312080978bde97a664d92a9", size = 80300654, upload-time = "2026-05-20T14:53:44.751Z" }, + { url = "https://files.pythonhosted.org/packages/be/08/0bc4132fb2fe224ee9d83fd60e3650fd4d893b4e5148707a98df8f0333a4/jaxlib-0.10.1-cp314-cp314-manylinux_2_27_x86_64.whl", hash = "sha256:26b94e9640b01968cc14b8353dd6b6540d723f30579c78b1f46a477fc4aa196d", size = 85841241, upload-time = "2026-05-20T14:53:49.13Z" }, + { url = "https://files.pythonhosted.org/packages/5c/de/423d748ce3367bd5ea20d8cc34a7ceb6420da4d41c20b247f54194700d04/jaxlib-0.10.1-cp314-cp314-win_amd64.whl", hash = "sha256:375820799bbf7d515dd4e4d40f3334566b73d3fe64d340afbd6aa897d5d7c486", size = 67301520, upload-time = "2026-05-20T14:53:52.989Z" }, + { url = "https://files.pythonhosted.org/packages/2d/bf/3f7ce089d62f7ac85ea678925471f7ec88038899e67ab02079c1e7d8ad4e/jaxlib-0.10.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:bc52f16bdd61b299efeed7ea8e743e91a118059a463da2ab97196fc05b69ddb7", size = 60935393, upload-time = "2026-05-20T14:53:56.886Z" }, + { url = "https://files.pythonhosted.org/packages/65/6a/38cf1d4ff8c8f74ec8d567e0aa6e3d2082ab6fc580545f6bb51368da20a9/jaxlib-0.10.1-cp314-cp314t-manylinux_2_27_aarch64.whl", hash = "sha256:55b0a473fbd57d31dc3935c4cd5c0c38af5f7c1f41300df27923cf46676972ca", size = 80405741, upload-time = "2026-05-20T14:54:01.94Z" }, + { url = "https://files.pythonhosted.org/packages/16/9e/d3cff171aaf13a09aab26a44d1a27dcbf0d6311e4d855f5d99685965ace3/jaxlib-0.10.1-cp314-cp314t-manylinux_2_27_x86_64.whl", hash = "sha256:4b0cb8ef960a3723037db63f28ffa20083d90fff6d30085e99d7c63cfa08e4c0", size = 85942183, upload-time = "2026-05-20T14:54:05.91Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -832,6 +882,42 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/5d/49/d651878698a0b67f23aa28e17f45a6d6dd3d3f933fa29087fa4ce5947b5a/matplotlib-3.10.8-cp314-cp314t-win_arm64.whl", hash = "sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f", size = 8192560, upload-time = "2025-12-10T22:56:38.008Z" }, ] +[[package]] +name = "ml-dtypes" +version = "0.5.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/4a/c27b42ed9b1c7d13d9ba8b6905dece787d6259152f2309338aed29b2447b/ml_dtypes-0.5.4.tar.gz", hash = "sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453", size = 692314, upload-time = "2025-11-17T22:32:31.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/b8/3c70881695e056f8a32f8b941126cf78775d9a4d7feba8abcb52cb7b04f2/ml_dtypes-0.5.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac", size = 676927, upload-time = "2025-11-17T22:31:48.182Z" }, + { url = "https://files.pythonhosted.org/packages/54/0f/428ef6881782e5ebb7eca459689448c0394fa0a80bea3aa9262cba5445ea/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900", size = 5028464, upload-time = "2025-11-17T22:31:50.135Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cb/28ce52eb94390dda42599c98ea0204d74799e4d8047a0eb559b6fd648056/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff", size = 5009002, upload-time = "2025-11-17T22:31:52.001Z" }, + { url = "https://files.pythonhosted.org/packages/f5/f0/0cfadd537c5470378b1b32bd859cf2824972174b51b873c9d95cfd7475a5/ml_dtypes-0.5.4-cp312-cp312-win_amd64.whl", hash = "sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7", size = 
212222, upload-time = "2025-11-17T22:31:53.742Z" }, + { url = "https://files.pythonhosted.org/packages/16/2e/9acc86985bfad8f2c2d30291b27cd2bb4c74cea08695bd540906ed744249/ml_dtypes-0.5.4-cp312-cp312-win_arm64.whl", hash = "sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460", size = 160793, upload-time = "2025-11-17T22:31:55.358Z" }, + { url = "https://files.pythonhosted.org/packages/d9/a1/4008f14bbc616cfb1ac5b39ea485f9c63031c4634ab3f4cf72e7541f816a/ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48", size = 676888, upload-time = "2025-11-17T22:31:56.907Z" }, + { url = "https://files.pythonhosted.org/packages/d3/b7/dff378afc2b0d5a7d6cd9d3209b60474d9819d1189d347521e1688a60a53/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b", size = 5036993, upload-time = "2025-11-17T22:31:58.497Z" }, + { url = "https://files.pythonhosted.org/packages/eb/33/40cd74219417e78b97c47802037cf2d87b91973e18bb968a7da48a96ea44/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d", size = 5010956, upload-time = "2025-11-17T22:31:59.931Z" }, + { url = "https://files.pythonhosted.org/packages/e1/8b/200088c6859d8221454825959df35b5244fa9bdf263fd0249ac5fb75e281/ml_dtypes-0.5.4-cp313-cp313-win_amd64.whl", hash = "sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328", size = 212224, upload-time = "2025-11-17T22:32:01.349Z" }, + { url = "https://files.pythonhosted.org/packages/8f/75/dfc3775cb36367816e678f69a7843f6f03bd4e2bcd79941e01ea960a068e/ml_dtypes-0.5.4-cp313-cp313-win_arm64.whl", hash = "sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175", size = 160798, upload-time = "2025-11-17T22:32:02.864Z" }, + { url = 
"https://files.pythonhosted.org/packages/4f/74/e9ddb35fd1dd43b1106c20ced3f53c2e8e7fc7598c15638e9f80677f81d4/ml_dtypes-0.5.4-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6", size = 702083, upload-time = "2025-11-17T22:32:04.08Z" }, + { url = "https://files.pythonhosted.org/packages/74/f5/667060b0aed1aa63166b22897fdf16dca9eb704e6b4bbf86848d5a181aa7/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d", size = 5354111, upload-time = "2025-11-17T22:32:05.546Z" }, + { url = "https://files.pythonhosted.org/packages/40/49/0f8c498a28c0efa5f5c95a9e374c83ec1385ca41d0e85e7cf40e5d519a21/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298", size = 5366453, upload-time = "2025-11-17T22:32:07.115Z" }, + { url = "https://files.pythonhosted.org/packages/8c/27/12607423d0a9c6bbbcc780ad19f1f6baa2b68b18ce4bddcdc122c4c68dc9/ml_dtypes-0.5.4-cp313-cp313t-win_amd64.whl", hash = "sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6", size = 225612, upload-time = "2025-11-17T22:32:08.615Z" }, + { url = "https://files.pythonhosted.org/packages/e5/80/5a5929e92c72936d5b19872c5fb8fc09327c1da67b3b68c6a13139e77e20/ml_dtypes-0.5.4-cp313-cp313t-win_arm64.whl", hash = "sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1", size = 164145, upload-time = "2025-11-17T22:32:09.782Z" }, + { url = "https://files.pythonhosted.org/packages/72/4e/1339dc6e2557a344f5ba5590872e80346f76f6cb2ac3dd16e4666e88818c/ml_dtypes-0.5.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22", size = 673781, upload-time = "2025-11-17T22:32:11.364Z" }, + { url = 
"https://files.pythonhosted.org/packages/04/f9/067b84365c7e83bda15bba2b06c6ca250ce27b20630b1128c435fb7a09aa/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465", size = 5036145, upload-time = "2025-11-17T22:32:12.783Z" }, + { url = "https://files.pythonhosted.org/packages/c6/bb/82c7dcf38070b46172a517e2334e665c5bf374a262f99a283ea454bece7c/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f", size = 5010230, upload-time = "2025-11-17T22:32:14.38Z" }, + { url = "https://files.pythonhosted.org/packages/e9/93/2bfed22d2498c468f6bcd0d9f56b033eaa19f33320389314c19ef6766413/ml_dtypes-0.5.4-cp314-cp314-win_amd64.whl", hash = "sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56", size = 221032, upload-time = "2025-11-17T22:32:15.763Z" }, + { url = "https://files.pythonhosted.org/packages/76/a3/9c912fe6ea747bb10fe2f8f54d027eb265db05dfb0c6335e3e063e74e6e8/ml_dtypes-0.5.4-cp314-cp314-win_arm64.whl", hash = "sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049", size = 163353, upload-time = "2025-11-17T22:32:16.932Z" }, + { url = "https://files.pythonhosted.org/packages/cd/02/48aa7d84cc30ab4ee37624a2fd98c56c02326785750cd212bc0826c2f15b/ml_dtypes-0.5.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9", size = 702085, upload-time = "2025-11-17T22:32:18.175Z" }, + { url = "https://files.pythonhosted.org/packages/5a/e7/85cb99fe80a7a5513253ec7faa88a65306be071163485e9a626fce1b6e84/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7", size = 5355358, upload-time = "2025-11-17T22:32:19.7Z" }, + { url = 
"https://files.pythonhosted.org/packages/79/2b/a826ba18d2179a56e144aef69e57fb2ab7c464ef0b2111940ee8a3a223a2/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf", size = 5366332, upload-time = "2025-11-17T22:32:21.193Z" }, + { url = "https://files.pythonhosted.org/packages/84/44/f4d18446eacb20ea11e82f133ea8f86e2bf2891785b67d9da8d0ab0ef525/ml_dtypes-0.5.4-cp314-cp314t-win_amd64.whl", hash = "sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1", size = 236612, upload-time = "2025-11-17T22:32:22.579Z" }, + { url = "https://files.pythonhosted.org/packages/ad/3f/3d42e9a78fe5edf792a83c074b13b9b770092a4fbf3462872f4303135f09/ml_dtypes-0.5.4-cp314-cp314t-win_arm64.whl", hash = "sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d", size = 168825, upload-time = "2025-11-17T22:32:23.766Z" }, +] + [[package]] name = "more-itertools" version = "11.0.1" @@ -1047,6 +1133,7 @@ io = [ [package.dev-dependencies] dev = [ + { name = "jax" }, { name = "mypy" }, { name = "pip" }, { name = "pre-commit" }, @@ -1094,6 +1181,7 @@ provides-extras = ["io"] [package.metadata.requires-dev] dev = [ + { name = "jax", specifier = ">=0.10.1" }, { name = "mypy", specifier = ">=1.15,<2.0.0" }, { name = "pip", specifier = ">=25.1.1" }, { name = "pre-commit", specifier = ">=4.2.0,<5.0.0" }, @@ -1119,6 +1207,15 @@ test = [ ] test-mpi = [{ name = "mpi-pytest", specifier = ">=2025.4.0,<2026.0.0" }] +[[package]] +name = "opt-einsum" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/b9/2ac072041e899a52f20cf9510850ff58295003aa75525e58343591b0cbfb/opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac", size = 63004, upload-time = "2024-09-26T14:33:24.483Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932, upload-time = "2024-09-26T14:33:23.039Z" }, +] + [[package]] name = "packaging" version = "26.0" From 5fceb99fcd286566c5d91d433e558f1a200047e1 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Sun, 24 May 2026 17:46:05 -0500 Subject: [PATCH 135/139] Add updated evaluate formats tests --- test/test_evaluate_formats.py | 185 ++++++++++++++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 test/test_evaluate_formats.py diff --git a/test/test_evaluate_formats.py b/test/test_evaluate_formats.py new file mode 100644 index 00000000..4ba425e0 --- /dev/null +++ b/test/test_evaluate_formats.py @@ -0,0 +1,185 @@ +import jax.numpy as jnp +import numpy as np +import pandas as pd +import polars as pl +import pyarrow as pa +import pyarrow.compute as pc +import pytest + +import opencosmo as oc + + +@pytest.fixture +def input_path(snapshot_path): + return snapshot_path / "haloproperties.hdf5" + + +FORMATS = ["jax", "pandas", "polars", "arrow"] +SCALARS = { + "jax": (jnp.ndarray, np.floating, float), + "pandas": (pd.Series, np.floating, float), + "polars": (pl.Series, float, int), + "arrow": (pa.Scalar, float, int), +} + + +def _vectorized_func(format): + """Multiply two columns using each format's native multiplication path.""" + + if format == "arrow": + + def fof_px(fof_halo_mass, fof_halo_com_vx): + return pc.multiply(fof_halo_mass, fof_halo_com_vx) + + else: + + def fof_px(fof_halo_mass, fof_halo_com_vx): + return fof_halo_mass * fof_halo_com_vx + + return fof_px + + +def _row_func(format): + """Multiply two scalars; works for any format because each row is a scalar.""" + + def fof_px(fof_halo_mass, fof_halo_com_vx): + if isinstance(fof_halo_mass, pa.Scalar): + return fof_halo_mass.as_py() * fof_halo_com_vx.as_py() + 
return float(fof_halo_mass) * float(fof_halo_com_vx) + + return fof_px + + +def _expected(input_path): + data = ( + oc.open(input_path) + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + return data["fof_halo_mass"] * data["fof_halo_com_vx"] + + +def _to_numpy(value): + if isinstance(value, jnp.ndarray): + return np.asarray(value) + if isinstance(value, (pd.Series, pl.Series)): + return value.to_numpy() + if isinstance(value, pa.Array): + return value.to_numpy(zero_copy_only=False) + return np.asarray(value) + + +# --------------------------------------------------------------------------- +# insert = False (return result directly) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_vectorized_noinsert(input_path, format): + ds = oc.open(input_path) + result = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=False, format=format + ) + expected = _expected(input_path) + assert np.allclose(_to_numpy(result["fof_px"]), expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_row_wise_noinsert(input_path, format): + ds = oc.open(input_path).take(500, at="start") + result = ds.evaluate( + _row_func(format), vectorize=False, insert=False, format=format + ) + selected = ( + oc.open(input_path) + .take(500, at="start") + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + expected = selected["fof_halo_mass"] * selected["fof_halo_com_vx"] + assert np.allclose(_to_numpy(result["fof_px"]), expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_batched_noinsert(input_path, format): + ds = oc.open(input_path) + batch_size = 10_000 + result = ds.evaluate( + _vectorized_func(format), + insert=False, + batch_size=batch_size, + format=format, + ) + expected = _expected(input_path) + assert np.allclose(_to_numpy(result["fof_px"]), expected) + + +# 
--------------------------------------------------------------------------- +# insert = True (converted to numpy, stored in cache) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_vectorized_insert(input_path, format): + ds = oc.open(input_path) + ds = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=True, format=format + ) + assert "fof_px" in ds.columns + data = ds.select("fof_px").get_data("numpy") + expected = _expected(input_path) + assert np.allclose(data, expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_row_wise_insert(input_path, format): + ds = oc.open(input_path).take(500, at="start") + ds = ds.evaluate(_row_func(format), vectorize=False, insert=True, format=format) + assert "fof_px" in ds.columns + data = ds.select("fof_px").get_data("numpy") + selected = ( + oc.open(input_path) + .take(500, at="start") + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + expected = selected["fof_halo_mass"] * selected["fof_halo_com_vx"] + assert np.allclose(data, expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_batched_insert(input_path, format): + ds = oc.open(input_path) + batch_size = 10_000 + ds = ds.evaluate( + _vectorized_func(format), + insert=True, + batch_size=batch_size, + format=format, + ) + assert "fof_px" in ds.columns + data = ds.select("fof_px").get_data("numpy") + expected = _expected(input_path) + assert np.allclose(data, expected) + + +# --------------------------------------------------------------------------- +# Output-type assertions on the not-insert path +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "format,expected_type", + [ + ("jax", jnp.ndarray), + ("pandas", pd.Series), + ("polars", pl.Series), + ("arrow", pa.Array), + ], +) +def 
test_evaluate_noinsert_returns_native_container(input_path, format, expected_type): + ds = oc.open(input_path) + result = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=False, format=format + ) + assert isinstance(result["fof_px"], expected_type) From 91d1a7850a6d6a5526ad92be216020c3cbde81de Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 27 May 2026 09:21:58 -0500 Subject: [PATCH 136/139] Add support for other formats to StructureCollection.evaluate --- .../collection/lightcone/lightcone.py | 34 ++-- .../collection/structure/evaluate.py | 127 ++++----------- .../collection/structure/structure.py | 22 ++- python/opencosmo/dataset/dataset.py | 17 +- python/opencosmo/dataset/formats.py | 50 +++--- test/test_evaluate_formats.py | 152 ++++++++++++++++++ 6 files changed, 270 insertions(+), 132 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index f82410e7..98766124 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -27,7 +27,7 @@ from opencosmo.collection.lightcone.stack import stack_lightcone_datasets_in_schema from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn from opencosmo.dataset.evaluate import build_evaluated_column -from opencosmo.dataset.formats import convert_data, verify_format +from opencosmo.dataset.formats import concat_chunks, convert_data, verify_format from opencosmo.dataset.take import ( get_end_take_index, get_random_take_index, @@ -325,7 +325,13 @@ def get_pixels(self, nside: int = 64): return lcutils.get_pixels(self, int(level)) - def get_data(self, format="astropy", unpack: bool = True, **kwargs): + def get_data( + self, + format="astropy", + unpack: bool = True, + wrap_single: bool = False, + **kwargs, + ): """ Get the data in this dataset as an astropy table/column or as numpy array(s). 
Note that a dataset does not load data from disk into @@ -340,7 +346,9 @@ def get_data(self, format="astropy", unpack: bool = True, **kwargs): units will be attached. If the dataset only contains a single column, it will be returned as an - astropy.table.Column or a single numpy array. + astropy.table.Column or a single numpy array. Pass :code:`wrap_single=True` + to always return the format's multi-column container (QTable, DataFrame, + dict, ...) regardless of column count. Parameters ---------- @@ -348,6 +356,10 @@ def get_data(self, format="astropy", unpack: bool = True, **kwargs): The format to output the data in. Currently supported are "astropy", "numpy", "pandas", "polars", and "arrow" + wrap_single: bool, default=False + If True, always return the format's natural multi-column container even + when only one column is present. + Returns ------- data: Table | Column | dict[str, ndarray] | ndarray @@ -383,11 +395,11 @@ def get_data(self, format="astropy", unpack: bool = True, **kwargs): key: value[0] if len(value) == 1 else value for key, value in table.items() } - return convert_data(output_data, format) + return convert_data(output_data, format, wrap_single=wrap_single) if format != "astropy": - return convert_data(dict(table), format) - elif len(table.columns) == 1: + return convert_data(dict(table), format, wrap_single=wrap_single) + elif len(table.columns) == 1 and not wrap_single: return next(iter(dict(table).values())) return table @@ -758,9 +770,11 @@ def evaluate( The function to evaluate on the rows in the dataset. format: str, default = "astropy" - The format of the data that is provided to your function. If "astropy", will be a dictionary of - astropy quantities. If "numpy", will be a dictionary of numpy arrays. Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. 
Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. Unit information is preserved only when the function + returns astropy Quantities; outputs in other formats are stored without unit metadata. vectorize: bool, default = False Whether to provide the values as full columns (True) or one row at a time (False) @@ -834,7 +848,7 @@ def evaluate( keys = next(iter(result.values())).keys() output = {} for key in keys: - output[key] = np.concatenate([r[key] for r in result.values()]) + output[key] = concat_chunks([r[key] for r in result.values()], format) return output def filter(self, *masks: ColumnMask, **kwargs) -> Self: diff --git a/python/opencosmo/collection/structure/evaluate.py b/python/opencosmo/collection/structure/evaluate.py index 7bed468a..efc80cb8 100644 --- a/python/opencosmo/collection/structure/evaluate.py +++ b/python/opencosmo/collection/structure/evaluate.py @@ -1,16 +1,12 @@ from __future__ import annotations from inspect import Parameter, signature -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, Optional -import numpy as np from astropy.units import Quantity # type: ignore from opencosmo import dataset as ds -from opencosmo.evaluate import ( - insert_data, - make_output_from_first_values, -) +from opencosmo.dataset.formats import concat_chunks, stack_rows if TYPE_CHECKING: from opencosmo import StructureCollection @@ -82,17 +78,22 @@ def evaluate_into_properties( kwargs: dict[str, Any], insert: bool, ): - storage = __make_output(function, collection, format, kwargs, {}, insert) - for i, structure in enumerate(collection.objects()): - if i == 0: - continue + per_column: dict[str, list] = {} + for structure in collection.objects(): input_structure = __make_input(structure, format) - output = function(**input_structure, **kwargs) 
- if storage is not None: - insert_data(storage, i, output) + if output is None and insert: + raise ValueError( + "You asked to insert these values, but your function returns None!" + ) + if not isinstance(output, dict): + output = {function.__name__: output} + for name, value in output.items(): + per_column.setdefault(name, []).append(value) - return storage + if not per_column: + return None + return {name: stack_rows(values, format) for name, values in per_column.items()} def evaluate_into_dataset( @@ -103,25 +104,28 @@ def evaluate_into_dataset( dataset: str, insert: bool, ): - storage = __make_chunked_output(function, collection, dataset, format, kwargs, {}) - + per_column: dict[str, list] = {} for i, structure in enumerate(collection.objects()): - if i == 0: - continue input_structure = __make_input(structure, format) - output = function(**input_structure, **kwargs) + if output is None and insert: + raise ValueError( + "You asked to insert these values, but your function returns None!" 
+ ) if not isinstance(output, dict): output = {function.__name__: output} - - if storage is not None: - for name, output_arr in output.items(): - storage[name].append(output_arr) - - if storage is None: - return - output_data = {name: np.concatenate(data) for name, data in storage.items()} - return output_data + if i == 0: + expected_length = len(input_structure[dataset]) + if any(len(v) != expected_length for v in output.values()): + raise ValueError( + "If you pass a `dataset` argument, your function should output an array with the same length as that dataset" + ) + for name, output_arr in output.items(): + per_column.setdefault(name, []).append(output_arr) + + if not per_column: + return None + return {name: concat_chunks(data, format) for name, data in per_column.items()} def __make_input(structure: dict, format: str = "astropy"): @@ -130,78 +134,15 @@ def __make_input(structure: dict, format: str = "astropy"): if isinstance(element, dict): values[name] = __make_input(element, format) elif isinstance(element, ds.Dataset): - data = element.get_data(format) + data = element.get_data(format, wrap_single=True) values[name] = data - elif isinstance(element, Quantity) and format == "numpy": + elif isinstance(element, Quantity) and format != "astropy": values[name] = element.value else: values[name] = element return values -def __make_output( - function: Callable, - collection: StructureCollection, - format: str = "astropy", - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, Sequence] = {}, - insert: bool = True, -) -> dict | None: - first_structure = next(collection.take(1, at="start").objects()) - first_input = __make_input(first_structure, format) - first_values = function( - **first_input, - **kwargs, - **{name: arr[0] for name, arr in iterable_kwargs.items()}, - ) - if first_values is None and insert: - raise ValueError( - "You asked to insert these values, but your function returns None!" 
- ) - elif first_values is None: - return None - if not isinstance(first_values, dict): - name = function.__name__ - first_values = {name: first_values} - n_rows = len(collection) - return make_output_from_first_values(first_values, n_rows) - - -def __make_chunked_output( - function: Callable, - collection: StructureCollection, - dataset: str, - format: str = "astropy", - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, Sequence] = {}, - insert: bool = True, -) -> dict | None: - first_structure = collection.take(1, at="start") - expected_length = len(first_structure[dataset]) - first_structure_data = next(iter(first_structure.objects())) - - first_input = __make_input(first_structure_data, format) - first_values = function( - **first_input, - **kwargs, - **{name: arr[0] for name, arr in iterable_kwargs.items()}, - ) - if first_values is None and insert: - raise ValueError( - "You asked to insert these values, but your function returns None!" - ) - elif first_values is None: - return None - if not isinstance(first_values, dict): - name = function.__name__ - first_values = {name: first_values} - if any(len(fv) != expected_length for fv in first_values.values()): - raise ValueError( - "If you pass a `dataset` argument, your function should output an array with the same length as that dataset" - ) - return {name: [fv] for name, fv in first_values.items()} - - def __prepare_collection( spec: dict[str, Optional[list[str]]], collection: StructureCollection ) -> StructureCollection: diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index 56b69ca5..d5100968 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -21,6 +21,7 @@ from opencosmo.collection.lightcone import lightcone as lc from opencosmo.collection.structure import evaluate from opencosmo.collection.structure import io as sio +from opencosmo.dataset.formats import 
verify_format from opencosmo.index.unary import get_length from opencosmo.io.schema import FileEntry, make_schema @@ -597,8 +598,11 @@ def computation(halo_properties, dm_particles): collection contains galaxies. If False, simply return the data. format: str, default = astropy - Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. Unit information is preserved only when the function + returns astropy Quantities; outputs in other formats are stored without unit metadata. **evaluate_kwargs: any, Any additional arguments that are required for your function to run. These will be passed directly @@ -623,8 +627,7 @@ def computation(halo_properties, dm_particles): **evaluate_kwargs, ) - if format not in ["astropy", "numpy"]: - raise ValueError(f"Invalid format requested for data: {format}") + verify_format(format) if dataset is not None and dataset.startswith("galaxies"): # Nested structure collection, special case @@ -701,10 +704,12 @@ def computation(halo_properties, dm_particles): ) if not insert or output is None: return output + from opencosmo.dataset.formats import to_numpy_dict + return self.with_new_columns( - **output, dataset=dataset if dataset is not None else self.__source.dtype, allow_overwrite=allow_overwrite, + **to_numpy_dict(output), # type: ignore ) def evaluate_on_dataset( @@ -751,8 +756,11 @@ def evaluate_on_dataset( Whether to provide the values as full columns (True) or one row at a time (False). Ignored if :code:`batch_size` is set. 
format: str, default = astropy - Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. Unit information is preserved only when the function + returns astropy Quantities; outputs in other formats are stored without unit metadata. insert: bool, default = True If true, the data will be inserted as a column in this dataset. The new column will have the same name diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index 7066b2ef..cc591d53 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -260,7 +260,12 @@ def get_metadata(self, columns: str | list[str] = [], ignore_sort: bool = False) return st.get_metadata(self.__state, columns, ignore_sort) def get_data( - self, format="astropy", unpack=True, metadata_columns=[], **kwargs + self, + format="astropy", + unpack=True, + metadata_columns=[], + wrap_single=False, + **kwargs, ) -> OpenCosmoData: """ Get the data in this dataset as an astropy table/column or as @@ -280,7 +285,9 @@ def get_data( If the dataset only contains a single column, it will not be put in a table or dictionary. "astropy", "numpy" and "arrow" will return a single array - in this case, while "polars" and "pandas" will return a Series object. + in this case, while "polars" and "pandas" will return a Series object. Pass + :code:`wrap_single=True` to always return the format's multi-column container + (QTable, DataFrame, dict, ...) regardless of column count. Parameters ---------- @@ -288,6 +295,10 @@ def get_data( The format to output the data in. 
Currently supported are "astropy", "numpy", "pandas", "polars", "arrow", "jax" + wrap_single: bool, default=False + If True, always return the format's natural multi-column container even + when only one column is present. + Returns ------- data: Any @@ -321,7 +332,7 @@ def get_data( for key, value in data.items() } - return convert_data(data, format) + return convert_data(data, format, wrap_single=wrap_single) def bound(self, region: Region, select_by: Optional[str] = None): """ diff --git a/python/opencosmo/dataset/formats.py b/python/opencosmo/dataset/formats.py index 2ac5f711..ff70f72b 100644 --- a/python/opencosmo/dataset/formats.py +++ b/python/opencosmo/dataset/formats.py @@ -40,20 +40,29 @@ def __verify_import(import_name: str, format_name: str): ) -def convert_data(data: dict[str, np.ndarray], output_format: str): +def convert_data( + data: dict[str, np.ndarray], output_format: str, wrap_single: bool = False +): + """ + If `wrap_single` is True, the result is always the format's natural + multi-column container (QTable, DataFrame, dict[name, array], etc.) even + when there is only one column, instead of collapsing to a bare array / + Series. Used by callers (e.g. evaluate) that want a uniform + `container[colname]` access pattern regardless of column count. 
+ """ match output_format: case "astropy": - return __convert_to_astropy(data) + return __convert_to_astropy(data, wrap_single) case "numpy": - return convert_to_numpy(data) + return convert_to_numpy(data, wrap_single) case "pandas": - return __convert_to_pandas(data) + return __convert_to_pandas(data, wrap_single) case "polars": - return __convert_to_polars(data) + return __convert_to_polars(data, wrap_single) case "arrow": - return __convert_to_arrow(data) + return __convert_to_arrow(data, wrap_single) case "jax": - return __convert_to_jax(data) + return __convert_to_jax(data, wrap_single) case _: raise ValueError(f"Unknown data output format {output_format}") @@ -196,8 +205,10 @@ def concat_chunks(chunks: list, output_format: str): raise ValueError(f"Unknown data output format {output_format}") -def __convert_to_astropy(data: dict[str, np.ndarray]) -> QTable: - if len(data) == 1: +def __convert_to_astropy( + data: dict[str, np.ndarray], wrap_single: bool = False +) -> QTable: + if len(data) == 1 and not wrap_single: return next(iter(data.values())) if any( (isinstance(d, u.Quantity) and d.isscalar) or not isinstance(d, np.ndarray) @@ -210,6 +221,7 @@ def __convert_to_astropy(data: dict[str, np.ndarray]) -> QTable: def convert_to_numpy( data: dict[str, np.ndarray], + wrap_single: bool = False, ) -> dict[str, np.ndarray] | np.ndarray: converted_data = dict( map( @@ -220,25 +232,25 @@ def convert_to_numpy( data.items(), ) ) - if len(converted_data) == 1: + if len(converted_data) == 1 and not wrap_single: return next(iter(converted_data.values())) return converted_data -def __convert_to_pandas(data: dict[str, np.ndarray]): +def __convert_to_pandas(data: dict[str, np.ndarray], wrap_single: bool = False): import pandas as pd - numpy_data = convert_to_numpy(data) - if isinstance(numpy_data, np.ndarray): # only one column + numpy_data = convert_to_numpy(data, wrap_single) + if isinstance(numpy_data, np.ndarray): # only one column, wrap_single=False return 
pd.Series(numpy_data, name=next(iter(data.keys()))) return pd.DataFrame(numpy_data, copy=True) -def __convert_to_arrow(data: dict[str, np.ndarray]): +def __convert_to_arrow(data: dict[str, np.ndarray], wrap_single: bool = False): import pyarrow as pa # type: ignore - numpy_data = convert_to_numpy(data) + numpy_data = convert_to_numpy(data, wrap_single) if isinstance(numpy_data, np.ndarray): return pa.array(numpy_data) @@ -249,20 +261,20 @@ def __convert_to_arrow(data: dict[str, np.ndarray]): return dict(converted_data) -def __convert_to_polars(data: dict[str, np.ndarray]): +def __convert_to_polars(data: dict[str, np.ndarray], wrap_single: bool = False): import polars as pl - numpy_data = convert_to_numpy(data) + numpy_data = convert_to_numpy(data, wrap_single) if isinstance(numpy_data, np.ndarray): return pl.Series(name=next(iter(data.keys())), values=numpy_data) return pl.from_dict(data) # type: ignore -def __convert_to_jax(data: dict[str, np.ndarray]): +def __convert_to_jax(data: dict[str, np.ndarray], wrap_single: bool = False): import jax.numpy as jnp - output_data = convert_to_numpy(data) + output_data = convert_to_numpy(data, wrap_single) if isinstance(output_data, np.ndarray): return jnp.asarray(output_data) return {key: jnp.asarray(value) for key, value in output_data.items()} diff --git a/test/test_evaluate_formats.py b/test/test_evaluate_formats.py index 4ba425e0..6b1363c9 100644 --- a/test/test_evaluate_formats.py +++ b/test/test_evaluate_formats.py @@ -183,3 +183,155 @@ def test_evaluate_noinsert_returns_native_container(input_path, format, expected _vectorized_func(format), vectorize=True, insert=False, format=format ) assert isinstance(result["fof_px"], expected_type) + + +# --------------------------------------------------------------------------- +# StructureCollection paths +# --------------------------------------------------------------------------- + + +@pytest.fixture +def halo_paths(snapshot_path): + files = ["haloproperties.hdf5", 
"haloparticles.hdf5"] + return [snapshot_path / f for f in files] + + +def _mean_x(format): + """Per-structure function: mean of dm_particles 'x' coord. Returns a scalar + in the user's format. fof_halo_center_x is a scalar Quantity that has had + its unit stripped for non-astropy formats, so it's plain float.""" + + def offset(halo_properties, dm_particles): + x = dm_particles["x"] + if format == "arrow": + mean_x = pc.mean(x).as_py() + elif format == "polars": + mean_x = x.mean() + elif format == "pandas": + mean_x = float(x.mean()) + elif format == "jax": + mean_x = float(jnp.mean(x)) + else: + mean_x = float(np.mean(x)) + return mean_x - float(halo_properties["fof_halo_center_x"]) + + return offset + + +def _arange_like(format): + """Per-structure function with dataset=`dm_particles`: must return an array + in the user's format with the same length as the input dataset.""" + + def particle_id(x, y, z): + n = len(x) + if format == "jax": + return jnp.arange(n) + if format == "pandas": + return pd.Series(np.arange(n)) + if format == "polars": + return pl.Series(values=np.arange(n)) + if format == "arrow": + return pa.array(np.arange(n)) + return np.arange(n) + + return particle_id + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_into_properties(halo_paths, format): + collection = oc.open(*halo_paths).take(50) + spec = { + "dm_particles": ["x"], + "halo_properties": ["fof_halo_center_x"], + } + collection = collection.evaluate( + _mean_x(format), **spec, format=format, insert=True + ) + data = collection["halo_properties"].select("offset").get_data("numpy") + assert len(data) == 50 + assert np.any(data != 0) + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_into_properties_noinsert(halo_paths, format): + collection = oc.open(*halo_paths).take(50) + spec = { + "dm_particles": ["x"], + "halo_properties": ["fof_halo_center_x"], + } + result = collection.evaluate(_mean_x(format), **spec, format=format, 
insert=False) + assert "offset" in result + assert len(result["offset"]) == 50 + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_into_dataset(halo_paths, format): + collection = oc.open(*halo_paths).take(20) + collection = collection.evaluate( + _arange_like(format), + dataset="dm_particles", + format=format, + insert=True, + ) + for halo in collection.halos(["dm_particles"]): + pid = halo["dm_particles"].select("particle_id").get_data("numpy") + assert np.all(pid == np.arange(len(pid))) + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_on_dataset(halo_paths, format): + """Routes through Dataset.evaluate via the collection wrapper.""" + collection = oc.open(*halo_paths).take(50) + selected = ( + collection["halo_properties"] + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + collection = collection.evaluate_on_dataset( + _vectorized_func(format), + dataset="halo_properties", + vectorize=True, + format=format, + insert=True, + ) + data = collection["halo_properties"].select("fof_px").get_data("numpy") + expected = selected["fof_halo_mass"] * selected["fof_halo_com_vx"] + assert np.allclose(data, expected) + + +# --------------------------------------------------------------------------- +# Lightcone paths +# --------------------------------------------------------------------------- + + +@pytest.fixture +def lc_paths(lightcone_path): + return [ + lightcone_path / "step_600" / "haloproperties.hdf5", + lightcone_path / "step_601" / "haloproperties.hdf5", + ] + + +@pytest.mark.parametrize("format", FORMATS) +def test_lightcone_evaluate_insert(lc_paths, format): + ds = oc.open(*lc_paths).take(100) + ds = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=True, format=format + ) + for name in ds.keys(): + data = ds[name].select("fof_px").get_data("numpy") + original = ( + ds[name].select(["fof_halo_mass", "fof_halo_com_vx"]).get_data("numpy") + ) + expected = 
original["fof_halo_mass"] * original["fof_halo_com_vx"] + assert np.allclose(data, expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_lightcone_evaluate_noinsert(lc_paths, format): + ds = oc.open(*lc_paths).take(100) + result = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=False, format=format + ) + assert "fof_px" in result + assert len(result["fof_px"]) == len(ds) From df709ec9dfee70b63840c1b4596965e448f565e5 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 27 May 2026 09:24:31 -0500 Subject: [PATCH 137/139] Update docstrings to better reflect behavior --- python/opencosmo/collection/lightcone/lightcone.py | 3 +-- python/opencosmo/collection/simulation/simulation.py | 8 +++++--- python/opencosmo/collection/structure/structure.py | 6 ++---- python/opencosmo/dataset/dataset.py | 3 +-- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py index 98766124..94be05ba 100644 --- a/python/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -773,8 +773,7 @@ def evaluate( The format in which to provide column data to your function. Supports the same formats as :py:meth:`get_data ` ("astropy", "numpy", "pandas", "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted - back to numpy before being stored. Unit information is preserved only when the function - returns astropy Quantities; outputs in other formats are stored without unit metadata. + back to numpy before being stored. 
vectorize: bool, default = False Whether to provide the values as full columns (True) or one row at a time (False) diff --git a/python/opencosmo/collection/simulation/simulation.py b/python/opencosmo/collection/simulation/simulation.py index e3284c16..30d04ef5 100644 --- a/python/opencosmo/collection/simulation/simulation.py +++ b/python/opencosmo/collection/simulation/simulation.py @@ -384,9 +384,11 @@ def evaluate( datasets: str | list[str], optional The datasets to evaluate on. If not provided, will be evaluated on all datasets format: str, default = "astropy" - The format of the data that is provided to your function. If "astropy", will be a dictionary of - astropy quantities. If "numpy", will be a dictionary of numpy arrays. Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. + vectorize: bool, default = False Whether to vectorize the computation. See :py:meth:`StructureCollection.evaluate ` and/or :py:meth:`Dataset.evaluate ` for more details. diff --git a/python/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py index d5100968..26c4430c 100644 --- a/python/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -601,8 +601,7 @@ def computation(halo_properties, dm_particles): The format in which to provide column data to your function. Supports the same formats as :py:meth:`get_data ` ("astropy", "numpy", "pandas", "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted - back to numpy before being stored. 
Unit information is preserved only when the function - returns astropy Quantities; outputs in other formats are stored without unit metadata. + back to numpy before being stored. **evaluate_kwargs: any, Any additional arguments that are required for your function to run. These will be passed directly @@ -759,8 +758,7 @@ def evaluate_on_dataset( The format in which to provide column data to your function. Supports the same formats as :py:meth:`get_data ` ("astropy", "numpy", "pandas", "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted - back to numpy before being stored. Unit information is preserved only when the function - returns astropy Quantities; outputs in other formats are stored without unit metadata. + back to numpy before being stored. insert: bool, default = True If true, the data will be inserted as a column in this dataset. The new column will have the same name diff --git a/python/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py index cc591d53..4dcba8e0 100644 --- a/python/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -492,8 +492,7 @@ def baryon_fraction_bias(sod_halo_mass_gas, sod_halo_mass, cosmology): The format in which to provide column data to your function. Supports the same formats as :py:meth:`get_data ` ("astropy", "numpy", "pandas", "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted - back to numpy before being stored. Unit information is preserved only when the function - returns astropy Quantities; outputs in other formats are stored without unit metadata. + back to numpy before being stored. 
allow_overwrite: bool, default = False batch_size: int, default = -1 From a8edab8f78b0633288594a0f6e414b2bcc85d7aa Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 27 May 2026 09:27:18 -0500 Subject: [PATCH 138/139] Add changelog --- changes/+95437205.feature.rst | 1 + changes/+c414feda.feature.rst | 1 + 2 files changed, 2 insertions(+) create mode 100644 changes/+95437205.feature.rst create mode 100644 changes/+c414feda.feature.rst diff --git a/changes/+95437205.feature.rst b/changes/+95437205.feature.rst new file mode 100644 index 00000000..f59c1e7f --- /dev/null +++ b/changes/+95437205.feature.rst @@ -0,0 +1 @@ +:py:meth:`get_data ` now supports :code:`jax` as an output format. diff --git a/changes/+c414feda.feature.rst b/changes/+c414feda.feature.rst new file mode 100644 index 00000000..d0e6e288 --- /dev/null +++ b/changes/+c414feda.feature.rst @@ -0,0 +1 @@ +All :code:`evaluate` methods (e.g. :py:meth:`Dataset.evaluate `) now support passing data to the function in any format supported by :py:meth:`get_data `. 
From 7697925b36bd89a034e2452b5009024b3c1708ee Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 27 May 2026 09:39:31 -0500 Subject: [PATCH 139/139] Delete a bunch of dead code --- python/opencosmo/dataset/evaluate.py | 85 +--------------------------- python/opencosmo/evaluate.py | 48 ---------------- 2 files changed, 1 insertion(+), 132 deletions(-) delete mode 100644 python/opencosmo/evaluate.py diff --git a/python/opencosmo/dataset/evaluate.py b/python/opencosmo/dataset/evaluate.py index e2f84210..2f14e24c 100644 --- a/python/opencosmo/dataset/evaluate.py +++ b/python/opencosmo/dataset/evaluate.py @@ -2,8 +2,7 @@ from collections import defaultdict from inspect import Parameter, signature -from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable import numpy as np from astropy.units import Quantity @@ -11,10 +10,6 @@ from opencosmo.column.column import EvaluatedColumn from opencosmo.column.evaluate import EvaluateStrategy, do_first_evaluation from opencosmo.dataset.formats import concat_chunks, fetch_as_dict -from opencosmo.evaluate import ( - insert_data, - make_output_from_first_values, -) if TYPE_CHECKING: from opencosmo import Dataset @@ -165,84 +160,6 @@ def verify_for_lazy_evaluation( return column -def __visit_rows_in_dataset( - function: Callable, - dataset: Dataset, - format: str, - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, Sequence] = {}, -): - first_row_values = dict(dataset.take(1, at="start").get_data()) - first_row_kwargs = kwargs | {name: arr[0] for name, arr in iterable_kwargs.items()} - storage = __make_output(function, first_row_values | first_row_kwargs, len(dataset)) - for i, row in enumerate(dataset.rows(include_units=format == "astropy")): - if i == 0: - continue - iter_kwargs = {name: arr[i] for name, arr in iterable_kwargs.items()} - output = function(**row, **kwargs, **iter_kwargs) - if storage is not None: - 
insert_data(storage, i, output) - return storage - - -def __visit_rows_in_data( - function: Callable, - data: dict[str, np.ndarray], - format="astropy", - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, np.ndarray] = {}, -): - data = {key: d for key, d in data.items() if key in signature(function).parameters} - first_row_data = {name: arr[0] for name, arr in data.items()} - first_row_kwargs = kwargs | {name: arr[0] for name, arr in iterable_kwargs.items()} - n_rows = len(next(iter(data.values()))) - storage = __make_output(function, first_row_data | first_row_kwargs, n_rows) - if format == "numpy": - data = { - key: arr.value if isinstance(arr, Quantity) else arr - for key, arr in data.items() - } - - for i in range(1, n_rows): - row = { - name: arr[i] for name, arr in chain(data.items(), iterable_kwargs.items()) - } - output = function(**row, **kwargs) - if storage is not None: - insert_data(storage, i, output) - return storage - - -def __make_output( - function: Callable, - first_input_values: dict[str, Any], - n_rows: int, -) -> dict | None: - first_values = function(**first_input_values) - if first_values is None: - return None - if not isinstance(first_values, dict): - name = function.__name__ - first_values = {name: first_values} - - return make_output_from_first_values(first_values, n_rows) - - -def __visit_vectorize( - function: Callable, - data: dict[str, Iterable] | Iterable, - evaluator_kwargs: dict[str, Any] = {}, -): - pars = signature(function).parameters - - if not isinstance(data, dict) or (len(data) > 1 and len(pars) == 1): - return function(data, **evaluator_kwargs) - - input_data = {pname: data[pname] for pname in pars if pname in data} - - return function(**input_data, **evaluator_kwargs) - - def __verify( function: Callable, data_columns: Iterable[str], kwarg_names: Iterable[str] ): diff --git a/python/opencosmo/evaluate.py b/python/opencosmo/evaluate.py deleted file mode 100644 index 74db256d..00000000 --- 
a/python/opencosmo/evaluate.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Any - -import astropy.units as u -import numpy as np - -""" -General helper routines for evaluating expressions on datasets and collections -""" - - -def insert_data( - storage: dict[str, np.ndarray], index: int, values_to_insert: dict[str, Any] -): - if isinstance(values_to_insert, dict): - for name, value in values_to_insert.items(): - storage[name][index] = value - return storage - - name = next(iter(storage.keys())) - storage[name][index] = values_to_insert - - -def make_output_from_first_values(first_values: dict, n_rows: int): - storage = {} - new_first_values = {} - for name, value in first_values.items(): - shape: tuple[int, ...] = (n_rows,) - dtype = type(value) - if not isinstance(value, np.ndarray): - new_first_values[name] = value - elif isinstance(value, u.Quantity) and value.isscalar: - dtype = value.value.dtype - new_first_values[name] = value - elif isinstance(value, np.ndarray) and len(value) == 1: - dtype = value.dtype - new_first_values[name] = value[0] - else: - dtype = value.dtype - shape = shape + value.shape - new_first_values[name] = value - - storage[name] = np.zeros(shape, dtype=dtype) - for name, value in new_first_values.items(): - if isinstance(value, u.Quantity): - storage[name] = storage[name] * value.unit - - storage[name][0] = value - return storage