diff --git a/.bumpversion.toml b/.bumpversion.toml index b5c93c25..0fa58ebf 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,7 +1,10 @@ [tool.bumpversion] -current_version = "1.2.5" -parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)" -serialize = ["{major}.{minor}.{patch}"] +current_version = "1.2.6" +parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)[-.]?(?P<pre_l>a|b|rc)?(?P<pre_n>\\d+)?" +serialize = [ + "{major}.{minor}.{patch}-{pre_l}{pre_n}", + "{major}.{minor}.{patch}", +] search = "{current_version}" replace = "{new_version}" regex = false @@ -21,11 +24,21 @@ pre_commit_hooks = [] post_commit_hooks = [] allow_shell_hooks = true +[tool.bumpversion.parts.pre_l] +optional_value = "final" +values = ["a", "b", "rc", "final"] + +[tool.bumpversion.parts.pre_n] +first_value = "1" + [[tool.bumpversion.files]] -filename = "src/opencosmo/__init__.py" +filename = "python/opencosmo/__init__.py" [[tool.bumpversion.files]] filename = "pyproject.toml" [[tool.bumpversion.files]] filename = "docs/source/conf.py" + +[[tool.bumpversion.files]] +filename = "Cargo.toml" diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3aff91fe..1d31beb3 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -1,5 +1,5 @@ name: build dev -on: +on: workflow_call: inputs: push-docker: @@ -14,34 +14,37 @@ on: required: false jobs: - build-wheel: - runs-on: ubuntu-latest + build-wheels: + name: Build wheels (${{ matrix.target }}) + strategy: + matrix: + include: + - { os: ubuntu-latest, target: x86_64 } + - { os: ubuntu-latest, target: aarch64 } + - { os: macos-latest, target: x86_64-apple-darwin } + - { os: macos-latest, target: aarch64-apple-darwin } + runs-on: ${{ matrix.os }} steps: - name: Checkout repository uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + if: matrix.os != 'ubuntu-latest' + uses: actions/setup-python@v5 with: - python-version: '3.11' - - name: Install uv - uses: astral-sh/setup-uv@v6 + python-version: | + 3.12 + 
3.13 + 3.14 + - name: Build wheels + uses: PyO3/maturin-action@v1 with: - version: "0.8.2" - - name: Build package - run: uv build - - build-container: - runs-on: ubuntu-latest - steps: - - id: login - if: ${{ inputs.push-docker }} - uses: docker/login-action@v3 + target: ${{ matrix.target }} + manylinux: auto + args: --release --out dist --interpreter 3.12 3.13 3.14 + - name: Upload wheels + uses: actions/upload-artifact@v4 with: - username: ${{ secrets.docker-username }} - password: ${{ secrets.docker-access-key }} - - uses: docker/setup-buildx-action@v3 - - uses: docker/bake-action@v6 - with: - push: ${{ inputs.push-docker }} - targets: dev + name: wheels-${{ matrix.target }} + path: dist/ + diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 55fbfc80..b8c79b1d 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -34,6 +34,6 @@ jobs: run: uv run ruff check . - name: Run mypy type checker - run: uv run mypy src/opencosmo + run: uv run mypy python/opencosmo diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index f0d67eec..5b18e27e 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -8,59 +8,150 @@ on: type: string jobs: - release: + prepare: runs-on: ubuntu-latest permissions: contents: write + outputs: + prerelease: ${{ steps.check-version.outputs.prerelease }} steps: + - name: Check if pre-release + id: check-version + run: | + if echo "${{ inputs.version }}" | grep -qE '(a|b|rc|alpha|beta|dev)[0-9]*$'; then + echo "prerelease=true" >> "$GITHUB_OUTPUT" + echo "ref=main" >> "$GITHUB_OUTPUT" + else + echo "prerelease=false" >> "$GITHUB_OUTPUT" + echo "ref=release" >> "$GITHUB_OUTPUT" + fi + - uses: actions/checkout@v4 with: fetch-depth: 0 token: ${{ secrets.RELEASE_PAT }} - ref: 'release' + ref: ${{ steps.check-version.outputs.ref }} - name: Install uv uses: astral-sh/setup-uv@v6 with: version: "0.8.2" - - name: Set version - run: echo "VERSION=${{ 
inputs.version }}" >> $GITHUB_ENV - - name: Bump version files - # Uses --new-version so you don't need to specify major/minor/patch. - # commit=false and tag=false are already set in .bumpversion.toml so - # bump-my-version only updates the files. - run: uvx bump-my-version bump --new-version ${{ env.VERSION }} + run: uvx bump-my-version bump --new-version ${{ inputs.version }} - name: Draft changelog (captured for GitHub release body) - run: uv run towncrier build --draft --version ${{ env.VERSION }} > release_notes.rst + run: uv run towncrier build --draft --version ${{ inputs.version }} > release_notes.rst - name: Build changelog - # Writes to docs/source/changelog.rst and removes news fragments in changes/ - run: uv run towncrier build --yes --version ${{ env.VERSION }} + if: steps.check-version.outputs.prerelease == 'false' + run: uv run towncrier build --yes --version ${{ inputs.version }} - name: Commit, tag, and push release branch + if: steps.check-version.outputs.prerelease == 'false' run: | git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" git checkout -B release git add --all - git commit -m "Release ${{ env.VERSION }}" - git tag ${{ env.VERSION }} + git commit -m "Release ${{ inputs.version }}" + git tag ${{ inputs.version }} git push origin release --force - git push origin ${{ env.VERSION }} + git push origin ${{ inputs.version }} - - name: Build package - run: uv build + - name: Tag pre-release + if: steps.check-version.outputs.prerelease == 'true' + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git add --all + git commit -m "Release ${{ inputs.version }}" + git tag ${{ inputs.version }} + git push origin ${{ inputs.version }} + + - name: Upload release notes + uses: actions/upload-artifact@v4 + with: + name: release-notes + path: release_notes.rst + + build-wheels: + needs: prepare + name: Build wheels (${{ 
matrix.target }}) + strategy: + matrix: + include: + - { os: ubuntu-latest, target: x86_64 } + - { os: ubuntu-latest, target: aarch64 } + - { os: macos-latest, target: x86_64-apple-darwin } + - { os: macos-latest, target: aarch64-apple-darwin } + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ inputs.version }} + - name: Set up Python + if: matrix.os != 'ubuntu-latest' + uses: actions/setup-python@v5 + with: + python-version: | + 3.12 + 3.13 + 3.14 + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + manylinux: auto + args: --release --out dist --interpreter 3.12 3.13 3.14 + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-${{ matrix.target }} + path: dist/ + + publish: + needs: build-wheels + runs-on: ubuntu-latest + steps: + - name: Download wheels + uses: actions/download-artifact@v4 + with: + pattern: wheels-* + path: dist/ + merge-multiple: true + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + version: "0.8.2" - name: Publish to PyPI run: uv publish --token ${{ secrets.PYPI_TOKEN }} + github-release: + needs: [publish, prepare] + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - name: Download wheels + uses: actions/download-artifact@v4 + with: + pattern: wheels-* + path: dist/ + merge-multiple: true + + - name: Download release notes + uses: actions/download-artifact@v4 + with: + name: release-notes + - name: Create GitHub release uses: ncipollo/release-action@v1 with: - tag: ${{ env.VERSION }} + tag: ${{ inputs.version }} bodyFile: release_notes.rst artifacts: dist/* - makeLatest: true + prerelease: ${{ needs.prepare.outputs.prerelease == 'true' }} + makeLatest: ${{ needs.prepare.outputs.prerelease == 'false' }} diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 25311478..d318e67a 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -4,12 +4,6 @@ jobs: get-test-data: 
runs-on: ubuntu-latest steps: - - name: Setup AWS CLI - uses: aws-actions/configure-aws-credentials@v1 - with: - aws-access-key-id: ${{ secrets.TEST_DATA_ACCESS_KEY }} - aws-secret-access-key: ${{ secrets.TEST_DATA_SECRET_KEY }} - aws-region: us-west-2 - name: check if cache exists id: check-cache uses: actions/cache@v4 @@ -19,6 +13,13 @@ jobs: lookup-only: true restore-keys: | test-data + - name: Setup AWS CLI + if: steps.check-cache.outputs.cache-hit != 'true' + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.TEST_DATA_ACCESS_KEY }} + aws-secret-access-key: ${{ secrets.TEST_DATA_SECRET_KEY }} + aws-region: us-west-2 - name: Download test data if: steps.check-cache.outputs.cache-hit != 'true' run: aws s3 cp s3://${{ secrets.TEST_DATA_BUCKET }}/test_data.tar.gz test_data.tar.gz diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 66f3ea6a..ebdaa2a4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,15 +14,20 @@ If you find that the documentation is not quite up to par, feel free to create a In addition to raising issues on the repository or adding documentation, we welcome PRs that fix bugs, improve performance, or add new features. The remainder of this document describes how to go about setting up a development environment and prepping your PR for merging. -### Step 0: Install UV +### Step 0: Install UV and Build Dependencies We use [uv](https://docs.astral.sh/uv/) to manage dependencies and provide a consistent execution environment for packaging and testing. If you do not already have uv installed on your machine, [follow the documentation](https://docs.astral.sh/uv/getting-started/installation/) for your system. +Because `opencosmo` includes some Rust extension modules, it uses [maturin](https://github.com/pyo3/maturin) as its build system and requires a Rust compiler. `uv` should install `maturin`, which is capable of bootstrapping itself. 
This should all happen automatically in step 2 below, but if for some reason this does not work you may need to [install a rust compiler manually](https://rust-lang.org/tools/install/). + + ### Step 1: Create a Fork of the OpenCosmo Repo and Clone You should preform your development in your own personal fork of the repository. You can create one [at this link](https://github.com/ArgonneCPAC/OpenCosmo/fork). Once you have a fork, you can clone it to the machine you will be doing your development on. -### Step 2: Create a Virtual Environment with UV +If you're only planning to make a single contribution, it's fine to commit to the main branch of your fork. If you intend to make many contributions, we recommend creating a branch for each feature. + +### Step 2: Create a Virtual Environment with UV and Install Dependencies In addition to managing dependencies, uv manages a virtual environment with the appropriate versions of all packages. From the root of the repository, run the command: @@ -32,6 +37,8 @@ uv sync This will create a virtual environment in the repository and install all necessary dependencies, as well as extra dependencies required for development work. +As part of the initial `uv sync`, `maturin` will build the rust extension modules and install them in the correct place. This may take some time, but unless you are modifying the rust code it will only have to happen once. + ### Step 2.1: Install Parallel HDF5 (Optional) If you plan to work on features that involve parallel I/O, you will need to install parallel HDF5 to run parallel tests. To start, install the additional packages necessary for developing and testing parallel features @@ -49,7 +56,9 @@ Note that running with uv is required to ensure that the parallel version of HDF ### Step 3: Add a Commit & Create a PR -I know what you're thinking: I haven't actually implemented my changes yet. Why am I already submitting a PR? Two reasons. 
First, so we know what people are working on so we're not doing duplicate work. Second, so that your PR starts running through the CI pipeline. +I know what you're thinking: I haven't actually implemented my changes yet. Why am I already submitting a PR? Two reasons. First, so we know what people are working on so we're not doing duplicate work. Second, so that your PR starts running through the CI pipeline. This will make it easier down the line! + +For more details of how to do all this, see the [GitHub documentation](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) ### Step 4: Implement Your Changes @@ -59,6 +68,8 @@ This is where the magic happens. Implement your features! If you have questions, Our CI pipeline performs linting with [ruff](https://astral.sh/ruff) and static type checking with [mypy](https://www.mypy-lang.org/). Both should have been installed automatically when you ran `uv sync`. Your PR must pass both to be merged. If you prefer, you can use [pre-commit](https://pre-commit.com/) to automatically run linting and type checking before you commit. Altenatively, you can call them manually on your last commit. +If you're unfamiliar with type hints, go ahead and implement your features without them and we can worry about the details later. + Many libraries do not have full typing support. If this is the case, you can add a `# type: ignore` directive when you import them. If the type stubs exist as a seperate library, you should instead add them as a development dependency in the pyprojet.toml with ```bash @@ -79,7 +90,7 @@ uv run pytest --ignore=test/parallel If you are working on features designed to be used in an MPI context, you can run the parallel tests with: ```bash -uv run mpiexec -n 4 pytest -m parallel test/parallel -x +uv run mpiexec -n 4 pytest -m parallel test/parallel ``` All tests must pass for your PR to be merged. 
diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 00000000..c8611404 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,229 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "libc" +version = "0.2.184" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48f5d2a454e16a5ea0f4ced81bd44e4cfc7bd3a507b61887c99fd3538b28e4af" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "numpy" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "778da78c64ddc928ebf5ad9df5edf0789410ff3bdbf3619aed51cd789a6af1e2" +dependencies = [ + "libc", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "pyo3", + "pyo3-build-config", + "rustc-hash", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "opencosmo" +version = "1.2.6" +dependencies = [ + "numpy", + "pyo3", +] + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "portable-atomic-util" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91fd8e38a3b50ed1167fb981cd6fd60147e091784c427b8f7183a7ee32c31c12" +dependencies = [ + "libc", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", +] + +[[package]] +name = "pyo3-build-config" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e368e7ddfdeb98c9bca7f8383be1648fd84ab466bf2bc015e94008db6d35611e" +dependencies = [ + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f29e10af80b1f7ccaf7f69eace800a03ecd13e883acfacc1e5d0988605f651e" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df6e520eff47c45997d2fc7dd8214b25dd1310918bbb2642156ef66a67f29813" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4cdc218d835738f81c2338f822078af45b4afdf8b2e33cbb5916f108b813acb" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rustc-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..64325640 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "opencosmo" +version = "1.2.6" +edition = "2021" + +[lib] +name = "_lib" +# "cdylib" is necessary to produce a shared library for Python to import from. +crate-type = ["cdylib"] + +[dependencies] +numpy = "0.28" +pyo3 = "0.28.2" + diff --git a/changes/+018f30cb.bugfix.rst b/changes/+018f30cb.bugfix.rst new file mode 100644 index 00000000..b236b155 --- /dev/null +++ b/changes/+018f30cb.bugfix.rst @@ -0,0 +1,2 @@ +HealpixMap now correctly unpacks data when there is only a single map, instead of returning a dictionary. Mirrors +behavior in Dataset etc. diff --git a/changes/+07fe4e9.feature.rst b/changes/+07fe4e9.feature.rst new file mode 100644 index 00000000..0e8debe3 --- /dev/null +++ b/changes/+07fe4e9.feature.rst @@ -0,0 +1 @@ +Added animate_halos function, which calls either "visualize_halo" or "halo_projection_array" in a loop to create an animated visualization of a given halo or set of halos. diff --git a/changes/+19969705.misc.rst b/changes/+19969705.misc.rst new file mode 100644 index 00000000..db75a8f2 --- /dev/null +++ b/changes/+19969705.misc.rst @@ -0,0 +1 @@ +The DatasetState object has been broken up to allow for more flexibility in instantiating datasets, laying the groundwork for future optimizations. 
diff --git a/changes/+21d9d2f3.improvement.rst b/changes/+21d9d2f3.improvement.rst new file mode 100644 index 00000000..c245a55f --- /dev/null +++ b/changes/+21d9d2f3.improvement.rst @@ -0,0 +1,3 @@ +The OpenCosmo internals now include a basic plugin architecture, which allows for specific data types to introduce +modified behavior at particular points. Currently being used to dynamically re-build the `top_host_idx` column in +diffsky data. diff --git a/changes/+31323023.improvement.rst b/changes/+31323023.improvement.rst new file mode 100644 index 00000000..0e7104d7 --- /dev/null +++ b/changes/+31323023.improvement.rst @@ -0,0 +1 @@ +You can now request HealpixMaps in Healpix format even when the map does not cover the full sky. Maps requested in this format will be returned as masked numpy arrays. diff --git a/changes/+342c1e1e.feature.rst b/changes/+342c1e1e.feature.rst new file mode 100644 index 00000000..04d44b02 --- /dev/null +++ b/changes/+342c1e1e.feature.rst @@ -0,0 +1 @@ +The :py:class:`StructureCollection ` now has a more complete string representation diff --git a/changes/+5661c945.bugfix.rst b/changes/+5661c945.bugfix.rst new file mode 100644 index 00000000..cd158d54 --- /dev/null +++ b/changes/+5661c945.bugfix.rst @@ -0,0 +1 @@ +Fix a bug that could cause renamed columns to be instantiated without the correct units diff --git a/changes/+66861282.feature.rst b/changes/+66861282.feature.rst new file mode 100644 index 00000000..b0a1eac9 --- /dev/null +++ b/changes/+66861282.feature.rst @@ -0,0 +1,2 @@ +Diffsky catalog can be forced to keep groups during filtering etc. by setting the `get_top_host = True` flag when +opening. diff --git a/changes/+6f102bf6.feature.rst b/changes/+6f102bf6.feature.rst new file mode 100644 index 00000000..85ffc520 --- /dev/null +++ b/changes/+6f102bf6.feature.rst @@ -0,0 +1 @@ +The :py:meth:`Lightcone.get_pixels() ` method now supports multi-step lightcones. 
diff --git a/changes/+89192e73.bugfix.rst b/changes/+89192e73.bugfix.rst new file mode 100644 index 00000000..96750606 --- /dev/null +++ b/changes/+89192e73.bugfix.rst @@ -0,0 +1 @@ +Fixed a bug that could cause column arithmetic to fail with scalars. diff --git a/changes/+91ec5e88.improvement.rst b/changes/+91ec5e88.improvement.rst new file mode 100644 index 00000000..c8f87405 --- /dev/null +++ b/changes/+91ec5e88.improvement.rst @@ -0,0 +1 @@ +Conversion to healsparse in :py:meth:`HealpixMap.get_data ` now supports :code:`jax` as an output format. diff --git a/changes/+bfffa1ad.feature.rst b/changes/+bfffa1ad.feature.rst new file mode 100644 index 00000000..f0ffb6ef --- /dev/null +++ b/changes/+bfffa1ad.feature.rst @@ -0,0 +1 @@ +:py:meth:`Dataset.filter ` now accepts masks created from expressions built from column arithmetic. diff --git a/changes/+c414feda.feature.rst b/changes/+c414feda.feature.rst new file mode 100644 index 00000000..d0e6e288 --- /dev/null +++ b/changes/+c414feda.feature.rst @@ -0,0 +1 @@ +All :code:`evaluate` methods (e.g. :py:meth:`Dataset.evaluate `) now support passing data to the function in any format supported by :py:meth:`get_data `. diff --git a/changes/+caf6ec12.improvement.rst b/changes/+caf6ec12.improvement.rst new file mode 100644 index 00000000..ef170bf4 --- /dev/null +++ b/changes/+caf6ec12.improvement.rst @@ -0,0 +1 @@ +Dataset instantiation and backend process has been reworked to allow for dynamic column updating. diff --git a/changes/+e13be920.feature.rst b/changes/+e13be920.feature.rst new file mode 100644 index 00000000..489d8367 --- /dev/null +++ b/changes/+e13be920.feature.rst @@ -0,0 +1 @@ +Add the :py:meth:`Lightcone.query_pixels ` method to query a lightcone based on healpix pixels. 
diff --git a/changes/+e334af41.feature.rst b/changes/+e334af41.feature.rst new file mode 100644 index 00000000..3c60cba0 --- /dev/null +++ b/changes/+e334af41.feature.rst @@ -0,0 +1 @@ +Calls to :py:meth:`with_new_columns ` and :py:meth:`evaluate ` now accept an `allow_overwrite` flag. In this way you can "transform" a column by creating a derived column that depends on a column of the same name. The input will be the original column, and the output will be the new version. diff --git a/changes/240.bugfix.rst b/changes/240.bugfix.rst new file mode 100644 index 00000000..958c8149 --- /dev/null +++ b/changes/240.bugfix.rst @@ -0,0 +1 @@ +Requesting more rows than exist via :py:meth:`take ` or :py:meth:`take_range ` no longer raises a ``ValueError``. Instead, all available rows are returned. This applies to :py:class:`Dataset `, :py:class:`Lightcone `, and :py:class:`StructureCollection `. diff --git a/changes/240.feature.rst b/changes/240.feature.rst new file mode 100644 index 00000000..0ef4ab54 --- /dev/null +++ b/changes/240.feature.rst @@ -0,0 +1 @@ +:py:meth:`take `, :py:meth:`take_range `, and their equivalents on :py:class:`Lightcone ` and :py:class:`StructureCollection ` now accept a ``mode`` keyword argument. Setting ``mode="global"`` when running under MPI causes ``n`` (or ``start``/``end``) to be interpreted across all ranks combined rather than per-rank. When the dataset is sorted, ranks coordinate to select from the globally-sorted order, so ``ds.sort_by("fof_halo_mass").take(1000, mode="global")`` returns exactly the 1000 most massive halos distributed across all ranks. 
diff --git a/docs/source/analysis.rst b/docs/source/analysis.rst index d1c96588..c11f6c9f 100644 --- a/docs/source/analysis.rst +++ b/docs/source/analysis.rst @@ -105,7 +105,7 @@ The two primary functions for this purpose are: - :func:`opencosmo.analysis.visualize_halo` — a simple 2x2 panel plot for one halo - :func:`opencosmo.analysis.halo_projection_array` — a customizable grid of halos and fields -These use yt under the hood, and are useful for visually inspecting halos with minimal input required. +These use yt under the hood, and are useful for visually inspecting halos with minimal input required. Animated versions of the visualizations outputted by either of these functions can be made using :func:`opencosmo.analysis.animate_halos`. Quick Projections diff --git a/docs/source/analysis_ref.rst b/docs/source/analysis_ref.rst index 934d9b17..761381a8 100644 --- a/docs/source/analysis_ref.rst +++ b/docs/source/analysis_ref.rst @@ -9,6 +9,8 @@ Analysis .. autofunction:: opencosmo.analysis.halo_projection_array +.. autofunction:: opencosmo.analysis.animate_halos + .. autofunction:: opencosmo.analysis.ParticleProjectionPlot .. autofunction:: opencosmo.analysis.ProjectionPlot diff --git a/docs/source/parameters_ref.rst b/docs/source/parameters_ref.rst index c4074a25..9f78f0f5 100644 --- a/docs/source/parameters_ref.rst +++ b/docs/source/parameters_ref.rst @@ -17,7 +17,7 @@ Cosmology Most OpenCosmo files will contain cosmology parameters, which describe the cosmology the simulation was run under. In general you will not interact with this parameter block directly. Instead, requiresting it will return an astropy.cosmology.Cosmology object. Dataset and collections will generally make this object available directly with the :py:attr:`.cosmology ` attribute. -.. autoclass:: opencosmo.parameters.cosmology.CosmologyParameters +.. 
autoclass:: opencosmo.dtypes.cosmology.CosmologyParameters :members: :undoc-members: :exclude-members: model_config, ACCESS_PATH, ACCESS_TRANSFORMATION @@ -28,14 +28,14 @@ Simulation Parameters Data that was originally produced by HACC will contain the parameters that were used to initialize the simulation. Datasets and collections will generally make these paramters available with the :py:attr:`.simulation ` attribute. -.. autoclass:: opencosmo.parameters.hacc.HaccSimulationParameters +.. autoclass:: opencosmo.dtypes.hacc.HaccSimulationParameters :members: :undoc-members: :exclude-members: model_config,empty_string_to_none,cosmology_parameters,ACCESS_PATH :member-order: bysource -.. autoclass:: opencosmo.parameters.hacc.HaccHydroSimulationParameters +.. autoclass:: opencosmo.dtypes.hacc.HaccHydroSimulationParameters :members: :undoc-members: :exclude-members: model_config diff --git a/pyproject.toml b/pyproject.toml index cc5966b1..f12d3c99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,10 +22,13 @@ dependencies = [ "deprecated>=1.2.58,<2.0.0", "numpy>=2.0,<2.5", "click (>=8.2.1,<9.0.0)", - "numba>=0.64.0,<1.0", "rustworkx>=0.17.1,<1.0", ] +[project.urls] +Repository = "https://github.com/ArgonneCPAC/OpenCosmo" +Documentation = "https://opencosmo.readthedocs.io/en/stable/" + [dependency-groups] dev = [ "ruff>=0.9.4,<1.0.0", @@ -36,6 +39,7 @@ dev = [ "pip>=25.1.1", "pytest>=8.3.4,<9.0.0", "pytest-timeout>=2.4.0", + "jax>=0.10.1", ] docs = [ "sphinx>=9.0.0", @@ -56,6 +60,7 @@ test = [ test-mpi = [ "mpi-pytest>=2025.4.0,<2026.0.0" ] + [project.scripts] opencosmo = "opencosmo.analysis.cli:cli" @@ -65,34 +70,27 @@ io = [ ] [build-system] -requires = ["uv_build>=0.8.2,<0.9.0", "pip"] -build-backend = "uv_build" +requires = ["maturin>=1.0,<2.0"] +build-backend = "maturin" -[tool.uv.build-backend] -source-include = ["LICENSE.md"] +[tool.maturin] +module-name = "opencosmo._lib" +include = ["LICENSE.md"] +python-source = "python" +[tool.uv] +cache-keys = [{file = 
"pyproject.toml"}, {file = "Cargo.toml"}, {file = "**/*.rs"}] [[tool.mypy.overrides]] module = ["h5py", "astropy.*", "healpy", "healsparse", "numba"] ignore_missing_imports = true -[tool.towncrier] -directory = "changes" -package = "opencosmo" -name = "opencosmo" -filename = "docs/source/changelog.rst" [tool.pytest.ini_options] testpaths = ["test"] log_cli = true log_cli_level = "INFO" -[tool.pyrefly] -project-includes = [ - "**/*.py*", -] -replace-imports-with-any = ["astropy.units", "astropy.constants"] - [lint.pycodestyle] max-doc-length = 90 @@ -101,6 +99,11 @@ future-annotations = true preview = false extend-select = ["TC"] +[tool.towncrier] +directory = "changes" +package = "opencosmo" +name = "opencosmo" +filename = "docs/source/changelog.rst" [tool.towncrier.fragment.improvement] name = "Improvements" diff --git a/src/opencosmo/__init__.py b/python/opencosmo/__init__.py similarity index 100% rename from src/opencosmo/__init__.py rename to python/opencosmo/__init__.py diff --git a/python/opencosmo/_lib/__init__.pyi b/python/opencosmo/_lib/__init__.pyi new file mode 100644 index 00000000..c417140d --- /dev/null +++ b/python/opencosmo/_lib/__init__.pyi @@ -0,0 +1 @@ +from . import index as index diff --git a/python/opencosmo/_lib/index.pyi b/python/opencosmo/_lib/index.pyi new file mode 100644 index 00000000..43f3c11b --- /dev/null +++ b/python/opencosmo/_lib/index.pyi @@ -0,0 +1,29 @@ +from opencosmo.index import IndexArray + +def take_chunked_from_chunked( + start: IndexArray, size: IndexArray, take_start: IndexArray, take_size: IndexArray +) -> tuple[IndexArray, IndexArray]: ... +def get_simple_range(index: IndexArray) -> tuple[int, int]: ... +def get_chunked_range(start: IndexArray, size: IndexArray) -> tuple[int, int]: ... +def n_in_range_chunked( + start: IndexArray, size: IndexArray, range_start: IndexArray, range_size: IndexArray +) -> IndexArray: ... +def chunked_into_array(start: IndexArray, size: IndexArray) -> IndexArray: ... 
+def take_chunked_from_simple( + simple: IndexArray, start: IndexArray, size: IndexArray +) -> IndexArray: ... +def reindex_column(index: IndexArray, index_column: IndexArray) -> IndexArray: ... +def rebuild_simple_by_ranges( + index: IndexArray, starts: IndexArray, sizes: IndexArray +) -> IndexArray: ... +def rebuild_chunked_by_ranges( + starts: IndexArray, + sizes: IndexArray, + chunk_starts: IndexArray, + chunk_sizes: IndexArray, +) -> IndexArray: ... +def project_chunked_on_simple( + simple: IndexArray, + chunk_starts: IndexArray, + chunk_sizes: IndexArray, +) -> IndexArray: ... diff --git a/src/opencosmo/analysis/__init__.py b/python/opencosmo/analysis/__init__.py similarity index 98% rename from src/opencosmo/analysis/__init__.py rename to python/opencosmo/analysis/__init__.py index e1c7f928..66de110f 100644 --- a/src/opencosmo/analysis/__init__.py +++ b/python/opencosmo/analysis/__init__.py @@ -13,6 +13,7 @@ "PhasePlot", "visualize_halo", "halo_projection_array", + "animate_halos", ] diff --git a/src/opencosmo/analysis/cli.py b/python/opencosmo/analysis/cli.py similarity index 100% rename from src/opencosmo/analysis/cli.py rename to python/opencosmo/analysis/cli.py diff --git a/src/opencosmo/analysis/diffsky.py b/python/opencosmo/analysis/diffsky.py similarity index 100% rename from src/opencosmo/analysis/diffsky.py rename to python/opencosmo/analysis/diffsky.py diff --git a/src/opencosmo/analysis/install/__init__.py b/python/opencosmo/analysis/install/__init__.py similarity index 100% rename from src/opencosmo/analysis/install/__init__.py rename to python/opencosmo/analysis/install/__init__.py diff --git a/src/opencosmo/analysis/install/install.py b/python/opencosmo/analysis/install/install.py similarity index 100% rename from src/opencosmo/analysis/install/install.py rename to python/opencosmo/analysis/install/install.py diff --git a/src/opencosmo/analysis/install/source.py b/python/opencosmo/analysis/install/source.py similarity index 100% rename from 
src/opencosmo/analysis/install/source.py rename to python/opencosmo/analysis/install/source.py diff --git a/src/opencosmo/analysis/install/specs/__init__.py b/python/opencosmo/analysis/install/specs/__init__.py similarity index 100% rename from src/opencosmo/analysis/install/specs/__init__.py rename to python/opencosmo/analysis/install/specs/__init__.py diff --git a/src/opencosmo/analysis/install/specs/diffsky.json b/python/opencosmo/analysis/install/specs/diffsky.json similarity index 100% rename from src/opencosmo/analysis/install/specs/diffsky.json rename to python/opencosmo/analysis/install/specs/diffsky.json diff --git a/src/opencosmo/analysis/install/specs/vizualize.json b/python/opencosmo/analysis/install/specs/vizualize.json similarity index 100% rename from src/opencosmo/analysis/install/specs/vizualize.json rename to python/opencosmo/analysis/install/specs/vizualize.json diff --git a/src/opencosmo/analysis/install/versions.py b/python/opencosmo/analysis/install/versions.py similarity index 100% rename from src/opencosmo/analysis/install/versions.py rename to python/opencosmo/analysis/install/versions.py diff --git a/src/opencosmo/analysis/mpi.py b/python/opencosmo/analysis/mpi.py similarity index 100% rename from src/opencosmo/analysis/mpi.py rename to python/opencosmo/analysis/mpi.py diff --git a/src/opencosmo/analysis/yt_utils.py b/python/opencosmo/analysis/yt_utils.py similarity index 97% rename from src/opencosmo/analysis/yt_utils.py rename to python/opencosmo/analysis/yt_utils.py index 9d724eed..41c201ed 100644 --- a/src/opencosmo/analysis/yt_utils.py +++ b/python/opencosmo/analysis/yt_utils.py @@ -92,9 +92,11 @@ def astropy_to_yt(array): return unyt_array(array.data, "dimensionless") if "littleh" in str(array.unit): - raise RuntimeError("cannot convert factors of littleh to yt convention, " - "try converting the opencosmo dataset to comoving units " - "(e.g. 
set `ds = ds.with_units(\"comoving\"))`") + raise RuntimeError( + "cannot convert factors of littleh to yt convention, " + "try converting the opencosmo dataset to comoving units " + '(e.g. set `ds = ds.with_units("comoving"))`' + ) return unyt_array.from_astropy(array) diff --git a/src/opencosmo/analysis/yt_viz.py b/python/opencosmo/analysis/yt_viz.py similarity index 55% rename from src/opencosmo/analysis/yt_viz.py rename to python/opencosmo/analysis/yt_viz.py index e96ae001..053fa4eb 100644 --- a/src/opencosmo/analysis/yt_viz.py +++ b/python/opencosmo/analysis/yt_viz.py @@ -4,18 +4,22 @@ import numpy as np import yt # type: ignore +import matplotlib.pyplot as plt from matplotlib.colors import LogNorm # type: ignore +from matplotlib.animation import FuncAnimation from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar # type: ignore from unyt import unyt_quantity # type: ignore from yt.visualization.base_plot_types import get_multi_plot # type: ignore +from yt.visualization.particle_plots import OffAxisParticleProjectionPlot import opencosmo as oc from opencosmo.analysis import create_yt_dataset if TYPE_CHECKING: - from matplotlib.figure import Figure from matplotlib.colors import Normalize + from matplotlib.figure import Figure from yt.visualization.plot_window import NormalPlot + from yt.data_objects.static_output import Dataset as YT_Dataset # ruff: noqa: E501 @@ -136,10 +140,14 @@ def PhasePlot(*args, **kwargs) -> yt.PhasePlot: def visualize_halo( halo_id: int, data: oc.StructureCollection, + yt_ds: Optional[YT_Dataset] = None, projection_axis: Optional[str] = "z", + north_vector: Optional[list[float]] = None, length_scale: Optional[str] = "top left", text_color: Optional[str] = "lightgray", width: Optional[float] = None, + manual_axis_alignment: Optional[bool] = False, + ) -> Figure: """ Creates a figure showing particle projections of dark matter, stars, gas, and/or gas temperature @@ -251,14 +259,27 @@ def visualize_halo( ) halo_ids: list[int] 
| tuple[list[int], list[int]] + + yt_dataset_provided = yt_ds is not None + if not yt_dataset_provided: + yt_ds_arr = None + if len(params["fields"]) == 4: # if 4 fields, make a 2x2 figure halo_ids = ([halo_id, halo_id], [halo_id, halo_id]) + + if yt_dataset_provided: + yt_ds_arr = ([yt_ds, yt_ds],[yt_ds, yt_ds]) + params = {key: (value[:2], value[2:]) for key, value in params.items()} else: # otherwise, do 1xN halo_ids = np.shape(params["fields"])[0] * [halo_id] + + if yt_dataset_provided: + yt_ds_arr = np.shape(params["fields"])[0] * [yt_ds] + params = {key: [value] for key, value in params.items()} return halo_projection_array( @@ -269,22 +290,27 @@ def visualize_halo( width=width, projection_axis=projection_axis, text_color=text_color, + north_vector=north_vector, + yt_ds=yt_ds_arr, ) def halo_projection_array( halo_ids: int | list[int] | tuple[list[int], list[int]] | np.ndarray, data: oc.StructureCollection, + yt_ds: Optional[ YT_Dataset | list[YT_Dataset|None] | tuple[list[YT_Dataset|None], list[YT_Dataset|None]] | np.ndarray ] = None, field: Optional[Tuple[str, str]] = ("dm", "particle_mass"), weight_field: Optional[Tuple[str, str]] = None, projection_axis: Optional[str] = "z", + north_vector: Optional[list[float]] = None, cmap: Optional[str] = "gray", - cmap_norm: Optional[Normalize] = None, # type: ignore + cmap_norm: Optional[Normalize] = None, # type: ignore zlim: Optional[Tuple[float, float]] = None, params: Optional[Dict[str, Any]] = None, length_scale: Optional[str] = None, text_color: Optional[str] = "lightgray", width: Optional[float] = None, + manual_axis_alignment: Optional[bool] = False, ) -> Figure: """ Creates a multipanel figure of projections for different fields and/or halos. @@ -306,20 +332,30 @@ def halo_projection_array( halo_ids : int or 2D array of int Unique ID of the halo(s) to be visualized. 
The shape of `halo_ids` sets the layout of the figure (e.g., if `halo_ids` is a 2x3 array, the outputted figure will be a 2x3 - array of projections). To leave a panel in the outputted figure blank, set the corresponding + array of projections). To leave a panel in the outputted figure blank, set the corresponding entry into the `halo_ids` array to `None`. If `int`, a single panel is output while preserving formatting. data : opencosmo.StructureCollection OpenCosmo StructureCollection dataset containing both halo properties and particle data (e.g., output of ``opencosmo.open([haloproperties, sodbighaloparticles])``). + yt_ds : yt dataset or 2D array of yt datasets, optional + Pre-loaded yt dataset (e.g., output of ``opencosmo.analysis.create_yt_dataset()``). + If ``None``, ``halo_projection_array`` will internally search for the halo ID to create the yt dataset. field : tuple of str, optional Field to plot for all panels. Follows yt naming conventions (e.g., ``("dm", "particle_mass")``, ``("gas", "temperature")``). Overridden if ``params["fields"]`` is provided. weight_field : tuple of str, optional Field to weight by during projection. Follows yt naming conventions. Overridden if ``params["weight_fields"]`` is provided. - projection_axis : str, optional - Data is projected along this axis (``"x"``, ``"y"``, or ``"z"``). + projection_axis : str, int, or 3-element sequence of floats, optional + Data is projected along this axis (``"x"``, ``"y"``, or ``"z"``), or, alternatively, + (0, 1, or 2). ``projection_axis`` is forwarded to the ``normal`` parameter of `ParticleProjectionPlot`. + An arbitrary projection axis may be provided as a 3-element sequence of floats. Overridden if ``params["projection_axes"]`` is provided + north_vector : str, int, or 3-element sequence of floats, optional + Sets the north vector of the projection (i.e. which axis corresponds to "up" in the final image). 
+ Setting ``north_vector`` requires setting a ``projection_axis`` that is perpendicular to ``north_vector``. + If ``north_vector`` is not set, yt will choose a north vector internally. + ``north_vector`` is forwarded to the ``north_vector`` parameter of ParticleProjectionPlot. cmap : str Matplotlib colormap to use for all panels. Overridden if ``params["cmaps"]`` is provided. See https://matplotlib.org/stable/gallery/color/colormap_reference.html for named colormaps. @@ -330,6 +366,10 @@ def halo_projection_array( Colorbar limits for `field`. Overridden if ``params["zlims"]`` is provided. length_scale : str or None, optional Optionally add a horizontal bar denoting length scale in Mpc. + manual_axis_alignment : bool, optional + Generate images by directly calling yt.OffAxisParticleProjectionPlot, + which can give more flexibility for managing image orientation. + If False, ``halo_projection_array`` will use yt.ParticleProjectionPlot. Options: - ``"top left"``: add to top left panel @@ -354,7 +394,7 @@ def halo_projection_array( - ``"zlims"``: 2D array of colorbar limits (log-scaled) - ``"labels"``: 2D array of panel labels (or None) - ``"cmaps"``: 2D array of Matplotlib colormaps for each panel - - ``"cmap_norms"``: 2D array of colormap normalization method (e.g. matplotlib.colors.LogNorm()) + - ``"cmap_norms"``: 2D array of colormap normalization method (e.g. matplotlib.colors.LogNorm()) - ``"widths"``: 2D array of widths in units of R200 text_color : str, optional Set the color of all text annotations. 
Default is "gray" @@ -373,10 +413,17 @@ def halo_projection_array( # easily translatable to yt's unit conventions data = data.with_units("comoving") - halo_ids = np.atleast_2d(halo_ids) + halo_ids_2d = np.atleast_2d(halo_ids) # determine shape of figure - fig_shape = np.shape(halo_ids) + fig_shape = np.shape(halo_ids_2d) + + yt_datasets_provided = yt_ds is not None + + if yt_datasets_provided: + yt_ds_2d = np.atleast_2d(yt_ds) # type: ignore + else: + yt_ds_2d = np.full(fig_shape, None) # Default plotting parameters if weight_field is None: @@ -391,9 +438,32 @@ def halo_projection_array( zlim_ = np.full(fig_shape, None) else: zlim_ = np.reshape( - [zlim for _ in range(np.prod(fig_shape))], (fig_shape[0], fig_shape[1], 2) + [zlim for _ in range(np.prod(fig_shape))], + (fig_shape[0], fig_shape[1], 2) + ) + + if isinstance(projection_axis, (str, int)): + projection_axis_ = np.full(fig_shape, projection_axis) + + elif isinstance(projection_axis, (list, tuple, np.ndarray)): + projection_axis_ = np.reshape( + [projection_axis for _ in range(np.prod(fig_shape))], + (fig_shape[0], fig_shape[1], 3), ) + else: + raise RuntimeError(f"`projection_axis` has unsopported type ({type(projection_axis)}).") + + + if north_vector is None: + north_vector_ = np.full(fig_shape, None) + else: + north_vector_ = np.reshape( + [north_vector for _ in range(np.prod(fig_shape))], + (fig_shape[0], fig_shape[1], 3) + ) + + default_params = { "fields": ( np.reshape( @@ -403,7 +473,8 @@ def halo_projection_array( ), "weight_fields": (weight_field_), "zlims": (zlim_), - "projection_axes": (np.full(fig_shape, projection_axis)), + "projection_axes": (projection_axis_), + "north_vectors": (north_vector_), "labels": (np.full(fig_shape, None)), "cmaps": (np.full(fig_shape, cmap)), "cmap_norms": (np.full(fig_shape, None)), @@ -416,6 +487,7 @@ def halo_projection_array( fields = params.get("fields", default_params["fields"]) weight_fields = params.get("weight_fields", default_params["weight_fields"]) 
projection_axes = params.get("projection_axes", default_params["projection_axes"]) + north_vectors = params.get("north_vectors", default_params["north_vectors"]) zlims = params.get("zlims", default_params["zlims"]) labels = params.get("labels", default_params["labels"]) cmaps = params.get("cmaps", default_params["cmaps"]) @@ -430,35 +502,44 @@ def halo_projection_array( fig, axes, cbars = get_multi_plot(fig_shape[1], fig_shape[0], cbar_padding=0) # are we plotting a single halo multiple times? - halo_ids = np.array(halo_ids) halo_id_previous = np.inf for i in range(nrow): for j in range(ncol): - halo_id = halo_ids[i][j] + halo_id = halo_ids_2d[i][j] ax = axes[i][j] if halo_id is None: ax.set_facecolor("black") continue - # retrieve halo particle info if new halo - if (i == 0 and j == 0) or halo_id != halo_id_previous: - # retrieve properties of halo - if len(data) > 1: - data_id = data.filter(oc.col("unique_tag") == halo_id) - else: - if data["halo_properties"].data["unique_tag"] != halo_id: # type: ignore - raise RuntimeError(f"Halo ID {halo_id} not in dataset!") - data_id = data - halo_data = next(iter(data_id.objects())) + if yt_datasets_provided: + ds = yt_ds_2d[i][j] + if ds is None: + raise ValueError("provided yt dataset cannot be None") - # load particles into yt - ds = create_yt_dataset(halo_data) + # sodbighaloparticles holds particle data out to 2*R200 + Rh = ds.domain_width[0] / 4 + + else: + # retrieve halo particle info if new halo + if (i == 0 and j == 0) or halo_id != halo_id_previous: + # retrieve properties of halo + if len(data) > 1: + data_id = data.filter(oc.col("unique_tag") == halo_id) + else: + if halo_id != data["halo_properties"].select("unique_tag").get_data(): # type: ignore + raise RuntimeError(f"Halo ID {halo_id} not in dataset!") + data_id = data + + halo_data = next(iter(data_id.objects())) + + # load particles into yt + ds = create_yt_dataset(halo_data) - halo_properties = halo_data["halo_properties"] + halo_properties = 
halo_data["halo_properties"] - Rh = unyt_quantity.from_astropy(halo_properties["sod_halo_radius"]) + Rh = unyt_quantity.from_astropy(halo_properties["sod_halo_radius"]) field, weight_field, zlim, width = ( tuple(fields[i][j]), @@ -473,10 +554,36 @@ def halo_projection_array( zlim = tuple(zlim) # type: ignore label = labels[i][j] + projection_axis = projection_axes[i][j] + + north_vector = _sanitize_input_vector(north_vector) + + + # we need to determine of the projection is going to be axis-aligned. + # yt internally checks this within ParticleProjectionPlot and forwards to + # OffAxisParticleProjectionPlot if it is not axis-aligned. We are manually + # calling OffAxisParticleProjectionPlot for more control over the normal/north + # vectors (ParticleProjectionPlot ignores these inputs if axis-aligned). + if manual_axis_alignment: + projection_axis = _sanitize_input_vector(projection_axis) + + proj = OffAxisParticleProjectionPlot( + ds, + projection_axis, + field, + weight_field=weight_field, + north_vector=north_vectors[i][j], + ) # type: ignore + + else: + proj = ParticleProjectionPlot( + ds, + projection_axis, + field, + weight_field=weight_field, + north_vector=north_vectors[i][j], + ) # type: ignore - proj = ParticleProjectionPlot( - ds, projection_axes[i][j], field, weight_field=weight_field - ) proj.set_background_color(field, color="black") @@ -567,3 +674,328 @@ def halo_projection_array( halo_id_previous = halo_id return fig + +def _fig_to_rgb(fig): + """ + Render a Matplotlib Figure to an (H, W, 3) uint8 RGB array in memory. 
+ """ + fig.canvas.draw() + buf = np.asarray(fig.canvas.buffer_rgba()) # (H, W, 4) + return buf[..., :3] # drop alpha channel + +def _sanitize_input_vector(v): + if isinstance(v, str): + match v: + case "x" | 0: + return (1, 0, 0) + case "y" | 1: + return (0, 1, 0) + case "z" | 2: + return (0, 0, 1) + else: + return v + +def _normalize(v, eps=0): + v = np.asarray(v, dtype=float) + + # normalize + v /= np.linalg.norm(v) + + if eps > 0: + # pad zeros with some non-zero value + v[v==0] = eps + + # normalize again + v /= np.linalg.norm(v) + return v + +def _rodrigues_rotate(v, axis, angle): + """ + Rotate vector v around 'axis' by 'angle' radians (right-hand rule). + """ + v = np.asarray(v, dtype=float) + k = _normalize(axis) + c = np.cos(angle) + s = np.sin(angle) + vrot = v * c + np.cross(k, v) * s + k * np.dot(k, v) * (1 - c) + + return vrot + +def _enforce_orthogonality(v1, v2): + # enforce orthogonality of v1, relative to v2 + return v1 - np.dot(v1, v2) * v2 + +def _get_rotation_vectors(rotations, frames, normal0=(0, 0, 1), north0=(0, 1, 0)): + + normal0 = _sanitize_input_vector(normal0) + north0 = _sanitize_input_vector(north0) + + normals = [_normalize(normal0, eps=1e-3)] + norths = [_normalize(_enforce_orthogonality(north0, normals[-1]))] + + factors = [] + axes = [] + + # get a list of rotations + for rotation in rotations: + if "*" in rotation: + if rotation.count("*") > 1: + raise RuntimeError(f"rotation \"{rotation}\" not recognized") + factor, axis = rotation.split("*") + factor = float(factor) + else: + factor = 1 + axis = rotation + + factors.append(factor) + axes.append(axis) + + # loop through rotations again and actually apply them + for i, rotation in enumerate(rotations): + + # determine number of frames for this rotation (round up) + frames_i = int(np.ceil( frames * np.absolute(factors[i])/sum(np.absolute(factors)) )) + + # angular distance traveled in theta and phi + delta_angle_i = factors[i] * 2*np.pi / frames_i + + axis = 
_sanitize_input_vector(axes[i]) + + for _ in range(frames_i): + n = _normalize(_rodrigues_rotate(normals[-1], axis, delta_angle_i), eps=1e-3) + u = _normalize(_rodrigues_rotate(norths[-1], axis, delta_angle_i)) + + # enforce orthogonality of normal and north vectors + u = _normalize( _enforce_orthogonality(u, n) ) + + normals.append(n) + norths.append(u) + + return normals, norths + + +def animate_halos( + halo_ids: int | list[int] | tuple[list[int], list[int]] | np.ndarray, + data: oc.StructureCollection, + func: str = "visualize_halo", + rotations: str | int | list[str] = "y", + frames: int = 30, + dpi: int = 100, + normal0: str | int | list[int] | tuple[int] = "z", + north0: str | int | list[int] | tuple[int] = "y", + **kwargs, +): + """ + Creates an animation of one or more halo projections while rotating the + viewing direction. + + The animation is constructed by repeatedly calling either ``visualize_halo`` or + ``halo_projection_array`` for a sequence of projection orientations and stacking the individual + frames into an animation. The viewing orientation evolves according + to ``rotations``, beginning from the initial projection axis ``normal0`` and + initial "up" direction ``north0``. + + By default, this function animates a single halo using ``visualize_halo`` while + rotating about the y-axis. It can also animate a customizable multipanel projection layout by + setting ``func="halo_projection_array"`` and passing a 2D arrangement of halo IDs. + + Example usage for visualizing the most massive halo in a dataset: + + .. code-block:: python + + import opencosmo as oc + from opencosmo.analysis import animate_halos + + # fetch data and ID for most massive halo + ds = oc.open("haloproperties.hdf5", "haloparticles.hdf5").sort_by("sod_halo_mass").take(1, at="end") + halo_id = ds["halo_properties"].select("unique_tag").get_data() + + # create a 30-frame animation that rotates the object once about the y-axis, then once about the x-axis. 
+ anim = animate_halos(halo_id, ds, rotations=["y", "x"], frames=30) + anim.save("animation.gif", fps=10) + + + Parameters + ---------- + halo_ids : int or array of int + Unique ID of the halo(s) to be animated. `halo_ids` is forwarded to the parameter of the same name + in either `visualize_halo` or `halo_projection_array`, depending on the value of ``func``. + + When ``func="visualize_halo"``, only a single halo ID is allowed. When + ``func="halo_projection_array"``, ``halo_ids`` can be an int, list, or 2D array. + + data : opencosmo.StructureCollection + OpenCosmo StructureCollection containing the halo properties and particle data + needed to create yt datasets for the requested halos. For example, this may be + the output of ``opencosmo.open(["haloproperties.hdf5", "haloparticles.hdf5"])``. + + func : str, optional + Name of the plotting function used to generate each animation frame. + + - ``"visualize_halo"``: animate a single halo using ``visualize_halo``. + - ``"halo_projection_array"``: animate a panel array using + ``halo_projection_array``. + + rotations : str or sequence of str, optional + Specification for how the camera rotates during the animation. + + For example, ``rotations = "x"`` rotates the object once about the x-axis, while + ``rotations = ["x", "y"]`` rotates the object once about the x-axis, then once about the y-axis. + Partial rotations can be defined by prepending the string with a float + (e.g. ``rotations=["0.5*x", "0.25*y"]`` does half a rotation about the x-axis, then a quarter rotation about the y-axis). + Prepending the string with a negative value reverses the rotation direction. + + frames : int, optional + Total number of frames in the animation. + + dpi : int, optional + Resolution of the persistent display figure used to assemble the animation. + This also controls the effective pixel size of the output animation. + + normal0 : str, int, or 3-tuple of float, optional + Projection axis for the initial frame. 
+ + north0 : str, int, or 3-tuple of float, optional + Initial north vector, i.e. the initial "up" direction in the image plane. + ``north0`` muat be perpendicular to ``normal0``. + + **kwargs + Additional keyword arguments passed directly to the selected plotting function + (either ``visualize_halo`` or ``halo_projection_array``). This can be used to + customize the field being projected, color normalization, labels, plot width, + colormap, and so on. + + Note that ``projection_axis``, ``north_vector``, ``yt_ds``, and + ``manual_axis_alignment`` are set internally by ``animate_halo`` and will be + overridden regardless of values passed through ``kwargs``. + + Returns + ------- + matplotlib.animation.FuncAnimation + Matplotlib animation object. + + """ + + halo_ids = np.atleast_2d(halo_ids) + + fig_shape = np.shape(halo_ids) + yt_ds_arr = np.full(fig_shape, None) + + nrow, ncol = fig_shape + halo_id_previous = np.inf + for i in range(nrow): + for j in range(ncol): + halo_id = halo_ids[i][j] + + if (i == 0 and j == 0) or halo_id != halo_id_previous: + # retrieve properties of halo and load into yt + # this part is skipped if the halo has just been found/loaded in the + # previous iteration + # Can make this slightly faster by copying directly yt_ds_arr in cases where + # the halo was loaded into yt more than 1 iteration ago + + if len(data) > 1: + data_id = data.filter(oc.col("unique_tag") == halo_id) + else: + if data["halo_properties"].data["unique_tag"] != halo_id: # type: ignore + raise RuntimeError(f"Halo ID {halo_id} not in dataset!") + data_id = data + + halo_data = next(iter(data_id.objects())) + + # load particles into yt + ds = create_yt_dataset(halo_data) + + yt_ds_arr[i][j] = ds + + else: + yt_ds_arr[i][j] = ds + + halo_id_previous = halo_id + + normals, norths = _get_rotation_vectors(rotations, + frames=frames, + normal0=normal0, + north0=north0 + ) + + call_visualize_halo = False + call_halo_projection_array = False + if func == "visualize_halo": + 
call_visualize_halo = True + if np.prod(np.shape(halo_ids)) > 1: + raise ValueError("`visualize_halo` requires a single int for `halo_id`, not an array of values") + + elif func == "halo_projection_array": + call_halo_projection_array = True + else: + raise RuntimeError(f"\`func\` {func} not recognized") + + if call_visualize_halo: + fig0 = visualize_halo( + halo_ids[0][0], + data, + projection_axis=normals[0], + north_vector=norths[0], + yt_ds=yt_ds_arr[0][0], + **kwargs, + ) + + elif call_halo_projection_array: + fig0 = halo_projection_array( + halo_ids, + data, + projection_axis=normals[0], + north_vector=norths[0], + yt_ds = yt_ds_arr, + **kwargs, + ) + + + frame0 = _fig_to_rgb(fig0) + plt.close(fig0) + + H, W = frame0.shape[:2] + + # ---- animation "display" figure (single persistent figure) ---- + fig = plt.figure(figsize=(W / dpi, H / dpi), dpi=dpi) + ax = fig.add_axes((0.0, 0.0, 1.0, 1.0)) + ax.set_axis_off() + ax.set_aspect("auto") + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + im = ax.imshow(frame0, interpolation="nearest") + + def update(i): + normal = normals[i] + north = norths[i] + + if call_visualize_halo: + f = visualize_halo( + halo_ids[0][0], + data, + projection_axis=normal, + north_vector=north, + yt_ds=yt_ds_arr[0][0], + manual_axis_alignment=True, + **kwargs, + ) + elif call_halo_projection_array: + f = halo_projection_array( + halo_ids, + data, + projection_axis=normal, + north_vector=north, + yt_ds=yt_ds_arr, + manual_axis_alignment=True, + **kwargs, + ) + frame = _fig_to_rgb(f) + plt.close(f) # close each per-frame figure + + im.set_data(frame) + + return (im,) + + anim = FuncAnimation(fig, update, frames=frames, interval=50, blit=True) + + return anim diff --git a/src/opencosmo/collection/__init__.py b/python/opencosmo/collection/__init__.py similarity index 100% rename from src/opencosmo/collection/__init__.py rename to python/opencosmo/collection/__init__.py diff --git a/src/opencosmo/collection/lightcone/__init__.py 
b/python/opencosmo/collection/lightcone/__init__.py similarity index 53% rename from src/opencosmo/collection/lightcone/__init__.py rename to python/opencosmo/collection/lightcone/__init__.py index 586a29eb..62f96d4a 100644 --- a/src/opencosmo/collection/lightcone/__init__.py +++ b/python/opencosmo/collection/lightcone/__init__.py @@ -1,4 +1,9 @@ +from importlib import import_module + from .healpix_map import HealpixMap from .lightcone import Lightcone +# register plugins +import_module(f"{__name__}.plugins") + __all__ = ["Lightcone", "HealpixMap"] diff --git a/src/opencosmo/collection/lightcone/healpix_map.py b/python/opencosmo/collection/lightcone/healpix_map.py similarity index 94% rename from src/opencosmo/collection/lightcone/healpix_map.py rename to python/opencosmo/collection/lightcone/healpix_map.py index 80097804..90157547 100644 --- a/src/opencosmo/collection/lightcone/healpix_map.py +++ b/python/opencosmo/collection/lightcone/healpix_map.py @@ -26,13 +26,45 @@ from opencosmo.column.column import ColumnMask, ConstructedColumn from opencosmo.dataset import Dataset from opencosmo.dataset.build import GroupedColumnData + from opencosmo.dtypes.hacc import HaccSimulationParameters from opencosmo.header import OpenCosmoHeader from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema - from opencosmo.parameters.hacc import HaccSimulationParameters from opencosmo.spatial import Region +def make_healsparse_maps( + table, + nside: int, + nside_lr: int, +) -> dict[str, hsp.HealSparseMap]: + sentinel = np.float32(hp.UNSEEN) + pixels = table["pixel"].value + + # Build coverage map once, shared across all columns. 
+ cov_map = hsp.HealSparseCoverage.make_empty(nside_lr, nside) + cov_pix = cov_map.cov_pixels(pixels) + unique_cov_pix = np.unique(cov_pix) + cov_map.initialize_pixels(unique_cov_pix) + sparse_indices = pixels + cov_map[cov_pix] + sparse_map_size = (len(unique_cov_pix) + 1) * cov_map.nfine_per_cov + + result = {} + for name, col in table.items(): + if name != "pixel": + sparse_map = np.full(sparse_map_size, sentinel, dtype=np.float32) + sparse_map[sparse_indices] = col.value.astype(np.float32) + result[name] = hsp.HealSparseMap( + cov_map=cov_map, + sparse_map=sparse_map, + nside_sparse=nside, + sentinel=sentinel, + ) + if len(result) == 1: + return next(iter(result.values())) + return result + + def take_from_sorted( healpix_map: "HealpixMap", sort_by: str, invert: bool, n: int, at: str | int ): @@ -276,7 +308,7 @@ def simulation(self) -> HaccSimulationParameters: Returns ------- - parameters: opencosmo.parameters.hacc.HaccSimulationParameters + parameters: opencosmo.dtypes.hacc.HaccSimulationParameters """ return self.__header.simulation @@ -302,7 +334,8 @@ def get_data(self, format="healsparse", nside_out: Optional[int] = None, **kwarg You can get the data in two formats, "healsparse" (the default) and "healpix". "healsparse" format will return the data as a healsparse sparse map. - "healpix" will return the data as a dictionary of numpy arrays. For map data, + "healpix" will return the data as a dictionary of numpy arrays. If the map does not + cover the full sky, this wll be a masked numpy array. For map data, due to format requirements, no units will be attached to the data itself, although these will match the units from the data attributes. 
@@ -340,31 +373,38 @@ def get_data(self, format="healsparse", nside_out: Optional[int] = None, **kwarg table["pixel"] = pixels table.sort("pixel", reverse=False) - if format == "healpix": - if self.__len__() != hp.nside2npix(self.nside): - raise ValueError( - "healpix format chosen but length of dataset doesn't match nside value. Use healsparse" - ) - if len(table.colnames) == 1: table = next(table.itercols()) if format == "healpix": + npix = hp.nside2npix(self.nside) if isinstance(table, (u.Quantity, Column)): - return table.value + vals = np.zeros(npix, dtype=np.float32) + vals[pixels] = table.value + storage = {"vals": vals} + else: table.remove_columns(self.__hidden) - return {name: col.value for name, col in table.items()} + storage = { + name: np.zeros(npix, dtype=np.float32) + for name in table.columns + if name != "pixel" + } + for name, arr in storage.items(): + arr[pixels] = table[name].value + + if len(pixels) != hp.nside2npix(self.nside): + mask = np.zeros(hp.nside2npix(self.nside), dtype=bool) + mask[pixels] = True + storage = { + name: np.ma.masked_array(arr, mask) for name, arr in storage.items() + } + if len(storage) == 1: + return next(iter(storage.values())) + return storage + elif format == "healsparse": - dict_maps = {} - for name, col in table.items(): - if name != "pixel": - hsp_out = hsp.HealSparseMap.make_empty( - self.nside_lr, self.nside, dtype=np.float32 - ) - hsp_out[table["pixel"].value] = (col.value).astype(np.float32) - dict_maps[name] = hsp_out - return dict_maps + return make_healsparse_maps(table, self.nside, self.nside_lr) @property def data(self): diff --git a/python/opencosmo/collection/lightcone/io.py b/python/opencosmo/collection/lightcone/io.py new file mode 100644 index 00000000..a2531c7d --- /dev/null +++ b/python/opencosmo/collection/lightcone/io.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from collections import OrderedDict, defaultdict + +from opencosmo.collection.lightcone import utils as lcutils +from 
opencosmo.dataset import dataset as ocds +from opencosmo.io.mpi import get_all_keys +from opencosmo.mpi import get_comm_world, get_mpi + + +def order_by_redshift_range(datasets: dict[str, ocds.Dataset]): + redshift_ranges = { + key: lcutils.get_single_redshift_range(ds) for key, ds in datasets.items() + } + sorted_ranges = sorted(redshift_ranges.items(), key=lambda item: item[1][0]) + output = OrderedDict() + for name, _ in sorted_ranges: + output[name] = datasets[name] + return output + + +def combine_adjacent_datasets_mpi( + ordered_datasets: dict[str, dict[str, ocds.Dataset]], + min_dataset_size: int, + no_stack: bool, +): + MIN_DATASET_SIZE = 100_000 + comm = get_comm_world() + MPI = get_mpi() + all_dataset_steps = get_all_keys(ordered_datasets, comm) + assert comm is not None and MPI is not None + rs = 0 + output_datasets: dict[str, list[dict[str, ocds.Dataset]]] = OrderedDict() + for step in all_dataset_steps: + if rs == 0: + current_key = step + output_datasets[current_key] = [] + + if step not in ordered_datasets: + rs += comm.allreduce(0, MPI.SUM) + else: + length = sum(len(ds) for ds in ordered_datasets[step].values()) + rs += comm.allreduce(length) + output_datasets[current_key].append(ordered_datasets[step]) + + if rs > MIN_DATASET_SIZE or no_stack: + rs = 0 + + output = OrderedDict() + for step, datasets in output_datasets.items(): + step_output = defaultdict(list) + for ds_group in datasets: + for ds_type, ds in ds_group.items(): + step_output[ds_type].append(ds) + output[step] = step_output + + return output + + +def combine_adjacent_datasets( + ordered_datasets: dict[str, ocds.Dataset] | dict[str, dict[str, ocds.Dataset]], + min_dataset_size: int, + no_stack: bool, +): + is_single = isinstance(next(iter(ordered_datasets.values())), ocds.Dataset) + datasets: dict[str, dict[str, ocds.Dataset]] + if is_single: + assert all(isinstance(ds, ocds.Dataset) for ds in ordered_datasets.values()) + datasets = {key: {"data": ds} for key, ds in 
ordered_datasets.items()} # type: ignore + else: + assert all(isinstance(ds, dict) for ds in ordered_datasets.values()) + datasets = ordered_datasets # type: ignore + + if get_comm_world() is not None: + return combine_adjacent_datasets_mpi(datasets, min_dataset_size, no_stack) + + running_sum = 0 + + current_key = next(iter(ordered_datasets.keys())) + output_datasets: dict[str, list[dict[str, ocds.Dataset]]] = OrderedDict( + {current_key: []} + ) + + for key, step_datasets in datasets.items(): + if not no_stack and running_sum < min_dataset_size: + running_sum += sum(len(ds) for ds in step_datasets.values()) + output_datasets[current_key].append(step_datasets) + continue + current_key = key + output_datasets[current_key] = [step_datasets] + running_sum = sum(len(ds) for ds in step_datasets.values()) + + # We have list of dicts, go to dict of lists + output = OrderedDict() + for step, step_datasets_ in output_datasets.items(): + step_output = defaultdict(list) + for ds_group in step_datasets_: + for ds_type, ds in ds_group.items(): + step_output[ds_type].append(ds) + output[step] = step_output + + return output diff --git a/src/opencosmo/collection/lightcone/lightcone.py b/python/opencosmo/collection/lightcone/lightcone.py similarity index 72% rename from src/opencosmo/collection/lightcone/lightcone.py rename to python/opencosmo/collection/lightcone/lightcone.py index e721ee65..94be05ba 100644 --- a/src/opencosmo/collection/lightcone/lightcone.py +++ b/python/opencosmo/collection/lightcone/lightcone.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections import OrderedDict, defaultdict +from collections import defaultdict from functools import cached_property, reduce from itertools import chain from typing import ( @@ -9,221 +9,58 @@ Callable, Generator, Iterable, + Literal, Mapping, Optional, Self, - Sequence, ) from warnings import warn +import healpy as hp import numpy as np from astropy.table import vstack # type: ignore +from deprecated import 
deprecated import opencosmo as oc -from opencosmo.collection.lightcone.coordinates import make_radec_columns +from opencosmo.collection.lightcone import io as lcio +from opencosmo.collection.lightcone import utils as lcutils from opencosmo.collection.lightcone.stack import stack_lightcone_datasets_in_schema from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn -from opencosmo.dataset import Dataset from opencosmo.dataset.evaluate import build_evaluated_column -from opencosmo.dataset.formats import convert_data, verify_format -from opencosmo.io.iopen import open_single_dataset -from opencosmo.io.mpi import get_all_keys +from opencosmo.dataset.formats import concat_chunks, convert_data, verify_format +from opencosmo.dataset.take import ( + get_end_take_index, + get_random_take_index, + get_range_take_index, + get_rows_take_index, +) +from opencosmo.index import get_range, into_array, rebuild_by_ranges +from opencosmo.io import iopen from opencosmo.io.schema import FileEntry, make_schema -from opencosmo.mpi import get_comm_world, get_mpi +from opencosmo.plugins.contexts import ( + HookPoint, + LightconeInstantiateCtx, + LightconeOpenCtx, + PostSortCtx, +) +from opencosmo.plugins.hook import fold if TYPE_CHECKING: import astropy.units as u # type: ignore + import numpy.typing as npt from astropy.coordinates import SkyCoord from astropy.cosmology import Cosmology - from astropy.table import Table from opencosmo.column.column import ColumnMask, ConstructedColumn + from opencosmo.dataset import Dataset + from opencosmo.dtypes.hacc import HaccSimulationParameters from opencosmo.header import OpenCosmoHeader + from opencosmo.index import DataIndex from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema - from opencosmo.parameters.hacc import HaccSimulationParameters from opencosmo.spatial import Region -def get_redshift_range(datasets: Sequence[Dataset | Lightcone]): - redshift_ranges = list(map(get_single_redshift_range, 
datasets)) - min_z = min(rr[0] for rr in redshift_ranges) - max_z = max(rr[1] for rr in redshift_ranges) - - return (min_z, max_z) - - -def get_single_redshift_range(dataset: Dataset | Lightcone): - if isinstance(dataset, Lightcone): - return dataset.z_range - redshift_range = dataset.header.lightcone["z_range"] - if redshift_range is not None: - return redshift_range - step_zs = dataset.header.simulation["step_zs"] - step = dataset.header.file.step - assert step is not None - min_redshift = step_zs[step] - max_redshift = step_zs[step - 1] - return (min_redshift, max_redshift) - - -def is_in_range(dataset: Dataset, z_low: float, z_high: float): - z_range = dataset.header.lightcone["z_range"] - if z_range is None: - z_range = get_single_redshift_range(dataset) - if z_high < z_range[0] or z_low > z_range[1]: - return False - return True - - -def sort_table(table: Table, column: str, invert: bool): - column_data = table[column] - if invert: - column_data = -column_data - indices = np.argsort(column_data) - for name in table.columns: - table[name] = table[name][indices] - return table - - -def take_from_sorted( - lightcone: "Lightcone", sort_by: str, invert: bool, n: int, at: str | int -): - column = np.concatenate( - [ds.select(sort_by).get_data("numpy") for ds in lightcone.values()] - ) - if invert: - column = -column - sort_index = np.argsort(column) - if at == "start": - sort_index = sort_index[:n] - elif at == "end": - sort_index = sort_index[-n:] - elif isinstance(at, int): - if at + n > len(sort_index) or at < 0: - raise ValueError( - "Requested a range that is outside the size of this dataset!" 
- ) - sort_index = sort_index[at : at + n] - - sorted_indices = np.sort(sort_index) - return sorted_indices - - -def order_by_redshift_range(datasets: dict[str, Dataset]): - redshift_ranges = { - key: get_single_redshift_range(ds) for key, ds in datasets.items() - } - sorted_ranges = sorted(redshift_ranges.items(), key=lambda item: item[1][0]) - output = OrderedDict() - for name, _ in sorted_ranges: - output[name] = datasets[name] - return output - - -def combine_adjacent_datasets_mpi( - ordered_datasets: dict[str, dict[str, Dataset]], - min_dataset_size, -): - MIN_DATASET_SIZE = 100_000 - comm = get_comm_world() - MPI = get_mpi() - all_dataset_steps = get_all_keys(ordered_datasets, comm) - assert comm is not None and MPI is not None - rs = 0 - output_datasets: dict[str, list[dict[str, Dataset]]] = OrderedDict() - for step in all_dataset_steps: - if rs == 0: - current_key = step - output_datasets[current_key] = [] - - if step not in ordered_datasets: - rs += comm.allreduce(0, MPI.SUM) - else: - length = sum(len(ds) for ds in ordered_datasets[step].values()) - rs += comm.allreduce(length) - output_datasets[current_key].append(ordered_datasets[step]) - - if rs > MIN_DATASET_SIZE: - rs = 0 - - output = OrderedDict() - for step, datasets in output_datasets.items(): - step_output = defaultdict(list) - for ds_group in datasets: - for ds_type, ds in ds_group.items(): - step_output[ds_type].append(ds) - output[step] = step_output - - return output - - -def combine_adjacent_datasets( - ordered_datasets: dict[str, Dataset] | dict[str, dict[str, Dataset]], - min_dataset_size=100_000, -): - is_single = isinstance(next(iter(ordered_datasets.values())), Dataset) - datasets: dict[str, dict[str, Dataset]] - if is_single: - assert all(isinstance(ds, Dataset) for ds in ordered_datasets.values()) - datasets = {key: {"data": ds} for key, ds in ordered_datasets.items()} # type: ignore - else: - assert all(isinstance(ds, dict) for ds in ordered_datasets.values()) - datasets = 
ordered_datasets # type: ignore - - if get_comm_world() is not None: - return combine_adjacent_datasets_mpi(datasets, min_dataset_size) - - running_sum = 0 - - current_key = next(iter(ordered_datasets.keys())) - output_datasets: dict[str, list[dict[str, Dataset]]] = OrderedDict( - {current_key: []} - ) - - for key, step_datasets in datasets.items(): - if running_sum < min_dataset_size: - running_sum += sum(len(ds) for ds in step_datasets.values()) - output_datasets[current_key].append(step_datasets) - continue - current_key = key - output_datasets[current_key] = [step_datasets] - running_sum = sum(len(ds) for ds in step_datasets.values()) - - # We have list of dicts, go to dict of lists - output = OrderedDict() - for step, step_datasets_ in output_datasets.items(): - step_output = defaultdict(list) - for ds_group in step_datasets_: - for ds_type, ds in ds_group.items(): - step_output[ds_type].append(ds) - output[step] = step_output - - return output - - -def with_redshift_column(dataset: Dataset): - """ - Ensures a column exists called "redshift" which contains the redshift of the objects - in the lightcone. - """ - if "redshift" in dataset.columns: - return dataset - - elif "fof_halo_center_a" in dataset.columns: - z_col = 1 / oc.col("fof_halo_center_a") - 1 - return dataset.with_new_columns(redshift=z_col) - elif "redshift_true" in dataset.columns: - z_col = oc.col("redshift_true") - return dataset.with_new_columns(redshift=z_col) - elif "zp" in dataset.columns: - z_col = oc.col("zp") - return dataset.with_new_columns(redshift=z_col) - raise ValueError( - "Unable to find a redshift or scale factor column for this lightcone dataset" - ) - - class Lightcone(dict): """ A lightcone contains two or more datasets that are part of a lightcone. 
Typically @@ -244,17 +81,13 @@ def __init__( datasets: Mapping[Any, Dataset | Lightcone], z_range: Optional[tuple[float, float]] = None, hidden: Optional[set[str]] = None, - ordered_by: Optional[tuple[str, bool]] = None, + sort_key: Optional[tuple[str, bool]] = None, ): - datasets = { - k: with_redshift_column(ds) if isinstance(ds, Dataset) else ds - for k, ds in datasets.items() - } self.update(datasets) z_range = ( z_range if z_range is not None - else get_redshift_range(list(datasets.values())) + else lcutils.get_redshift_range(list(datasets.values())) ) columns: set[str] = reduce( @@ -269,7 +102,7 @@ def __init__( hidden = set() self.__hidden = hidden - self.__ordered_by = ordered_by + self.__sort_key = sort_key def __repr__(self): """ @@ -284,7 +117,7 @@ def __repr__(self): repr_ds = self.take(10, at="start") table_head = "First 10 rows:\n" - table_repr = repr_ds.data.__repr__() + table_repr = repr_ds.get_data().__repr__() # remove the first line table_repr = table_repr[table_repr.find("\n") + 1 :] z_range = self.z_range @@ -337,6 +170,17 @@ def columns(self) -> list[str]: cols = list(filter(lambda col: col not in self.__hidden, cols)) return cols + @property + def meta_columns(self) -> list[str]: + """ + The names of the columns in this dataset. + + Returns + ------- + columns: list[str] + """ + return next(iter(self.values())).meta_columns + @cached_property def descriptions(self) -> dict[str, Optional[str]]: """ @@ -420,10 +264,21 @@ def simulation(self) -> HaccSimulationParameters: Returns ------- - parameters: opencosmo.parameters.hacc.HaccSimulationParameters + parameters: opencosmo.dtypes.hacc.HaccSimulationParameters """ return self.__header.simulation + @property + def sorted_by(self) -> Optional[str]: + """ + The column this dataset is sorted by. If not sorted, returns None. 
+ + Returns + ------- + column: Optional[str] + """ + return self.__sort_key[0] if self.__sort_key is not None else None + @property def z_range(self): """ @@ -436,7 +291,47 @@ def z_range(self): return self.__header.lightcone["z_range"] - def get_data(self, format="astropy", unpack: bool = False, **kwargs): + def get_pixels(self, nside: int = 64): + """ + Return the HEALPix pixels occupied by this lightcone at a given resolution. + + Pixel indices are returned in nested ordering. The ``nside`` parameter + controls angular resolution: larger values produce finer pixels. The + requested resolution may not exceed the resolution of the spatial index + stored in the file. + + Parameters + ---------- + nside : int, default = 64 + The HEALPix resolution parameter. Must be a positive power of two. + + Returns + ------- + pixels : numpy.ndarray[int] + HEALPix pixel indices (nested ordering) occupied by this lightcone + at the given resolution. + + Raises + ------ + ValueError + If ``nside`` is not a positive power of two, if ``nside`` exceeds + the maximum resolution of the spatial index, or if the lightcone + does not have a spatial index. + """ + + level = np.log2(nside) + if not level.is_integer() or level < 0: + raise ValueError("nside must be a positive power of two!") + + return lcutils.get_pixels(self, int(level)) + + def get_data( + self, + format="astropy", + unpack: bool = True, + wrap_single: bool = False, + **kwargs, + ): """ Get the data in this dataset as an astropy table/column or as numpy array(s). Note that a dataset does not load data from disk into @@ -451,7 +346,9 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): units will be attached. If the dataset only contains a single column, it will be returned as an - astropy.table.Column or a single numpy array. + astropy.table.Column or a single numpy array. Pass :code:`wrap_single=True` + to always return the format's multi-column container (QTable, DataFrame, + dict, ...) 
regardless of column count. Parameters ---------- @@ -459,6 +356,10 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): The format to output the data in. Currently supported are "astropy", "numpy", "pandas", "polars", and "arrow" + wrap_single: bool, default=False + If True, always return the format's natural multi-column container even + when only one column is present. + Returns ------- data: Table | Column | dict[str, ndarray] | ndarray @@ -470,27 +371,57 @@ def get_data(self, format="astropy", unpack: bool = False, **kwargs): ) format = kwargs["output"] verify_format(format) - - data = [ds.get_data(unpack=unpack) for ds in self.values()] + lightcone = fold( + HookPoint.LightconeInstantiate, LightconeInstantiateCtx(self) + ).lightcone + data = [ds.get_data(unpack=False) for ds in lightcone.values()] data_with_length = [d for d in data if len(d) > 0] if len(data_with_length) == 0: return data[0] table = vstack(data_with_length, join_type="exact") - if self.__ordered_by is not None: - table.sort(self.__ordered_by[0], reverse=self.__ordered_by[1]) + if self.__sort_key is not None and not kwargs.get("ignore_sort", False): + order = table.argsort(self.__sort_key[0], reverse=self.__sort_key[1]) + table = table[order] + table = fold( + HookPoint.PostSort, PostSortCtx(self, table, np.argsort(order)) + ).data to_remove = self.__hidden.intersection(table.colnames) table.remove_columns(to_remove) + if len(table) == 1 and unpack: + output_data = { + key: value[0] if len(value) == 1 else value + for key, value in table.items() + } + return convert_data(output_data, format, wrap_single=wrap_single) + if format != "astropy": - return convert_data(dict(table), format) - elif len(table.columns) == 1: + return convert_data(dict(table), format, wrap_single=wrap_single) + elif len(table.columns) == 1 and not wrap_single: return next(iter(dict(table).values())) return table + def get_metadata(self, columns: str | list[str] = [], ignore_sort: bool = False): + 
data = [ds.get_metadata(columns) for ds in self.values()] + + output = {} + for key in data[0].keys(): + output[key] = np.concatenate([d[key] for d in data]) + if ignore_sort or self.__sort_key is None: + return output + order = np.argsort(self.select(self.__sort_key[0]).get_data("numpy")) + if self.__sort_key[1]: + order = order[::-1] + return {name: arr[order] for name, arr in output.items()} + @property + @deprecated( + version="1.1.0", + reason="Accessing data through the .data attribute is deprecated and will be removed in a future version. Use get_data()", + ) def data(self): """ Return the data in the dataset in astropy format. The value of this @@ -516,7 +447,9 @@ def open(cls, targets: list[FileTarget], **kwargs): for i, ds_target in enumerate(dataset_targets): group_name = ds_target["dataset_group"].name.split("/")[-1] group_name = group_name.lstrip(f"{ds_target['header'].file.step}_") - ds = open_single_dataset(ds_target, bypass_lightcone=True) + ds = iopen.open_single_dataset( + ds_target, bypass_lightcone=True, open_kwargs=kwargs + ) step = ds_target["header"].file.step if step is None: step = i @@ -535,14 +468,19 @@ def open(cls, targets: list[FileTarget], **kwargs): raise ValueError() result = cls(output) - return make_radec_columns(result) + return fold(HookPoint.LightconeOpen, LightconeOpenCtx(result, kwargs)).lightcone @classmethod def from_datasets( - cls, datasets: dict[str, oc.Dataset], z_range: tuple[float, float] + cls, + datasets: Mapping[int, oc.Dataset], + z_range: Optional[tuple[float, float]] = None, + **open_kwargs, ): result = cls(datasets, z_range) - return make_radec_columns(result) + return fold( + HookPoint.LightconeOpen, LightconeOpenCtx(result, open_kwargs) + ).lightcone def with_redshift_range(self, z_low: float, z_high: float): """ @@ -567,16 +505,14 @@ def with_redshift_range(self, z_low: float, z_high: float): raise ValueError("Low and high values of the redshift range are the same!") new_datasets = {} for key, dataset in 
self.items(): - if not is_in_range(dataset, z_low, z_high): + if not lcutils.is_in_range(dataset, z_low, z_high): continue new_dataset = dataset.filter( oc.col("redshift") > z_low, oc.col("redshift") < z_high ) if len(new_dataset) > 0: new_datasets[key] = new_dataset - return Lightcone( - new_datasets, (z_low, z_high), self.__hidden, self.__ordered_by - ) + return Lightcone(new_datasets, (z_low, z_high), self.__hidden, self.__sort_key) def __map( self, @@ -611,36 +547,40 @@ def __map( if not output: output = zero_length_output if construct: - return Lightcone(output, self.z_range, hidden, self.__ordered_by) + return Lightcone(output, self.z_range, hidden, self.__sort_key) return output def __map_attribute(self, attribute): return {k: getattr(v, attribute) for k, v in self.items()} - def make_schema(self, name: str = "", _min_size=100_000) -> Schema: - datasets = order_by_redshift_range(self) + def make_schema( + self, name: str = "", _min_size=100_000, no_stack: bool = False + ) -> Schema: + datasets = lcio.order_by_redshift_range(self) for key in datasets: if isinstance(datasets[key], Lightcone): datasets[key] = dict(datasets[key]) - output_datasets = combine_adjacent_datasets( - datasets, min_dataset_size=_min_size + output_datasets = lcio.combine_adjacent_datasets( + datasets, min_dataset_size=_min_size, no_stack=no_stack ) children = {} for step, datasets in output_datasets.items(): if len(datasets) == 0: - stack_lightcone_datasets_in_schema(datasets, None, None) + stack_lightcone_datasets_in_schema(datasets, None, None, no_stack) continue all_datasets = list(chain(*tuple(lst for lst in datasets.values()))) - header_zrange = get_redshift_range(all_datasets) + header_zrange = lcutils.get_redshift_range(all_datasets) my_zrange = self.z_range zrange = ( max(header_zrange[0], my_zrange[0]), min(header_zrange[1], my_zrange[1]), ) - child_schemas = stack_lightcone_datasets_in_schema(datasets, step, zrange) + child_schemas = stack_lightcone_datasets_in_schema( + 
datasets, step, zrange, no_stack + ) child_schemas = { f"{step}_{name}": schema for name, schema in child_schemas.items() } @@ -737,6 +677,57 @@ def box_search(self, p1: tuple | SkyCoord, p2: tuple | SkyCoord): region = oc.make_skybox(p1, p2) return self.bound(region) + def pixel_search(self, pixels: npt.NDArray[np.int_], nside: int = 64): + """ + Return the subset of this lightcone that falls within a set of HEALPix pixels. + + Pixels must be specified in nested ordering and must be valid indices at + the given ``nside``. Duplicate pixel indices are ignored. Use + :py:meth:`get_pixels` to discover + which pixels this lightcone covers. + + Parameters + ---------- + pixels : array_like[int] + HEALPix pixel indices to query, in nested ordering. Must be a 1-D + array of non-negative integers. Values must be less than + ``healpy.nside2npix(nside)``. + nside : int, default = 64 + The HEALPix resolution parameter. Must be a positive power of two + and must not exceed the resolution of the spatial index stored in + the file. + + Returns + ------- + lightcone : opencosmo.Lightcone + A new lightcone containing only the objects that fall within the + specified pixels. + + Raises + ------ + ValueError + If ``nside`` is not a positive power of two, or if ``pixels`` + contains values that are out of range for the given ``nside``. 
+ """ + level = np.log2(nside) + if not level.is_integer() or level < 0: + raise ValueError("nside must be a positive power of two!") + level = int(level) + pixels = np.atleast_1d(pixels) + pixels = np.unique(pixels) + if not np.isdtype(pixels.dtype, "integral") or len(pixels) == 0: + raise ValueError("Pixels must be a 1d array of positive integers") + if pixels[0] < 0 or pixels[-1] >= hp.nside2npix(nside): + raise ValueError("Pixels must be a 1d array of positive integers") + output = {} + for name, ds in self.items(): + if isinstance(ds, Lightcone): + output[name] = ds.pixel_search(pixels, nside) + continue + rows = ds.tree.project_on_index(level, ds.index, pixels) + output[name] = ds.take_rows(rows) + return Lightcone(output, self.z_range, self.__hidden, self.__sort_key) + def evaluate( self, func: Callable, @@ -744,6 +735,7 @@ def evaluate( insert=True, format: str = "astropy", batch_size: int = -1, + allow_overwrite: bool = False, **evaluate_kwargs, ): """ @@ -778,9 +770,10 @@ def evaluate( The function to evaluate on the rows in the dataset. format: str, default = "astropy" - The format of the data that is provided to your function. If "astropy", will be a dictionary of - astropy quantities. If "numpy", will be a dictionary of numpy arrays. Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. 
vectorize: bool, default = False Whether to provide the values as full columns (True) or one row at a time (False) @@ -832,6 +825,7 @@ def evaluate( "with_new_columns", mapped_arguments=mapped_evaluated_columns, construct=True, + allow_overwrite=allow_overwrite, ) result = self.__map( @@ -842,6 +836,7 @@ def evaluate( insert=insert, mapped_arguments=mapped_kwargs, batch_size=batch_size, + allow_overwrite=allow_overwrite, construct=insert, **evaluate_kwargs, ) @@ -852,7 +847,7 @@ def evaluate( keys = next(iter(result.values())).keys() output = {} for key in keys: - output[key] = np.concatenate([r[key] for r in result.values()]) + output[key] = concat_chunks([r[key] for r in result.values()], format) return output def filter(self, *masks: ColumnMask, **kwargs) -> Self: @@ -879,7 +874,9 @@ def filter(self, *masks: ColumnMask, **kwargs) -> Self: """ return self.__map("filter", *masks, **kwargs) - def rows(self) -> Generator[dict[str, float | u.Quantity], None, None]: + def rows( + self, metadata_columns=[] + ) -> Generator[dict[str, float | u.Quantity], None, None]: """ Iterate over the rows in the dataset. Rows are returned as a dictionary For performance, it is recommended to first select the columns you need to @@ -890,7 +887,9 @@ def rows(self) -> Generator[dict[str, float | u.Quantity], None, None]: row : dict A dictionary of values for each row in the dataset with units. 
""" - yield from chain.from_iterable(v.rows() for v in self.values()) + yield from chain.from_iterable( + v.rows(metadata_columns=metadata_columns) for v in self.values() + ) def select( self, *columns: str | Iterable[str], **derived_columns: ConstructedColumn @@ -950,13 +949,13 @@ def select( hidden = self.__hidden additional_columns = set() - if "redshift" not in all_columns: + if "redshift" not in all_columns and "properties" in self.dtype: additional_columns.add("redshift") hidden = hidden.union({"redshift"}) - if self.__ordered_by is not None and self.__ordered_by[0] not in all_columns: - additional_columns.add(self.__ordered_by[0]) - hidden = hidden.union({self.__ordered_by[0]}) + if self.__sort_key is not None and self.__sort_key[0] not in all_columns: + additional_columns.add(self.__sort_key[0]) + hidden = hidden.union({self.__sort_key[0]}) return self.__map( "select", @@ -1001,7 +1000,9 @@ def drop(self, *columns: str | Iterable[str]) -> Self: kept_columns = current_columns - dropped_columns return self.select(kept_columns) - def take(self, n: int, at: str = "random") -> "Lightcone": + def take( + self, n: int, at: str = "random", mode: Literal["local", "global"] = "local" + ) -> "Lightcone": """ Create a new dataset from some number of rows from this dataset. @@ -1015,11 +1016,21 @@ def take(self, n: int, at: str = "random") -> "Lightcone": at : str Where to take the rows from. One of "start", "end", or "random". The default is "random". + mode : str, "local" or "global", default = "local" + Controls how ``n`` is interpreted when running under MPI. Has no + effect if you are not using MPI. + + * ``"local"`` (default): ``n`` rows are taken independently on + each rank. + * ``"global"``: ``n`` is the total number of rows to select across + all ranks combined. Each rank receives the portion of those rows + that it owns. If the dataset is sorted, ranks will coordinate + to take from the globally-sorted dataset. 
Returns ------- - dataset : Dataset - The new dataset with only the selected rows. + lightcone : Lightcone + The new lightcone with only the selected rows. Raises ------ @@ -1028,28 +1039,27 @@ def take(self, n: int, at: str = "random") -> "Lightcone": or if 'at' is invalid. """ - if n > len(self): - raise ValueError( - "Number of rows to take must be less than number of rows in dataset" - ) if at == "random": - indices = np.random.choice(len(self), n, replace=False) - indices = np.sort(indices) - return self.__take_rows(indices) - - elif self.__ordered_by is not None: - index = take_from_sorted(self, *self.__ordered_by, n=n, at=at) - return self.__take_rows(index) + index = get_random_take_index(n, len(self), mode) elif at == "start": - return self.take_range(0, n) + index = get_range_take_index(self, self.__sort_key, 0, n, mode) + if self.__sort_key is not None: + sort_index = self.__make_sort_index() + index = np.sort(sort_index[into_array(index)]) elif at == "end": - return self.take_range(len(self) - n, len(self)) + index = get_end_take_index(n, self, self.__sort_key, mode) + if self.__sort_key is not None: + sort_index = self.__make_sort_index() + index = np.sort(sort_index[into_array(index)]) else: raise ValueError( f'"at" should be one of ("start", "end", "random", got {at}' ) + return self.__take_rows(index) - def take_range(self, start: int, end: int): + def take_range( + self, start: int, end: int, mode: Literal["local", "global"] = "local" + ): """ Create a new lightcone from a row range in this lightcone. We use standard indexing conventions, so the rows included will be start -> end - 1. Because @@ -1060,9 +1070,19 @@ def take_range(self, start: int, end: int): Parameters ---------- start : int - The beginning of the range + The beginning of the range. end : int - The end of the range + The end of the range (exclusive). + mode : str, "local" or "global", default = "local" + Controls how ``start`` and ``end`` are interpreted when running + under MPI. 
Has no effect if you are not using MPI. + + * ``"local"`` (default): the range is applied independently on + each rank. + * ``"global"``: ``start`` and ``end`` index into the global row + space across all ranks combined. Each rank receives the portion + of that range it owns. If the lightcone is sorted, ranks will + coordinate to take from the globally-sorted lightcone. Returns ------- @@ -1076,31 +1096,16 @@ def take_range(self, start: int, end: int): or if end is greater than start. """ - if start < 0 or end > len(self): - raise ValueError("Got row indices that are out of range!") - - if self.__ordered_by is not None: - indices = take_from_sorted(self, *self.__ordered_by, end - start, at=start) - return self.__take_rows(indices) - - ends = np.cumsum(np.fromiter((len(ds) for ds in self.values()), dtype=int)) - starts = np.insert(ends, 0, 0)[:-1] - clipped_starts = np.clip(starts, a_min=start, a_max=None) - clipped_ends = np.clip(ends, a_min=None, a_max=end) + if start < 0: + raise ValueError("Tried to take negative rows!") - output = {} - for i, (name, dataset) in enumerate(self.items()): - if starts[i] == clipped_starts[i] and ends[i] == clipped_ends[i]: - output[name] = dataset - elif clipped_starts[i] >= clipped_ends[i]: - continue - else: - output[name] = dataset.take_range( - clipped_starts[i] - starts[i], clipped_ends[i] - starts[i] - ) - return Lightcone(output, self.z_range, self.__hidden, self.__ordered_by) + index = get_range_take_index(self, self.__sort_key, start, end - start, mode) + if self.__sort_key is not None: + sort_index = self.__make_sort_index() + index = np.sort(sort_index[into_array(index)]) + return self.__take_rows(index) - def take_rows(self, rows: np.ndarray): + def take_rows(self, rows: DataIndex): """ Take the rows of a lightcone specified by the :code:`rows` argument. :code:`rows` should be an array of integers. @@ -1121,47 +1126,48 @@ def take_rows(self, rows: np.ndarray): lightcone. 
""" - rows = np.sort(rows) - if rows[-1] >= len(self) or rows[0] < 0: + index_range = get_range(rows) + + if index_range[0] < 0 or index_range[1] > len(self): raise ValueError( "Rows must be between 0 and the length of this dataset - 1" ) - if self.__ordered_by is not None: - sort_index = self.__make_sort_index() - rows = sort_index[rows] - rows.sort() - + rows = get_rows_take_index(self, rows, self.__sort_key) return self.__take_rows(rows) def __make_sort_index(self): - if self.__ordered_by is None: + if self.__sort_key is None: return None data = np.concatenate( - [ds.select(self.__ordered_by[0]).get_data("numpy") for ds in self.values()] + [ds.select(self.__sort_key[0]).get_data("numpy") for ds in self.values()] ) - if self.__ordered_by[1]: + if self.__sort_key[1]: data = -data return np.argsort(data) - def __take_rows(self, rows: np.ndarray): + def __take_rows(self, rows: DataIndex): """ - Takes rows from this lightcone while ignoring sort. "rows" is assumed to be sorte. + Takes rows from this lightcone while ignoring sort. "rows" is assumed to be sorted. For internal use only. 
""" - ds_ends = np.cumsum(np.fromiter((len(ds) for ds in self.values()), dtype=int)) - partitions = np.searchsorted(rows, ds_ends) - splits = np.split(rows, partitions) - rs = 0 + sizes = np.fromiter((len(ds) for ds in self.values()), dtype=np.int64) + starts = np.zeros_like(sizes) + starts[1:] = np.cumsum(sizes)[:-1] + projected = rebuild_by_ranges(rows, (starts, sizes)) output = {} - for split, (name, dataset) in zip(splits, self.items()): - if len(split) > 0: - output[name] = dataset.take_rows(split - rs) - rs += len(dataset) - return Lightcone(output, self.z_range, self.__hidden, self.__ordered_by) + for (name, ds), index in zip(self.items(), projected): + output[name] = ds.take_rows(index) + if all(len(ds) == 0 for ds in output.values()): + output = {"data": next(iter(output.values()))} + else: + output = {name: ds for name, ds in output.items() if len(ds) != 0} + + return Lightcone(output, self.z_range, self.__hidden, self.__sort_key) def with_new_columns( self, descriptions: str | dict[str, str] = {}, + allow_overwrite: bool = False, **columns: ConstructedColumn | np.ndarray | u.Quantity, ): """ @@ -1204,7 +1210,7 @@ def with_new_columns( else: raw[name] = column - if self.__ordered_by is not None: + if self.__sort_key is not None: sort_index = self.__make_sort_index() sort_index = np.argsort(sort_index) raw = {name: raw_data[sort_index] for name, raw_data in raw.items()} @@ -1216,11 +1222,13 @@ def with_new_columns( for i, (ds_name, ds) in enumerate(self.items()): raw_columns = {name: arrs[i] for name, arrs in raw_split.items()} columns_input = raw_columns | derived - new_dataset = ds.with_new_columns(descriptions, **columns_input) + new_dataset = ds.with_new_columns( + descriptions, allow_overwrite=allow_overwrite, **columns_input + ) new_datasets[ds_name] = new_dataset - return Lightcone(new_datasets, self.z_range, self.__hidden, self.__ordered_by) + return Lightcone(new_datasets, self.z_range, self.__hidden, self.__sort_key) - def sort_by(self, column: 
str, invert: bool = False): + def sort_by(self, column: Optional[str], invert: bool = False): """ Sort this dataset by the values in a given column. By default sorting is in ascending order (least to greatest). Pass invert = True to sort in descending @@ -1238,9 +1246,9 @@ def sort_by(self, column: str, invert: bool = False): Parameters ---------- - column : str + column : Optional[str] The column in the halo_properties or galaxy_properties dataset to - order the collection by. + order the collection by. Pass None to remove sorting. invert : bool, default = False If False (the default), ordering will be from least to greatest. @@ -1254,9 +1262,14 @@ def sort_by(self, column: str, invert: bool = False): """ - if column not in self.columns: + if column is None: + sort_key = None + elif column not in self.columns: raise ValueError(f"Column {column} does not exist in this dataset!") - return Lightcone(dict(self), self.z_range, self.__hidden, (column, invert)) + else: + sort_key = (column, invert) + + return Lightcone(dict(self), self.z_range, self.__hidden, sort_key) def with_units( self, diff --git a/python/opencosmo/collection/lightcone/plugins.py b/python/opencosmo/collection/lightcone/plugins.py new file mode 100644 index 00000000..904d1a6d --- /dev/null +++ b/python/opencosmo/collection/lightcone/plugins.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import dataclasses +import warnings +from typing import TYPE_CHECKING + +import astropy.cosmology.units as cu +import astropy.units as u +import numpy as np + +import opencosmo as oc +from opencosmo.column import norm_cols +from opencosmo.plugins.contexts import HookPoint +from opencosmo.plugins.hook import hook + +if TYPE_CHECKING: + from opencosmo import Lightcone + from opencosmo.plugins.contexts import LightconeOpenCtx + + +@hook( + HookPoint.LightconeOpen, + when=lambda ctx: ctx.lightcone.dtype in ["halo_properties", "galaxy_properties"], +) +def _ensure_coordinates(ctx: LightconeOpenCtx) -> 
LightconeOpenCtx: + known_columns = set(ctx.lightcone.columns) + if known_columns.issuperset(("phi", "theta")) or known_columns.issuperset( + ("ra", "dec") + ): + return ctx + + prefix = "fof_halo_center" + if ctx.lightcone.dtype == "galaxy_properties": + prefix = "gal_center" + + coord_columns = {coord: f"{prefix}_{coord}" for coord in ["x", "y", "z"]} + if not set(ctx.lightcone.columns).issuperset(coord_columns.values()): + raise ValueError("Unable to find coordinate columns for this lightcone dataset") + chi = norm_cols(*list(coord_columns.values())) + phi = oc.col(coord_columns["y"]).arctan2(oc.col(coord_columns["x"])) + theta = (oc.col(coord_columns["z"]) / oc.col("chi")).arccos() + new_lightcone = ctx.lightcone.with_new_columns(chi=chi, phi=phi, theta=theta) + return dataclasses.replace(ctx, lightcone=new_lightcone) + + +@hook(HookPoint.LightconeOpen) +def _ensure_redshift_column(ctx: LightconeOpenCtx) -> LightconeOpenCtx: + """Ensures a column called 'redshift' exists on every lightcone.""" + lightcone: Lightcone = ctx.lightcone + if ( + "particles" in lightcone.dtype or "profiles" in lightcone.dtype + ): # Particles or profiles, redshift handled at structure collection level + return ctx + if "redshift" in lightcone.columns: + return ctx + elif "fof_halo_center_a" in lightcone.columns: + z_col = 1 / oc.col("fof_halo_center_a") - 1 + elif "redshift_true" in lightcone.columns: + z_col = oc.col("redshift_true") + elif "zp" in lightcone.columns: + z_col = oc.col("zp") + elif "chi" in lightcone.columns: + lightcone = lightcone.evaluate( + redshift_from_chi, + cosmology=lightcone.cosmology, + vectorize=True, + ) + return dataclasses.replace(ctx, lightcone=lightcone) + + else: + raise ValueError( + "Unable to find a redshift or scale factor column for this lightcone dataset" + ) + return dataclasses.replace( + ctx, lightcone=lightcone.with_new_columns(redshift=z_col) + ) + + +@hook(HookPoint.LightconeOpen) +def _make_radec_columns(ctx: LightconeOpenCtx): + 
lightcone: Lightcone = ctx.lightcone + if "ra" in lightcone.columns and "dec" in lightcone.columns: + pass + elif "theta" in lightcone.columns and "phi" in lightcone.columns: + lightcone = lightcone.evaluate( + radec_from_thetaphi, vectorize=True, insert=True, format="numpy" + ) + elif "properties" in lightcone.dtype: + warnings.warn( + "Could not find coordinates in this catalog. Spatial queries will not be available" + ) + + return dataclasses.replace(ctx, lightcone=lightcone) + + +def radec_from_thetaphi(theta, phi): + theta_deg = theta * 180 / np.pi + phi_deg = phi * 180 / np.pi + return {"ra": phi_deg * u.deg, "dec": (90.0 - theta_deg) * u.deg} + + +def redshift_from_chi(chi: u.Quantity, cosmology): + distance = chi.to(u.Mpc, cu.with_H0(cosmology.H0)) + + redshift = distance.to( + cu.redshift, cu.redshift_distance(cosmology, kind="comoving") + ) + return {"redshift": redshift} diff --git a/src/opencosmo/collection/lightcone/stack.py b/python/opencosmo/collection/lightcone/stack.py similarity index 70% rename from src/opencosmo/collection/lightcone/stack.py rename to python/opencosmo/collection/lightcone/stack.py index 22ac5be1..313cc715 100644 --- a/src/opencosmo/collection/lightcone/stack.py +++ b/python/opencosmo/collection/lightcone/stack.py @@ -24,6 +24,69 @@ def update_order(data: np.ndarray, comm: Optional[MPI.Comm], order: np.ndarray): return data[order] +def _global_inverse_order_mpi(order: np.ndarray, comm) -> np.ndarray: + """ + Build the global inverse permutation across all MPI ranks. + + Each rank's `order` (possibly referencing rows on other ranks) is combined + into a single global permutation of length N. The inverse maps global input + position -> global output position in the written file. 
+ """ + sizes = comm.allgather(len(order)) + ends = np.cumsum(sizes) + starts = np.insert(ends, 0, 0) + N = int(starts[-1]) + + # global_order[i] is the global input position that lands at global output i + global_order = order + starts[comm.Get_rank()] + all_global_orders = np.concatenate(comm.allgather(global_order)) + + global_inv = np.empty(N, dtype=np.int64) + global_inv[all_global_orders] = np.arange(N) + return global_inv + + +def update_top_host_idx( + data: np.ndarray, + comm: Optional[MPI.Comm], + order: np.ndarray, + slice_sizes: list[int], +): + result = data.copy() + + # Add per-slice global offsets so local row references become global + offset = 0 + if comm is not None: + offsets = np.cumsum(comm.allgather(len(data))) + rank = comm.Get_rank() + if rank > 0: + offset = offsets[rank - 1] + + pos = 0 + for size in slice_sizes: + segment = result[pos : pos + size] + segment[segment >= 0] += offset + offset += size + pos += size + + # Reorder row positions (same as all other columns) + if comm is not None: + result = update_global_order_mpi(result, comm, order) + # After a cross-rank shuffle, result[valid] may contain global indices + # (0..N-1). np.argsort(order) only covers this rank's local portion, so + # we need the full global inverse permutation instead. 
+ inverse_order = _global_inverse_order_mpi(order, comm) + else: + result = result[order] + inverse_order = np.argsort(order) + + # Remap stored row references through the inverse permutation + output = np.full_like(result, -1) + valid = result >= 0 + output[valid] = inverse_order[result[valid]] + return output + + def update_global_order_mpi(data, comm, order): needs_global_reordering = comm.allgather(np.any((order < 0) | (order > len(order)))) if not np.any(needs_global_reordering): @@ -86,6 +149,7 @@ def stack_lightcone_datasets_in_schema( datasets: dict[str, list[ds.Dataset]], name: Optional[str], redshift_range: Optional[tuple[float, float]], + no_stack: bool = False, ): n_datasets = sum(len(lst) for lst in datasets.values()) if n_datasets == 1 and get_comm_world() is None: @@ -99,9 +163,12 @@ def stack_lightcone_datasets_in_schema( schema_children = {} ds_groups = get_all_keys(datasets, get_comm_world()) for ds_group in ds_groups: + schema_name = ds_group if len(datasets) > 1 else name ds_list = datasets.get(ds_group, []) ds_list = list(filter(lambda ds: len(ds) > 0, ds_list)) if len(ds_list) == 0: + if no_stack: + continue get_stacked_lightcone_order([], -1) sync_headers(ds_list, None) continue @@ -111,15 +178,25 @@ def stack_lightcone_datasets_in_schema( max_level = int(index_names[-1][-1]) assert all(isinstance(dataset, ds.Dataset) for dataset in ds_list) + if no_stack: + assert len(schemas) == 1 + schema_children[schema_name] = schemas[0] + continue new_data_group = stack_data_groups( [schema.children["data"] for schema in schemas] ) - order = get_stacked_lightcone_order(ds_list, max_level) updater = partial(update_order, order=order) + slice_sizes = [len(dataset) for dataset in ds_list] + top_host_idx_updater = partial( + update_top_host_idx, order=order, slice_sizes=slice_sizes + ) - for column in new_data_group.columns.values(): - column.set_transformation(updater) + for col_name, column in new_data_group.columns.items(): + if col_name == 
"top_host_idx": + column.set_transformation(top_host_idx_updater) + else: + column.set_transformation(updater) new_index_group = stack_index_groups( [schema.children["index"] for schema in schemas] @@ -132,8 +209,11 @@ def stack_lightcone_datasets_in_schema( "index": new_index_group, "header": header_schema, } + if all("data_linked" in c.children for c in schemas): + children["data_linked"] = stack_data_groups( + [schema.children["data_linked"] for schema in schemas] + ) - schema_name = ds_group if len(datasets) > 1 else name assert schema_name is not None schema_children[schema_name] = make_schema( schema_name, diff --git a/python/opencosmo/collection/lightcone/utils.py b/python/opencosmo/collection/lightcone/utils.py new file mode 100644 index 00000000..107c100e --- /dev/null +++ b/python/opencosmo/collection/lightcone/utils.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Sequence + +import healpy as hp +import numpy as np + +from opencosmo.collection.lightcone import lightcone as lc +from opencosmo.dataset import dataset as ds + +if TYPE_CHECKING: + from astropy.table import Table + + +def get_redshift_range(datasets: Sequence[ds.Dataset | lc.Lightcone]): + redshift_ranges = list(map(get_single_redshift_range, datasets)) + min_z = min(rr[0] for rr in redshift_ranges) + max_z = max(rr[1] for rr in redshift_ranges) + + return (min_z, max_z) + + +def get_single_redshift_range(dataset: ds.Dataset | lc.Lightcone): + if isinstance(dataset, lc.Lightcone): + return dataset.z_range + redshift_range = dataset.header.lightcone["z_range"] + if redshift_range is not None: + return redshift_range + step_zs = dataset.header.simulation["step_zs"] + step = dataset.header.file.step + assert step is not None + min_redshift = step_zs[step] + max_redshift = step_zs[step - 1] + return (min_redshift, max_redshift) + + +def is_in_range(dataset: ds.Dataset, z_low: float, z_high: float): + z_range = 
dataset.header.lightcone["z_range"] + if z_range is None: + z_range = get_single_redshift_range(dataset) + if z_high < z_range[0] or z_low > z_range[1]: + return False + return True + + +def sort_table(table: Table, column: str, invert: bool): + column_data = table[column] + if invert: + column_data = -column_data + indices = np.argsort(column_data) + for name in table.columns: + table[name] = table[name][indices] + return table + + +def take_from_sorted( + lightcone: lc.Lightcone, sort_by: str, invert: bool, n: int, at: str | int +): + column = np.concatenate( + [ds.select(sort_by).get_data("numpy") for ds in lightcone.values()] + ) + if invert: + column = -column + sort_index = np.argsort(column) + if at == "start": + sort_index = sort_index[:n] + elif at == "end": + sort_index = sort_index[-n:] + elif isinstance(at, int): + if at + n > len(sort_index) or at < 0: + raise ValueError( + "Requested a range that is outside the size of this dataset!" + ) + sort_index = sort_index[at : at + n] + + sorted_indices = np.sort(sort_index) + return sorted_indices + + +def determine_max_level(lightcone: lc.Lightcone) -> Optional[int]: + """ + Return the minimum tree max_level across all datasets in the lightcone, or + None if any dataset has no spatial index. 
+ """ + max_level: Optional[int] = None + for ds_ in lightcone.values(): + if isinstance(ds_, lc.Lightcone): + ds_level = determine_max_level(ds_) + else: + assert isinstance(ds_, ds.Dataset) + ds_level = ds_.tree.max_level if ds_.tree is not None else None + if ds_level is None: + return None + if max_level is None or ds_level < max_level: + max_level = ds_level + return max_level + + +def get_pixels( + lightcone: lc.Lightcone, level: int, is_occupied: Optional[np.ndarray] = None +): + # We know nside is a power of two at this point + available_level = determine_max_level(lightcone) + if available_level is None: + raise ValueError("Lightcone does not have a spatial index!") + if level > available_level: + raise ValueError( + f"The maximum available nside for this lightcone is {2**available_level}, but {2**level} was requested" + ) + + if is_occupied is None: + is_occupied = np.zeros(hp.nside2npix(2**level), dtype=bool) + + for ds_ in lightcone.values(): + if isinstance(ds_, lc.Lightcone): + lightcone_pixels = get_pixels(ds_, level, is_occupied) + is_occupied[lightcone_pixels] = True + continue + + assert isinstance(ds_, ds.Dataset) + tree = ds_.tree + if tree is None: + raise ValueError( + "One or more datasets in this lightcone does not have a spatial index!" 
+ ) + read_level = tree.max_level if tree.max_level < level else level + ds_pixels = tree.get_occupied_partitions(read_level, ds_.index) + is_occupied[ds_pixels] = True + return np.where(is_occupied)[0] diff --git a/src/opencosmo/collection/protocols.py b/python/opencosmo/collection/protocols.py similarity index 93% rename from src/opencosmo/collection/protocols.py rename to python/opencosmo/collection/protocols.py index 3d5b9382..2c4c0bbc 100644 --- a/src/opencosmo/collection/protocols.py +++ b/python/opencosmo/collection/protocols.py @@ -5,7 +5,6 @@ if TYPE_CHECKING: from opencosmo.column.column import ColumnMask from opencosmo.dataset import Dataset - from opencosmo.header import OpenCosmoHeader from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema @@ -38,9 +37,6 @@ def make_schema(self) -> Schema: ... @property def dtype(self) -> str | dict[str, str]: ... - @property - def header(self) -> OpenCosmoHeader | dict[str, OpenCosmoHeader]: ... - def __getitem__(self, key: str) -> Union[Dataset, "Collection"]: ... def keys(self) -> Iterable[str]: ... def values(self) -> Iterable[Union[Dataset, "Collection"]]: ... 
diff --git a/src/opencosmo/collection/simulation/__init__.py b/python/opencosmo/collection/simulation/__init__.py similarity index 100% rename from src/opencosmo/collection/simulation/__init__.py rename to python/opencosmo/collection/simulation/__init__.py diff --git a/src/opencosmo/collection/simulation/simulation.py b/python/opencosmo/collection/simulation/simulation.py similarity index 95% rename from src/opencosmo/collection/simulation/simulation.py rename to python/opencosmo/collection/simulation/simulation.py index 19ba8745..30d04ef5 100644 --- a/src/opencosmo/collection/simulation/simulation.py +++ b/python/opencosmo/collection/simulation/simulation.py @@ -14,10 +14,10 @@ from opencosmo.collection.protocols import Collection from opencosmo.column.column import ColumnMask, ConstructedColumn + from opencosmo.dtypes import HaccSimulationParameters from opencosmo.header import OpenCosmoHeader from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema - from opencosmo.parameters import HaccSimulationParameters from opencosmo.spatial.protocols import Region @@ -154,7 +154,7 @@ def simulation(self) -> dict[str, HaccSimulationParameters]: Returns -------- - simulation_parameters: dict[str, opencosmo.parameters.HaccSimulationParameters] + simulation_parameters: dict[str, opencosmo.dtypes.HaccSimulationParameters] """ return self.__map_attribute("simulation") @@ -301,6 +301,7 @@ def with_new_columns( *args, datasets: Optional[str | Iterable[str]] = None, descriptions: str | dict[str, str] = {}, + allow_overwrite: bool = False, **new_columns: ConstructedColumn | np.ndarray, ): """ @@ -324,6 +325,9 @@ def with_new_columns( :py:attr:`SimulationCollection(datasets).descriptions `. If a dictionary, should have keys matching the column names. 
+ allow_overwrite: bool, default = False + + ** columns : opencosmo.DerivedColumn | np.ndarray | units.Quantity The new columns """ @@ -336,7 +340,10 @@ def with_new_columns( output = {name: ds for name, ds in self.items()} for ds_name in datasets: output[ds_name] = output[ds_name].with_new_columns( - *args, descriptions=descriptions, **new_columns + *args, + descriptions=descriptions, + allow_overwrite=allow_overwrite, + **new_columns, ) return SimulationCollection(output) @@ -345,6 +352,7 @@ def with_new_columns( *args, descriptions=descriptions, datasets=datasets, + allow_overwrite=allow_overwrite, **new_columns, ) @@ -355,6 +363,7 @@ def evaluate( format: str = "astropy", vectorize: bool = False, insert: bool = False, + allow_overwrite: bool = False, **evaluate_kwargs, ): """ @@ -375,9 +384,11 @@ def evaluate( datasets: str | list[str], optional The datasets to evaluate on. If not provided, will be evaluated on all datasets format: str, default = "astropy" - The format of the data that is provided to your function. If "astropy", will be a dictionary of - astropy quantities. If "numpy", will be a dictionary of numpy arrays. Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. + vectorize: bool, default = False Whether to vectorize the computation. See :py:meth:`StructureCollection.evaluate ` and/or :py:meth:`Dataset.evaluate ` for more details. 
@@ -403,6 +414,7 @@ def evaluate( vectorize=vectorize, insert=insert, format=format, + allow_overwrite=allow_overwrite, construct=insert, **evaluate_kwargs, ) diff --git a/src/opencosmo/collection/structure/__init__.py b/python/opencosmo/collection/structure/__init__.py similarity index 100% rename from src/opencosmo/collection/structure/__init__.py rename to python/opencosmo/collection/structure/__init__.py diff --git a/src/opencosmo/collection/structure/evaluate.py b/python/opencosmo/collection/structure/evaluate.py similarity index 54% rename from src/opencosmo/collection/structure/evaluate.py rename to python/opencosmo/collection/structure/evaluate.py index 7bed468a..efc80cb8 100644 --- a/src/opencosmo/collection/structure/evaluate.py +++ b/python/opencosmo/collection/structure/evaluate.py @@ -1,16 +1,12 @@ from __future__ import annotations from inspect import Parameter, signature -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, Optional -import numpy as np from astropy.units import Quantity # type: ignore from opencosmo import dataset as ds -from opencosmo.evaluate import ( - insert_data, - make_output_from_first_values, -) +from opencosmo.dataset.formats import concat_chunks, stack_rows if TYPE_CHECKING: from opencosmo import StructureCollection @@ -82,17 +78,22 @@ def evaluate_into_properties( kwargs: dict[str, Any], insert: bool, ): - storage = __make_output(function, collection, format, kwargs, {}, insert) - for i, structure in enumerate(collection.objects()): - if i == 0: - continue + per_column: dict[str, list] = {} + for structure in collection.objects(): input_structure = __make_input(structure, format) - output = function(**input_structure, **kwargs) - if storage is not None: - insert_data(storage, i, output) + if output is None and insert: + raise ValueError( + "You asked to insert these values, but your function returns None!" 
+ ) + if not isinstance(output, dict): + output = {function.__name__: output} + for name, value in output.items(): + per_column.setdefault(name, []).append(value) - return storage + if not per_column: + return None + return {name: stack_rows(values, format) for name, values in per_column.items()} def evaluate_into_dataset( @@ -103,25 +104,28 @@ def evaluate_into_dataset( dataset: str, insert: bool, ): - storage = __make_chunked_output(function, collection, dataset, format, kwargs, {}) - + per_column: dict[str, list] = {} for i, structure in enumerate(collection.objects()): - if i == 0: - continue input_structure = __make_input(structure, format) - output = function(**input_structure, **kwargs) + if output is None and insert: + raise ValueError( + "You asked to insert these values, but your function returns None!" + ) if not isinstance(output, dict): output = {function.__name__: output} - - if storage is not None: - for name, output_arr in output.items(): - storage[name].append(output_arr) - - if storage is None: - return - output_data = {name: np.concatenate(data) for name, data in storage.items()} - return output_data + if i == 0: + expected_length = len(input_structure[dataset]) + if any(len(v) != expected_length for v in output.values()): + raise ValueError( + "If you pass a `dataset` argument, your function should output an array with the same length as that dataset" + ) + for name, output_arr in output.items(): + per_column.setdefault(name, []).append(output_arr) + + if not per_column: + return None + return {name: concat_chunks(data, format) for name, data in per_column.items()} def __make_input(structure: dict, format: str = "astropy"): @@ -130,78 +134,15 @@ def __make_input(structure: dict, format: str = "astropy"): if isinstance(element, dict): values[name] = __make_input(element, format) elif isinstance(element, ds.Dataset): - data = element.get_data(format) + data = element.get_data(format, wrap_single=True) values[name] = data - elif isinstance(element, 
Quantity) and format == "numpy": + elif isinstance(element, Quantity) and format != "astropy": values[name] = element.value else: values[name] = element return values -def __make_output( - function: Callable, - collection: StructureCollection, - format: str = "astropy", - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, Sequence] = {}, - insert: bool = True, -) -> dict | None: - first_structure = next(collection.take(1, at="start").objects()) - first_input = __make_input(first_structure, format) - first_values = function( - **first_input, - **kwargs, - **{name: arr[0] for name, arr in iterable_kwargs.items()}, - ) - if first_values is None and insert: - raise ValueError( - "You asked to insert these values, but your function returns None!" - ) - elif first_values is None: - return None - if not isinstance(first_values, dict): - name = function.__name__ - first_values = {name: first_values} - n_rows = len(collection) - return make_output_from_first_values(first_values, n_rows) - - -def __make_chunked_output( - function: Callable, - collection: StructureCollection, - dataset: str, - format: str = "astropy", - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, Sequence] = {}, - insert: bool = True, -) -> dict | None: - first_structure = collection.take(1, at="start") - expected_length = len(first_structure[dataset]) - first_structure_data = next(iter(first_structure.objects())) - - first_input = __make_input(first_structure_data, format) - first_values = function( - **first_input, - **kwargs, - **{name: arr[0] for name, arr in iterable_kwargs.items()}, - ) - if first_values is None and insert: - raise ValueError( - "You asked to insert these values, but your function returns None!" 
- ) - elif first_values is None: - return None - if not isinstance(first_values, dict): - name = function.__name__ - first_values = {name: first_values} - if any(len(fv) != expected_length for fv in first_values.values()): - raise ValueError( - "If you pass a `dataset` argument, your function should output an array with the same length as that dataset" - ) - return {name: [fv] for name, fv in first_values.items()} - - def __prepare_collection( spec: dict[str, Optional[list[str]]], collection: StructureCollection ) -> StructureCollection: diff --git a/src/opencosmo/collection/structure/handler.py b/python/opencosmo/collection/structure/handler.py similarity index 62% rename from src/opencosmo/collection/structure/handler.py rename to python/opencosmo/collection/structure/handler.py index 15935894..8db23a0b 100644 --- a/src/opencosmo/collection/structure/handler.py +++ b/python/opencosmo/collection/structure/handler.py @@ -1,14 +1,16 @@ from __future__ import annotations from functools import partial, reduce -from typing import TYPE_CHECKING, Any, Iterable, Optional +from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional import numpy as np +from opencosmo.collection.lightcone import lightcone as lc from opencosmo.index import into_array if TYPE_CHECKING: import opencosmo as oc + from opencosmo.collection.structure import structure as sc """ A tale in 3 acts: @@ -43,25 +45,45 @@ } -def create_start_size(data, start_name, size_name): +def create_start_size(data, start_name, size_name, offsets): start = data.pop(start_name, None) size = data.pop(size_name, None) - if start is None: + if start is None or size is None: return None + + start = np.atleast_1d(start).astype(np.int64) + size = np.atleast_1d(size).astype(np.int64) valid = size > 0 - if isinstance(start, np.ndarray): - return (start[valid], size[valid]) - if size == 0: + if not np.any(valid): return None - return (np.atleast_1d(start), np.atleast_1d(size)) + + if offsets is not None: + ds_rs = 0 + 
src_rs = 0 + for source_len, ds_len in offsets: + slice = start[src_rs : src_rs + source_len] + slice[slice >= 0] += ds_rs + src_rs += source_len + ds_rs += ds_len + + return (start[valid], size[valid]) -def create_idx(data, idx_name): +def create_idx(data, idx_name, offsets): idx = data.pop(idx_name, None) if idx is None: return None + idx = idx.astype(np.int64) valid = idx >= 0 + if offsets is not None: + ds_rs = 0 + src_rs = 0 + for source_len, ds_len in offsets: + slice = idx[src_rs : src_rs + source_len] + slice[slice >= 0] += ds_rs + src_rs += source_len + ds_rs += ds_len if isinstance(idx, np.ndarray): return idx[valid] @@ -70,6 +92,27 @@ def create_idx(data, idx_name): return np.atleast_1d(idx) +def build_lightcone_index(old_source: lc.Lightcone, new_source: lc.Lightcone): + index = np.zeros(len(new_source), dtype=np.int64) + offset = 0 + rs = 0 + for name, ds in old_source.items(): + if name not in new_source.keys(): + offset += len(ds) + continue + original_index = into_array(ds.index) + new_index = into_array(new_source[name].index) + _, index_into_original, index_into_new = np.intersect1d( + original_index, new_index, assume_unique=True, return_indices=True + ) + index_into_original = index_into_original[np.argsort(index_into_new)] + + index[rs : rs + len(index_into_original)] = index_into_original + offset + offset += len(ds) + rs += len(index_into_original) + return index + + def make_links(keys, rename_galaxies=False): starts = list(filter(lambda key: "start" in key, keys)) sizes = list(filter(lambda key: "size" in key, keys)) @@ -98,6 +141,52 @@ def make_links(keys, rename_galaxies=False): return output, columns +def resort_datasets( + source: oc.Dataset | oc.Lightcone, + datasets: Mapping[str, oc.Dataset | oc.Lightcone | oc.StructureCollection], + columns: dict[str, list[str]], +): + all_columns: list[str] = reduce( + lambda acc, ds: acc + columns[ds], datasets.keys(), [] + ) + all_columns = list( + filter(lambda name: "idx" in name or "size" in 
name, all_columns) + ) + sort_column = next(filter(lambda c: "start" in c or "idx" in c, all_columns)) + unsorted_meta_column = source.get_metadata(sort_column, ignore_sort=True) + sorted_meta_column = source.get_metadata(sort_column) + + argsort_meta_column = np.argsort(sorted_meta_column[sort_column]) + + sort_index = argsort_meta_column[ + np.searchsorted( + sorted_meta_column[sort_column], + unsorted_meta_column[sort_column], + sorter=argsort_meta_column, + ) + ] + + meta = source.get_metadata(all_columns) + output = {} + for name, dataset in datasets.items(): + if len(columns[name]) == 1: + valid_rows = meta[columns[name][0]] >= 0 + new_dataset = dataset.take_rows(sort_index[valid_rows]) + else: + size_column = [name for name in columns[name] if "size" in name] + assert len(size_column) == 1 + size_column_data = meta[size_column[0]].astype(np.int64) + chunk_boundaries = np.zeros(len(size_column_data) + 1, dtype=np.int64) + _ = np.cumsum(size_column_data, out=chunk_boundaries[1:]) + starts = chunk_boundaries[sort_index] + sizes = size_column_data[sort_index] + valid = sizes > 0 + idx = (starts[valid], sizes[valid]) + new_dataset = dataset.take_rows(idx) + output[name] = new_dataset + return output + + class LinkHandler: """ This needs some explanation. We break the "don't mutate state" rule pretty hard here. 
@@ -131,7 +220,7 @@ def __init__( self, links, columns, - derived_from: Optional[oc.Dataset], + derived_from: Optional[oc.Dataset | oc.Lightcone], ): self.__derived_from = derived_from self.links = links @@ -142,26 +231,39 @@ def from_link_names(cls, names: Iterable[str], rename_galaxies=False): links, columns = make_links(names, rename_galaxies) return LinkHandler(links, columns, None) - def parse(self, data: dict[str, Any]): + def parse( + self, + data: dict[str, Any], + offsets: Optional[dict[str, list[tuple[int, int]]]] = None, + ): output = {} for name, handler in self.links.items(): - result = handler(data) + result = handler( + data, offsets=offsets.get(name) if offsets is not None else None + ) if result is not None: output[name] = result return output - def prep_datasets(self, source: oc.Dataset, datasets: dict[str, oc.Dataset]): + def prep_datasets( + self, + source: oc.Dataset | oc.Lightcone, + datasets: dict[str, oc.Dataset | oc.Lightcone], + ): """ Called once when a datasets are opened for the first time. Downstream versions always use rebuild_datsets """ - all_columns: list[str] = reduce( lambda acc, ds: acc + self.columns[ds], datasets.keys(), [] ) meta = source.get_metadata(all_columns) - indices = self.parse(meta) + # Offsets are now baked into the metadata columns at construction time + # (see build_lightcone_structure_collection in io.py), so no per-step + # offset calculation is needed here. 
+ indices = self.parse(meta, offsets=None) new_datasets = datasets + for name, index in indices.items(): new_datasets[name] = new_datasets[name].take_rows(index) return new_datasets @@ -181,8 +283,8 @@ def make_derived(self, source: oc.Dataset): def rebuild_datasets( self, - new_source: oc.Dataset, - datasets: dict[str, oc.Dataset], + new_source: oc.Dataset | oc.Lightcone, + datasets: Mapping[str, oc.Dataset | oc.Lightcone | sc.StructureCollection], ): """ We have a few guarantees here: @@ -194,17 +296,23 @@ def rebuild_datasets( """ if self.__derived_from is None: return datasets - original_index = into_array(self.__derived_from.index) - new_index = into_array(new_source.index) - - _, index_into_original, index_into_new = np.intersect1d( - original_index, new_index, assume_unique=True, return_indices=True - ) + return self.__rebuild_datasets(self.__derived_from, new_source, datasets) + + def __rebuild_datasets(self, derived_from, new_source, datasets): + if isinstance(derived_from, lc.Lightcone): + index_into_original = build_lightcone_index(derived_from, new_source) + else: + original_index = into_array(derived_from.index) + new_index = into_array(new_source.index) + + _, index_into_original, index_into_new = np.intersect1d( + original_index, new_index, assume_unique=True, return_indices=True + ) + index_into_original = index_into_original[np.argsort(index_into_new)] all_columns: list[str] = reduce( lambda acc, ds: acc + self.columns[ds], datasets.keys(), [] ) - index_into_original = index_into_original[np.argsort(index_into_new)] - metadata = self.__derived_from.get_metadata(all_columns) + metadata = derived_from.get_metadata(all_columns) new_datasets = {} for name, dataset in datasets.items(): @@ -218,50 +326,30 @@ def rebuild_datasets( assert len(size_column) == 1 size_column_name = size_column[0] size_column_data = metadata[size_column_name] + new_datasets[name] = rebuild_chunk_index( - new_source, dataset, size_column_data, index_into_original + new_source, 
+ dataset, + size_column_data.astype(np.int64), + index_into_original.astype(np.int64), ) return new_datasets - def resort(self, source: oc.Dataset, datasets: dict[str, oc.Dataset]): + def resort( + self, source: oc.Dataset | oc.Lightcone, datasets: dict[str, oc.Dataset] + ): """ Data is always written in its original order, whether or not it has been sorted. This is to preserve the spatial index. However, when linked datasets are rebuilt they are rebuilt in the sorted order. This method re-sorts them based on the index from the original data. """ - all_columns: list[str] = reduce( - lambda acc, ds: acc + self.columns[ds], datasets.keys(), [] - ) - all_columns = list( - filter(lambda name: "idx" in name or "size" in name, all_columns) - ) - - sort_index = np.argsort(into_array(source.index)) - if np.all(sort_index[1:] >= sort_index[:-1]): - # Already sorted. Carry on! + is_sorted = source.sorted_by is not None + if not is_sorted: return datasets - meta = source.get_metadata(all_columns) - output = {} - for name, dataset in datasets.items(): - if len(self.columns[name]) == 1: - valid_rows = meta[self.columns[name][0]] >= 0 - new_dataset = dataset.take_rows(sort_index[valid_rows]) - else: - size_column = [name for name in self.columns[name] if "size" in name] - assert len(size_column) == 1 - size_column_data = meta[size_column[0]] - chunk_boundaries = np.zeros(len(size_column_data) + 1, dtype=int) - _ = np.cumsum(size_column_data, out=chunk_boundaries[1:]) - starts = chunk_boundaries[sort_index] - sizes = size_column_data[sort_index] - valid = sizes > 0 - idx = (starts[valid], sizes[valid]) - new_dataset = dataset.take_rows(idx) - output[name] = new_dataset - return output + return resort_datasets(source, datasets, self.columns) def rebuild_row_index( @@ -271,7 +359,7 @@ def rebuild_row_index( index_into_original: np.ndarray, ): valid_rows = original_metadata_column >= 0 - index = np.full(len(original_metadata_column), -1, dtype=int) + index = 
np.full(len(original_metadata_column), -1, dtype=np.int64) index[valid_rows] = np.arange(0, sum(valid_rows)) index_to_take = index[index_into_original] index_to_take = index_to_take[index_to_take >= 0] @@ -285,7 +373,7 @@ def rebuild_chunk_index( original_size_column: np.ndarray, index_into_original: np.ndarray, ): - chunk_boundaries = np.zeros(len(original_size_column) + 1, dtype=int) + chunk_boundaries = np.zeros(len(original_size_column) + 1, dtype=np.int64) _ = np.cumsum(original_size_column, out=chunk_boundaries[1:]) valid_rows = original_size_column[index_into_original] > 0 diff --git a/python/opencosmo/collection/structure/io.py b/python/opencosmo/collection/structure/io.py new file mode 100644 index 00000000..939e79cd --- /dev/null +++ b/python/opencosmo/collection/structure/io.py @@ -0,0 +1,414 @@ +from __future__ import annotations + +from collections import defaultdict +from functools import partial +from itertools import chain +from typing import TYPE_CHECKING, Any, Mapping, Optional, TypeGuard + +import numpy as np + +import opencosmo as oc +from opencosmo import dataset as d +from opencosmo import io +from opencosmo.collection.lightcone import lightcone as lc +from opencosmo.collection.structure import structure as sc +from opencosmo.collection.structure.handler import LINK_ALIASES + +if TYPE_CHECKING: + import h5py + from mpi4py import MPI + + from opencosmo.io.iopen import FileTarget + +ALLOWED_LINKS = { # h5py.Files that can serve as a link holder and + "halo_properties": ["halo_particles", "halo_profiles", "galaxy_properties"], + "galaxy_properties": ["galaxy_particles"], +} + + +def remove_empty(dataset): + metadata = dataset.get_metadata() + mask = np.ones(len(dataset), dtype=bool) + for name, col in metadata.items(): + if "size" in name: + mask &= col != 0 + elif "idx" in name: + mask &= col != -1 + + if not mask.all(): + dataset = dataset.take_rows(np.where(mask)[0]) + return dataset + + +def is_dataset(ds: Any) -> TypeGuard[d.Dataset]: + 
return isinstance(ds, d.Dataset) + + +def validate_linked_groups(groups: dict[str, h5py.Group]): + if "halo_properties" in groups: + if "data_linked" not in groups["halo_properties"].keys(): + raise ValueError( + "File appears to be a structure collection, but does not have links!" + ) + elif "galaxy_properties" in groups: + if "data_linked" not in groups["galaxy_properties"].keys(): + raise ValueError( + "File appears to be a structure collection, but does not have links!" + ) + if len(groups) == 1: + raise ValueError("Structure collections must have more than one dataset") + + +def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): + link_sources: dict[str, list[io.iopen.DatasetTarget]] = defaultdict(list) + link_targets: dict[str, dict[str, list[d.Dataset | sc.StructureCollection]]] = ( + defaultdict(lambda: defaultdict(list)) + ) + + dataset_targets: list[io.iopen.DatasetTarget] = [] + for t in targets: + dataset_targets.extend(t["dataset_targets"]) + for datasets in t["dataset_groups"].values(): + dataset_targets.extend(datasets) + + for target in dataset_targets: + if target["header"].file.data_type == "halo_properties": + link_sources["halo_properties"].append(target) + elif target["header"].file.data_type == "galaxy_properties": + link_sources["galaxy_properties"].append(target) + elif str(target["header"].file.data_type).startswith("halo"): + dataset = io.iopen.open_single_dataset( + target, bypass_lightcone=True, bypass_mpi=True + ) + name_source = target["dataset_group"] + if ( + "particles" in name_source.parent.name + or "profiles" in target["dataset_group"].parent.name + ): + name_source = target["dataset_group"].parent + name = name_source.name.split("/")[-1] + + if not name: + name = target["header"].file.data_type + elif name.startswith("halo_properties"): + name = name[16:] + link_targets["halo_properties"][name].append(dataset) + elif str(target["header"].file.data_type).startswith("galaxy"): + dataset = 
io.iopen.open_single_dataset( + target, bypass_lightcone=True, bypass_mpi=True + ) + name_source = target["dataset_group"] + if ( + "particles" in name_source.parent.name + or "profiles" in target["dataset_group"].parent.name + ): + name_source = target["dataset_group"].parent + name = name_source.name.split("/")[-1] + + if not name: + name = target["header"].file.data_type + elif name.startswith("galaxy_properties"): + name = name[18:] + link_targets["galaxy_properties"][name].append(dataset) + else: + raise ValueError( + f"Unknown data type for structure collection {target['header'].data_type}" + ) + + if ( + len(link_sources["halo_properties"]) > 1 + or len(link_sources["galaxy_properties"]) > 1 + ): + # Potentially a lightcone structure collection + return build_lightcone_structure_collection(link_sources, link_targets) + + halo_properties_target = None + galaxy_properties_target = None + if link_sources["halo_properties"]: + halo_properties_target = link_sources["halo_properties"][0] + if link_sources["galaxy_properties"]: + galaxy_properties_target = link_sources["galaxy_properties"][0] + + input_link_targets: dict[str, dict[str, d.Dataset | sc.StructureCollection]] = ( + defaultdict(dict) + ) + for source_type, source_targets in link_targets.items(): + if any(len(ts) > 1 for ts in source_targets.values()): + raise ValueError("Found more than one linked file of a given type!") + input_link_targets[source_type] = { + key: t[0] for key, t in source_targets.items() + } + + return __build_structure_collection( + halo_properties_target, + galaxy_properties_target, + input_link_targets, + ignore_empty, + ) + + +def _apply_offset_corrections( + source_by_step: Mapping[int, d.Dataset], + targets_by_step: Mapping[str, Mapping[int, d.Dataset | sc.StructureCollection]], +) -> dict[int, d.Dataset]: + """ + Correct step-local _start and _idx metadata columns to be globally correct + before stacking per-step source datasets into a Lightcone. 
+ + For _start columns: apply a lazy DerivedColumn offset (oc.col(name) + offset). + For _idx columns: apply the offset eagerly (only to non-negative values). + + targets_by_step keys may be either file-level group name prefixes (e.g. + "sodbighaloparticles_dm_particles") or alias values (e.g. "galaxy_properties"). + Column names always use file-level prefixes (e.g. "sodbighaloparticles_dm_particles_start"), + so we match via direct lookup then fall back to a LINK_ALIASES alias lookup. + """ + steps = sorted(source_by_step) + + type_step_offset: dict[str, dict[int, int]] = {} + for target_type, step_map in targets_by_step.items(): + cumulative = 0 + per_step: dict[int, int] = {} + for step in steps: + per_step[step] = cumulative + ds = step_map.get(step) + if ds is not None: + cumulative += len(ds) + type_step_offset[target_type] = per_step + + corrected: dict[int, d.Dataset] = {} + for step in steps: + step_ds = source_by_step[step] + meta_cols = set(step_ds.meta_columns) + updates: dict = {} + + for col in meta_cols: + if col.endswith("_start"): + prefix = col[:-6] + is_idx = False + elif col.endswith("_idx"): + prefix = col[:-4] + is_idx = True + else: + continue + + # Try direct match (target_type key == file-level prefix), then + # fall back to alias lookup for cases like "galaxyproperties" -> "galaxy_properties". 
+ offset = None + if prefix in type_step_offset: + offset = type_step_offset[prefix][step] + elif prefix in LINK_ALIASES and LINK_ALIASES[prefix] in type_step_offset: + offset = type_step_offset[LINK_ALIASES[prefix]][step] + + if not offset: + continue + + if is_idx: + arr = step_ds.get_metadata([col])[col].copy() + arr[arr >= 0] += offset + updates[col] = arr + else: + updates[col] = oc.col(col) + offset + + if updates: + step_ds = step_ds.with_new_columns(allow_overwrite=True, **updates) + corrected[step] = step_ds + + return corrected + + +def build_lightcone_structure_collection( + link_sources: dict[str, list[io.iopen.DatasetTarget]], + link_targets: dict[str, dict[str, list[d.Dataset | sc.StructureCollection]]], +): + found_redshift_steps: set[int] = set() + for source_type, source_list in link_sources.items(): + if not all(t["header"].file.is_lightcone for t in source_list): + raise ValueError("All sources must be lightcone datasets!") + redshift_steps = set(t["header"].file.step for t in source_list) + if found_redshift_steps and found_redshift_steps != redshift_steps: + raise ValueError( + "All source types must have the same set of redshift steps!" + ) + if not all( + t.header.file.is_lightcone + for t in chain.from_iterable(link_targets[source_type].values()) + ): + raise ValueError("All dataset must be lightcone datasets!") + for targets in link_targets[source_type].values(): + target_redshift_steps = set(t.header.file.step for t in targets) + if target_redshift_steps != redshift_steps: + raise ValueError( + "All datasets must have the same set of redshift steps!" 
+ ) + if ( + len(link_sources.get("galaxy_properties", [])) > 0 + and "galaxy_properties" in link_targets + ): + # Galaxy properties and galaxy particles + galaxy_datasets = [ + io.iopen.open_single_dataset( + t, + "data_linked", + bypass_lightcone=True, + bypass_mpi=len(link_sources.get("halo_properties", [])) > 0, + ) + for t in link_sources["galaxy_properties"] + ] + galaxy_source_by_step: dict[int, d.Dataset] = {} + for ds in galaxy_datasets: + assert ds.header.file.step is not None + galaxy_source_by_step[ds.header.file.step] = ds + galaxy_targets_by_step: dict[ + str, Mapping[int, d.Dataset | sc.StructureCollection] + ] = { + target_type: {ds.header.file.step: ds for ds in targets} # type: ignore[misc] + for target_type, targets in link_targets["galaxy_properties"].items() + } + galaxy_source_by_step = _apply_offset_corrections( + galaxy_source_by_step, galaxy_targets_by_step + ) + galaxy_lightcone = lc.Lightcone.from_datasets(galaxy_source_by_step) + galaxy_target_datasets = {} + for target_type, targets in link_targets["galaxy_properties"].items(): + galaxy_target_datasets[target_type] = lc.Lightcone.from_datasets( + {ds.header.file.step: ds for ds in targets} # type: ignore + ) + collection = sc.StructureCollection(galaxy_lightcone, galaxy_target_datasets) + if len(link_sources.get("halo_properties", [])) > 0: + link_targets["halo_properties"]["galaxy_properties"] = collection # type: ignore[assignment] + else: + return collection + + halo_source_list = link_sources["halo_properties"] + halo_datasets = [ + io.iopen.open_single_dataset(t, "data_linked", bypass_lightcone=True) + for t in halo_source_list + ] + halo_source_by_step: dict[int, d.Dataset] = {} + for ds in halo_datasets: + assert ds.header.file.step is not None + halo_source_by_step[ds.header.file.step] = ds + halo_targets_by_step: dict[str, dict[int, d.Dataset]] = {} + for target_type, targets in link_targets["halo_properties"].items(): + if isinstance(targets, sc.StructureCollection): + # For a 
nested SC (e.g. galaxies), target_type is its source dtype + # ("galaxy_properties"), so targets[target_type] returns the source + # Lightcone, giving us per-step sizes for offset accounting. + inner_lc = targets[target_type] + assert isinstance(inner_lc, lc.Lightcone) + halo_targets_by_step[target_type] = dict(inner_lc) + elif isinstance(targets, list): + step_map: dict[int, d.Dataset] = {} + for ds in targets: + assert isinstance(ds, d.Dataset) + assert ds.header.file.step is not None + step_map[ds.header.file.step] = ds + halo_targets_by_step[target_type] = step_map + halo_source_by_step = _apply_offset_corrections( + halo_source_by_step, halo_targets_by_step + ) + source_lightcone = lc.Lightcone.from_datasets(halo_source_by_step) + + output_targets = {} + for target_type, targets in link_targets["halo_properties"].items(): + if isinstance(targets, (d.Dataset, sc.StructureCollection)): + output_targets[target_type] = targets + continue + output_targets_of_type: dict[int, d.Dataset] = {} + for ds in targets: + assert isinstance(ds, d.Dataset) + assert ds.header.file.step is not None + output_targets_of_type[ds.header.file.step] = ds + + output_targets[target_type] = lc.Lightcone.from_datasets(output_targets_of_type) + return sc.StructureCollection(source_lightcone, output_targets) + + +def __build_structure_collection( + halo_properties_target: Optional[io.iopen.DatasetTarget], + galaxy_properties_target: Optional[io.iopen.DatasetTarget], + link_targets: dict[str, dict[str, d.Dataset | sc.StructureCollection]], + ignore_empty: bool, +): + if galaxy_properties_target is not None and "galaxy_properties" in link_targets: + # Galaxy properties and galaxy particles + source_dataset = io.iopen.open_single_dataset( + galaxy_properties_target, + metadata_group="data_linked", + bypass_lightcone=True, + bypass_mpi=halo_properties_target is not None, + ) + if ignore_empty and halo_properties_target is None: + source_dataset = remove_empty(source_dataset) + collection = 
sc.StructureCollection( + source_dataset, + link_targets["galaxy_properties"], + ) + if halo_properties_target is not None: + link_targets["halo_properties"]["galaxy_properties"] = collection + else: + return collection + + if ( + halo_properties_target is not None + and galaxy_properties_target is not None + and "galaxy_properties" not in link_targets + ): + # Halo properties and galaxy properties, but no galaxy particles + galaxy_properties = io.iopen.open_single_dataset( + galaxy_properties_target, bypass_lightcone=True, bypass_mpi=True + ) + link_targets["halo_properties"]["galaxy_properties"] = galaxy_properties + + if halo_properties_target is not None and link_targets["halo_properties"]: + source_dataset = io.iopen.open_single_dataset( + halo_properties_target, metadata_group="data_linked", bypass_lightcone=True + ) + if ignore_empty: + source_dataset = remove_empty(source_dataset) + + return sc.StructureCollection( + source_dataset, + link_targets["halo_properties"], + ) + + +def do_idx_update(data: np.ndarray, comm: Optional[MPI.Comm] = None): + if comm is None: + return np.arange(len(data)) + lengths = comm.allgather(len(data)) + offsets = np.insert(np.cumsum(lengths), 0, 0) + offset = offsets[comm.Get_rank()] + result = np.arange(offset, offset + len(data)) + return result + + +def do_start_update(data: np.ndarray, size: np.ndarray, comm: Optional[MPI.Comm]): + psum = np.insert(np.cumsum(size), 0, 0)[:-1] + if comm is None: + return psum + lengths = comm.allgather(np.sum(size)) + offsets = np.insert(np.cumsum(lengths), 0, 0) + offset = offsets[comm.Get_rank()] + return psum + offset + + +def rebuild_data_linked(source_schema): + if ( + source_schema.type == io.schema.FileEntry.LIGHTCONE + and "data" not in source_schema.children + ): + for key, value in source_schema.children.items(): + source_schema.children[key] = rebuild_data_linked(value) + return source_schema + + for colname, column in source_schema.children["data_linked"].columns.items(): + if 
"idx" in colname: + column.set_transformation(do_idx_update) + elif "start" in colname: + size_colname = colname.replace("start", "size") + size_data = source_schema.children["data_linked"].columns[size_colname].data + updater = partial(do_start_update, size=size_data) + column.set_transformation(updater) + return source_schema diff --git a/src/opencosmo/collection/structure/structure.py b/python/opencosmo/collection/structure/structure.py similarity index 77% rename from src/opencosmo/collection/structure/structure.py rename to python/opencosmo/collection/structure/structure.py index 05e9463f..26c4430c 100644 --- a/src/opencosmo/collection/structure/structure.py +++ b/python/opencosmo/collection/structure/structure.py @@ -1,16 +1,27 @@ from __future__ import annotations from collections import defaultdict -from functools import partial, reduce +from functools import reduce, wraps from inspect import signature -from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Mapping, Optional +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Iterable, + Literal, + Mapping, + Optional, +) from warnings import warn import numpy as np import opencosmo as oc +from opencosmo.collection.lightcone import lightcone as lc from opencosmo.collection.structure import evaluate from opencosmo.collection.structure import io as sio +from opencosmo.dataset.formats import verify_format from opencosmo.index.unary import get_length from opencosmo.io.schema import FileEntry, make_schema @@ -21,11 +32,12 @@ import astropy.units as u from opencosmo.column.column import ConstructedColumn + from opencosmo.dtypes import HaccSimulationParameters + from opencosmo.header import OpenCosmoHeader from opencosmo.index import DataIndex from opencosmo.io.iopen import FileTarget from opencosmo.io.schema import Schema from opencosmo.mpi import MPI - from opencosmo.parameters import HaccSimulationParameters from opencosmo.spatial.protocols import Region @@ -67,6 +79,18 @@ 
def do_start_update(data: np.ndarray, size: np.ndarray, comm: Optional[MPI.Comm] return psum + offset +def _lightcone_only(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if not isinstance(self._StructureCollection__source, lc.Lightcone): + raise AttributeError( + f"{func.__name__} is only available on lightcone structure collections." + ) + return func(self, *args, **kwargs) + + return wrapper + + class StructureCollection: """ A collection of datasets that contain both high-level properties @@ -83,7 +107,6 @@ class StructureCollection: def __init__( self, source: oc.Dataset, - header: oc.header.OpenCosmoHeader, datasets: Mapping[str, oc.Dataset | StructureCollection], hide_source: bool = False, link_handler: Optional[LinkHandler] = None, @@ -95,9 +118,7 @@ def __init__( """ self.__source = source - self.__header = header self.__datasets = dict(datasets) - self.__index = self.__source.index self.__hide_source = hide_source if isinstance(self.__datasets.get("galaxy_properties"), StructureCollection): self.__datasets["galaxies"] = self.__datasets.pop("galaxy_properties") @@ -106,7 +127,9 @@ def __init__( self.__handler = LinkHandler.from_link_names( self.__source.meta_columns, "galaxies" in self.__datasets ) - datasets = self.__handler.prep_datasets(self.__source, self.__datasets) + self.__datasets = self.__handler.prep_datasets( + self.__source, self.__datasets + ) else: self.__handler = link_handler @@ -131,13 +154,17 @@ def __get_datasets(self): return self.__datasets def __repr__(self): - structure_type = self.__header.file.data_type.split("_")[0] + "s" + structure_type = self.__source.header.file.data_type.split("_")[0] + "s" + is_lightcone = isinstance(self.__source, lc.Lightcone) keys = list(self.keys()) if len(keys) == 2: dtype_str = " and ".join(keys) else: dtype_str = ", ".join(keys[:-1]) + ", and " + keys[-1] - return f"Collection of {structure_type} with {dtype_str}" + header = f"Collection of {structure_type} {'on a lightcone ' if 
is_lightcone else ' '}with {dtype_str}\n" + source_repr = self.__source.__repr__().split("\n", maxsplit=1)[1] + + return header + source_repr def __len__(self): return len(self.__source) @@ -146,15 +173,12 @@ def __len__(self): def open( cls, targets: list[FileTarget], ignore_empty=True, **kwargs ) -> StructureCollection: - return sio.build_structure_collection(targets, ignore_empty) - - @property - def header(self): - return self.__header + result = sio.build_structure_collection(targets, ignore_empty) + return result @property def dtype(self): - structure_type = self.__header.file.data_type.split("_")[0] + structure_type = self.__source.header.file.dt return structure_type @property @@ -164,6 +188,10 @@ def cosmology(self) -> astropy.cosmology.Cosmology: """ return self.__source.cosmology + @property + def header(self) -> OpenCosmoHeader: + return self.__source.header + @property def properties(self) -> list[str]: """ @@ -175,27 +203,61 @@ def properties(self) -> list[str]: @property def redshift(self) -> float | tuple[float, float] | None: """ - For snapshots, return the redshift or redshift range - this dataset was drawn from. + For snapshots, return the redshift this dataset was drawn from. Returns ------- redshift: float | tuple[float, float] """ - return self.__header.file.redshift + if isinstance(self.__source, lc.Lightcone): + raise AttributeError( + "This is a lightcone structure collection. Use .z_range to get the redshift range." + ) + return self.__source.header.file.redshift + + @property + def z_range(self) -> tuple[float, float]: + """ + The redshift range covered by this lightcone structure collection. + + Returns + ------- + z_range: tuple[float, float] + + Raises + ------ + AttributeError + If this is not a lightcone structure collection. + """ + if not isinstance(self.__source, lc.Lightcone): + raise AttributeError( + "This is not a lightcone structure collection. Use .redshift to get the redshift." 
+ ) + return self.__source.z_range + + @property + def sorted_by(self) -> Optional[str]: + """ + The column this collection is currently sorted by, or ``None`` if unsorted. + + Returns + ------- + column : Optional[str] + """ + return self.__source.sorted_by @property - def simulation(self) -> HaccSimulationParameters: + def simulation(self) -> HaccSimulationParameters | None: """ Get the parameters of the simulation this dataset is drawn from. Returns ------- - parameters: opencosmo.parameters.HaccSimulationParameters + parameters: opencosmo.dtypes.HaccSimulationParameters """ - return self.__header.simulation + return self.__source.simulation def keys(self) -> list[str]: """ @@ -224,7 +286,7 @@ def __getitem__(self, key: str) -> oc.Dataset | oc.StructureCollection: """ Return the linked dataset with the given key. """ - if key == self.__header.file.data_type: + if key == self.__source.header.file.data_type: return self.__source datasets = self.__get_datasets() if key not in datasets.keys(): @@ -277,7 +339,177 @@ def bound( new_handler = self.__handler.make_derived(self.__source) return StructureCollection( bounded, - self.__header, + self.__datasets, + self.__hide_source, + new_handler, + self.__derived_columns, + ) + + @_lightcone_only + def with_redshift_range(self, z_low: float, z_high: float) -> StructureCollection: + """ + Restrict this lightcone structure collection to a specific redshift range. + This is more efficient than filtering on the redshift column directly because + it prunes entire redshift steps before row-level filtering, and it updates + the z_range metadata on the returned collection. + + Parameters + ---------- + z_low : float + The lower bound of the redshift range (inclusive). + z_high : float + The upper bound of the redshift range (inclusive). + + Returns + ------- + result : StructureCollection + A new StructureCollection restricted to the given redshift range. 
+ + Raises + ------ + AttributeError + If this is not a lightcone structure collection. + ValueError + If the requested range does not overlap the available redshift range. + """ + assert isinstance(self.__source, lc.Lightcone) + new_source = self.__source.with_redshift_range(z_low, z_high) + new_handler = self.__handler.make_derived(self.__source) + return StructureCollection( + new_source, + self.__datasets, + self.__hide_source, + new_handler, + self.__derived_columns, + ) + + @_lightcone_only + def cone_search(self, center, radius) -> StructureCollection: + """ + Search for structures within an angular distance of a point on the sky. + Equivalent to ``collection.bound(oc.make_cone(center, radius))``. + + Parameters + ---------- + center : tuple | astropy.coordinates.SkyCoord + Center of the search cone. If a tuple with no units, assumed to be + (RA, Dec) in degrees. + radius : float | astropy.units.Quantity + Angular radius of the search cone. If no units, assumed to be degrees. + + Returns + ------- + result : StructureCollection + + Raises + ------ + AttributeError + If this is not a lightcone structure collection. + """ + region = oc.make_cone(center, radius) + return self.bound(region) + + @_lightcone_only + def box_search(self, p1, p2) -> StructureCollection: + """ + Search for structures within a rectangular region of the sky (defined by + RA/Dec corners). Equivalent to ``collection.bound(oc.make_skybox(p1, p2))``. + + Parameters + ---------- + p1 : tuple | astropy.coordinates.SkyCoord + One corner of the box. If a tuple with no units, assumed to be + (RA, Dec) in degrees. + p2 : tuple | astropy.coordinates.SkyCoord + The opposite corner of the box. + + Returns + ------- + result : StructureCollection + + Raises + ------ + AttributeError + If this is not a lightcone structure collection. 
+ """ + region = oc.make_skybox(p1, p2) + return self.bound(region) + + @_lightcone_only + def get_pixels(self, nside: int = 64) -> np.ndarray: + """ + Return the HEALPix pixels occupied by this lightcone structure collection + at a given resolution. + + Pixel indices are returned in nested ordering. The ``nside`` parameter + controls angular resolution: larger values produce finer pixels. The + requested resolution may not exceed the resolution of the spatial index + stored in the file. + + Parameters + ---------- + nside : int, default = 64 + The HEALPix resolution parameter. Must be a positive power of two. + + Returns + ------- + pixels : numpy.ndarray[int] + HEALPix pixel indices (nested ordering) occupied by this collection + at the given resolution. + + Raises + ------ + AttributeError + If this is not a lightcone structure collection. + ValueError + If ``nside`` is not a positive power of two, if ``nside`` exceeds + the maximum resolution of the spatial index, or if the lightcone + does not have a spatial index. + """ + assert isinstance(self.__source, lc.Lightcone) + return self.__source.get_pixels(nside) + + @_lightcone_only + def pixel_search(self, pixels: np.ndarray, nside: int = 64) -> StructureCollection: + """ + Return the subset of this lightcone structure collection that falls within + a set of HEALPix pixels. + + Pixels must be specified in nested ordering and must be valid indices at + the given ``nside``. Duplicate pixel indices are ignored. Use + :py:meth:`get_pixels ` to + discover which pixels this collection covers. + + Parameters + ---------- + pixels : array_like[int] + HEALPix pixel indices to query, in nested ordering. Must be a 1-D + array of non-negative integers. Values must be less than + ``healpy.nside2npix(nside)``. + nside : int, default = 64 + The HEALPix resolution parameter. Must be a positive power of two + and must not exceed the resolution of the spatial index stored in + the file. 
+ + Returns + ------- + result : StructureCollection + A new collection containing only the structures that fall within + the specified pixels. + + Raises + ------ + AttributeError + If this is not a lightcone structure collection. + ValueError + If ``nside`` is not a positive power of two, or if ``pixels`` + contains values that are out of range for the given ``nside``. + """ + assert isinstance(self.__source, lc.Lightcone) + new_source = self.__source.pixel_search(pixels, nside) + new_handler = self.__handler.make_derived(self.__source) + return StructureCollection( + new_source, self.__datasets, self.__hide_source, new_handler, @@ -290,6 +522,7 @@ def evaluate( dataset: Optional[str] = None, format: str = "astropy", insert: bool = True, + allow_overwrite: bool = False, **evaluate_kwargs: Any, ): """ @@ -365,8 +598,10 @@ def computation(halo_properties, dm_particles): collection contains galaxies. If False, simply return the data. format: str, default = astropy - Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. **evaluate_kwargs: any, Any additional arguments that are required for your function to run. These will be passed directly @@ -383,11 +618,15 @@ def computation(halo_properties, dm_particles): # If the user sets insert=False, everything is eager. 
if dataset is not None and dataset == self.__source.dtype: return self.evaluate_on_dataset( - func, dataset=dataset, format=format, insert=insert, **evaluate_kwargs + func, + dataset=dataset, + format=format, + insert=insert, + allow_overwrite=allow_overwrite, + **evaluate_kwargs, ) - if format not in ["astropy", "numpy"]: - raise ValueError(f"Invalid format requested for data: {format}") + verify_format(format) if dataset is not None and dataset.startswith("galaxies"): # Nested structure collection, special case @@ -397,13 +636,17 @@ def computation(halo_properties, dm_particles): sub_dataset = dataset_path[1] result = self[dataset_path[0]].evaluate( - func, sub_dataset, format, insert, **evaluate_kwargs + func, + sub_dataset, + format, + insert, + allow_overwrite=allow_overwrite, + **evaluate_kwargs, ) if not insert: return result return StructureCollection( self.__source, - self.__header, self.__get_datasets() | {dataset: result}, self.__hide_source, self.__handler.make_derived(self.__source), @@ -427,6 +670,7 @@ def computation(halo_properties, dm_particles): func, insert=insert, format=format, + allow_overwrite=allow_overwrite, strategy="chunked", **evaluate_kwargs, ) @@ -440,7 +684,6 @@ def computation(halo_properties, dm_particles): new_derived_columns_ = [f"{dataset}.{col}" for col in new_derived_columns] return StructureCollection( self.__source, - self.__header, self.__get_datasets() | {dataset: result}, self.__hide_source, self.__handler.make_derived(self.__source), @@ -460,9 +703,12 @@ def computation(halo_properties, dm_particles): ) if not insert or output is None: return output + from opencosmo.dataset.formats import to_numpy_dict + return self.with_new_columns( - **output, dataset=dataset if dataset is not None else self.__source.dtype, + allow_overwrite=allow_overwrite, + **to_numpy_dict(output), # type: ignore ) def evaluate_on_dataset( @@ -473,6 +719,7 @@ def evaluate_on_dataset( format: str = "astropy", insert: bool = True, batch_size: int = -1, 
+ allow_overwrite: bool = False, **evaluate_kwargs: Any, ): """ @@ -508,8 +755,10 @@ def evaluate_on_dataset( Whether to provide the values as full columns (True) or one row at a time (False). Ignored if :code:`batch_size` is set. format: str, default = astropy - Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. insert: bool, default = True If true, the data will be inserted as a column in this dataset. The new column will have the same name @@ -530,14 +779,18 @@ def evaluate_on_dataset( ds: oc.Dataset | StructureCollection if dataset is None or dataset == self.__source.dtype: result = self.__source.evaluate( - func, vectorize, insert, format, **evaluate_kwargs + func, + vectorize, + insert, + format, + allow_overwrite=allow_overwrite, + **evaluate_kwargs, ) if not insert: return result assert isinstance(result, oc.Dataset) return StructureCollection( result, - self.__header, self.__datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -554,7 +807,14 @@ def evaluate_on_dataset( ) if len(ds_path) == 1 and isinstance(ds, oc.Dataset): - result = ds.evaluate(func, vectorize, insert, format, **evaluate_kwargs) + result = ds.evaluate( + func, + vectorize, + insert, + format, + allow_overwrite=allow_overwrite, + **evaluate_kwargs, + ) if not insert: return result assert isinstance(result, oc.Dataset) @@ -563,18 +823,24 @@ def evaluate_on_dataset( new_derived_columns_ = [f"{dataset}.{col}" for col in new_derived_columns] return StructureCollection( self.__source, - self.__header, self.__datasets | {dataset: result}, 
self.__hide_source, self.__handler.make_derived(self.__source), self.__derived_columns.union(new_derived_columns_), ) elif len(ds_path) == 1 and isinstance(ds, oc.StructureCollection): - result = ds.evaluate(func, None, format, insert) + result = ds.evaluate( + func, None, format, insert, allow_overwrite=allow_overwrite + ) elif len(ds_path) > 1 and isinstance(ds, oc.StructureCollection): result = ds.evaluate_on_dataset( - func, ".".join(dataset[1:]), vectorize, format, insert + func, + ".".join(ds_path[1:]), + vectorize, + format, + insert, + allow_overwrite=allow_overwrite, ) if not insert: @@ -583,7 +849,6 @@ def evaluate_on_dataset( assert isinstance(result, (oc.Dataset, StructureCollection)) return StructureCollection( self.__source, - self.__header, self.__datasets | {ds_path[0]: result}, self.__hide_source, self.__handler.make_derived(self.__source), @@ -635,13 +900,12 @@ def filter(self, *masks, on_galaxies: bool = False) -> StructureCollection: galaxy_properties = self["galaxy_properties"] assert isinstance(galaxy_properties, oc.Dataset) filtered = filter_source_by_dataset( - galaxy_properties, self.__source, self.__header, *masks + galaxy_properties, self.__source, self.__source.header, *masks ) new_handler = self.__handler.make_derived(self.__source) return StructureCollection( filtered, - self.__header, self.__datasets, self.__hide_source, new_handler, @@ -731,7 +995,7 @@ def select( arg = columns # type: ignore kwargs = {} - if dataset == self.__header.file.data_type: + if dataset == self.__source.header.file.data_type: new_source = self.__source.select(arg, **kwargs) continue @@ -753,7 +1017,6 @@ def select( return StructureCollection( new_source, - self.__header, self.__datasets | new_datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -795,7 +1058,7 @@ def drop(self, **columns_to_drop): new_datasets = {} for dataset_name, columns in columns_to_drop.items(): - if dataset_name == self.__header.file.data_type: + if dataset_name 
== self.__source.header.file.data_type: new_source = self.__source.drop(columns) continue @@ -811,7 +1074,6 @@ def drop(self, **columns_to_drop): return StructureCollection( new_source, - self.__header, self.__datasets | new_datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -846,7 +1108,6 @@ def sort_by(self, column: str, invert: bool = False) -> StructureCollection: return StructureCollection( new_source, - self.__header, self.__datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -950,13 +1211,14 @@ def with_units( return StructureCollection( new_source, - self.__header, new_datasets, self.__hide_source, self.__handler.make_derived(self.__source), ) - def take(self, n: int, at: str = "random"): + def take( + self, n: int, at: str = "random", mode: Literal["local", "global"] = "local" + ): """ Take some number of structures from the collection. See :py:meth:`opencosmo.Dataset.take`. @@ -968,25 +1230,37 @@ def take(self, n: int, at: str = "random"): at : str, optional The method to use to take the structures. One of "random", "first", or "last". Default is "random". + mode : str, "local" or "global", default = "local" + Controls how ``n`` is interpreted when running under MPI. Has no + effect if you are not using MPI. + + * ``"local"`` (default): ``n`` rows are taken independently on + each rank. + * ``"global"``: ``n`` is the total number of rows to select across + all ranks combined. Each rank receives the portion of those rows + that it owns. If the dataset is sorted, ranks will coordinate + to take from the globally-sorted dataset. + Returns ------- StructureCollection A new collection with the structures taken from the original. 
""" - new_source = self.__source.take(n, at) + new_source = self.__source.take(n, at, mode) new_handler = self.__handler.make_derived(self.__source) return StructureCollection( new_source, - self.__header, self.__datasets, self.__hide_source, new_handler, self.__derived_columns, ) - def take_range(self, start: int, end: int): + def take_range( + self, start: int, end: int, mode: Literal["local", "global"] = "local" + ): """ Create a new collection from a row range in this collection. We use standard indexing conventions, so the rows included will be start -> end - 1. @@ -997,11 +1271,21 @@ def take_range(self, start: int, end: int): The first row to get. end : int The last row to get. + mode : str, "local" or "global", default = "local" + Controls how ``start`` and ``end`` are interpreted when running + under MPI. Has no effect if you are not using MPI. + + * ``"local"`` (default): the range is applied independently on + each rank. + * ``"global"``: ``start`` and ``end`` index into the global row + space across all ranks combined. Each rank receives the portion + of that range it owns. If the collection is sorted, ranks will + coordinate to take from the globally-sorted collection. Returns ------- - table : astropy.table.Table - The table with only the rows from start to end. + collection : StructureCollection + The collection with only the rows from start to end. Raises ------ @@ -1010,10 +1294,9 @@ def take_range(self, start: int, end: int): or if end is greater than start. 
""" - new_source = self.__source.take_range(start, end) + new_source = self.__source.take_range(start, end, mode) return StructureCollection( new_source, - self.__header, self.__datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -1043,7 +1326,6 @@ def take_rows(self, rows: np.ndarray | DataIndex): new_source = self.__source.take_rows(rows) return StructureCollection( new_source, - self.__header, self.__datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -1054,6 +1336,7 @@ def with_new_columns( self, dataset: str, descriptions: str | dict[str, str] = {}, + allow_overwrite: bool = False, **new_columns: ConstructedColumn | np.ndarray, ): """ @@ -1122,11 +1405,13 @@ def with_new_columns( if not isinstance(new_collection, StructureCollection): raise ValueError(f"{collection_name} is not a collection!") new_collection = new_collection.with_new_columns( - ".".join(path[1:]), descriptions=descriptions, **new_columns + ".".join(path[1:]), + descriptions=descriptions, + allow_overwrite=allow_overwrite, + **new_columns, ) return StructureCollection( self.__source, - self.__header, {**datasets, collection_name: new_collection}, self.__hide_source, self.__handler.make_derived(self.__source), @@ -1134,11 +1419,12 @@ def with_new_columns( if dataset == self.__source.dtype: new_source = self.__source.with_new_columns( - **new_columns, descriptions=descriptions + **new_columns, + descriptions=descriptions, + allow_overwrite=allow_overwrite, ) return StructureCollection( new_source, - self.__header, self.__datasets, self.__hide_source, self.__handler.make_derived(self.__source), @@ -1156,7 +1442,9 @@ def with_new_columns( if not isinstance(ds, oc.Dataset): raise ValueError(f"{dataset} is not a dataset!") - new_ds = ds.with_new_columns(**new_columns, descriptions=descriptions) + new_ds = ds.with_new_columns( + **new_columns, descriptions=descriptions, allow_overwrite=allow_overwrite + ) new_derived_columns = ( 
set(new_ds.columns).difference(ds.columns).difference(new_im_cols) ) @@ -1164,7 +1452,6 @@ def with_new_columns( return StructureCollection( self.__source, - self.__header, {**datasets, dataset: new_ds}, self.__hide_source, self.__handler.make_derived(self.__source), @@ -1212,7 +1499,6 @@ def objects( for column in self.__derived_columns: name_parts = column.split(".") columns_to_collect[name_parts[0]][name_parts[1]] = [] - try: for row in self.__source.rows(metadata_columns=metadata_columns): row = dict(row) @@ -1308,7 +1594,6 @@ def with_datasets(self, datasets: str | Iterable[str]): new_datasets = {name: self.__datasets[name] for name in requested_datasets} return StructureCollection( self.__source, - self.__header, new_datasets, hide_source, self.__handler.make_derived(self.__source), @@ -1333,29 +1618,21 @@ def galaxies(self, *args, **kwargs): else: raise AttributeError("This collection does not contain galaxies!") - def make_schema(self, name: Optional[str] = None) -> Schema: + def make_schema(self, name: Optional[str] = None, **kwargs) -> Schema: children = {} source_name = self.__source.dtype datasets = self.__handler.resort(self.__source, self.__get_datasets()) + schema_kwargs: dict[str, Any] = ( + {"no_stack": True} if isinstance(self.__source, lc.Lightcone) else {} + ) - source_schema = self.__source.make_schema() - for colname, column in source_schema.children["data_linked"].columns.items(): - if "idx" in colname: - column.set_transformation(do_idx_update) - elif "start" in colname: - size_colname = colname.replace("start", "size") - size_data = ( - source_schema.children["data_linked"].columns[size_colname].data - ) - updater = partial(do_start_update, size=size_data) - column.set_transformation(updater) - - children[source_name] = source_schema + source_schema = self.__source.make_schema(**schema_kwargs) + children[source_name] = sio.rebuild_data_linked(source_schema) for name, dataset in datasets.items(): if name == "galaxies": name = 
"galaxy_properties" - ds_schema = dataset.make_schema() + ds_schema = dataset.make_schema(**schema_kwargs) if not isinstance(dataset, StructureCollection): children[name] = ds_schema continue diff --git a/src/opencosmo/column/__init__.py b/python/opencosmo/column/__init__.py similarity index 100% rename from src/opencosmo/column/__init__.py rename to python/opencosmo/column/__init__.py diff --git a/python/opencosmo/column/cache.py b/python/opencosmo/column/cache.py new file mode 100644 index 00000000..9e154a7c --- /dev/null +++ b/python/opencosmo/column/cache.py @@ -0,0 +1,405 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Iterable, Optional +from uuid import UUID +from weakref import finalize, ref + +import astropy.units as u +import numpy as np + +from opencosmo.index import DataIndex +from opencosmo.index.build import from_size +from opencosmo.index.get import get_data +from opencosmo.index.take import take +from opencosmo.index.unary import get_length, get_range +from opencosmo.io.schema import FileEntry, make_schema +from opencosmo.io.writer import ColumnWriter + +if TYPE_CHECKING: + from opencosmo.index import DataIndex + +# (producer_uuid, column_name) — the unambiguous key for a cached column. 
+CacheKey = tuple[UUID, str] + +ColumnUpdater = Callable[[np.ndarray | u.Quantity], np.ndarray | u.Quantity] + + +def finish( + cached_data: dict[CacheKey, np.ndarray], + index: Optional[DataIndex], + cache_ref: ref[ColumnCache], +): + cache = cache_ref() + if cache is None: + return + + # pylint: disable=protected-access + if index is None: + index = from_size(len(cache)) + pairs_to_add = cache.registered_pairs.intersection(cached_data.keys()) - set( + cache.keys() + ) + data = {key[0]: {key[1]: get_data(cached_data[key], index)} for key in pairs_to_add} + if data: + cache.add_data(data) + + +def check_length(cache: ColumnCache, data: dict[UUID, dict[str, np.ndarray]]): + lengths = {len(arr) for uuid_data in data.values() for arr in uuid_data.values()} + if len(lengths) > 1: + raise ValueError( + "When adding data to the cache, all columns must be the same length" + ) + elif (length := len(cache)) > 0 and lengths and length != lengths.pop(): + raise ValueError( + "When adding data to the cache, the columns must be the same length as the columns currently in the cache" + ) + + +class ColumnCache: + """ + A column cache is used to persist data that is read from an hdf5 file. Caches can get data in one of two ways: + 1. They are explicitly given data that has been recently read from disk or + 2. They take data from a previous cache + + ColumnCaches break some of the rules that most other things follow in this library, notably that they have internal + state (which can change). This mutability is required for two reasons. + + 1. If the parent cache is garbage collected, the child cache needs to be able to copy over any data it needs + 2. If a new cache is created by adding columns, we need to signal the child to update their parent to the new + cache. This allows us to preserve the standard "operations create new objects" pattern that is present + throughout the library. 
+ + Internal storage uses (producer_uuid, column_name) tuples as keys so that multiple + producers that happen to produce a column with the same name are kept separate. + """ + + def __init__( + self, + cached_data: dict[CacheKey, np.ndarray], + registered_column_groups: dict[int, set[CacheKey]], + column_descriptions: dict[str, str], + metadata_columns: set[str], + metadata_data: dict[str, np.ndarray], + derived_index: Optional[DataIndex], + parent: Optional[ref[ColumnCache]], + children: Optional[list[ref[ColumnCache]]], + ): + self.__cached_data = cached_data + self.__registered_column_groups = registered_column_groups + self.__metadata_columns = metadata_columns + self.__metadata_data = metadata_data + self.__descriptions = column_descriptions + self.__derived_index = derived_index + self.__parent = parent + if children is None: + children = [] + self.__children = children + self.__finalizer = None + + if parent is not None and (p := parent()) is not None: + self.__finalizer = finalize( + p, + finish, + p.__cached_data, + derived_index, + ref(self), + ) + self.__finalizer.atexit = False # type: ignore + + @classmethod + def empty(cls): + return ColumnCache({}, {}, {}, set(), {}, None, None, []) + + @property + def columns(self) -> set[str]: + return {name for (_, name) in self.__cached_data.keys()} + + def keys(self) -> set[CacheKey]: + return set(self.__cached_data.keys()) + + @property + def metadata_columns(self) -> set[str]: + return self.__metadata_columns + + @property + def descriptions(self) -> dict[str, str]: + return self.__descriptions + + @property + def registered_pairs(self) -> set[CacheKey]: + if not self.__registered_column_groups: + return set() + return set().union(*self.__registered_column_groups.values()) + + def create_child(self) -> ColumnCache: + return ColumnCache({}, {}, {}, self.__metadata_columns, {}, None, ref(self), []) + + def make_schema( + self, columns_to_uuid: dict[str, UUID], meta_columns: list[str] + ) -> tuple: + data = {} + 
metadata = {} + + cached = self.get_data({(uuid, name) for name, uuid in columns_to_uuid.items()}) + for name, coldata in _flatten(cached).items(): + if isinstance(coldata, u.Quantity): + column_data = coldata.value + unit_str = str(coldata.unit) + else: + column_data = coldata + unit_str = "" + attrs = { + "unit": unit_str, + "description": self.__descriptions.get(name, "None"), + } + writer = ColumnWriter.from_numpy_array(column_data, attrs=attrs) + data[name] = writer + + for name, coldata in self.get_metadata(meta_columns).items(): + if isinstance(coldata, u.Quantity): + column_data = coldata.value + unit_str = str(coldata.unit) + else: + column_data = coldata + unit_str = "" + attrs = { + "unit": unit_str, + "description": self.__descriptions.get(name, "None"), + } + writer = ColumnWriter.from_numpy_array(column_data, attrs=attrs) + metadata[name] = writer + + if not data and not metadata: + return ( + make_schema("data", FileEntry.EMPTY), + make_schema("metadata", FileEntry.EMPTY), + ) + + data_schema = ( + make_schema("data", FileEntry.COLUMNS, columns=data) + if data + else make_schema("data", FileEntry.EMPTY) + ) + metadata_schema = ( + make_schema("metadata", FileEntry.COLUMNS, columns=metadata) + if metadata + else make_schema("metadata", FileEntry.EMPTY) + ) + return data_schema, metadata_schema + + def __push_down(self, data: dict[CacheKey, np.ndarray]): + pairs_to_keep = self.registered_pairs.intersection(data.keys()).difference( + self.__cached_data.keys() + ) + cached_data = {key: data[key] for key in pairs_to_keep} + if self.__derived_index is not None: + cached_data = { + key: get_data(cd, self.__derived_index) + for key, cd in cached_data.items() + } + self.__cached_data |= cached_data + + def __push_up(self, data: dict[CacheKey, np.ndarray]): + assert len(self) == 0 or all(len(d) == len(self) for d in data.values()) + pairs_to_keep = self.registered_pairs.intersection(data.keys()).difference( + self.__cached_data.keys() + ) + self.__cached_data 
|= {key: data[key] for key in pairs_to_keep} + + def register_column_group(self, state_id: int, columns: dict[str, UUID]): + assert state_id not in self.__registered_column_groups + self.__registered_column_groups[state_id] = { + (uuid, name) for name, uuid in columns.items() + } + + def deregister_column_group(self, state_id: int): + assert state_id in self.__registered_column_groups + pairs = self.__registered_column_groups.pop(state_id) + remaining = ( + set().union(*self.__registered_column_groups.values()) + if self.__registered_column_groups + else set() + ) + to_drop = pairs.difference(remaining) + cached_data = { + key: self.__cached_data.pop(key) + for key in to_drop + if key in self.__cached_data + } + if not cached_data: + return + for child_ in self.__children: + if (child := child_()) is None: + continue + child.__push_down(cached_data) + + def __update_parent(self, parent: ColumnCache): + assert self.__parent is not None + assert self.__finalizer is not None + self.__finalizer.detach() + self.__parent = ref(parent) + self.__finalizer = finalize( + parent, finish, parent.__cached_data, self.__derived_index, ref(self) + ) + self.__finalizer.atexit = False # type: ignore + + def __len__(self): + if not self.__cached_data and self.__derived_index is None: + return 0 + elif self.__derived_index is not None: + return get_length(self.__derived_index) + elif self.__cached_data: + return len(next(iter(self.__cached_data.values()))) + elif self.__parent is not None and (p := self.__parent()) is not None: + return len(p) + return 0 + + def add_data( + self, + data: dict[UUID, dict[str, np.ndarray]], + descriptions: dict[str, str] = {}, + push_up: bool = True, + ): + """Add UUID-keyed column data to the cache.""" + + if not data: + return + check_length(self, data) + self.__descriptions |= descriptions + + flat: dict[CacheKey, np.ndarray] = { + (uuid, name): arr + for uuid, uuid_data in data.items() + for name, arr in uuid_data.items() + } + if ( + push_up + and 
self.__derived_index is None + and self.__parent is not None + and (p := self.__parent()) is not None + ): + p.__push_up(flat) + + self.__cached_data |= flat + + def add_metadata( + self, + data: dict[str, np.ndarray], + descriptions: dict[str, str] = {}, + ): + """Add metadata columns (name-keyed, no producer UUID) to the cache.""" + self.__metadata_columns = self.__metadata_columns.union(data.keys()) + self.__descriptions |= descriptions + self.__metadata_data |= data + + def drop(self, column_names: Iterable[str]) -> ColumnCache: + names_to_drop = set(column_names) + data = { + key: val + for key, val in self.__cached_data.items() + if key[1] not in names_to_drop + } + descriptions = { + name: desc + for name, desc in self.__descriptions.items() + if name not in names_to_drop + } + new_meta_columns = self.__metadata_columns.difference(names_to_drop) + new_meta_data = { + name: val + for name, val in self.__metadata_data.items() + if name not in names_to_drop + } + return ColumnCache( + data, {}, descriptions, new_meta_columns, new_meta_data, None, None, [] + ) + + def request( + self, pairs: set[CacheKey], index: Optional[DataIndex] + ) -> dict[CacheKey, np.ndarray]: + pairs_in_cache = pairs.intersection(self.__cached_data.keys()) + data = {key: self.__cached_data[key] for key in pairs_in_cache} + if index is not None: + data = {key: get_data(cd, index) for key, cd in data.items()} + + if self.__parent is None or pairs == pairs_in_cache: + return data + + parent = self.__parent() + if parent is None: + return data + + match (index, self.__derived_index): + case (None, None): + new_index = None + case (_, None): + new_index = index + case (None, _): + new_index = self.__derived_index + case _: + assert self.__derived_index is not None and index is not None + new_index = take(self.__derived_index, index) + + return data | parent.request(pairs - pairs_in_cache, new_index) + + def take(self, index: DataIndex) -> ColumnCache: + if len(self) == 0 and not self.columns: 
+ return ColumnCache.empty() + if get_range(index)[1] > len(self): + raise ValueError( + "Tried to take more elements than the length of the cache!" + ) + new_cache = ColumnCache( + {}, {}, {}, self.__metadata_columns, {}, index, ref(self), [] + ) + self.__children.append(ref(new_cache)) + return new_cache + + def get_data(self, pairs: set[CacheKey]) -> dict[UUID, dict[str, np.ndarray]]: + """ + Retrieve data for the requested (uuid, name) pairs. Returns a + UUID-keyed dict so callers can look up each producer's contribution. + """ + pairs_in_cache = pairs.intersection(self.__cached_data.keys()) + missing_pairs = pairs - pairs_in_cache + flat = {key: self.__cached_data[key] for key in pairs_in_cache} + flat |= self.__get_derived_pairs(missing_pairs) + return _unflatten(flat) + + def get_metadata(self, column_names: Iterable[str]) -> dict[str, np.ndarray]: + """Retrieve name-keyed metadata columns.""" + names = set(column_names) + result = { + name: self.__metadata_data[name] + for name in names + if name in self.__metadata_data + } + if self.__parent is not None and (p := self.__parent()) is not None: + missing = names - set(result.keys()) + if missing: + result |= p.get_metadata(missing) + return result + + def __get_derived_pairs(self, pairs: set[CacheKey]) -> dict[CacheKey, np.ndarray]: + if self.__parent is None: + return {} + parent = self.__parent() + if parent is None: + return {} + result = parent.request(pairs, self.__derived_index) + self.__cached_data |= result + return result + + +def _flatten(uuid_data: dict[UUID, dict[str, np.ndarray]]) -> dict[str, np.ndarray]: + """Collapse UUID-keyed data to a flat name-keyed dict. 
Last writer wins.""" + return {name: arr for d in uuid_data.values() for name, arr in d.items()} + + +def _unflatten(flat: dict[CacheKey, np.ndarray]) -> dict[UUID, dict[str, np.ndarray]]: + """Group (uuid, name) → array back into {uuid: {name: array}}.""" + result: dict[UUID, dict[str, np.ndarray]] = {} + for (uuid, name), arr in flat.items(): + result.setdefault(uuid, {})[name] = arr + return result diff --git a/src/opencosmo/column/column.py b/python/opencosmo/column/column.py similarity index 58% rename from src/opencosmo/column/column.py rename to python/opencosmo/column/column.py index 60409ca0..d8dd419a 100644 --- a/src/opencosmo/column/column.py +++ b/python/opencosmo/column/column.py @@ -2,7 +2,7 @@ import operator as op from copy import copy -from functools import cached_property, partial, partialmethod, wraps +from functools import partial, partialmethod, wraps from inspect import signature from typing import ( TYPE_CHECKING, @@ -14,10 +14,10 @@ Self, Union, ) +from uuid import uuid4 import astropy.units as u # type: ignore import numpy as np -from astropy import table # type: ignore from opencosmo.column.evaluate import ( EvaluateStrategy, @@ -28,12 +28,39 @@ from opencosmo.units import UnitsError if TYPE_CHECKING: + from uuid import UUID + + from astropy import table + from opencosmo import Dataset + from opencosmo.index import DataIndex Comparison = Callable[[float, float], bool] +""" +The structures in this file are used both internally and user-facing. + +The Column class represents a reference to a single column in the dataset. If +`is_raw` is set to true, we are requiring that this column is actually instantiated +from data originally in the hdf5 file. This is not intended to be set by the user, +and is only used internally. + +A DerivedColumn represents a combination of columns that produces a single new column. 
+The columns it depends on may or may not be raw columns themselves, but eventually +the dependency graph always points back to raw columns, or columns that were provided +directly by the user as numpy arrays/astropy quantities. + +An evaluted column takes an arbitrary number of columns as input and returns an arbitrary +number of columns as output. These columns HAVE NOT actually been evaluated yet. +These combinations all form a dependency graph, which can easily be evaluated for validity. -def col(column_name: str) -> Column: +Raw columns and columns that are in-memory are allowed to be sources. All other columns +must take input. The dependency graph must be a DAG, where only those two types of columns +have no inputs. +""" + + +def col(name: str) -> Column: """ Create a reference to a column with a given name. These references can be combined to produce new columns or express queries that operate on the values in a given @@ -50,7 +77,7 @@ def col(column_name: str) -> Column: For more advanced usage, see :doc:`cols` """ - return Column(column_name) + return Column(name) def _require_scalar_quantity(func: Callable) -> Callable: @@ -68,7 +95,7 @@ def wrapper(self: Any, other: Any) -> Any: return wrapper -ColumnOrScalar = Union["Column", "DerivedColumn", int, float] +ColumnOrScalar = Union["ConstructedColumn", int, float, u.Quantity] def _log10( @@ -129,6 +156,57 @@ def _sqrt(left: np.ndarray | u.Unit, right: None): return left**0.5 +def _require_dimensionless(unit: u.UnitBase, func_name: str) -> None: + if not unit.is_equivalent(u.dimensionless_unscaled): + raise UnitsError( + f"{func_name} requires a dimensionless input, got unit '{unit}'" + ) + + +def _arcsin(left: Any, right: Any) -> Any: + if isinstance(left, u.UnitBase): + _require_dimensionless(left, "arcsin") + return u.rad + if isinstance(left, u.Quantity): + _require_dimensionless(left.unit, "arcsin") + return np.arcsin(left.value) * u.rad + return np.arcsin(left) + + +def _arccos(left: Any, right: Any) 
-> Any: + if isinstance(left, u.UnitBase): + _require_dimensionless(left, "arccos") + return u.rad + if isinstance(left, u.Quantity): + _require_dimensionless(left.unit, "arccos") + return np.arccos(left.value) * u.rad + return np.arccos(left) + + +def _arctan2(left: Any, right: Any) -> Any: + left_is_unit = isinstance(left, u.UnitBase) + right_is_unit = isinstance(right, u.UnitBase) + if left_is_unit or right_is_unit: + if not (left_is_unit and right_is_unit) or not left.is_equivalent(right): + raise UnitsError( + "arctan2 requires both inputs to have equivalent units or both to be unitless" + ) + return u.rad + left_is_qty = isinstance(left, u.Quantity) + right_is_qty = isinstance(right, u.Quantity) + if left_is_qty != right_is_qty: + raise UnitsError( + "arctan2 requires both inputs to have equivalent units or both to be unitless" + ) + if left_is_qty: + if not left.unit.is_equivalent(right.unit): + raise UnitsError( + f"arctan2 inputs have incompatible units: '{left.unit}' and '{right.unit}'" + ) + return np.arctan2(left.value, right.value) * u.rad + return np.arctan2(left, right) + + class Column: """ Represents a reference to a column with a given name. 
Column reference @@ -167,45 +245,34 @@ class Column: """ - def __init__(self, column_name: str): - self.column_name = column_name + def __init__(self, name: str): + self.name = name self.description = None @property def requires(self): - return set([self.column_name]) + return {self.name} - @property - def produces(self): - return None - - def get_units(self, column_units: dict[str, u.Unit]): - return column_units[self.column_name] - - def evaluate(self, data: dict[str, np.ndarray], *args): - return data[self.column_name] - - # mypy doesn't reason about eq and neq correctly def __eq__(self, other: float | u.Quantity) -> ColumnMask: # type: ignore - return ColumnMask(self.column_name, other, op.eq) + return ColumnMask(self, other, op.eq) def __ne__(self, other: float | u.Quantity) -> ColumnMask: # type: ignore - return ColumnMask(self.column_name, other, op.ne) + return ColumnMask(self, other, op.ne) def __gt__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.column_name, other, op.gt) + return ColumnMask(self, other, op.gt) def __ge__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.column_name, other, op.ge) + return ColumnMask(self, other, op.ge) def __lt__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.column_name, other, op.lt) + return ColumnMask(self, other, op.lt) def __le__(self, other: float | u.Quantity) -> ColumnMask: - return ColumnMask(self.column_name, other, op.le) + return ColumnMask(self, other, op.le) def isin(self, other: Iterable[float | u.Quantity]) -> ColumnMask: - return ColumnMask(self.column_name, other, np.isin) + return ColumnMask(self, other, np.isin) @_require_scalar_quantity def __rmul__(self, other: Any) -> DerivedColumn: @@ -247,18 +314,16 @@ def __pow__(self, other: Any) -> DerivedColumn: case _: return NotImplemented - @_require_scalar_quantity def __add__(self, other: Any) -> DerivedColumn: match other: - case Column(): + case Column() | int() | float() | 
u.Quantity(): return DerivedColumn(self, other, op.add) case _: return NotImplemented - @_require_scalar_quantity def __sub__(self, other: Any) -> DerivedColumn: match other: - case Column(): + case Column() | int() | float() | u.Quantity(): return DerivedColumn(self, other, op.sub) case _: return NotImplemented @@ -294,26 +359,130 @@ def sqrt(self) -> DerivedColumn: """ return DerivedColumn(self, None, _sqrt) + def arcsin(self) -> DerivedColumn: + """ + Create a derived column containing the arcsine of this column (in radians). + The column must be dimensionless. + """ + return DerivedColumn(self, None, _arcsin) + + def arccos(self) -> DerivedColumn: + """ + Create a derived column containing the arccosine of this column (in radians). + The column must be dimensionless. + """ + return DerivedColumn(self, None, _arccos) + + def arctan2(self, other: ColumnOrScalar) -> DerivedColumn: + """ + Create a derived column containing arctan2(self, other) in radians. + Both columns must be dimensionless. + """ + return DerivedColumn(self, other, _arctan2) + class ConstructedColumn(Protocol): pass @property - def requires(self) -> set[str]: ... + def uuid(self) -> UUID: ... + + @property + def requires(self) -> set[UUID]: ... + + @property + def dep_map(self) -> dict[str, UUID] | None: ... + @property - def produces(self) -> Optional[set[str]]: ... + def produces(self) -> set[str]: ... + @property def description(self) -> Optional[str]: ... + def bind(self, name_to_uuid: dict[str, UUID]) -> Self: ... + + @property + def no_cache(self) -> bool: ... def evaluate( self, data: dict[str, np.ndarray], - chunk_sizes: Optional[np.ndarray], + index: DataIndex, ) -> np.ndarray | dict[str, np.ndarray]: ... def get_units(self, values: dict[str, u.Quantity]) -> dict[str, u.Unit]: ... 
+class RawColumn: + def __init__(self, name, description, alias=None, _dep_uuid=None, _uuid=None): + self.__name = name + self.__description = description + self.__alias = alias + self.__uuid = _uuid if _uuid is not None else uuid4() + self.__dep_uuid: UUID | None = _dep_uuid + + @property + def uuid(self) -> UUID: + return self.__uuid + + @property + def name(self): + return self.__name + + def bind(self, name_to_uuid: dict[str, UUID]) -> RawColumn: + if self.__alias is None: + return self + dep_uuid = name_to_uuid[self.__name] + return RawColumn( + self.__name, + self.__description, + alias=self.__alias, + _dep_uuid=dep_uuid, + _uuid=self.__uuid, + ) + + @property + def requires(self) -> set[UUID]: + if self.__alias is None: + return set() + if self.__dep_uuid is None: + raise RuntimeError( + f"RawColumn alias '{self.__alias}' has not been bound yet." + ) + return {self.__dep_uuid} + + @property + def dep_map(self) -> dict[str, UUID]: + if self.__alias is None: + return {} + if self.__dep_uuid is None: + raise RuntimeError( + f"RawColumn alias '{self.__alias}' has not been bound yet." 
+ ) + return {self.__name: self.__dep_uuid} + + @property + def no_cache(self): + return False + + @property + def alias(self) -> str | None: + return self.__alias + + @property + def produces(self) -> set[str]: + return set([self.__alias or self.__name]) + + @property + def description(self): + return self.__description + + def get_units(self, values: dict[str, u.Quantity]) -> dict[str, u.Unit]: + return values[self.__name] + + def evaluate(self, data: dict[str, np.ndarray], *args): + return data[self.__name] + + class DerivedColumn: """ A derived column represents a combination of multiple columns that already exist in @@ -336,39 +505,83 @@ def __init__( rhs: Optional[ColumnOrScalar], operation: Callable, description: Optional[str] = None, + output_name: Optional[str] = None, + _dep_map: dict[str, UUID] | None = None, + no_cache: bool = False, + _uuid: UUID | None = None, ): self.lhs = lhs self.rhs = rhs + + self.name = output_name self.operation = operation self.description = description if description is not None else "None" + self.__uuid = _uuid if _uuid is not None else uuid4() + self.__dep_map: dict[str, UUID] | None = _dep_map + self.__no_cache = no_cache - @cached_property - def requires(self): + @property + def uuid(self) -> UUID: + return self.__uuid + + @property + def dep_map(self) -> dict[str, UUID] | None: + return self.__dep_map + + def bind(self, name_to_uuid: dict[str, UUID]) -> DerivedColumn: """ - Return the raw data columns required to make this column + Resolve each dependency column name to the UUID of the producer that was + producing it at the time this column was registered with a dataset. + Returns a new bound DerivedColumn; does not mutate this instance. 
""" - vals = set() + required_names = self._traverse_names() + if missing := required_names.difference(name_to_uuid): + raise ValueError(f"Derived column depends on unknown columns {missing}") + + dep_map = {name: name_to_uuid[name] for name in self._traverse_names()} + return DerivedColumn( + self.lhs, + self.rhs, + self.operation, + self.description, + self.name, + _dep_map=dep_map, + _uuid=self.__uuid, + ) + + def _traverse_names(self) -> set[str]: + """Walk the expression tree and collect all Column leaf names.""" + vals: set[str] = set() match self.lhs: case Column(): - vals.add(self.lhs.column_name) + vals.add(self.lhs.name) case DerivedColumn(): - vals = vals | self.lhs.requires + vals |= self.lhs._traverse_names() match self.rhs: case Column(): - vals.add(self.rhs.column_name) + vals.add(self.rhs.name) case DerivedColumn(): - vals = vals | self.rhs.requires - + vals |= self.rhs._traverse_names() return vals + @property + def requires(self) -> set[UUID]: + if self.__dep_map is None: + raise RuntimeError(f"DerivedColumn '{self.name}' has not been bound yet.") + return set(self.__dep_map.values()) + @property def produces(self): - return None + return None if self.name is None else set([self.name]) + + @property + def no_cache(self): + return self.__no_cache def check_parent_existance(self, names: set[str]): match self.rhs: case Column(): - rhs_valid = self.rhs.column_name in names + rhs_valid = self.rhs.name in names case DerivedColumn(): rhs_valid = self.rhs.check_parent_existance(names) case _: @@ -376,7 +589,7 @@ def check_parent_existance(self, names: set[str]): match self.lhs: case Column(): - lhs_valid = self.lhs.column_name in names + lhs_valid = self.lhs.name in names case DerivedColumn(): lhs_valid = self.lhs.check_parent_existance(names) case _: @@ -387,7 +600,7 @@ def check_parent_existance(self, names: set[str]): def get_units(self, units: dict[str, u.Unit]): match self.lhs: case Column(): - lhs_unit = units[self.lhs.column_name] + lhs_unit = 
units[self.lhs.name] case DerivedColumn(): lhs_unit = self.lhs.get_units(units) case u.Quantity(): @@ -396,7 +609,7 @@ def get_units(self, units: dict[str, u.Unit]): lhs_unit = None match self.rhs: case Column(): - rhs_unit = units[self.rhs.column_name] + rhs_unit = units[self.rhs.name] case DerivedColumn(): rhs_unit = self.rhs.get_units(units) case u.Quantity(): @@ -475,21 +688,51 @@ def exp10(self, expected_unit_container: u.LogUnit = u.DexUnit): def sqrt(self): return DerivedColumn(self, None, _sqrt) + def arcsin(self) -> DerivedColumn: + return DerivedColumn(self, None, _arcsin) + + def arccos(self) -> DerivedColumn: + return DerivedColumn(self, None, _arccos) + + def arctan2(self, other: ColumnOrScalar) -> DerivedColumn: + return DerivedColumn(self, other, _arctan2) + + def __eq__(self, other: float | u.Quantity) -> ColumnMask: # type: ignore + return ColumnMask(self, other, op.eq) + + def __ne__(self, other: float | u.Quantity) -> ColumnMask: # type: ignore + return ColumnMask(self, other, op.ne) + + def __gt__(self, other: float | u.Quantity) -> ColumnMask: + return ColumnMask(self, other, op.gt) + + def __ge__(self, other: float | u.Quantity) -> ColumnMask: + return ColumnMask(self, other, op.ge) + + def __lt__(self, other: float | u.Quantity) -> ColumnMask: + return ColumnMask(self, other, op.lt) + + def __le__(self, other: float | u.Quantity) -> ColumnMask: + return ColumnMask(self, other, op.le) + + def isin(self, other: Iterable[float | u.Quantity]) -> ColumnMask: + return ColumnMask(self, other, np.isin) + def evaluate(self, data: dict[str, np.ndarray], *args) -> np.ndarray: - lhs: np.typing.ArrayLike - rhs: Optional[np.typing.ArrayLike] + lhs: Any + rhs: Any match self.lhs: case DerivedColumn(): lhs = self.lhs.evaluate(data) case Column(): - lhs = data[self.lhs.column_name] + lhs = data[self.lhs.name] case _: lhs = self.lhs match self.rhs: case DerivedColumn(): rhs = self.rhs.evaluate(data) case Column(): - rhs = data[self.rhs.column_name] + rhs = 
data[self.rhs.name] case _: rhs = self.rhs @@ -508,6 +751,9 @@ def __init__( strategy: EvaluateStrategy = EvaluateStrategy.ROW_WISE, batch_size: int = -1, description: Optional[str] = None, + _dep_map: dict[str, UUID] | None = None, + no_cache: bool = False, + _uuid: UUID | None = None, **kwargs: Any, ): self.__func = func @@ -518,7 +764,47 @@ def __init__( self.__format = format self.__strategy = strategy self.__batch_size = batch_size + self.__no_cache = no_cache self.description = description + self.__uuid = _uuid if _uuid is not None else uuid4() + self.__dep_map = _dep_map + + @property + def uuid(self) -> UUID: + return self.__uuid + + @property + def dep_map(self) -> dict[str, UUID] | None: + return self.__dep_map + + def bind(self, name_to_uuid: dict[str, UUID]) -> EvaluatedColumn: + """ + Resolve each dependency column name to the UUID of the producer that was + producing it at the time this column was registered with a dataset. + Returns a new bound EvaluatedColumn; does not mutate this instance. 
+ """ + dep_map = {name: name_to_uuid[name] for name in self.__requires} + return EvaluatedColumn( + self.__func, + self.__requires, + self.__produces, + self.__format, + self.__units, + self.__strategy, + self.__batch_size, + self.description, + _dep_map=dep_map, + no_cache=self.__no_cache, + _uuid=self.__uuid, + **self.__kwargs, + ) + + @property + def requires_names(self) -> set[str]: + """Return the required column names (as strings), for use in data lookup.""" + if self.__dep_map is not None: + return set(self.__dep_map.keys()) + return copy(self.__requires) def with_kwargs(self, **new_kwargs: Any): new_kwargs = self.__kwargs | new_kwargs @@ -531,12 +817,23 @@ def with_kwargs(self, **new_kwargs: Any): self.__strategy, self.__batch_size, self.description, + _dep_map=self.__dep_map, **new_kwargs, ) @property - def requires(self): - return copy(self.__requires) + def name(self): + return self.__func.__name__ + + @property + def requires(self) -> set[UUID]: + if self.__dep_map is None: + raise RuntimeError(f"EvaluatedColumn '{self.name}' has not been bound yet.") + return set(self.__dep_map.values()) + + @property + def no_cache(self): + return self.__no_cache @property def produces(self): @@ -565,13 +862,9 @@ def kwarg_names(self): def get_units(self, units: dict[str, np.ndarray]): return self.__units - def evaluate(self, data: dict[str, np.ndarray], chunk_sizes: Optional[np.ndarray]): + def evaluate(self, data: dict[str, np.ndarray], index: DataIndex | None): data = {name: data[name] for name in self.__requires} - if self.__format != "astropy": - data = { - name: val.value if isinstance(val, u.Quantity) else val - for name, val in data.items() - } + chunk_sizes = index[1] if isinstance(index, tuple) else None if self.batch_size > 0: length = len(next(iter(data.values()))) @@ -586,15 +879,35 @@ def evaluate(self, data: dict[str, np.ndarray], chunk_sizes: Optional[np.ndarray match strategy: case EvaluateStrategy.VECTORIZE: - return evaluate_vectorized(data, 
self.__func, self.__kwargs) + return evaluate_vectorized(data, self.__func, self.__kwargs, index) case EvaluateStrategy.ROW_WISE: - return evaluate_rows(data, self.__func, self.__kwargs) + return evaluate_rows(data, self.__func, self.__kwargs, self.__format) case EvaluateStrategy.CHUNKED: if chunk_sizes is None: raise ValueError( "Cannot evaluate in CHUNKED strategy with a non-chunked index" ) - return evaluate_chunks(data, self.__func, self.__kwargs, chunk_sizes) + return evaluate_chunks( + data, self.__func, self.__kwargs, chunk_sizes, self.__format + ) + + def evaluate_for_storage( + self, data: dict[str, np.ndarray], index: DataIndex | None + ) -> dict[str, np.ndarray]: + """ + Evaluate and return numpy-formatted output suitable for the column + cache. Input arrives in the numpy/astropy form used internally, so + it is first converted to the user's requested format before the + function runs, and the output is converted back to numpy. + """ + from opencosmo.dataset.formats import to_format_dict, to_numpy_dict + + required = {name: data[name] for name in self.__requires} + converted = to_format_dict(required, self.__format) + output = self.evaluate(converted, index) + if not isinstance(output, dict): + output = {next(iter(self.__produces)): output} + return to_numpy_dict(output) def evaluate_one(self, dataset: Dataset): match self.__strategy: @@ -639,36 +952,41 @@ class ColumnMask: def __init__( self, - column_name: str, - value: float | u.Quantity, + left: ColumnOrScalar, + right: ColumnOrScalar, operator: Callable[[table.Column, float | u.Quantity], np.ndarray], ): - self.column_name = column_name - self.value = value + self.left = left + self.right = right self.operator = operator - @property - def requires(self): - return {self.column_name} - - def apply(self, column: u.Quantity | np.ndarray) -> np.ndarray: - """ - mask the dataset based on the mask. 
- """ - # Astropy's errors are good enough here - if isinstance(column, table.Table): - column = column[self.column_name] - - if isinstance(self.value, u.Quantity) and isinstance(column, u.Quantity): - if self.value.unit != column.unit: - raise ValueError( - f"Incompatible units in fiter: {self.value.unit} and {column.unit}" - ) - - elif isinstance(column, u.Quantity): - return self.operator(column.value, self.value) + def apply(self, ds: Dataset): + match self.left: + case Column(): + left = ds.select(self.left.name).get_data() + case DerivedColumn(): + left = ds.select(data=self.left).get_data() + case _: + left = self.left - return self.operator(column, self.value) # type: ignore + right_selected = False + match self.right: + case Column(): + right = ds.select(self.right.name).get_data() + right_selected = True + case DerivedColumn(): + right = ds.select(data=self.right).get_data() + right_selected = True + case _: + right = self.right + if ( + isinstance(left, u.Quantity) + and not isinstance(right, u.Quantity) + and not right_selected + ): + return self.operator(left.value, right) + result = self.operator(left, right) + return result def __and__(self, other: Self | CompoundColumnMask): return CompoundColumnMask(self, other, lambda left, right: left & right) @@ -701,7 +1019,7 @@ def __and__(self, other: ColumnMask | Self): def __or__(self, other: ColumnMask | Self): return CompoundColumnMask(self, other, lambda left, right: left | right) - def apply(self, data): - left_mask = self.__left.apply(data) - right_mask = self.__right.apply(data) + def apply(self, ds: Dataset): + left_mask = self.__left.apply(ds) + right_mask = self.__right.apply(ds) return self.__op(left_mask, right_mask) diff --git a/python/opencosmo/column/evaluate.py b/python/opencosmo/column/evaluate.py new file mode 100644 index 00000000..d4822c84 --- /dev/null +++ b/python/opencosmo/column/evaluate.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from enum import Enum +from typing 
import TYPE_CHECKING, Any, Callable + +import numpy as np + +if TYPE_CHECKING: + from opencosmo import Dataset + + +class EvaluateStrategy(Enum): + VECTORIZE = "vectorize" + ROW_WISE = "row_wise" + CHUNKED = "chunked" + + +def evaluate_rows( + data: dict[str, Any], + func: Callable, + kwargs: dict[str, Any], + format: str, +): + from opencosmo.dataset.formats import stack_rows + + data_length = len(next(iter(data.values()))) + per_column: dict[str, list] = {} + for i in range(data_length): + iterable_inputs = {name: values[i] for name, values in data.items()} + output = func(**iterable_inputs, **kwargs) + if not isinstance(output, dict): + output = {func.__name__: output} + for name, value in output.items(): + per_column.setdefault(name, []).append(value) + return {name: stack_rows(values, format) for name, values in per_column.items()} + + +def evaluate_chunks( + data: dict[str, Any], + func: Callable, + kwargs: dict[str, Any], + chunk_sizes: np.ndarray, + format: str, +): + from opencosmo.dataset.formats import concat_chunks + + chunk_splits = np.cumsum(chunk_sizes) + starts = np.concatenate([[0], chunk_splits[:-1]]) + per_column: dict[str, list] = {} + for start, end in zip(starts, chunk_splits): + chunk_input_data = { + name: arr[int(start) : int(end)] for name, arr in data.items() + } + output = func(**chunk_input_data, **kwargs) + if not isinstance(output, dict): + output = {func.__name__: output} + for name, value in output.items(): + per_column.setdefault(name, []).append(value) + return {name: concat_chunks(chunks, format) for name, chunks in per_column.items()} + + +def evaluate_vectorized(data, func, kwargs, index): + try: + return func(**data, **kwargs, index=index) + except TypeError: + return func(**data, **kwargs) + + +def do_first_evaluation( + func: Callable, + strategy: str, + format: str, + kwargs: dict[str, Any], + dataset: Dataset, +): + from opencosmo.dataset.formats import fetch_as_dict + + eval_strategy = EvaluateStrategy(strategy) + columns 
= list(dataset.columns) + match eval_strategy: + case EvaluateStrategy.VECTORIZE: + values = fetch_as_dict(dataset.take(1), columns, format, unpack=False) + return func(**values, **kwargs), eval_strategy + + case EvaluateStrategy.ROW_WISE: + values = fetch_as_dict(dataset.take(1), columns, format, unpack=False) + values = {name: container[0] for name, container in values.items()} + return func(**values, **kwargs), eval_strategy + + case EvaluateStrategy.CHUNKED: + index = dataset.index + assert isinstance(index, tuple) + first_chunk_size = index[1][0] + first_chunk = fetch_as_dict( + dataset.take(first_chunk_size, at="start"), columns, format + ) + return func(**first_chunk, **kwargs), eval_strategy diff --git a/src/opencosmo/column/select.py b/python/opencosmo/column/select.py similarity index 100% rename from src/opencosmo/column/select.py rename to python/opencosmo/column/select.py diff --git a/src/opencosmo/column/stock.py b/python/opencosmo/column/stock.py similarity index 100% rename from src/opencosmo/column/stock.py rename to python/opencosmo/column/stock.py diff --git a/src/opencosmo/cosmology.py b/python/opencosmo/cosmology.py similarity index 97% rename from src/opencosmo/cosmology.py rename to python/opencosmo/cosmology.py index b9d23548..3c07e294 100644 --- a/src/opencosmo/cosmology.py +++ b/python/opencosmo/cosmology.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: import h5py - from opencosmo.parameters import CosmologyParameters + from opencosmo.dtypes import CosmologyParameters """ Reads cosmology from the header of the file and returns the diff --git a/src/opencosmo/dataset/__init__.py b/python/opencosmo/dataset/__init__.py similarity index 100% rename from src/opencosmo/dataset/__init__.py rename to python/opencosmo/dataset/__init__.py diff --git a/src/opencosmo/dataset/build.py b/python/opencosmo/dataset/build.py similarity index 97% rename from src/opencosmo/dataset/build.py rename to python/opencosmo/dataset/build.py index cfb8cd74..edcd1d62 100644 --- 
a/src/opencosmo/dataset/build.py +++ b/python/opencosmo/dataset/build.py @@ -7,8 +7,8 @@ import h5py import numpy as np +import opencosmo.dataset.state as state from opencosmo.dataset import Dataset -from opencosmo.dataset.state import DatasetState from opencosmo.spatial.healpix import HealPixIndex from opencosmo.spatial.tree import Tree from opencosmo.spatial.utils import combine_upwards @@ -51,12 +51,13 @@ def build_dataset_from_data( metadata_group = {} data_descriptions = descriptions.get("data", {}) - new_state = DatasetState.in_memory( + new_state = state.state_in_memory( data_group, metadata_group, header, header.file.unit_convention, region, + {}, data_descriptions, ) return Dataset(header, new_state, tree=tree) diff --git a/python/opencosmo/dataset/columns.py b/python/opencosmo/dataset/columns.py new file mode 100644 index 00000000..e8ffab6b --- /dev/null +++ b/python/opencosmo/dataset/columns.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from copy import copy +from typing import TYPE_CHECKING, Optional + +import astropy.units as u +import numpy as np + +from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn, RawColumn +from opencosmo.dataset.graph import validate_column_producers + +if TYPE_CHECKING: + from uuid import UUID + + from opencosmo.column.column import ConstructedColumn + from opencosmo.handler.protocols import DataCache + from opencosmo.units.handler import UnitHandler + + ColumnMap = dict[str, UUID] + + +def resort(columns: dict[str, np.ndarray], sorted_index: Optional[np.ndarray]): + if sorted_index is None or not columns: + return columns + reverse_sort = np.argsort(sorted_index) + return {name: data[reverse_sort] for name, data in columns.items()} + + +def validate_in_memory_columns( + columns: dict[str, np.ndarray], unit_handler: UnitHandler, ds_length: int +) -> UnitHandler: + new_units = {} + for colname, column in columns.items(): + if len(column) != ds_length: + raise ValueError(f"Column {colname} is 
not the same length as the dataset!") + if isinstance(column, u.Quantity): + new_units[colname] = column.unit + else: + new_units[colname] = None + + return unit_handler.with_static_columns(**new_units) + + +def __categorize_columns( + new_columns: dict, + descriptions: dict[str, str], + ds_length: int, +) -> tuple[list[ConstructedColumn], dict, dict, list[str], dict]: + """ + Classify incoming columns by type and build the producer list, in-memory column + dicts, and static unit map needed by add_columns. + + Returns: + new_derived_columns, new_in_memory_columns, new_in_memory_descriptions, + new_column_names, new_static_units + """ + new_derived_columns: list[ConstructedColumn] = [] + new_in_memory_columns: dict = {} + new_in_memory_descriptions: dict = {} + new_column_names: list[str] = [] + new_static_units: dict = {} + + for colname, column in new_columns.items(): + match column: + case DerivedColumn(): + column.name = colname + column.description = descriptions.get(colname, "None") + new_derived_columns.append(column) + new_column_names.extend(column.produces) + case EvaluatedColumn(): + column.description = descriptions.get(colname, "None") + new_derived_columns.append(column) + new_column_names.extend(column.produces) + case Column(): + producer = DerivedColumn( + lhs=column, + rhs=None, + operation=lambda x, _: x, + output_name=colname, + ) + new_derived_columns.append(producer) + new_column_names.extend(producer.produces) + case np.ndarray(): + if len(column) != ds_length: + raise ValueError( + f"New column {colname} does not have the same length as this dataset!" 
+ ) + new_in_memory_descriptions[colname] = descriptions.get(colname, "None") + new_in_memory_columns[colname] = column + new_column_names.append(colname) + new_derived_columns.append(RawColumn(colname, None)) + new_static_units[colname] = ( + column.unit if isinstance(column, u.Quantity) else None + ) + case _: + raise ValueError(f"Got an invalid new column of type {type(column)}") + + return ( + new_derived_columns, + new_in_memory_columns, + new_in_memory_descriptions, + new_column_names, + new_static_units, + ) + + +def add_columns( + producers: list[ConstructedColumn], + unit_handler: UnitHandler, + cache: DataCache, + name_to_uuid: ColumnMap, + sorted_index: np.ndarray | None, + descriptions: dict[str, str], + new_columns: dict, + ds_length: int, + allow_overwrite: bool, +) -> tuple[list[ConstructedColumn], ColumnMap, UnitHandler]: + if ( + inter := set(name_to_uuid.keys()).intersection(new_columns.keys()) + ) and not allow_overwrite: + raise ValueError(f"Some columns are already in the dataset: {inter}") + + ( + new_derived_columns, + new_in_memory_columns, + new_in_memory_descriptions, + new_column_names, + new_static_units, + ) = __categorize_columns(new_columns, descriptions, ds_length) + + # Extend the name→UUID map with new producers' outputs so that columns added + # in the same with_new_columns call can reference each other. + # For overwritten columns, preserve the OLD UUID so that new producers can + # correctly declare a dependency on the existing data rather than on themselves. 
+ extended_name_to_uuid = dict(name_to_uuid) + for producer in new_derived_columns: + if producer.produces: + for name in producer.produces: + if name not in name_to_uuid: + extended_name_to_uuid[name] = producer.uuid + + new_derived_columns = [ + producer.bind(extended_name_to_uuid) for producer in new_derived_columns + ] + + new_unit_handler = unit_handler.with_static_columns(**new_static_units) + new_producers = copy(producers) + new_derived_columns + + new_units = validate_column_producers(new_producers, new_unit_handler) + if new_units: + new_unit_handler = new_unit_handler.with_new_columns(**new_units) + + if new_in_memory_columns: + new_unit_handler = validate_in_memory_columns( + new_in_memory_columns, unit_handler, ds_length + ) + new_in_memory_columns = resort(new_in_memory_columns, sorted_index) + # Build UUID-keyed data: each in-memory column is owned by its RawColumn producer. + uuid_keyed = { + producer.uuid: {producer.name: new_in_memory_columns[producer.name]} + for producer in new_derived_columns + if isinstance(producer, RawColumn) + and producer.name in new_in_memory_columns + } + cache.add_data(uuid_keyed, descriptions=new_in_memory_descriptions) + + updated_columns = dict(name_to_uuid) + for producer in new_derived_columns: + for name in producer.produces: + updated_columns[name] = producer.uuid + + return new_producers, updated_columns, new_unit_handler diff --git a/src/opencosmo/dataset/dataset.py b/python/opencosmo/dataset/dataset.py similarity index 74% rename from src/opencosmo/dataset/dataset.py rename to python/opencosmo/dataset/dataset.py index d0baf19f..4dcba8e0 100644 --- a/src/opencosmo/dataset/dataset.py +++ b/python/opencosmo/dataset/dataset.py @@ -1,11 +1,11 @@ from __future__ import annotations -from functools import reduce from typing import ( TYPE_CHECKING, Callable, Generator, Iterable, + Literal, Mapping, Optional, TypeAlias, @@ -17,10 +17,16 @@ from astropy.table import QTable # type: ignore from deprecated.sphinx import 
deprecated +import opencosmo.dataset.state as st from opencosmo.column import Column from opencosmo.dataset.evaluate import build_evaluated_column, visit_dataset from opencosmo.dataset.formats import convert_data, verify_format -from opencosmo.index import into_array, mask, project +from opencosmo.dataset.take import ( + get_end_take_index, + get_random_take_index, + get_range_take_index, +) +from opencosmo.index import empty, get_range, into_array, mask, project from opencosmo.spatial import check from opencosmo.units.converters import get_scale_factor @@ -29,10 +35,10 @@ from opencosmo.column.column import Column, ColumnMask, ConstructedColumn from opencosmo.dataset.state import DatasetState + from opencosmo.dtypes import HaccSimulationParameters from opencosmo.header import OpenCosmoHeader from opencosmo.index import DataIndex from opencosmo.io.schema import Schema - from opencosmo.parameters import HaccSimulationParameters from opencosmo.spatial.protocols import Region from opencosmo.spatial.tree import Tree @@ -64,7 +70,7 @@ def __repr__(self): repr_ds = self.take(10, at="start") table_head = "First 10 rows:\n" - table_repr = repr_ds.data.__repr__() + table_repr = repr_ds.get_data().__repr__() # remove the first line table_repr = table_repr[table_repr.find("\n") + 1 :] head = f"OpenCosmo Dataset (length={length})\n" @@ -83,10 +89,10 @@ def __enter__(self): return self def __exit__(self, *exc_details): - return self.__state.__exit__(*exc_details) + return st.exit_state(self.__state, *exc_details) def close(self): - return self.__state.__exit__() + return st.exit_state(self.__state) @property def header(self) -> OpenCosmoHeader: @@ -207,10 +213,25 @@ def simulation(self) -> Optional[HaccSimulationParameters]: Returns ------- - parameters: Optional[opencosmo.parameters.hacc.HaccSimulationParameters] + parameters: Optional[opencosmo.dtypes.hacc.HaccSimulationParameters] """ return getattr(self.__header, "simulation", None) + @property + def sorted_by(self) -> 
Optional[str]: + """ + The column this dataset is sorted by. If not sorted, returns None. + + Returns + ------- + column: Optional[str] + """ + return self.__state.sort_key[0] if self.__state.sort_key is not None else None + + @property + def tree(self) -> Optional[Tree]: + return self.__tree + @property @deprecated( version="1.1.0", @@ -232,14 +253,19 @@ def data(self) -> QTable | u.Quantity: # Also the point is that there's MORE data than just the table return self.get_data("astropy") - def get_metadata(self, columns: str | list[str] = []): + def get_metadata(self, columns: str | list[str] = [], ignore_sort: bool = False): if isinstance(columns, str): columns = [columns] - return self.__state.get_metadata(columns) + return st.get_metadata(self.__state, columns, ignore_sort) def get_data( - self, format="astropy", unpack=True, metadata_columns=[], **kwargs + self, + format="astropy", + unpack=True, + metadata_columns=[], + wrap_single=False, + **kwargs, ) -> OpenCosmoData: """ Get the data in this dataset as an astropy table/column or as @@ -249,7 +275,7 @@ def get_data( on the data. The method supports output into several different formats, including - "astropy", "numpy", "pandas", "polars", and "pyarrow". Although astropy + "astropy", "numpy", "pandas", "polars", "jax", and "arrow". Although astropy and numpy are core dependencies of OpenCosmo, the remaining formats require you to have the relevant libraries installed in your python environment. This method will check that it can import the necessary @@ -259,13 +285,19 @@ def get_data( If the dataset only contains a single column, it will not be put in a table or dictionary. "astropy", "numpy" and "arrow" will return a single array - in this case, while "polars" and "pandas" will return a Series object. + in this case, while "polars" and "pandas" will return a Series object. Pass + :code:`wrap_single=True` to always return the format's multi-column container + (QTable, DataFrame, dict, ...) 
regardless of column count. Parameters ---------- output: str, default="astropy" The format to output the data in. - Currently supported are "astropy", "numpy", "pandas", "polars", "arrow" + Currently supported are "astropy", "numpy", "pandas", "polars", "arrow", "jax" + + wrap_single: bool, default=False + If True, always return the format's natural multi-column container even + when only one column is present. Returns ------- @@ -286,8 +318,11 @@ def get_data( else: unit_kwargs = {} - data = self.__state.get_data( - unit_kwargs=unit_kwargs, metadata_columns=metadata_columns + data = st.get_data( + self.__state, + unit_kwargs=unit_kwargs, + metadata_columns=metadata_columns, + **kwargs, ) # dict if unpack: data = { @@ -297,7 +332,7 @@ def get_data( for key, value in data.items() } - return convert_data(data, format) + return convert_data(data, format, wrap_single=wrap_single) def bound(self, region: Region, select_by: Optional[str] = None): """ @@ -336,7 +371,7 @@ def bound(self, region: Region, select_by: Optional[str] = None): columns = check.find_coordinates_3d(self, self.dtype) check_region = region.into_base_convention( - self.__state.unit_handler, + self.__state.unit_handler, # type: ignore[arg-type] columns, self.__state.convention, { @@ -349,7 +384,7 @@ def bound(self, region: Region, select_by: Optional[str] = None): check_region = region if not self.__state.region.intersects(check_region): - new_state = self.__state.take_rows(np.array([])) + new_state = st.take_rows(self.__state, empty()) return Dataset(self.__header, new_state, self.__tree) if not self.__state.region.contains(check_region): @@ -365,7 +400,7 @@ def bound(self, region: Region, select_by: Optional[str] = None): contained_index = project(self.__state.raw_index, contained_index) intersects_index = project(self.__state.raw_index, intersects_index) - check_state = self.__state.take_rows(intersects_index) + check_state = st.take_rows(self.__state, intersects_index) check_dataset = Dataset( 
self.__header, check_state, @@ -382,11 +417,13 @@ def bound(self, region: Region, select_by: Optional[str] = None): else: new_intersects_index = np.array([], dtype=np.int64) - new_index = np.concatenate( - [into_array(contained_index), into_array(new_intersects_index)] + new_index = np.sort( + np.concatenate( + [into_array(contained_index), into_array(new_intersects_index)] + ) ) - new_state = self.__state.take_rows(new_index).with_region(check_region) + new_state = st.with_region(st.take_rows(self.__state, new_index), check_region) return Dataset(self.__header, new_state, self.__tree) @@ -397,6 +434,7 @@ def evaluate( insert=True, format="astropy", batch_size: int = -1, + allow_overwrite: bool = False, _verify: bool = True, **evaluate_kwargs, ) -> Dataset | dict[str, np.ndarray]: @@ -411,18 +449,31 @@ def evaluate( instead of adding it as a column. The function should take in arguments with the same name as the columns in this dataset that - are needed for the computation, and should return a dictionary of output values. - The dataset will automatically selected the needed columns to avoid reading unnecessarily reading - data from disk + are needed for the computation, and should return a dictionary of output values. Any addition + arguments needed by the function can be passed as keyword arguments to :code:`evaluate`. - The new columns will have the same names as the keys of the output dictionary - See :ref:`Evaluating On Datasets` for more details. + The dataset will automatically selected the needed columns to avoid reading unnecessarily reading + data from disk. The new columns will have the same names as the keys of the output dictionary + See :ref:`Evaluating On Datasets` for more details. The keys of this dictionary must be different + from the names of the columns that are already in the dataset, unless allow_overwrite is set + to :code`True` If vectorize is set to True, the full columns will be pased to the dataset. 
Otherwise, - rows will be passed to the function one at a time. + rows will be passed to the function one at a time. If the function returns None, this method + will also return None as output. + + Keyword arguments can be used to pass in external values that are not columns in the dataset. + For example, we can compute each halo's gas fraction bias — how much gas it retains relative to + the cosmic baryon fraction — by passing the dataset's cosmology object as a keyword argument: + + .. code-block:: python + + def baryon_fraction_bias(sod_halo_mass_gas, sod_halo_mass, cosmology): + f_gas = sod_halo_mass_gas / sod_halo_mass + f_cosmic = cosmology.Ob0 / cosmology.Om0 + return {"sod_halo_baryon_bias": f_gas / f_cosmic} - If the function returns None, this method will also return None as output. For example, the function - could simply produce plots and save the to files. + ds = ds.evaluate(baryon_fraction_bias, cosmology=ds.cosmology, vectorize=True) Parameters ---------- @@ -438,8 +489,11 @@ def evaluate( as the function. Otherwise the data will be returned directly. format: str, default = astropy - Whether to provide data to your function as "astropy" quantities or "numpy" arrays/scalars. Default "astropy". Note that - this method does not support all the formats available in :py:meth:`get_data ` + The format in which to provide column data to your function. Supports the same formats + as :py:meth:`get_data ` ("astropy", "numpy", "pandas", + "polars", "arrow", "jax"). When :code:`insert=True`, the function's output is converted + back to numpy before being stored. + allow_overwrite: bool, default = False batch_size: int, default = -1 If set, feed data to the function in batches of the specified size. Default is -1, which disables batching. 
If @@ -455,6 +509,7 @@ def evaluate( result : Dataset | dict[str, np.ndarray | astropy.units.Quantity] The new dataset with the evaluated column(s) or the results as numpy arrays or astropy quantities """ + verify_format(format) evaluated_column = build_evaluated_column( self, func, vectorize, insert, format, batch_size, evaluate_kwargs ) @@ -464,7 +519,9 @@ def evaluate( return output return self.with_new_columns( - descriptions={}, **{func.__name__: evaluated_column} + descriptions={}, + allow_overwrite=allow_overwrite, + **{func.__name__: evaluated_column}, ) def filter(self, *masks: ColumnMask) -> Dataset: @@ -491,15 +548,11 @@ def filter(self, *masks: ColumnMask) -> Dataset: """ if not masks: return self - required_columns: set[str] = reduce( - lambda acc, r: acc | r.requires, masks, set() - ) - data = self.select(required_columns).get_data(unpack=False) - bool_mask = np.ones(len(data), dtype=bool) + bool_mask = np.ones(len(self), dtype=bool) for m in masks: - bool_mask &= m.apply(data) + bool_mask &= m.apply(self) - new_state = self.__state.with_mask(bool_mask) + new_state = st.take_rows(self.__state, np.where(bool_mask)[0]) return Dataset(self.__header, new_state, self.__tree) def rows( @@ -530,7 +583,7 @@ def rows( else: unit_kwargs = {} - for row in self.__state.rows(metadata_columns, unit_kwargs): + for row in st.iter_rows(self.__state, metadata_columns, unit_kwargs): output_data = row if not isinstance(output_data, dict): output_data = {self.columns[0]: row} @@ -591,10 +644,10 @@ def select( new_state = self.__state if derived_columns: - new_state = new_state.with_new_columns({}, **derived_columns) + new_state = st.with_new_columns(new_state, {}, False, **derived_columns) all_columns.update(derived_columns.keys()) - new_state = new_state.select(all_columns) + new_state = st.select(new_state, all_columns) return Dataset( self.__header, new_state, @@ -631,14 +684,14 @@ def drop(self, *columns: str | Iterable[str]) -> Dataset: col_group = {col_group} 
all_columns.update(col_group) - new_state = self.__state.select(all_columns, drop=True) + new_state = st.select(self.__state, all_columns, drop=True) return Dataset( self.__header, new_state, self.__tree, ) - def sort_by(self, column: str, invert: bool = False) -> Dataset: + def sort_by(self, column: Optional[str], invert: bool = False) -> Dataset: """ Sort this dataset by the values in a given column. By default sorting is in ascending order (least to greatest). Pass invert = True to sort in descending @@ -656,9 +709,9 @@ def sort_by(self, column: str, invert: bool = False) -> Dataset: Parameters ---------- - column : str + column : Optional[str] The column in the halo_properties or galaxy_properties dataset to - order the collection by. + order the collection by. Pass :code:`None` to remove sorting. invert : bool, default = False If False (the default), ordering will be from least to greatest. @@ -671,7 +724,7 @@ def sort_by(self, column: str, invert: bool = False) -> Dataset: """ - new_state = self.__state.sort_by(column, invert) + new_state = st.sort_by(self.__state, column, invert) return Dataset( self.__header, new_state, @@ -679,9 +732,7 @@ def sort_by(self, column: str, invert: bool = False) -> Dataset: ) def take( - self, - n: int, - at: str = "random", + self, n: int, at: str = "random", mode: Literal["local", "global"] = "local" ) -> Dataset: """ Create a new dataset from some number of rows from this dataset. @@ -697,7 +748,16 @@ def take( at : str Where to take the rows from. One of "start", "end", or "random". The default is "random". + mode : str, "local" or "global", default = "local" + Controls how ``n`` is interpreted when running under MPI. Has no + effect if you are not using MPI. + * ``"local"`` (default): ``n`` rows are taken independently on + each rank. + * ``"global"``: ``n`` is the total number of rows to select across + all ranks combined. Each rank receives the portion of those rows + that it owns. 
If the dataset is sorted, ranks will coordinate + to take from the globally-sorted dataset. Returns ------- @@ -711,16 +771,20 @@ def take( or if 'at' is invalid. """ - - new_state = self.__state.take(n, at) - - return Dataset( - self.__header, - new_state, - self.__tree, - ) - - def take_range(self, start: int, end: int) -> Dataset: + if at == "start": + return self.take_range(0, n, mode) + elif at == "end": + take_index = get_end_take_index(n, self, self.__state.sort_key, mode) + return self.take_rows(take_index) + elif at != "random": + raise ValueError(f"Unknown take type {at}") + + row_indices = get_random_take_index(n, len(self), mode) + return self.take_rows(row_indices) + + def take_range( + self, start: int, end: int, mode: Literal["local", "global"] = "local" + ) -> Dataset: """ Create a new dataset from a row range in this dataset. We use standard indexing conventions, so the rows included will be start -> end - 1. @@ -728,14 +792,25 @@ def take_range(self, start: int, end: int) -> Dataset: Parameters ---------- start : int - The beginning of the range + The beginning of the range. end : int - The end of the range + The end of the range (exclusive). + + mode : str, "local" or "global", default = "local" + Controls how ``start`` and ``end`` are interpreted when running + under MPI. Has no effect if you are not using MPI. + + * ``"local"`` (default): the range is applied independently on + each rank. + * ``"global"``: ``start`` and ``end`` index into the global row + space across all ranks combined. Each rank receives the portion + of that range it owns. If the dataset is sorted, ranks will + coordinate to take from the globally-sorted dataset. Returns ------- - table : astropy.table.Table - The table with only the rows from start to end. + dataset : Dataset + The new dataset with only the rows from start to end. Raises ------ @@ -744,14 +819,15 @@ def take_range(self, start: int, end: int) -> Dataset: or if end is greater than start. 
""" + if start < 0 or end < 0: + raise ValueError("start and end must be positive.") + if end < start: + raise ValueError("end must be greater than start.") - new_state = self.__state.take_range(start, end) - - return Dataset( - self.__header, - new_state, - self.__tree, + take_index = get_range_take_index( + self, self.__state.sort_key, start, end - start, mode ) + return self.take_rows(take_index) def take_rows(self, rows: np.ndarray | DataIndex): """ @@ -773,12 +849,20 @@ def take_rows(self, rows: np.ndarray | DataIndex): dataset. """ - new_state = self.__state.take_rows(rows) + + row_range = get_range(rows) + if row_range[0] < 0 or row_range[1] > len(self): + raise ValueError( + "Row indices must be between 0 and the length of this dataset - 1!" + ) + + new_state = st.take_rows(self.__state, rows) return Dataset(self.__header, new_state, self.__tree) def with_new_columns( self, descriptions: str | dict[str, str] = {}, + allow_overwrite: bool = False, **new_columns: ConstructedColumn | Column | np.ndarray | u.Quantity, ): """ @@ -789,6 +873,25 @@ def with_new_columns( quantities will not change under unit transformations. See :ref:`Adding Custom Columns` for examples. + If allow_overwrite is :code:`True`, the new column may have the same name as + a column that already exists in the dataset. This can be used to transform a column, + for example: + + .. code-block:: python + + log_mass = oc.col("fof_halo_mass").log10() + ds = ds.with_new_columns(fof_halo_mass=log_mass, allow_overwrite=True) + + The "fof_halo_mass" column will now be the log of the original "fof_halo_mass" column. + + Columns will be given the same name as the argument you use when you pass them into the function. + For example, we could do the same as above but name the column "log_fof_halo_mass" with + + .. 
code-block:: python + + log_mass = oc.col("fof_halo_mass").logo10() + ds = ds.with_new_columns(log_fof_halo_mass = log_mass) + Parameters ---------- @@ -796,8 +899,12 @@ def with_new_columns( A description for the new columns. These descriptions will be accessible through :py:attr:`Dataset.descriptions `. If a dictionary, should have keys matching the column names. + allow_overwrites : bool, default = False + If false, attempting to add a new column with the same name as an existing column will throw an error. + If true, overwrites are allowed. ** new_columns : opencosmo.DerivedColumn | np.ndarray | units.Quantity + The new columns to add. The name of the argument is the name the column will take. Returns ------- @@ -807,7 +914,9 @@ def with_new_columns( """ if isinstance(descriptions, str): descriptions = {key: descriptions for key in new_columns.keys()} - new_state = self.__state.with_new_columns(descriptions, **new_columns) + new_state = st.with_new_columns( + self.__state, descriptions, allow_overwrite, **new_columns + ) return Dataset(self.__header, new_state, self.__tree) def make_schema( @@ -825,7 +934,7 @@ def make_schema( The name of the dataset in the file. The default is "data". 
""" - schema = self.__state.make_schema(name) + schema = st.make_schema(self.__state, name) if self.__tree is not None: tree = self.__tree.apply_index(self.__state.raw_index) @@ -901,8 +1010,13 @@ def with_units( """ - new_state = self.__state.with_units( - convention, conversions, columns, self.cosmology, self.redshift + new_state = st.with_units( + self.__state, + convention, + conversions, + columns, + self.cosmology, + self.redshift, ) if convention is not None: new_header = self.__header.with_units(convention) diff --git a/src/opencosmo/dataset/evaluate.py b/python/opencosmo/dataset/evaluate.py similarity index 60% rename from src/opencosmo/dataset/evaluate.py rename to python/opencosmo/dataset/evaluate.py index 9910e22b..2f14e24c 100644 --- a/src/opencosmo/dataset/evaluate.py +++ b/python/opencosmo/dataset/evaluate.py @@ -2,18 +2,14 @@ from collections import defaultdict from inspect import Parameter, signature -from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable import numpy as np from astropy.units import Quantity from opencosmo.column.column import EvaluatedColumn from opencosmo.column.evaluate import EvaluateStrategy, do_first_evaluation -from opencosmo.evaluate import ( - insert_data, - make_output_from_first_values, -) +from opencosmo.dataset.formats import concat_chunks, fetch_as_dict if TYPE_CHECKING: from opencosmo import Dataset @@ -27,10 +23,6 @@ def build_evaluated_column( dataset, func, vectorize, insert, format, batch_size, evaluate_kwargs ): - if format not in ["astropy", "numpy"]: - raise ValueError( - f"Evaluate only supports numpy and astropy format, got: {format}" - ) kwarg_columns = set(evaluate_kwargs.keys()).intersection(dataset.columns) if kwarg_columns: raise ValueError( @@ -69,11 +61,7 @@ def visit_dataset( ) -> dict[str, np.ndarray]: if column.batch_size > 0: return visit_dataset_batched(column, dataset) - data = 
dataset.select(column.requires).get_data(format=column.format) - try: - data = dict(data) - except (TypeError, ValueError): - data = {column.requires.pop(): data} + data = fetch_as_dict(dataset, column.requires_names, column.format) output = column.evaluate(data, dataset.index) if not isinstance(output, dict): assert len(column.produces) == 1 @@ -89,22 +77,21 @@ def visit_dataset_batched(column: EvaluatedColumn, dataset: Dataset): output = defaultdict(list) for start, end in np.lib.stride_tricks.sliding_window_view(ranges, 2): - batch_data = ( - dataset.select(column.requires) - .take_range(start, end) - .get_data(format=column.format, unpack=False) + batch_data = fetch_as_dict( + dataset.take_range(start, end), + column.requires_names, + column.format, + unpack=False, ) - try: - batch_data = dict(batch_data) - except TypeError: - batch_data = {column.requires.pop(): batch_data} batch_output = column.evaluate(batch_data, None) if batch_output is not None and not isinstance(batch_output, dict): batch_output = {column.produces.pop(): batch_output} for name, column_batch in batch_output.items(): output[name].append(column_batch) - full_output = {name: np.concat(out) for name, out in output.items()} + full_output = { + name: concat_chunks(out, column.format) for name, out in output.items() + } return full_output @@ -173,84 +160,6 @@ def verify_for_lazy_evaluation( return column -def __visit_rows_in_dataset( - function: Callable, - dataset: Dataset, - format: str, - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, Sequence] = {}, -): - first_row_values = dict(dataset.take(1, at="start").get_data()) - first_row_kwargs = kwargs | {name: arr[0] for name, arr in iterable_kwargs.items()} - storage = __make_output(function, first_row_values | first_row_kwargs, len(dataset)) - for i, row in enumerate(dataset.rows(include_units=format == "astropy")): - if i == 0: - continue - iter_kwargs = {name: arr[i] for name, arr in iterable_kwargs.items()} - output = 
function(**row, **kwargs, **iter_kwargs) - if storage is not None: - insert_data(storage, i, output) - return storage - - -def __visit_rows_in_data( - function: Callable, - data: dict[str, np.ndarray], - format="astropy", - kwargs: dict[str, Any] = {}, - iterable_kwargs: dict[str, np.ndarray] = {}, -): - data = {key: d for key, d in data.items() if key in signature(function).parameters} - first_row_data = {name: arr[0] for name, arr in data.items()} - first_row_kwargs = kwargs | {name: arr[0] for name, arr in iterable_kwargs.items()} - n_rows = len(next(iter(data.values()))) - storage = __make_output(function, first_row_data | first_row_kwargs, n_rows) - if format == "numpy": - data = { - key: arr.value if isinstance(arr, Quantity) else arr - for key, arr in data.items() - } - - for i in range(1, n_rows): - row = { - name: arr[i] for name, arr in chain(data.items(), iterable_kwargs.items()) - } - output = function(**row, **kwargs) - if storage is not None: - insert_data(storage, i, output) - return storage - - -def __make_output( - function: Callable, - first_input_values: dict[str, Any], - n_rows: int, -) -> dict | None: - first_values = function(**first_input_values) - if first_values is None: - return None - if not isinstance(first_values, dict): - name = function.__name__ - first_values = {name: first_values} - - return make_output_from_first_values(first_values, n_rows) - - -def __visit_vectorize( - function: Callable, - data: dict[str, Iterable] | Iterable, - evaluator_kwargs: dict[str, Any] = {}, -): - pars = signature(function).parameters - - if not isinstance(data, dict) or (len(data) > 1 and len(pars) == 1): - return function(data, **evaluator_kwargs) - - input_data = {pname: data[pname] for pname in pars if pname in data} - - return function(**input_data, **evaluator_kwargs) - - def __verify( function: Callable, data_columns: Iterable[str], kwarg_names: Iterable[str] ): diff --git a/python/opencosmo/dataset/formats.py 
b/python/opencosmo/dataset/formats.py new file mode 100644 index 00000000..ff70f72b --- /dev/null +++ b/python/opencosmo/dataset/formats.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any, Iterable + +import astropy.units as u +import numpy as np +from astropy.table import Column, QTable + +if TYPE_CHECKING: + from opencosmo import Dataset + + +def verify_format(output_format: str): + match output_format: + case "astropy": + return + case "numpy": # these two are core dependencies + return + case "pandas": + import_name = "pandas" + case "arrow": + import_name = "pyarrow" + case "polars": + import_name = "polars" + case "jax": + import_name = "jax" + case _: + raise ValueError(f"Unknown data output format {output_format}") + + __verify_import(import_name, output_format) + + +def __verify_import(import_name: str, format_name: str): + try: + import_module(import_name) + except ImportError as e: + raise ImportError( + f"Data was requested in {format_name} format but could not import {import_name} package. Got '{e}'" + ) + + +def convert_data( + data: dict[str, np.ndarray], output_format: str, wrap_single: bool = False +): + """ + If `wrap_single` is True, the result is always the format's natural + multi-column container (QTable, DataFrame, dict[name, array], etc.) even + when there is only one column, instead of collapsing to a bare array / + Series. Used by callers (e.g. evaluate) that want a uniform + `container[colname]` access pattern regardless of column count. 
+ """ + match output_format: + case "astropy": + return __convert_to_astropy(data, wrap_single) + case "numpy": + return convert_to_numpy(data, wrap_single) + case "pandas": + return __convert_to_pandas(data, wrap_single) + case "polars": + return __convert_to_polars(data, wrap_single) + case "arrow": + return __convert_to_arrow(data, wrap_single) + case "jax": + return __convert_to_jax(data, wrap_single) + case _: + raise ValueError(f"Unknown data output format {output_format}") + + +def fetch_as_dict( + dataset: Dataset, + requires_names: Iterable[str], + output_format: str, + unpack: bool = True, +) -> dict[str, Any]: + """ + Fetch the requested columns and return them as a {name: container} dict in + the user's requested format. Routes through astropy so that Quantities (with + units) survive into the conversion step; other formats then receive plain + values via to_format_dict. + """ + requires_names = list(requires_names) + raw = dataset.select(requires_names).get_data(format="astropy", unpack=unpack) + if isinstance(raw, QTable): + raw = {name: raw[name] for name in raw.colnames} + elif not isinstance(raw, dict): + raw = {requires_names[0]: raw} + if output_format == "astropy": + return raw + return to_format_dict(raw, output_format) + + +def to_format_dict(data: dict[str, np.ndarray], output_format: str) -> dict: + """ + Convert each column of a numpy/astropy dict to the requested format, + preserving the dict shape. Unlike convert_data, this never wraps the + result in a higher-level container (DataFrame, QTable, ...). Used to + feed user-supplied evaluate functions when the upstream data is + numpy-shaped (e.g. when reading from the column cache). 
+ """ + if output_format == "astropy": + return data + + def strip(value): + return value.value if isinstance(value, (u.Quantity, Column)) else value + + match output_format: + case "numpy": + return {k: strip(v) for k, v in data.items()} + case "jax": + import jax.numpy as jnp + + return {k: jnp.asarray(strip(v)) for k, v in data.items()} + case "pandas": + import pandas as pd + + return {k: pd.Series(strip(v)) for k, v in data.items()} + case "polars": + import polars as pl + + return {k: pl.Series(values=strip(v)) for k, v in data.items()} + case "arrow": + import pyarrow as pa # type: ignore + + return {k: pa.array(strip(v)) for k, v in data.items()} + case _: + raise ValueError(f"Unknown data output format {output_format}") + + +def to_numpy_dict(data: dict) -> dict[str, np.ndarray]: + """ + Convert each value in a dict-of-format-arrays back to a numpy array + suitable for the column cache. Astropy Quantities are preserved so + that downstream unit handling continues to work; other formats are + converted to plain numpy with no unit information. + """ + result: dict[str, np.ndarray] = {} + for name, value in data.items(): + if isinstance(value, (u.Quantity, np.ndarray)): + result[name] = value + else: + result[name] = np.asarray(value) + return result + + +def stack_rows(values: list, output_format: str): + """ + Stack a list of per-row values into a 1-D container in the target format. + Used by row-wise evaluation strategies to assemble output without + preallocation, which would break for formats with immutable arrays + (e.g. jax). 
+ """ + match output_format: + case "astropy": + if values and isinstance(values[0], u.Quantity): + return u.Quantity(values) + return np.array(values) + case "numpy": + return np.array(values) + case "jax": + import jax.numpy as jnp + + return jnp.array(values) + case "pandas": + import pandas as pd + + return pd.Series(values) + case "polars": + import polars as pl + + return pl.Series(values=values) + case "arrow": + import pyarrow as pa # type: ignore + + return pa.array(values) + case _: + raise ValueError(f"Unknown data output format {output_format}") + + +def concat_chunks(chunks: list, output_format: str): + """ + Concatenate a list of per-chunk arrays into a single container in the + target format. + """ + match output_format: + case "astropy" | "numpy": + return np.concatenate(chunks) + case "jax": + import jax.numpy as jnp + + return jnp.concatenate(chunks) + case "pandas": + import pandas as pd + + return pd.concat(chunks, ignore_index=True) + case "polars": + import polars as pl + + return pl.concat(chunks) + case "arrow": + import pyarrow as pa # type: ignore + + return pa.concat_arrays(chunks) + case _: + raise ValueError(f"Unknown data output format {output_format}") + + +def __convert_to_astropy( + data: dict[str, np.ndarray], wrap_single: bool = False +) -> QTable: + if len(data) == 1 and not wrap_single: + return next(iter(data.values())) + if any( + (isinstance(d, u.Quantity) and d.isscalar) or not isinstance(d, np.ndarray) + for d in data.values() + ): + return data + + return QTable(data, copy=False) + + +def convert_to_numpy( + data: dict[str, np.ndarray], + wrap_single: bool = False, +) -> dict[str, np.ndarray] | np.ndarray: + converted_data = dict( + map( + lambda kv: ( + kv[0], + kv[1].value if isinstance(kv[1], (u.Quantity, Column)) else kv[1], + ), + data.items(), + ) + ) + if len(converted_data) == 1 and not wrap_single: + return next(iter(converted_data.values())) + return converted_data + + +def __convert_to_pandas(data: dict[str, 
np.ndarray], wrap_single: bool = False): + import pandas as pd + + numpy_data = convert_to_numpy(data, wrap_single) + if isinstance(numpy_data, np.ndarray): # only one column, wrap_single=False + return pd.Series(numpy_data, name=next(iter(data.keys()))) + + return pd.DataFrame(numpy_data, copy=True) + + +def __convert_to_arrow(data: dict[str, np.ndarray], wrap_single: bool = False): + import pyarrow as pa # type: ignore + + numpy_data = convert_to_numpy(data, wrap_single) + if isinstance(numpy_data, np.ndarray): + return pa.array(numpy_data) + + converted_data = map( + lambda kv: (kv[0], pa.array(kv[1])), + data.items(), + ) + return dict(converted_data) + + +def __convert_to_polars(data: dict[str, np.ndarray], wrap_single: bool = False): + import polars as pl + + numpy_data = convert_to_numpy(data, wrap_single) + if isinstance(numpy_data, np.ndarray): + return pl.Series(name=next(iter(data.keys())), values=numpy_data) + + return pl.from_dict(data) # type: ignore + + +def __convert_to_jax(data: dict[str, np.ndarray], wrap_single: bool = False): + import jax.numpy as jnp + + output_data = convert_to_numpy(data, wrap_single) + if isinstance(output_data, np.ndarray): + return jnp.asarray(output_data) + return {key: jnp.asarray(value) for key, value in output_data.items()} diff --git a/python/opencosmo/dataset/graph.py b/python/opencosmo/dataset/graph.py new file mode 100644 index 00000000..4c5e85a7 --- /dev/null +++ b/python/opencosmo/dataset/graph.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from functools import reduce +from typing import TYPE_CHECKING + +import rustworkx as rx + +from opencosmo.column.column import RawColumn + +if TYPE_CHECKING: + from uuid import UUID + + import astropy.units as u + + from opencosmo.column.column import ConstructedColumn + from opencosmo.units.handler import UnitHandler + + +def validate_column_producers( + producers: list[ConstructedColumn], unit_handler: UnitHandler +): + """ + Validate the network of column 
producers. + """ + dependency_graph = build_dependency_graph(producers) + + if cycle := rx.digraph_find_cycle(dependency_graph): + all_nodes: set[int] = reduce( + lambda known, edge: known.union(edge), cycle, set() + ) + names = [dependency_graph[i].produces for i in all_nodes] + raise ValueError(f"Found columns that depend on each other! Columns: {names}") + + for i in range(dependency_graph.num_nodes()): + if dependency_graph.in_degree(i): + continue + node = dependency_graph[i] + if not isinstance(node, RawColumn): + raise ValueError( + f"Tried to derive columns from unknown columns: {node.produces}" + ) + + return get_derived_units(dependency_graph, unit_handler.current_units) + + +def build_dependency_graph( + producers: list[ConstructedColumn], +) -> rx.PyDiGraph: + graph = rx.PyDiGraph() + uuid_to_node: dict[UUID, int] = {} + + for producer in producers: + node_idx = graph.add_node(producer) + uuid_to_node[producer.uuid] = node_idx + + for producer in producers: + produces_idx = uuid_to_node[producer.uuid] + if not producer.requires.issubset(uuid_to_node.keys()): + raise ValueError( + f"Producer {producer.produces} depends on an unknown producer UUID." 
+ ) + new_edges = ( + (uuid_to_node[dep_uuid], produces_idx) for dep_uuid in producer.requires + ) + graph.add_edges_from_no_data(new_edges) + + return graph + + +def get_derived_units( + dependency_graph: rx.PyDiGraph, + units: dict[str, u.Unit], +): + new_units: dict[str, u.Unit | None] = {} + for node_idx in rx.topological_sort(dependency_graph): + node = dependency_graph[node_idx] + if isinstance(node, RawColumn): + continue + column_units = node.get_units(units | new_units) + if not isinstance(column_units, dict): + column_units = {prod: column_units for prod in node.produces} + new_units |= column_units + return new_units diff --git a/python/opencosmo/dataset/instantiate.py b/python/opencosmo/dataset/instantiate.py new file mode 100644 index 00000000..79d4ce75 --- /dev/null +++ b/python/opencosmo/dataset/instantiate.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import rustworkx as rx + +from opencosmo.column.column import EvaluatedColumn, RawColumn +from opencosmo.dataset.graph import build_dependency_graph + +if TYPE_CHECKING: + from uuid import UUID + + import numpy as np + + from opencosmo.column.column import ConstructedColumn + from opencosmo.handler.protocols import DataCache, DataHandler + from opencosmo.index import DataIndex + from opencosmo.units.handler import UnitHandler + + +def get_all_required_pairs( + columns_to_uuid: dict[str, UUID], dependency_graph: rx.PyDiGraph +) -> set[tuple[UUID, str]]: + """ + Return the full set of (producer_uuid, column_name) pairs needed to + produce the requested columns, including all transitive dependencies. 
+ """ + uuid_to_node: dict[UUID, int] = { + dependency_graph[i].uuid: i for i in range(dependency_graph.num_nodes()) + } + required_nodes: set[int] = set() + for uuid in columns_to_uuid.values(): + if uuid in uuid_to_node: + node_idx = uuid_to_node[uuid] + required_nodes.add(node_idx) + required_nodes.update(rx.ancestors(dependency_graph, node_idx)) + + pairs: set[tuple[UUID, str]] = { + (uuid, name) for name, uuid in columns_to_uuid.items() + } + for node_idx in required_nodes: + producer = dependency_graph[node_idx] + for name in producer.produces: + pairs.add((producer.uuid, name)) + return pairs + + +def build_initial_uuid_data( + column_producers: list[ConstructedColumn], + raw_data: dict[str, np.ndarray], + cached_data: dict[UUID, dict[str, np.ndarray]], +) -> dict[UUID, dict[str, np.ndarray]]: + """ + Merge cached and freshly-fetched raw data into UUID-keyed storage. + Cached data is the starting point; raw data fills in any gaps. + """ + uuid_data: dict[UUID, dict[str, np.ndarray]] = {**cached_data} + for producer in column_producers: + if not isinstance(producer, RawColumn) or producer.uuid in uuid_data: + continue + output_name = producer.alias or producer.name + if output_name in raw_data: + uuid_data[producer.uuid] = {output_name: raw_data[output_name]} + return uuid_data + + +def build_derived_columns( + columns_to_uuid: dict[str, UUID], + uuid_data: dict[UUID, dict[str, np.ndarray]], + dependency_graph: rx.PyDiGraph, + index: DataIndex, + cache: DataCache, +) -> dict[UUID, dict[str, np.ndarray]]: + """ + Evaluate all derived producers needed to produce the requested columns, + in topological order. Each producer's inputs are resolved by UUID via + dep_map, so column-name shadowing cannot cause a derived column to + receive data from the wrong producer. 
+ """ + uuid_to_node: dict[UUID, int] = { + dependency_graph[i].uuid: i for i in range(dependency_graph.num_nodes()) + } + + required_uuids: set[UUID] = set(columns_to_uuid.values()) + for producer_uuid in list(required_uuids): + if producer_uuid in uuid_to_node: + for node_idx in rx.ancestors(dependency_graph, uuid_to_node[producer_uuid]): + required_uuids.add(dependency_graph[node_idx].uuid) + + new_derived: dict[UUID, dict[str, np.ndarray]] = {} + to_cache: dict[UUID, dict[str, np.ndarray]] = {} + for node_idx in rx.topological_sort(dependency_graph): + producer = dependency_graph[node_idx] + if isinstance(producer, RawColumn): + continue + if producer.uuid not in required_uuids or producer.uuid in uuid_data: + continue + + all_data = uuid_data | new_derived + input_data = { + name: all_data[dep_uuid][name] + for name, dep_uuid in producer.dep_map.items() + } + if isinstance(producer, EvaluatedColumn): + output = producer.evaluate_for_storage(input_data, index) + else: + output = producer.evaluate(input_data, index) + if not isinstance(output, dict): + output = {next(iter(producer.produces)): output} + new_derived[producer.uuid] = output + if not producer.no_cache: + to_cache[producer.uuid] = output + + cache.add_data(to_cache, {}) + return new_derived + + +def __cache_raw_columns( + raw_columns: list[RawColumn], + raw_data: dict[str, Any], + working_columns: dict[str, UUID], + unit_handler: UnitHandler, + unit_kwargs: dict[str, Any], + cache: DataCache, +) -> dict[UUID, dict[str, Any]]: + """ + Write freshly-fetched raw columns to the cache and return the merged + (pre- and post-conversion) UUID-keyed data for merging into uuid_data. + + Pre-conversion data is pushed up to parent caches. Converted data is kept + local (push_up=False) to avoid propagating dataset-specific unit conversions + to parent caches, which could cause drift through repeated rounding. 
+ """ + raw_by_uuid: dict[UUID, dict[str, Any]] = { + col.uuid: {col.alias or col.name: raw_data[col.alias or col.name]} + for col in raw_columns + if (col.alias or col.name) in raw_data + } + converted_by_uuid = unit_handler.apply_unit_conversions(raw_by_uuid, unit_kwargs) + + cacheable = set(working_columns.values()) + cache.add_data( + {uuid: data for uuid, data in raw_by_uuid.items() if uuid in cacheable}, + {}, + push_up=True, + ) + cache.add_data( + {uuid: data for uuid, data in converted_by_uuid.items() if uuid in cacheable}, + {}, + push_up=False, + ) + + return { + uuid: (data | converted_by_uuid.get(uuid, {})) + for uuid, data in raw_by_uuid.items() + } + + +def instantiate_dataset( + column_producers: list[ConstructedColumn], + columns_to_uuid: dict[str, UUID], + raw_data_handler: DataHandler, + cache: DataCache, + unit_handler: UnitHandler, + unit_kwargs: dict[str, Any], + sort_by: str | None = None, +): + # Extend working_columns with the sort column if it isn't already included. + working_columns = dict(columns_to_uuid) + if sort_by is not None and sort_by not in working_columns: + sort_name = sort_by + for producer in column_producers: + if sort_name in producer.produces: + working_columns[sort_name] = producer.uuid + break + + dependency_graph = build_dependency_graph(column_producers) + required_pairs = get_all_required_pairs(working_columns, dependency_graph) + + cached_data = cache.get_data(required_pairs) + + converted_cached = unit_handler.apply_unit_conversions(cached_data, unit_kwargs) + if converted_cached: + cache.add_data(converted_cached, {}, push_up=False) + for uuid, col_data in converted_cached.items(): + cached_data.setdefault(uuid, {}).update(col_data) + + # Determine which raw columns still need to be fetched from the handler. 
+ cached_uuids = set(cached_data.keys()) + raw_columns = [ + col + for col in column_producers + if isinstance(col, RawColumn) + and col.uuid not in cached_uuids + and col.name in {name for (_, name) in required_pairs} + ] + raw_data = raw_data_handler.get_data(set(col.name for col in raw_columns)) + for column in raw_columns: + if column.alias is None: + continue + raw_data[column.alias] = raw_data[column.name] + + raw_data = unit_handler.apply_raw_units(raw_data, unit_kwargs) + + uuid_data = build_initial_uuid_data(column_producers, raw_data, cached_data) + new_derived = build_derived_columns( + working_columns, uuid_data, dependency_graph, raw_data_handler.index, cache + ) + + uuid_data |= new_derived + + uuid_data |= __cache_raw_columns( + raw_columns, raw_data, working_columns, unit_handler, unit_kwargs, cache + ) + + data = { + name: uuid_data[producer_uuid][name] + for name, producer_uuid in working_columns.items() + if producer_uuid in uuid_data and name in uuid_data[producer_uuid] + } + return data diff --git a/src/opencosmo/dataset/mpi.py b/python/opencosmo/dataset/mpi.py similarity index 69% rename from src/opencosmo/dataset/mpi.py rename to python/opencosmo/dataset/mpi.py index d5dda2f5..67ae301b 100644 --- a/src/opencosmo/dataset/mpi.py +++ b/python/opencosmo/dataset/mpi.py @@ -4,19 +4,23 @@ from warnings import warn from opencosmo.index.build import single_chunk +from opencosmo.plugins.contexts import HookPoint, PartitionCtx +from opencosmo.plugins.hook import query from opencosmo.spatial.protocols import TreePartition if TYPE_CHECKING: import h5py from mpi4py import MPI + from opencosmo.header import OpenCosmoHeader from opencosmo.spatial.tree import Tree def partition( comm: MPI.Comm, - length: int, - counts: h5py.Group, + header: OpenCosmoHeader, + index_group: h5py.Group, + data_group: h5py.Group, tree: Optional[Tree], min_level: Optional[int] = None, ) -> Optional[TreePartition]: @@ -25,8 +29,15 @@ def partition( spatial index. 
In principle this means the number of objects are similar between ranks. """ + partition_plugin_result = query( + HookPoint.Partition, + PartitionCtx(comm, header, index_group, data_group, tree, min_level), + ) + if partition_plugin_result is not None: + return partition_plugin_result + if tree is not None: - partitions = tree.partition(comm.Get_size(), counts, min_level) + partitions = tree.partition(comm.Get_size(), index_group, min_level) try: part = partitions[comm.Get_rank()] except IndexError: @@ -37,6 +48,7 @@ def partition( part = None return part + length = len(next(iter(data_group.values()))) nranks = comm.Get_size() rank = comm.Get_rank() if rank == nranks - 1: diff --git a/python/opencosmo/dataset/output.py b/python/opencosmo/dataset/output.py new file mode 100644 index 00000000..e09accd3 --- /dev/null +++ b/python/opencosmo/dataset/output.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from functools import reduce +from typing import TYPE_CHECKING, Optional + +import astropy.units as u + +from opencosmo.column.column import RawColumn +from opencosmo.io.schema import ( + FileEntry, + combine_with_cached_schema, + make_schema, +) +from opencosmo.io.writer import ColumnCombineStrategy, ColumnWriter, NumpySource + +if TYPE_CHECKING: + from uuid import UUID + + from opencosmo.column.column import ConstructedColumn + from opencosmo.handler.protocols import DataCache, DataHandler + from opencosmo.header import OpenCosmoHeader + from opencosmo.io.schema import Schema + from opencosmo.spatial.protocols import Region + + +def get_derived_column_names( + producers: list[ConstructedColumn], columns: set[str] +) -> set[str]: + all_derived: set[str] = reduce( + lambda acc, col: acc.union( + col.produces if not isinstance(col, RawColumn) else set() + ), + producers, + set(), + ) + return all_derived.intersection(columns) + + +def build_derived_writers( + producers: list[ConstructedColumn], + derived_data: dict, + data_schema: Schema, + cached_data_schema: 
Schema, +) -> None: + """Add ColumnWriter entries to data_schema for each non-raw, non-cached producer.""" + for producer in producers: + if isinstance(producer, RawColumn) or producer.produces.issubset( + cached_data_schema.columns.keys() + ): + continue + coldata = {name: derived_data[name] for name in producer.produces} + units = { + name: str(cd.unit) if isinstance(cd, u.Quantity) else "" + for name, cd in coldata.items() + } + coldata = { + name: cd.value if isinstance(cd, u.Quantity) else cd + for name, cd in coldata.items() + } + for name, cd in coldata.items(): + attrs = {"unit": units[name], "description": producer.description or "None"} + source = NumpySource(cd) + writer = ColumnWriter([source], ColumnCombineStrategy.CONCAT, attrs=attrs) + data_schema.columns[name] = writer + + +def make_dataset_schema( + producers: list[ConstructedColumn], + raw_data_handler: DataHandler, + cache: DataCache, + columns_to_uuid: dict[str, UUID], + meta_columns: list[str], + header: OpenCosmoHeader, + region: Region, + derived_data: dict, + name: Optional[str] = None, +) -> Schema: + columns = set(columns_to_uuid.keys()) + header = header.with_region(region) + raw_columns = columns.intersection(raw_data_handler.columns) + raw_meta_columns = raw_columns & set(meta_columns) + data_schema, metadata_schema = raw_data_handler.make_schema( + raw_columns, raw_meta_columns, header + ) + + cached_data_schema, cached_metadata_schema = cache.make_schema( + columns_to_uuid, meta_columns + ) + + data_producers = [ + prod for prod in producers if not prod.produces.issubset(meta_columns) + ] + build_derived_writers(data_producers, derived_data, data_schema, cached_data_schema) + + attributes = {} + if (load_conditions := raw_data_handler.load_conditions) is not None: + attributes["load/if"] = load_conditions + + data_schema = combine_with_cached_schema(data_schema, cached_data_schema) + metadata_schema = combine_with_cached_schema( + metadata_schema, cached_metadata_schema + ) + + 
children = {"data": data_schema} + if metadata_schema.type != FileEntry.EMPTY: + children[metadata_schema.name] = metadata_schema + if name is None: + name = "" + + return make_schema( + name, FileEntry.DATASET, children=children, attributes=attributes + ) diff --git a/python/opencosmo/dataset/state.py b/python/opencosmo/dataset/state.py new file mode 100644 index 00000000..b01ffa20 --- /dev/null +++ b/python/opencosmo/dataset/state.py @@ -0,0 +1,530 @@ +from __future__ import annotations + +import dataclasses +from dataclasses import dataclass +from functools import reduce +from typing import TYPE_CHECKING, Any, Generator, Optional +from weakref import finalize + +import astropy.units as u +import numpy as np + +from opencosmo.column.cache import ColumnCache +from opencosmo.column.column import RawColumn +from opencosmo.column.select import get_column_selection +from opencosmo.dataset.columns import add_columns, resort +from opencosmo.dataset.instantiate import instantiate_dataset +from opencosmo.dataset.output import get_derived_column_names, make_dataset_schema +from opencosmo.handler.empty import EmptyHandler +from opencosmo.handler.hdf5 import Hdf5Handler +from opencosmo.index import single_chunk +from opencosmo.index.mask import into_array +from opencosmo.plugins.contexts import ( + DatasetInstantiateCtx, + HookPoint, + IndexUpdateCtx, + PostSortCtx, +) +from opencosmo.plugins.hook import fold +from opencosmo.units import UnitConvention +from opencosmo.units.handler import ( + make_unit_handler_from_hdf5, + make_unit_handler_from_units, +) + +if TYPE_CHECKING: + from uuid import UUID + + from astropy import table + from astropy.cosmology import Cosmology + + from opencosmo.column.column import ConstructedColumn + from opencosmo.handler.protocols import DataCache, DataHandler + from opencosmo.header import OpenCosmoHeader + from opencosmo.index import DataIndex + from opencosmo.io.iopen import DatasetTarget + from opencosmo.io.schema import Schema + from 
opencosmo.spatial.protocols import Region + from opencosmo.units.handler import UnitHandler + + +def deregister_state(id: int, cache: DataCache): + cache.deregister_column_group(id) + + +def sort_data( + data: dict[str, np.ndarray], sort_by: tuple[str, bool] | None, state: DatasetState +): + if sort_by is None: + return data + sort_column = data[sort_by[0]] + order = np.argsort(sort_column) + if sort_by[1]: + order = order[::-1] + + data = {key: value[order] for key, value in data.items()} + if sort_by[0] not in state.columns: + data.pop(sort_by[0]) + return fold(HookPoint.PostSort, PostSortCtx(state, data, np.argsort(order))).data + + +@dataclass(frozen=True) +class DatasetState: + """ + Main state container for the Dataset. Functions for manipulating it can be found below. The dataclass + itself only exposes basic lookup operations. + """ + + producers: dict[UUID, ConstructedColumn] + raw_data_handler: DataHandler + cache: DataCache + unit_handler: UnitHandler + header: OpenCosmoHeader + column_map: dict[str, UUID] + region: Region + open_kwargs: dict[str, Any] + sort_key: Optional[tuple[str, bool]] + metadata_columns: frozenset[str] + + def __post_init__(self): + self.cache.register_column_group(id(self), self.column_map) + finalize(self, deregister_state, id(self), self.cache) + + @property + def columns(self) -> list[str]: + return [c for c in self.column_map if c not in self.metadata_columns] + + @property + def meta_columns(self) -> list[str]: + return [c for c in self.column_map if c in self.metadata_columns] + + @property + def descriptions(self): + all_descriptions = {} + for producer in self.producers.values(): + update = {name: producer.description for name in producer.produces} + all_descriptions |= update + all_descriptions |= self.cache.descriptions + + return { + name: description + for name, description in all_descriptions.items() + if name in self.columns + } + + @property + def kwargs(self): + return self.open_kwargs + + @property + def 
raw_index(self): + if (si := get_sorted_index(self)) is not None: + ni = into_array(self.raw_data_handler.index) + return ni[si] + return self.raw_data_handler.index + + @property + def units(self): + units = self.unit_handler.current_units + return {name: units[name] for name in self.columns} + + @property + def convention(self): + return self.unit_handler.current_convention + + def __len__(state: DatasetState) -> int: + if isinstance(state.raw_data_handler, EmptyHandler): + return len(state.cache) + return len(state.raw_data_handler) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# Factory functions (replace classmethods) +# --------------------------------------------------------------------------- + + +def state_from_target( + target: DatasetTarget, + unit_convention: UnitConvention, + region: Region, + open_kwargs: dict[str, Any], + index: Optional[DataIndex] = None, + metadata_group: Optional[str] = None, +) -> DatasetState: + data_group = target["dataset_group"] + if "load" in data_group.keys(): + load_conditions = dict(data_group["load/if"].attrs) + else: + load_conditions = None + + handler = Hdf5Handler.from_columns( + target["columns"], + index, + metadata_group, + load_conditions, + ) + unit_handler = make_unit_handler_from_hdf5( + target["columns"], target["header"], unit_convention + ) + meta_column_names = frozenset( + col.name.split("/")[-1] + for col in target["columns"] + if metadata_group and col.name.split("/")[-2] == metadata_group + ) + descriptions = handler.descriptions + + raw_producers = [ + RawColumn(cname, descriptions.get(cname, "None")) for cname in handler.columns + ] + column_map = {p.name: p.uuid for p in raw_producers} + producers: dict[UUID, ConstructedColumn] = {p.uuid: p for p in raw_producers} + cache = ColumnCache.empty() + return DatasetState( + producers=producers, + raw_data_handler=handler, + cache=cache, + unit_handler=unit_handler, + header=target["header"], + 
column_map=column_map, + region=region, + open_kwargs=open_kwargs, + sort_key=None, + metadata_columns=meta_column_names, + ) + + +def state_in_memory( + data_columns: dict, + metadata_columns: dict, + header: OpenCosmoHeader, + unit_convention: UnitConvention, + region: Region, + open_kwargs: dict[str, Any], + descriptions: Optional[dict[str, str]] = None, + index: Optional[DataIndex] = None, +) -> DatasetState: + descriptions = descriptions or {} + + all_columns = dict(data_columns) | dict(metadata_columns) + raw_producers = [ + RawColumn(cname, descriptions.get(cname, "None")) + for cname in all_columns.keys() + ] + column_map = {p.name: p.uuid for p in raw_producers} + producers: dict[UUID, ConstructedColumn] = {p.uuid: p for p in raw_producers} + + cache = ColumnCache.empty() + if all_columns: + uuid_data = {p.uuid: {p.name: all_columns[p.name]} for p in raw_producers} + cache.add_data(uuid_data, descriptions) + + units: dict[str, u.Unit] = {} + for name, column in all_columns.items(): + units[name] = None + if isinstance(column, u.Quantity): + units[name] = column.unit + + unit_handler = make_unit_handler_from_units(units, header, unit_convention) + + return DatasetState( + producers=producers, + raw_data_handler=EmptyHandler(), + cache=cache, + unit_handler=unit_handler, + header=header, + column_map=column_map, + region=region, + open_kwargs=open_kwargs, + sort_key=None, + metadata_columns=frozenset(metadata_columns.keys()), + ) + + +# --------------------------------------------------------------------------- +# Standalone functions (replace methods) +# --------------------------------------------------------------------------- + + +def exit_state(state: DatasetState, *exec_details): + return None + + +def get_data( + state: DatasetState, + ignore_sort: bool = False, + metadata_columns: list = [], + unit_kwargs: dict = {}, +) -> dict: + """ + Use a State to get the associated data. Most of the logic can be found in the + instantiate_dataset method. 
+ """ + state = fold(HookPoint.DatasetInstantiate, DatasetInstantiateCtx(state)).state + data = instantiate_dataset( + list(state.producers.values()), + state.column_map, + state.raw_data_handler, + state.cache, + state.unit_handler, + unit_kwargs, + None if (ignore_sort or state.sort_key is None) else state.sort_key[0], + ) + + if missing := set(state.columns).difference(data.keys()): + raise RuntimeError( + f"Some columns are missing from the output! This is likely a bug. Please report it on GitHub. Missing: {missing}" + ) + + if not ignore_sort: + data = sort_data(data, state.sort_key, state) + + new_order = list(state.columns) + for name in metadata_columns: + if name in state.metadata_columns: + new_order.append(name) + + return {name: data[name] for name in new_order} + + +def iter_rows( + state: DatasetState, + metadata_columns: list = [], + unit_kwargs: dict = {}, +) -> Generator: + """ + Iterate over the rows of a given DatasetState + """ + derived_to_collect = ( + set(state.columns) + .difference(state.cache.columns) + .difference(state.raw_data_handler.columns) + ) + derived_storage: dict[str, list[np.ndarray]] = { + name: [] for name in derived_to_collect + } + total_length = len(state) + chunk_ranges = [ + (i, min(i + 1000, total_length)) for i in range(0, total_length, 1000) + ] + if not chunk_ranges: + raise StopIteration + + try: + for start, end in chunk_ranges: + chunk = take_rows(state, single_chunk(start, end - start)) + data = get_data( + chunk, metadata_columns=metadata_columns, unit_kwargs=unit_kwargs + ) + for name in derived_to_collect: + derived_storage[name].append(data[name]) + + for i in range(len(chunk)): + yield {name: column[i] for name, column in data.items()} + all_derived = { + name: np.concatenate(arr) for name, arr in derived_storage.items() + } + derived_storage = resort(all_derived, get_sorted_index(state)) + if derived_storage: + uuid_keyed: dict = {} + for name, arr in derived_storage.items(): + uuid = state.column_map[name] 
+ uuid_keyed.setdefault(uuid, {})[name] = arr + state.cache.add_data(uuid_keyed, {}) + except GeneratorExit: + pass + except BaseException: + raise + + +def get_metadata( + state: DatasetState, columns: list = [], ignore_sort: bool = False +) -> dict: + names = list(columns) if columns else list(state.metadata_columns) + data = instantiate_dataset( + list(state.producers.values()), + {name: state.column_map[name] for name in names}, + state.raw_data_handler, + state.cache, + state.unit_handler, + {}, + None, + ) + if ignore_sort: + return data + + sorted_index = get_sorted_index(state) + if sorted_index is not None: + data = {name: values[sorted_index] for name, values in data.items()} + return data + + +def make_schema(state: DatasetState, name: Optional[str] = None) -> Schema: + """ + Get metadata columns. + """ + producers = list(state.producers.values()) + columns = set(state.column_map.keys()).difference(state.metadata_columns) + derived_names = get_derived_column_names(producers, columns) + if derived_names: + selected = select(state, derived_names) + converted = with_units( + selected, state.unit_handler.base_convention, {}, {}, None, None + ) + derived_data = get_data(converted, ignore_sort=True) + else: + derived_data = {} + return make_dataset_schema( + producers, + state.raw_data_handler, + state.cache, + state.column_map, + state.meta_columns, + state.header, + state.region, + derived_data, + name, + ) + + +def with_new_columns( + state: DatasetState, + descriptions: dict[str, str] = {}, + allow_overwrite: bool = False, + **new_columns: ConstructedColumn | np.ndarray | u.Quantity, +) -> DatasetState: + """ + Add columns to a given state + """ + new_producers_list, new_column_map, new_unit_handler = add_columns( + list(state.producers.values()), + state.unit_handler, + state.cache, + state.column_map, + get_sorted_index(state), + descriptions, + new_columns, + len(state), + allow_overwrite=allow_overwrite, + ) + return dataclasses.replace( + state, + 
producers={p.uuid: p for p in new_producers_list}, + column_map=new_column_map, + unit_handler=new_unit_handler, + ) + + +def with_region(state: DatasetState, region: Region) -> DatasetState: + return dataclasses.replace(state, region=region) + + +def select(state: DatasetState, columns: set[str], drop: bool = False) -> DatasetState: + """ + Select a set of columns + """ + selections, missing = get_column_selection(state.columns, columns) + if missing: + raise ValueError( + f"Columns are included that are not in this dataset: {missing}" + ) + elif not selections and columns: + raise ValueError("No columns matched the provided wildcards!") + + if drop: + selections = set(state.columns) - selections + + new_column_map = {n: state.column_map[n] for n in selections} + new_column_map |= {n: state.column_map[n] for n in state.metadata_columns} + return dataclasses.replace(state, column_map=new_column_map) + + +def sort_by( + state: DatasetState, column_name: Optional[str], invert: bool +) -> DatasetState: + if column_name is None: + sort_key = None + elif column_name not in state.columns: + raise ValueError(f"This dataset has no column {column_name}") + else: + sort_key = (column_name, invert) + + return dataclasses.replace(state, sort_key=sort_key) + + +def get_sorted_index(state: DatasetState) -> np.ndarray | None: + if state.sort_key is not None: + column = get_data(select(state, {state.sort_key[0]}), ignore_sort=True)[ + state.sort_key[0] + ] + sorted_idx = np.argsort(column) + if state.sort_key[1]: + sorted_idx = sorted_idx[::-1] + else: + sorted_idx = None + + return sorted_idx + + +def take_rows(state: DatasetState, rows: DataIndex) -> DatasetState: + """ + Take a set of rows. The associated "take" functions in the + dataset all delegate to this function. 
+ """ + if len(state) == 0: + return state + rows = fold(HookPoint.IndexUpdate, IndexUpdateCtx(state, rows)).index + sorted_idx = get_sorted_index(state) + if sorted_idx is not None: + rows = np.sort(sorted_idx[into_array(rows)]) + new_handler = state.raw_data_handler.take(rows) + new_cache = state.cache.take(rows) + return dataclasses.replace(state, raw_data_handler=new_handler, cache=new_cache) + + +def with_units( + state: DatasetState, + convention: Optional[str], + conversions: dict[u.Unit, u.Unit], + columns: dict[str, u.Unit], + cosmology: Cosmology, + redshift: float | table.Column, +) -> DatasetState: + """ + Update the units of a given state. + """ + if convention is None: + convention_ = state.unit_handler.current_convention + else: + convention_ = UnitConvention(convention) + + if ( + convention_ == UnitConvention.SCALEFREE + and UnitConvention(state.header.file.unit_convention) + != UnitConvention.SCALEFREE + ): + raise ValueError( + f"Cannot convert units with convention {state.header.file.unit_convention} to convention scalefree" + ) + column_keys = set(columns.keys()) + missing_columns = column_keys - set(state.columns) + if missing_columns: + raise ValueError(f"Dataset does not have columns {missing_columns}") + + new_handler = state.unit_handler.with_convention(convention_).with_conversions( + conversions, columns + ) + + if convention_ == state.unit_handler.current_convention: + cache = state.cache.create_child() + else: + all_derived_names: set[str] = set() + all_derived_names = reduce( + lambda acc, col: acc.union( + col.produces if not isinstance(col, RawColumn) else set() + ), + state.producers.values(), + all_derived_names, + ).intersection(state.columns) + columns_to_drop = all_derived_names.union(state.raw_data_handler.columns) + cache = state.cache.drop(columns_to_drop) + return dataclasses.replace(state, unit_handler=new_handler, cache=cache) diff --git a/python/opencosmo/dataset/take.py b/python/opencosmo/dataset/take.py new file mode 
100644 index 00000000..645db2ac --- /dev/null +++ b/python/opencosmo/dataset/take.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Optional + +import numpy as np + +from opencosmo.index import empty, from_size, into_array, single_chunk +from opencosmo.mpi import get_comm_world, get_mpi, has_mpi + +if TYPE_CHECKING: + from opencosmo.collection.lightcone import Lightcone + from opencosmo.dataset.dataset import Dataset + from opencosmo.index import DataIndex, IndexArray + + +def get_random_take_index( + n: int, + ds_length: int, + mode: Literal["local", "global"], +) -> DataIndex: + if mode == "global" and has_mpi(): + return get_random_take_index_mpi(n, ds_length) + + if n > ds_length: + return from_size(ds_length) + + generator = np.random.default_rng() + rows = generator.choice(ds_length, n, replace=False) + return apply_sort_index(rows) + + +def apply_sort_index( + rows: DataIndex, sort_index: Optional[np.ndarray] = None +) -> np.ndarray: + """Return row positions in ascending physical order. + + With sort_index: treats rows as logical sorted-order positions and maps + them to physical row positions. Without sort_index: rows are already + physical positions and are simply sorted. 
+ """ + arr = into_array(rows) + if sort_index is not None: + return np.sort(sort_index[arr]) + return np.sort(arr) + + +def _get_sort_index(ds: Dataset | Lightcone, sort_key: tuple[str, bool]) -> np.ndarray: + sort_col, sort_desc = sort_key + values = ds.select(sort_col).get_data("numpy", ignore_sort=True) + assert isinstance(values, np.ndarray) + if sort_desc: + values = -values + return np.argsort(values, kind="stable") + + +def get_rows_take_index( + ds: Dataset | Lightcone, rows: DataIndex, sort_key: Optional[tuple[str, bool]] +) -> DataIndex: + """Map user-provided logical (sorted-order) row positions to physical row positions.""" + if sort_key is None: + return rows + sort_index = _get_sort_index(ds, sort_key) + return apply_sort_index(rows, sort_index) + + +def get_range_take_index( + ds: Dataset | Lightcone, + sort_key: Optional[tuple[str, bool]], + start: int, + size: int, + mode: Literal["local", "global"], +): + if mode == "global" and has_mpi(): + return get_range_take_index_mpi(ds, sort_key, start, size) + + ds_len = len(ds) + if start + size > ds_len: + size = ds_len - start + + return single_chunk(start, size) + + +def get_end_take_index( + n: int, + ds: Dataset | Lightcone, + sort_key: Optional[tuple[str, bool]], + mode: Literal["local", "global"], +): + ds_length = len(ds) + if mode == "global" and has_mpi(): + comm = get_comm_world() + assert comm is not None + total_length = np.sum(comm.allgather(ds_length)) + if n > total_length: + return from_size(ds_length) + + return get_range_take_index_mpi(ds, sort_key, total_length - n, n) + + start = ds_length - n + if n > ds_length: + start = 0 + n = ds_length + + return single_chunk(start, n) + + +def get_range_take_index_mpi( + ds: Dataset | Lightcone, sort_key: Optional[tuple[str, bool]], start: int, size: int +): + comm = get_comm_world() + assert comm is not None + lengths = np.array(comm.allgather(len(ds)), dtype=np.int64) + total_length = int(np.sum(lengths)) + + if start > total_length: + return 
empty() + + if start + size > total_length: + size = total_length - start + + if sort_key is not None: + global_sort_order = get_global_sort_order(ds, sort_key) + + if comm.Get_rank() == 0: + assert global_sort_order is not None + n_ranks = comm.Get_size() + chunk_ranges = np.zeros(n_ranks + 1, dtype=np.int64) + chunk_ranges[1:] = np.cumsum(lengths) + + # Map each position in the global sort order to the rank that owns it. + rank_of_element = np.searchsorted( + chunk_ranges[1:], global_sort_order, side="right" + ) + # Count how many sorted elements before `start` belong to each rank; + # this is the local sorted-order start index for each rank's slice. + lo_per_rank = np.bincount(rank_of_element[:start], minlength=n_ranks) + # Count how many sorted elements in [start, start+size) belong to each rank. + count_per_rank = np.bincount( + rank_of_element[start : start + size], minlength=n_ranks + ) + else: + lo_per_rank = None + count_per_rank = None + + lo_per_rank = comm.bcast(lo_per_rank) + count_per_rank = comm.bcast(count_per_rank) + + rank = comm.Get_rank() + local_start = int(lo_per_rank[rank]) + local_size = int(count_per_rank[rank]) + + if local_size == 0: + return empty() + return single_chunk(local_start, local_size) + + # Handle the case without sorting: contiguous global range + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + local_start = max(0, start - offset) + local_end = min(int(lengths[rank]), start + size - offset) + + if local_end <= local_start: + return np.array([], dtype=np.int64) + return single_chunk(local_start, local_end - local_start) + + +def get_global_sort_order(ds: Dataset | Lightcone, sort_key: tuple[str, bool]): + comm = get_comm_world() + assert comm is not None + + assert sort_key is not None + sort_col, sort_desc = sort_key + raw = ds.select(sort_col).get_data("numpy", ignore_sort=True) + local_values = np.asarray( + raw.value if hasattr(raw, "value") else raw, dtype=np.float64 + ) + + lengths = 
np.array(comm.allgather(len(local_values)), dtype=np.int64) + total_length = int(np.sum(lengths)) + rank = comm.Get_rank() + + offsets = np.zeros(len(lengths), dtype=np.int64) + offsets[1:] = np.cumsum(lengths)[:-1] + recv = np.empty(total_length, dtype=np.float64) if rank == 0 else None + comm.Gatherv(local_values, [recv, lengths, offsets, get_mpi().DOUBLE], root=0) + + # Other ranks return None + if rank != 0: + return None + + # Determine global sort range + assert recv is not None + global_sorted_phys = np.argsort(recv, kind="stable") + if sort_desc: + global_sorted_phys = global_sorted_phys[::-1] + + return global_sorted_phys + + +def get_random_take_index_mpi(n: int, ds_length: int): + comm = get_comm_world() + assert comm is not None + lengths = comm.allgather(ds_length) + + if (total_length := np.sum(lengths)) < n: + return from_size(ds_length) + + if comm.Get_rank() == 0: + rng = np.random.default_rng() + rows = np.sort(rng.choice(total_length, n, replace=False)) + else: + rows = None + return get_local_rows_simple(rows, lengths, comm) + + +def get_local_rows_simple(rows: IndexArray | None, lengths: list[int], comm): + chunk_ranges = np.zeros(len(lengths) + 1, dtype=np.int64) + chunk_ranges[1:] = np.cumsum(lengths) + if comm.Get_rank() == 0: + assert rows is not None + chunk_ranges_in_index = np.searchsorted(rows, chunk_ranges) + chunk_ranges_in_index = comm.bcast(chunk_ranges_in_index) + else: + chunk_ranges_in_index = comm.bcast(None) + + rank_num = comm.Get_rank() + n_rows_local = chunk_ranges_in_index[rank_num + 1] - chunk_ranges_in_index[rank_num] + + local_rows = np.empty(n_rows_local, dtype=np.int64) + + if comm.Get_rank() == 0: + scatter_lengths = chunk_ranges_in_index[1:] - chunk_ranges_in_index[:-1] + buffer_offsets = np.zeros_like(scatter_lengths) + buffer_offsets[1:] = np.cumsum(scatter_lengths)[:-1] + + buffspec = [rows, scatter_lengths, buffer_offsets, get_mpi().INT64_T] + comm.Scatterv(buffspec, local_rows) + else: + comm.Scatterv([None, 
None, None, get_mpi().INT64_T], local_rows) + + return local_rows - chunk_ranges[rank_num] diff --git a/src/opencosmo/parameters/__init__.py b/python/opencosmo/dtypes/__init__.py similarity index 100% rename from src/opencosmo/parameters/__init__.py rename to python/opencosmo/dtypes/__init__.py diff --git a/src/opencosmo/parameters/cosmology.py b/python/opencosmo/dtypes/cosmology.py similarity index 100% rename from src/opencosmo/parameters/cosmology.py rename to python/opencosmo/dtypes/cosmology.py diff --git a/python/opencosmo/dtypes/diffsky.py b/python/opencosmo/dtypes/diffsky.py new file mode 100644 index 00000000..f40e75c5 --- /dev/null +++ b/python/opencosmo/dtypes/diffsky.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import dataclasses +from datetime import datetime # noqa +from typing import TYPE_CHECKING, ClassVar, Optional + +import numpy as np +from pydantic import BaseModel, ConfigDict, field_serializer + +import opencosmo.dataset.state as st +from opencosmo.column.column import EvaluatedColumn, EvaluateStrategy +from opencosmo.index import into_array +from opencosmo.index.ops import reindex_column +from opencosmo.mpi import get_mpi +from opencosmo.plugins.contexts import HookPoint +from opencosmo.plugins.hook import hook +from opencosmo.spatial.tree import TreePartition + +if TYPE_CHECKING: + from astropy.table import Table + from mpi4py import MPI + + from opencosmo import Dataset + from opencosmo.dataset.state import DatasetState + from opencosmo.index import DataIndex + from opencosmo.plugins.contexts import ( + DatasetOpenCtx, + IndexUpdateCtx, + LightconeInstantiateCtx, + PartitionCtx, + PostSortCtx, + ) + +else: + MPI = get_mpi() + + +class DiffskyVersionInfo(BaseModel): + model_config = ConfigDict(frozen=True) + ACCESS_PATH: ClassVar[str] = "diffsky_versions" + diffmah: str + diffsky: str + diffstar: str + diffstarpop: Optional[str] = None + dsps: str + jax: str + numpy: str + + +class DiffskyCatalogInfo(BaseModel): + model_config 
= ConfigDict(frozen=True) + ACCESS_PATH: ClassVar[str] = "catalog_info" + README: Optional[str] = None + mock_version_name: str + zphot_table: Optional[tuple[float, ...]] = None + + @field_serializer("zphot_table") + def serialize_zphot_table(self, value): + if value is not None: + return list(value) + return None + + +# --- pure logic --- + + +def __offset(top_host_idx, offset): + output = top_host_idx + output[output >= 0] += offset + return {"top_host_idx": output} + + +def offset_top_host_idx(datasets: list[Dataset]): + lengths = [len(ds) for ds in datasets] + offsets = np.cumsum(lengths) + output_datasets = [datasets[0]] + for offset, ds in zip(offsets, datasets[1:]): + output_ds = ds.evaluate( + __offset, + offset=offset, + vectorize=True, + allow_overwrite=True, + ) + output_datasets.append(output_ds) # type: ignore + return output_datasets + + +def rebuild_top_host_idx(top_host_idx, index): + result = reindex_column(index, top_host_idx) + return {"top_host_idx": result} + + +def keep_top_host_idx(dataset: DatasetState, new_index: DataIndex): + index_array = into_array(new_index) + top_host_idx = st.get_data(st.select(dataset, {"top_host_idx"}))["top_host_idx"] + unique_in_sample = np.unique(top_host_idx[index_array]) + + missing_hosts = np.setdiff1d(unique_in_sample, index_array) + all_satellites = np.where(np.isin(top_host_idx, unique_in_sample))[0] + missing_satellites = np.setdiff1d(all_satellites, index_array) + + if len(missing_hosts) == 0 and len(missing_satellites) == 0: + return new_index + + all_missing = np.sort(np.concatenate((missing_hosts, missing_satellites))) + insert_idx = np.searchsorted(index_array, all_missing) + return np.insert(index_array, insert_idx, all_missing) + + +def _is_synthetic_galaxies_with_top_host_idx(obj) -> bool: + return ( + obj.header.file.data_type == "synthetic_galaxies" + and "top_host_idx" in obj.columns + ) + + +# --- hooks --- + + +@hook( + HookPoint.DatasetOpen, + when=lambda ctx: 
_is_synthetic_galaxies_with_top_host_idx(ctx.dataset), +) +def _attach_top_host_idx_column(ctx: DatasetOpenCtx) -> DatasetOpenCtx: + top_host_idx = EvaluatedColumn( + rebuild_top_host_idx, + requires={"top_host_idx"}, + produces={"top_host_idx"}, + format="numpy", + units={"top_host_idx": None}, + strategy=EvaluateStrategy.VECTORIZE, + no_cache=True, + ) + new_dataset = ctx.dataset.with_new_columns( + updated_host_idx=top_host_idx, allow_overwrite=True + ) + return dataclasses.replace(ctx, dataset=new_dataset) + + +@hook( + HookPoint.LightconeInstantiate, + when=lambda ctx: _is_synthetic_galaxies_with_top_host_idx(ctx.lightcone), +) +def _offset_top_host_idx(ctx: LightconeInstantiateCtx) -> LightconeInstantiateCtx: + cs = 0 + output = {} + + def _offset_top_host_idx(top_host_idx, offset): + top_host_idx[top_host_idx >= 0] += offset + return top_host_idx + + for key, ds in ctx.lightcone.items(): + output[key] = ds.evaluate( + _offset_top_host_idx, allow_overwrite=True, vectorize=True, offset=cs + ) + cs += len(ds) + + return dataclasses.replace(ctx, lightcone=output) # type: ignore[arg-type] + + +# Registers an IndexUpdate hook dynamically so that keep_top_host_idx only +# activates when the user explicitly requests it via open(..., keep_top_host=True). 
+@hook( + HookPoint.IndexUpdate, + when=lambda ctx: ( + "top_host_idx" in ctx.state.columns + and ctx.state.kwargs.get("keep_top_host", False) + ), +) +def _keep(ctx: IndexUpdateCtx) -> IndexUpdateCtx: + return dataclasses.replace(ctx, index=keep_top_host_idx(ctx.state, ctx.index)) + + +@hook( + HookPoint.PostSort, + when=lambda ctx: _is_synthetic_galaxies_with_top_host_idx(ctx.state), +) +def _remap_top_host_idx_after_sort(ctx: PostSortCtx) -> PostSortCtx: + data: Table = ctx.data # type: ignore[assignment] + mask = data["top_host_idx"] >= 0 + data["top_host_idx"][mask] = ctx.index[data["top_host_idx"][mask]] + return ctx + + +@hook( + HookPoint.Partition, + when=lambda ctx: ( + ctx.header.file.data_type == "synthetic_galaxies" + and "top_host_idx" in ctx.data_group.keys() + ), +) +def _partition_by_top_host_groups(ctx: PartitionCtx) -> Optional[TreePartition]: + top_host_idx = ctx.data_group["top_host_idx"][:] + n_rows = len(top_host_idx) + n_ranks = ctx.comm.Get_size() + rank = ctx.comm.Get_rank() + + ave, res = divmod(n_rows, n_ranks) + counts = np.array( + [ave + 1 if i < res else ave for i in range(n_ranks)], dtype=np.int64 + ) + displs = np.cumsum(np.concatenate(([0], counts[:-1]))) + start = displs[rank] + count = counts[rank] + row_indices = np.arange(start, start + count, dtype=np.int64) + chunk = top_host_idx[start : start + count] + + # find top hosts (self-referential) and orphans (top_host_idx == -1) + rank_top_hosts = row_indices[chunk == row_indices] + rank_orphans = row_indices[chunk == -1] + + # gather all rows belonging to this rank's top hosts + all_group_rows = np.where(np.isin(top_host_idx, rank_top_hosts))[0] + + index = np.union1d(all_group_rows, rank_orphans) + return TreePartition(idx=index, region=None, level=None) diff --git a/src/opencosmo/parameters/dtype.py b/python/opencosmo/dtypes/dtype.py similarity index 84% rename from src/opencosmo/parameters/dtype.py rename to python/opencosmo/dtypes/dtype.py index 558c5c22..ee2dc125 100644 --- 
a/src/opencosmo/parameters/dtype.py +++ b/python/opencosmo/dtypes/dtype.py @@ -2,15 +2,14 @@ from typing import TYPE_CHECKING -from opencosmo.parameters import hacc, lightcone +from opencosmo.dtypes import hacc, lightcone if TYPE_CHECKING: from pydantic import BaseModel - from opencosmo.parameters.file import FileParameters + from opencosmo.dtypes.file import FileParameters -# TODO: think I need to alter this def get_dtype_parameters( file_parameters: FileParameters, ) -> dict[str, dict[str, type[BaseModel]]]: diff --git a/src/opencosmo/parameters/file.py b/python/opencosmo/dtypes/file.py similarity index 100% rename from src/opencosmo/parameters/file.py rename to python/opencosmo/dtypes/file.py diff --git a/src/opencosmo/parameters/hacc.py b/python/opencosmo/dtypes/hacc.py similarity index 100% rename from src/opencosmo/parameters/hacc.py rename to python/opencosmo/dtypes/hacc.py diff --git a/src/opencosmo/parameters/lightcone.py b/python/opencosmo/dtypes/lightcone.py similarity index 100% rename from src/opencosmo/parameters/lightcone.py rename to python/opencosmo/dtypes/lightcone.py diff --git a/src/opencosmo/parameters/origin.py b/python/opencosmo/dtypes/origin.py similarity index 86% rename from src/opencosmo/parameters/origin.py rename to python/opencosmo/dtypes/origin.py index ad65f431..e62b59ca 100644 --- a/src/opencosmo/parameters/origin.py +++ b/python/opencosmo/dtypes/origin.py @@ -1,4 +1,4 @@ -from opencosmo.parameters import hacc +from opencosmo.dtypes import hacc from .cosmology import CosmologyParameters diff --git a/src/opencosmo/parameters/parameters.py b/python/opencosmo/dtypes/parameters.py similarity index 100% rename from src/opencosmo/parameters/parameters.py rename to python/opencosmo/dtypes/parameters.py diff --git a/src/opencosmo/parameters/units.py b/python/opencosmo/dtypes/units.py similarity index 100% rename from src/opencosmo/parameters/units.py rename to python/opencosmo/dtypes/units.py diff --git a/src/opencosmo/parameters/utils.py 
b/python/opencosmo/dtypes/utils.py similarity index 100% rename from src/opencosmo/parameters/utils.py rename to python/opencosmo/dtypes/utils.py diff --git a/src/opencosmo/file.py b/python/opencosmo/file.py similarity index 100% rename from src/opencosmo/file.py rename to python/opencosmo/file.py diff --git a/src/opencosmo/handler/empty.py b/python/opencosmo/handler/empty.py similarity index 84% rename from src/opencosmo/handler/empty.py rename to python/opencosmo/handler/empty.py index e46161da..09d12237 100644 --- a/src/opencosmo/handler/empty.py +++ b/python/opencosmo/handler/empty.py @@ -2,9 +2,10 @@ from typing import TYPE_CHECKING, Iterable, Optional, Self -from opencosmo.index import empty from opencosmo.io.schema import FileEntry, make_schema +from opencosmo.index import empty + if TYPE_CHECKING: import numpy as np @@ -15,9 +16,6 @@ class EmptyHandler: def get_data(self, *args): return {} - def get_metadata(self, columns: Iterable[str]) -> dict[str, np.ndarray]: - return {} - def take(self, other: DataIndex, sorted: Optional[np.ndarray] = None) -> Self: return self @@ -37,10 +35,6 @@ def columns(self) -> Iterable[str]: def load_conditions(self): return None - @property - def metadata_columns(self) -> Iterable[str]: - return set() - @property def descriptions(self): return {} diff --git a/src/opencosmo/handler/hdf5.py b/python/opencosmo/handler/hdf5.py similarity index 58% rename from src/opencosmo/handler/hdf5.py rename to python/opencosmo/handler/hdf5.py index b5a3e5d0..16b81af3 100644 --- a/src/opencosmo/handler/hdf5.py +++ b/python/opencosmo/handler/hdf5.py @@ -1,10 +1,13 @@ from __future__ import annotations from functools import cached_property -from itertools import chain from typing import TYPE_CHECKING, Iterable, Optional import numpy as np +from opencosmo.io.schema import FileEntry, make_schema +from opencosmo.io.writer import ( + ColumnWriter, +) from opencosmo.index import ( SimpleIndex, @@ -14,18 +17,14 @@ into_array, take, ) -from 
opencosmo.io.schema import FileEntry, make_schema -from opencosmo.io.writer import ( - ColumnWriter, -) if TYPE_CHECKING: import h5py - from opencosmo.header import OpenCosmoHeader - from opencosmo.index import DataIndex from opencosmo.io.schema import Schema + from opencosmo.index import DataIndex + class Hdf5Handler: """ @@ -36,12 +35,10 @@ def __init__( self, columns: dict[str, h5py.Dataset], index: DataIndex, - metadata_columns: dict[str, h5py.Dataset], load_conditions: Optional[dict[str, bool]] = None, ): self.__index = index self.__columns = columns - self.__metadata_columns = metadata_columns self.__in_memory = next(iter(columns.values())).file.driver == "core" self.__load_conditions = load_conditions @@ -53,26 +50,24 @@ def from_columns( metadata_group: Optional[str] = None, load_conditions: Optional[dict[str, bool]] = None, ): - data_columns = filter(lambda col: col.name.split("/")[-2] == "data", columns) - metadata_columns: Iterable[h5py.Dataset] = [] + groups = {"data"} if metadata_group: - metadata_columns = filter( - lambda col: col.name.split("/")[-2] == metadata_group, columns - ) + groups.add(metadata_group) - data_columns_ = {col.name.split("/")[-1]: col for col in data_columns} - metadata_columns_ = {col.name.split("/")[-1]: col for col in metadata_columns} - lengths = set( - len(col) - for col in chain(data_columns_.values(), metadata_columns_.values()) - ) + all_columns = { + col.name.split("/")[-1]: col + for col in columns + if col.name.split("/")[-2] in groups + } + + lengths = set(len(col) for col in all_columns.values()) if len(lengths) > 1: raise ValueError("Not all columns are the same length!") if index is None: index = from_size(lengths.pop()) - return Hdf5Handler(data_columns_, index, metadata_columns_, load_conditions) + return Hdf5Handler(all_columns, index, load_conditions) def __len__(self): return get_length(self.__index) @@ -87,17 +82,13 @@ def load_conditions(self) -> Optional[dict[str, bool]]: def take(self, other: DataIndex, 
sorted: Optional[np.ndarray] = None): if len(other) == 0: - return Hdf5Handler( - self.__columns, other, self.__metadata_columns, self.__load_conditions - ) + return Hdf5Handler(self.__columns, other, self.__load_conditions) if sorted is not None: return self.__take_sorted(other, sorted) new_index = take(self.__index, other) - return Hdf5Handler( - self.__columns, new_index, self.__metadata_columns, self.__load_conditions - ) + return Hdf5Handler(self.__columns, new_index, self.__load_conditions) def __take_sorted(self, other: DataIndex, sorted: np.ndarray): if get_length(sorted) != get_length(self.__index): @@ -107,9 +98,7 @@ def __take_sorted(self, other: DataIndex, sorted: np.ndarray): new_raw_index = into_array(self.__index)[new_indices] new_index = np.sort(new_raw_index) - return Hdf5Handler( - self.__columns, new_index, self.__metadata_columns, self.__load_conditions - ) + return Hdf5Handler(self.__columns, new_index, self.__load_conditions) @property def data(self): @@ -121,11 +110,7 @@ def index(self): @cached_property def columns(self): - return self.__columns.keys() - - @property - def metadata_columns(self): - return self.__metadata_columns.keys() + return list(self.__columns.keys()) @cached_property def descriptions(self): @@ -148,33 +133,32 @@ def __exit__(self, *exec_details): def make_schema( self, columns: Iterable[str], + metadata_columns: set[str] = set(), header: Optional[OpenCosmoHeader] = None, ) -> tuple[Schema, Optional[Schema]]: - column_writers = {} - for column_name in columns: - column_writers[column_name] = ColumnWriter.from_h5_dataset( - self.__columns[column_name], - self.__index, - attrs=dict(self.__columns[column_name].attrs), + columns = set(columns) + data_writers = {} + for column_name in columns - metadata_columns: + column = self.__columns[column_name] + data_writers[column_name] = ColumnWriter.from_h5_dataset( + column, self.__index, attrs=dict(column.attrs) ) - - data_schema = make_schema("data", FileEntry.COLUMNS, 
columns=column_writers) - - if self.metadata_columns: - assert len(self.__metadata_columns) > 0 - metadata_writers = {} - group_name = next(iter(self.__metadata_columns.values())).parent.name - group_name = group_name.split("/")[-1] - for column_name, column in self.__metadata_columns.items(): - metadata_writers[column_name] = ColumnWriter.from_h5_dataset( - column, self.__index, attrs=dict(column.attrs) - ) - metadata_schema = make_schema( - group_name, FileEntry.COLUMNS, columns=metadata_writers + data_schema = make_schema("data", FileEntry.COLUMNS, columns=data_writers) + + raw_meta = columns & metadata_columns + if not raw_meta: + return data_schema, make_schema("metadata", FileEntry.EMPTY) + + group_name = self.__columns[next(iter(raw_meta))].parent.name.split("/")[-1] + metadata_writers = {} + for column_name in raw_meta: + column = self.__columns[column_name] + metadata_writers[column_name] = ColumnWriter.from_h5_dataset( + column, self.__index, attrs=dict(column.attrs) ) - else: - metadata_schema = make_schema("metadata", FileEntry.EMPTY) - return data_schema, metadata_schema + return data_schema, make_schema( + group_name, FileEntry.COLUMNS, columns=metadata_writers + ) def get_data(self, columns: Iterable[str]) -> dict[str, np.ndarray]: """ """ @@ -187,18 +171,6 @@ def get_data(self, columns: Iterable[str]) -> dict[str, np.ndarray]: # Ensure order is preserved return {name: data[name] for name in columns} - def get_metadata(self, columns: Iterable[str]) -> Optional[dict[str, np.ndarray]]: - if len(self.__metadata_columns) == 0: - return None - if not columns: - columns = self.metadata_columns - - data = {} - for colname in columns: - data[colname] = get_data(self.__metadata_columns[colname], self.__index) - - return data - def take_range(self, start: int, end: int, indices: np.ndarray) -> np.ndarray: if start < 0 or end > len(indices): raise ValueError("Indices out of range") diff --git a/src/opencosmo/handler/protocols.py 
b/python/opencosmo/handler/protocols.py similarity index 52% rename from src/opencosmo/handler/protocols.py rename to python/opencosmo/handler/protocols.py index faf119dc..9a125546 100644 --- a/src/opencosmo/handler/protocols.py +++ b/python/opencosmo/handler/protocols.py @@ -3,47 +3,80 @@ from typing import TYPE_CHECKING, Iterable, Optional, Protocol, Self if TYPE_CHECKING: - import numpy as np + from uuid import UUID + import numpy as np from opencosmo.header import OpenCosmoHeader - from opencosmo.index import DataIndex from opencosmo.io.schema import Schema + from opencosmo.index import DataIndex + class DataHandler(Protocol): def get_data(self, columns: Iterable[str]) -> dict[str, np.ndarray]: """ """ - def get_metadata(self, columns: Iterable[str]) -> dict[str, np.ndarray]: ... - def take(self, other: DataIndex, sorted: Optional[np.ndarray] = None) -> Self: ... def make_schema( - self, columns: Iterable[str], header: Optional[OpenCosmoHeader] = None + self, + columns: Iterable[str], + metadata_columns: set[str] = set(), + header: Optional[OpenCosmoHeader] = None, ) -> tuple[Schema, Schema]: ... @property def columns(self) -> Iterable[str]: ... @property - def metadata_columns(self) -> Iterable[str]: ... - @property def load_conditions(self) -> Optional[dict]: ... @property def index(self) -> DataIndex: ... -class DataCache(DataHandler, Protocol): +class DataCache(Protocol): def add_data( self, - data: dict[str, np.ndarray], + data: dict[UUID, dict[str, np.ndarray]], descriptions: dict[str, str], push_up: bool = True, ): ... + def add_metadata( + self, + data: dict[str, np.ndarray], + descriptions: dict[str, str] = {}, + ): ... + + def get_data( + self, pairs: set[tuple[UUID, str]] + ) -> dict[UUID, dict[str, np.ndarray]]: ... + + def get_metadata(self, column_names: Iterable[str]) -> dict[str, np.ndarray]: ... + def __len__(self) -> int: ... + def take(self, index: DataIndex) -> Self: ... + def drop(self, columns: Iterable[str]) -> Self: ... 
- def register_column_group(self, state_id: int, columns: set[str]) -> None: ... + + def register_column_group( + self, state_id: int, columns: dict[str, UUID] + ) -> None: ... + def deregister_column_group(self, state_id: int) -> None: ... + def create_child(self) -> Self: ... + + @property + def columns(self) -> set[str]: ... + + @property + def metadata_columns(self) -> set[str]: ... + + @property + def descriptions(self) -> dict[str, str]: ... + + def make_schema( + self, columns: dict[str, UUID], meta_columns: list[str] + ) -> tuple[Schema, Schema]: ... diff --git a/src/opencosmo/header.py b/python/opencosmo/header.py similarity index 99% rename from src/opencosmo/header.py rename to python/opencosmo/header.py index ba9b7b27..d8dd742a 100644 --- a/src/opencosmo/header.py +++ b/python/opencosmo/header.py @@ -10,17 +10,17 @@ import numpy as np from pydantic import BaseModel, ValidationError -from opencosmo.file import broadcast_read, file_reader, file_writer -from opencosmo.io.schema import FileEntry, make_schema -from opencosmo.io.writer import ColumnCombineStrategy, ColumnWriter -from opencosmo.parameters import ( +from opencosmo.dtypes import ( FileParameters, dtype, origin, read_header_attributes, write_header_attributes, ) -from opencosmo.parameters.units import apply_units +from opencosmo.dtypes.units import apply_units +from opencosmo.file import broadcast_read, file_reader, file_writer +from opencosmo.io.schema import FileEntry, make_schema +from opencosmo.io.writer import ColumnCombineStrategy, ColumnWriter from opencosmo.units import UnitConvention if TYPE_CHECKING: diff --git a/src/opencosmo/index/__init__.py b/python/opencosmo/index/__init__.py similarity index 60% rename from src/opencosmo/index/__init__.py rename to python/opencosmo/index/__init__.py index 5d3f64f3..29fd1959 100644 --- a/src/opencosmo/index/__init__.py +++ b/python/opencosmo/index/__init__.py @@ -1,16 +1,18 @@ import numpy as np -from numpy.typing import NDArray -from .build import 
concatenate, empty, from_size, single_chunk +from .build import concatenate, empty, from_size, from_start_size_group, single_chunk from .get import get_data from .in_range import n_in_range from .mask import into_array, mask +from .ops import offset, rebuild_by_ranges, reindex_column from .project import project from .take import take from .unary import get_length, get_range -SimpleIndex = NDArray[np.int_] -ChunkedIndex = tuple[NDArray[np.int_], NDArray[np.int_]] +IndexArray = np.ndarray[tuple[int], np.dtype[np.int64]] + +SimpleIndex = IndexArray +ChunkedIndex = tuple[IndexArray, IndexArray] DataIndex = SimpleIndex | ChunkedIndex @@ -32,4 +34,8 @@ "take", "get_length", "get_range", + "from_start_size_group", + "rebuild_by_ranges", + "reindex_column", + "offset", ] diff --git a/python/opencosmo/index/build.py b/python/opencosmo/index/build.py new file mode 100644 index 00000000..5d831f06 --- /dev/null +++ b/python/opencosmo/index/build.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from .mask import into_array + +if TYPE_CHECKING: + import h5py + + from . 
import ChunkedIndex, DataIndex, SimpleIndex + + +def from_size(size: int) -> ChunkedIndex: + return (np.array([0], dtype=np.int64), np.array([size], dtype=np.int64)) + + +def single_chunk(start: int, size: int) -> ChunkedIndex: + return (np.array([start], dtype=np.int64), np.array([size], np.int64)) + + +def empty() -> ChunkedIndex: + return (np.array([], dtype=np.int64), np.array([], dtype=np.int64)) + + +def from_range(start: int, end: int) -> ChunkedIndex: + size = end - start + return (np.array([start], dtype=np.int64), np.array([size], np.int64)) + + +def concatenate(*indices: DataIndex) -> SimpleIndex: + return np.concatenate(list(map(into_array, indices))) + + +def from_start_size_group(group: h5py) -> ChunkedIndex: + start = group["start"][:].astype(np.int64) + size = group["size"][:].astype(np.int64) + return (start, size) diff --git a/src/opencosmo/index/get.py b/python/opencosmo/index/get.py similarity index 76% rename from src/opencosmo/index/get.py rename to python/opencosmo/index/get.py index 6baa8eb7..2f5c2a18 100644 --- a/src/opencosmo/index/get.py +++ b/python/opencosmo/index/get.py @@ -9,22 +9,22 @@ from .unary import get_length if TYPE_CHECKING: - from numpy.typing import NDArray + from opencosmo.index import ChunkedIndex, DataIndex, SimpleIndex -def get_data(data: h5py.Dataset | np.ndarray, index: np.ndarray | tuple): +def get_data(data: h5py.Dataset | np.ndarray, index: DataIndex): if get_length(index) == 0: - return np.array([]) + return np.array([], data.dtype) match index: case np.ndarray(): return get_data_simple(data, index) case (np.ndarray(), np.ndarray()): - return get_data_chunked(data, *index) + return get_data_chunked(data, index) case _: raise ValueError(f"Got invalid index of type {type(index)}") -def get_data_simple(data: h5py.Dataset | np.ndarray, index: NDArray[np.int_]): +def get_data_simple(data: h5py.Dataset | np.ndarray, index: SimpleIndex): if isinstance(data, np.ndarray): return data[index] @@ -41,14 +41,15 @@ def 
get_data_simple(data: h5py.Dataset | np.ndarray, index: NDArray[np.int_]): return buffer[index - min_] -def get_data_chunked( - data: h5py.Dataset | np.ndarray, starts: NDArray[np.int_], sizes: NDArray[np.int_] -): +def get_data_chunked(data: h5py.Dataset | np.ndarray, index: ChunkedIndex): """ We assume that starts are ordered, and chunks are non-overlapping """ + starts = index[0] + sizes = index[1] + unit = None if isinstance(data, u.Quantity): unit = data.unit @@ -66,7 +67,7 @@ def get_data_chunked( else: storage[dest_slice] = data[source_slice] - running_index += size + running_index += int(size) if unit is not None: storage *= unit diff --git a/python/opencosmo/index/in_range.py b/python/opencosmo/index/in_range.py new file mode 100644 index 00000000..165a6836 --- /dev/null +++ b/python/opencosmo/index/in_range.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from opencosmo._lib import index as idx + +if TYPE_CHECKING: + from opencosmo.index import DataIndex, IndexArray, SimpleIndex + + +def n_in_range( + index: DataIndex, + range_starts: int | IndexArray, + range_sizes: int | IndexArray, +) -> IndexArray: + range_starts = np.atleast_1d(range_starts) + range_sizes = np.atleast_1d(range_sizes) + match index: + case np.ndarray(): + return __n_in_range_simple(index, range_starts, range_sizes) + case (np.ndarray(), np.ndarray()): + return idx.n_in_range_chunked(index[0], index[1], range_starts, range_sizes) + case _: + raise ValueError(f"Unknown index type {type(index)}") + + +def __n_in_range_simple( + index: SimpleIndex, start: IndexArray, size: IndexArray +) -> IndexArray: + if len(start) != len(size): + raise ValueError("Start and size arrays must have the same length") + if np.any(size < 0): + raise ValueError("Sizes must greater than or equal to zero") + if len(index) == 0: + return np.zeros_like(start) + + ends = start + size + index_to_search = np.sort(index) + start_idxs = 
np.searchsorted(index_to_search, start, "left") + end_idxs = np.searchsorted(index_to_search, ends, "left") + return end_idxs - start_idxs diff --git a/src/opencosmo/index/mask.py b/python/opencosmo/index/mask.py similarity index 62% rename from src/opencosmo/index/mask.py rename to python/opencosmo/index/mask.py index 9dc95f26..5f85cc98 100644 --- a/src/opencosmo/index/mask.py +++ b/python/opencosmo/index/mask.py @@ -2,16 +2,18 @@ from typing import TYPE_CHECKING -import numba as nb import numpy as np if TYPE_CHECKING: from numpy.typing import NDArray + from opencosmo.index import ChunkedIndex, DataIndex, SimpleIndex + +from opencosmo._lib import index as idxlib from opencosmo.index.unary import get_length -def mask(index, boolean_mask): +def mask(index: DataIndex, boolean_mask: NDArray[np.bool_]) -> SimpleIndex: match index: case np.ndarray(): return __mask_simple(index, boolean_mask) @@ -21,7 +23,7 @@ def mask(index, boolean_mask): raise TypeError(f"Unknown index type {type(index)}") -def __mask_simple(index: NDArray[np.int_], boolean_mask: NDArray[np.bool_]): +def __mask_simple(index: SimpleIndex, boolean_mask: NDArray[np.bool_]) -> SimpleIndex: if (lm := len(boolean_mask)) > len(index): raise ValueError( "Boolean mask must be less than or equal to the length of the index itself" @@ -30,12 +32,12 @@ def __mask_simple(index: NDArray[np.int_], boolean_mask: NDArray[np.bool_]): return index[:lm][boolean_mask] -def __mask_chunked(index: tuple, boolean_mask: NDArray[np.bool_]): +def __mask_chunked(index: ChunkedIndex, boolean_mask: NDArray[np.bool_]) -> SimpleIndex: array = into_array(index) return array[boolean_mask] -def into_array(index: np.ndarray | tuple): +def into_array(index: DataIndex) -> SimpleIndex: if get_length(index) == 0: return np.array([], dtype=np.int64) @@ -46,16 +48,6 @@ def into_array(index: np.ndarray | tuple): if len(index[0]) == 1: return np.arange(index[0][0], index[0][0] + index[1][0]) - return __chunked_into_array(*index) - - -@nb.njit 
-def __chunked_into_array(starts: NDArray[np.int_], sizes: NDArray[np.int_]): - output = np.zeros(np.sum(sizes), dtype=np.int64) - rs = 0 - for i in range(len(starts)): - output[rs : rs + sizes[i]] = np.arange( - starts[i], starts[i] + sizes[i], dtype=np.int64 - ) - rs += sizes[i] - return output + return idxlib.chunked_into_array(*index) + case _: + raise ValueError(f"Expected a DataIndex, got {type(index)}") diff --git a/python/opencosmo/index/ops.py b/python/opencosmo/index/ops.py new file mode 100644 index 00000000..bdb12a42 --- /dev/null +++ b/python/opencosmo/index/ops.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from opencosmo._lib import index as idxlib +from opencosmo.index import into_array + +if TYPE_CHECKING: + from opencosmo.index import ChunkedIndex, DataIndex + + +def reindex_column(index: DataIndex, column: np.ndarray): + column = column.astype(np.int64) + return idxlib.reindex_column(into_array(index), column) + + +def rebuild_by_ranges(index: DataIndex, ranges: ChunkedIndex): + match index: + case np.ndarray(): + return idxlib.rebuild_simple_by_ranges(index, *ranges) + case (np.ndarray(), np.ndarray()): + return idxlib.rebuild_chunked_by_ranges(*index, *ranges) + + +def offset(index: DataIndex, offset_amount: int): + if isinstance(index, np.ndarray): + return index + offset_amount + return (index[0] + offset_amount, index[1]) diff --git a/src/opencosmo/index/project.py b/python/opencosmo/index/project.py similarity index 67% rename from src/opencosmo/index/project.py rename to python/opencosmo/index/project.py index d2b6d63b..57865a08 100644 --- a/src/opencosmo/index/project.py +++ b/python/opencosmo/index/project.py @@ -4,13 +4,15 @@ import numpy as np +from opencosmo._lib import index as idxlib + from . 
import into_array if TYPE_CHECKING: from opencosmo.index import ChunkedIndex, DataIndex, SimpleIndex -def project(source: DataIndex, other: DataIndex): +def project(source: DataIndex, other: DataIndex) -> DataIndex: match (source, other): case (tuple(), np.ndarray()): return __project_simple_on_chunked(source, other) @@ -20,22 +22,30 @@ def project(source: DataIndex, other: DataIndex): return __project_simple_on_simple(source, other) case (np.ndarray(), tuple()): return __project_chunked_on_simple(source, other) + case _: + raise TypeError(f"Invalid index types: {type(source)}, {type(other)}") -def __project_simple_on_simple(source, other: DataIndex): +def __project_simple_on_simple(source: SimpleIndex, other: SimpleIndex) -> SimpleIndex: isin = np.isin(source, other) return np.where(isin)[0] -def __project_chunked_on_simple(source, other: DataIndex): - return project(source, into_array(other)) +def __project_chunked_on_simple( + source: SimpleIndex, other: ChunkedIndex +) -> SimpleIndex: + if len(other[0]) == 0: + return np.array([], dtype=np.int64) + return idxlib.project_chunked_on_simple(source, *other) -def __project_simple_on_chunked(source: ChunkedIndex, other: SimpleIndex): +def __project_simple_on_chunked(source: ChunkedIndex, other: SimpleIndex) -> DataIndex: return project(into_array(source), other) -def __project_chunked_on_chunked(source: ChunkedIndex, other: ChunkedIndex): +def __project_chunked_on_chunked( + source: ChunkedIndex, other: ChunkedIndex +) -> ChunkedIndex: source_ends = source[0] + source[1] other_ends = other[0] + other[1] diff --git a/python/opencosmo/index/take.py b/python/opencosmo/index/take.py new file mode 100644 index 00000000..b4d9e284 --- /dev/null +++ b/python/opencosmo/index/take.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from opencosmo._lib import index as idxlib + +if TYPE_CHECKING: + from opencosmo.index import ChunkedIndex, DataIndex, SimpleIndex + 
+ +def take(from_: DataIndex, by: DataIndex) -> DataIndex: + match (from_, by): + case (np.ndarray(), np.ndarray()): + return __take_simple_from_simple(from_, by) + case (np.ndarray(), (np.ndarray(), np.ndarray())): + return idxlib.take_chunked_from_simple(from_, *by) + case ((np.ndarray(), np.ndarray()), np.ndarray()): + return __take_simple_from_chunked(from_, by) + case ((np.ndarray(), np.ndarray()), (np.ndarray(), np.ndarray())): + return idxlib.take_chunked_from_chunked(*from_, *by) + case _: + raise TypeError(f"Invalid index types: {type(from_)}, {type(by)}") + + +def __take_simple_from_chunked(from_: ChunkedIndex, by: SimpleIndex) -> SimpleIndex: + cumulative = np.insert(np.cumsum(from_[1]), 0, 0)[:-1] + + indices_into_chunks = np.argmax(by[:, np.newaxis] < cumulative, axis=1) - 1 + output = by - cumulative[indices_into_chunks] + from_[0][indices_into_chunks] + return output + + +def __take_simple_from_simple(from_: SimpleIndex, by: SimpleIndex) -> SimpleIndex: + return from_[by] diff --git a/python/opencosmo/index/unary.py b/python/opencosmo/index/unary.py new file mode 100644 index 00000000..8abe2526 --- /dev/null +++ b/python/opencosmo/index/unary.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from opencosmo._lib import index as idx + +if TYPE_CHECKING: + from opencosmo.index import DataIndex + +""" +Implementations for unary operations on indices +""" + + +def get_length(index: DataIndex) -> int: + match index: + case np.ndarray(): + return len(index) + case (np.ndarray(), np.ndarray()): + return int(np.sum(index[1])) + case _: + raise TypeError(f"Invalid index type {type(index)}") + + +def get_range(index: DataIndex) -> tuple[int, int]: + match index: + case np.ndarray(): + return idx.get_simple_range(index) + case (np.ndarray(), np.ndarray()): + return idx.get_chunked_range(*index) + case _: + raise ValueError(f"Unknown index type {type(index)}") diff --git 
a/src/opencosmo/io/__init__.py b/python/opencosmo/io/__init__.py similarity index 100% rename from src/opencosmo/io/__init__.py rename to python/opencosmo/io/__init__.py diff --git a/src/opencosmo/io/io.py b/python/opencosmo/io/io.py similarity index 100% rename from src/opencosmo/io/io.py rename to python/opencosmo/io/io.py diff --git a/src/opencosmo/io/iopen.py b/python/opencosmo/io/iopen.py similarity index 91% rename from src/opencosmo/io/iopen.py rename to python/opencosmo/io/iopen.py index 0cdfa1a8..d3c9f67f 100644 --- a/src/opencosmo/io/iopen.py +++ b/python/opencosmo/io/iopen.py @@ -11,11 +11,14 @@ import opencosmo as oc from opencosmo import collection as occ +from opencosmo.collection.structure import structure as sc from opencosmo.dataset import state as st from opencosmo.dataset.mpi import partition from opencosmo.header import OpenCosmoHeader, read_header from opencosmo.index.build import empty, from_range from opencosmo.mpi import get_comm_world +from opencosmo.plugins.contexts import DatasetOpenCtx, HookPoint +from opencosmo.plugins.hook import fold from opencosmo.spatial.builders import from_model from opencosmo.spatial.region import FullSkyRegion, HealpixRegion from opencosmo.spatial.tree import open_tree @@ -50,6 +53,7 @@ class DatasetTarget(TypedDict): header: OpenCosmoHeader dataset_group: h5py.Group columns: list[h5py.Dataset] + spatial_index: Optional[h5py.Group] class FileType(Enum): @@ -85,7 +89,8 @@ def open_files(paths: list[Path], open_kwargs: dict[str, Any]): if len(valid_targets) > 1: collection_type = __determine_multi_file_collection_type(valid_targets) return collection_type.open(valid_targets, **open_kwargs) - return __open_single_file(valid_targets[0]) + + return __open_single_file(valid_targets[0], open_kwargs) def __make_group_map(group: h5py.File | h5py.Group, prefix: str = ""): @@ -117,14 +122,18 @@ def __make_file_target(path: Path, open_kwargs: dict[str, Any]) -> Optional[File ) -def __open_single_file(target: FileTarget) -> 
oc.Dataset | oc.collection.Collection: +def __open_single_file( + target: FileTarget, open_kwargs: dict[str, Any] = {} +) -> oc.Dataset | oc.collection.Collection: """ Opens a single file, which may or may not contain several datasets """ if len(target["dataset_targets"]) == 1: # Just one dataset, easy - return open_single_dataset(target["dataset_targets"][0]) + return open_single_dataset( + target["dataset_targets"][0], open_kwargs=open_kwargs + ) elif target["dataset_targets"]: # Multiple datasets, but all grouped together @@ -136,7 +145,7 @@ def __open_single_file(target: FileTarget) -> oc.Dataset | oc.collection.Collect == FileType.STRUCTURE_COLLECTION ): # Structure collection - return occ.StructureCollection.open([target]) + return occ.StructureCollection.open([target], **open_kwargs) elif target["dataset_groups"]: # Sometimes, lightcones have multiple datasets per slice if all( @@ -145,6 +154,15 @@ def __open_single_file(target: FileTarget) -> oc.Dataset | oc.collection.Collect ): return occ.Lightcone.open([target]) + # Lightcone structure collection + elif ( + target["dataset_group_types"].get("halo_properties") == FileType.LIGHTCONE + or target["dataset_group_types"].get("galaxy_properties") + == FileType.LIGHTCONE + ): + result = sc.StructureCollection.open([target], **open_kwargs) + return result + datasets = { name: __open_dataset_targets_for_sim_collection( targets, target["dataset_group_types"][name] @@ -389,6 +407,7 @@ def __find_datasets_under_group( file_map.keys(), ) ) + for ds_group_name in known_dataset_groups: ds_group_parent = ds_group_name.rsplit("/", maxsplit=1)[0] ds_group_parent += "/" @@ -396,17 +415,22 @@ def __find_datasets_under_group( columns = [ nds_[1] for nds_ in filter( - lambda nds: ds_group_parent in nds[0] - and "header" not in nds[0] - and isinstance(nds[1], h5py.Dataset), + lambda nds: ( + ds_group_parent in nds[0] + and isinstance(nds[1], h5py.Dataset) + and "header" not in nds[0] + and f"{ds_group_parent}index" not in 
nds[0] + ), file_map.items(), ) ] + index_group = file_map.get(f"{ds_group_parent}index") target = DatasetTarget( header=header, dataset_group=file_map[ds_group_name].parent, columns=columns, + spatial_index=index_group, ) if evaluate_load_conditions(target, open_kwargs): known_datasets.append(target) @@ -462,6 +486,7 @@ def open_single_dataset( metadata_group: Optional[str] = None, bypass_lightcone: bool = False, bypass_mpi: bool = False, + open_kwargs: dict[str, Any] = {}, ): header = target["header"] ds_group = target["dataset_group"] @@ -469,27 +494,24 @@ def open_single_dataset( assert header is not None - index_columns = { - col.name: col for col in columns if col.name.split("/")[-3] == "index" - } try: box_size = header.with_units("scalefree").simulation["box_size"].value except AttributeError: box_size = None - try: + if target["spatial_index"] is not None: tree = open_tree( - index_columns, + target["spatial_index"], box_size, header.file.is_lightcone, ) - except (ValueError, AttributeError): + else: tree = None if header.file.region is not None: sim_region = from_model(header.file.region) elif header.file.is_lightcone and tree is not None: - pixels = tree.get_full_index(tree.max_level) + pixels = tree.get_partitions_with_data(tree.max_level) sim_region = HealpixRegion(pixels, nside=2**tree.max_level) elif header.file.data_type == "healpix_map": assert header.healpix_map["full_sky"] @@ -505,8 +527,7 @@ def open_single_dataset( if not bypass_mpi and (comm := get_comm_world()) is not None: assert partition is not None try: - idx_data = ds_group["index"] - part = partition(comm, ds_length, idx_data, tree) + part = partition(comm, header, ds_group["index"], ds_group["data"], tree) if part is None: index = empty() else: @@ -523,10 +544,11 @@ def open_single_dataset( rank = comm.Get_rank() index = from_range(chunk_boundaries[rank], chunk_boundaries[rank + 1]) - state = st.DatasetState.from_target( + state = st.state_from_target( target, UnitConvention.COMOVING, 
sim_region, + open_kwargs, index, metadata_group, ) @@ -536,11 +558,12 @@ def open_single_dataset( state, tree=tree, ) + dataset = fold(HookPoint.DatasetOpen, DatasetOpenCtx(dataset, open_kwargs)).dataset if header.file.data_type == "healpix_map": return __open_healpix_map(dataset, sim_region) elif header.file.is_lightcone and not bypass_lightcone: return occ.Lightcone.from_datasets( - {"data": dataset}, header.lightcone["z_range"] + {0: dataset}, header.lightcone["z_range"], **open_kwargs ) return dataset @@ -579,7 +602,7 @@ def __expand_lightcone_region(region, tree): pixels = pixels[:, None] * npix_ratio + np.arange(npix_ratio) pixels = pixels.flatten() - full_pixels = tree.get_full_index(tree.max_level) + full_pixels = tree.get_partitions_with_data(tree.max_level) full_pixels = np.intersect1d(pixels, full_pixels) return HealpixRegion(full_pixels, 2**tree.max_level) diff --git a/src/opencosmo/io/mpi.py b/python/opencosmo/io/mpi.py similarity index 99% rename from src/opencosmo/io/mpi.py rename to python/opencosmo/io/mpi.py index cc990ad3..f4122700 100644 --- a/src/opencosmo/io/mpi.py +++ b/python/opencosmo/io/mpi.py @@ -72,6 +72,7 @@ def write_parallel(file: Path, file_schema: Schema): results = comm.allgather(CombineState.VALID) except ValueError: results = comm.allgather(CombineState.INVALID) + raise except ZeroLengthError: results = comm.allgather(CombineState.ZERO_LENGTH) if any(rs == CombineState.INVALID for rs in results): @@ -297,6 +298,7 @@ def __replace_writers_with_updates(schema: Schema, comm: MPI.Comm): child_schema = schema.children.get(cn, make_schema(cn, FileEntry.EMPTY)) new_child_schema = __replace_writers_with_updates(child_schema, comm) schema.children[cn] = new_child_schema + return schema @@ -461,7 +463,8 @@ def __write_column( if writer is not None: data = writer.get_data(new_comm) else: - data = np.empty((0,), dtype=ds.dtype) + shape = (0,) + ds.shape[1:] + data = np.empty(shape, dtype=ds.dtype) ds.write_direct(data, dest_sel=np.s_[offset 
: offset + len(data)]) diff --git a/src/opencosmo/io/parquet.py b/python/opencosmo/io/parquet.py similarity index 100% rename from src/opencosmo/io/parquet.py rename to python/opencosmo/io/parquet.py diff --git a/src/opencosmo/io/protocols.py b/python/opencosmo/io/protocols.py similarity index 100% rename from src/opencosmo/io/protocols.py rename to python/opencosmo/io/protocols.py diff --git a/src/opencosmo/io/schema.py b/python/opencosmo/io/schema.py similarity index 88% rename from src/opencosmo/io/schema.py rename to python/opencosmo/io/schema.py index 416a1267..7e6fa4f6 100644 --- a/src/opencosmo/io/schema.py +++ b/python/opencosmo/io/schema.py @@ -28,6 +28,14 @@ class Schema(NamedTuple): attributes: dict[str, Any] +def dataset_schema_length(schema: Schema) -> Optional[int]: + if schema.type != FileEntry.DATASET: + return None + + column = next(iter(schema.children["data"].columns.values())) + return column.shape[0] + + def empty_schema(name: str, type_: FileEntry) -> Schema: return Schema(name, type_, {}, {}, {}) diff --git a/src/opencosmo/io/serial.py b/python/opencosmo/io/serial.py similarity index 75% rename from src/opencosmo/io/serial.py rename to python/opencosmo/io/serial.py index 15dab840..0653eae9 100644 --- a/src/opencosmo/io/serial.py +++ b/python/opencosmo/io/serial.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING +from opencosmo.io.schema import dataset_schema_length + if TYPE_CHECKING: import h5py @@ -10,17 +12,26 @@ def allocate(group: h5py.File | h5py.Group, schema: Schema): for column_name, column_writer in schema.columns.items(): + if column_writer.shape[0] == 0: + continue group.require_dataset(column_name, column_writer.shape, column_writer.dtype) for child_name, child_schema in schema.children.items(): + if dataset_schema_length(child_schema) == 0: + continue + child_group = group.require_group(child_name) allocate(child_group, child_schema) def write_columns(group: h5py.File | h5py.Group, schema: Schema): for column_path, 
column_writer in schema.columns.items(): + if column_writer.shape[0] == 0: + continue group[column_path][:] = column_writer.data group[column_path].attrs.update(column_writer.attrs) for child_name, child_schema in schema.children.items(): + if dataset_schema_length(child_schema) == 0: + continue write_columns(group[child_name], child_schema) @@ -30,4 +41,6 @@ def write_metadata(group: h5py.File | h5py.Group, schema: Schema): metadata_group.attrs.update(metadata) for child_name, child_schema in schema.children.items(): + if dataset_schema_length(child_schema) == 0: + continue write_metadata(group[child_name], child_schema) diff --git a/src/opencosmo/io/updaters.py b/python/opencosmo/io/updaters.py similarity index 100% rename from src/opencosmo/io/updaters.py rename to python/opencosmo/io/updaters.py diff --git a/src/opencosmo/io/verify.py b/python/opencosmo/io/verify.py similarity index 88% rename from src/opencosmo/io/verify.py rename to python/opencosmo/io/verify.py index efbfc0f1..c34bd96b 100644 --- a/src/opencosmo/io/verify.py +++ b/python/opencosmo/io/verify.py @@ -146,7 +146,20 @@ def verify_structure_collection_data(schema: Schema): raise ValueError("No valid link holder found in schema!") for child_name, child_schema in schema.children.items(): - if child_name == link_holder: + if child_schema.type == FileEntry.LIGHTCONE and child_name == link_holder: + for grandchild_name, grandchild_schema in child_schema.children.items(): + has_link = any( + map( + lambda cn: "data_linked" in cn, + grandchild_schema.children.keys(), + ) + ) + if not has_link: + raise ValueError( + f'Source dataset {child_name}/{grandchild_name} does not have expected "data_linked" group' + ) + + elif child_name == link_holder: has_link = any( map(lambda cn: "data_linked" in cn, child_schema.children.keys()) ) @@ -160,5 +173,7 @@ def verify_structure_collection_data(schema: Schema): verify_dataset_data(child_schema) case FileEntry.STRUCTURE_COLLECTION: 
verify_structure_collection_data(child_schema) + case FileEntry.LIGHTCONE: + verify_lightcone_collection_schema(child_schema) case _: raise ValueError("Got an unknown child for structure collection!") diff --git a/src/opencosmo/io/writer.py b/python/opencosmo/io/writer.py similarity index 100% rename from src/opencosmo/io/writer.py rename to python/opencosmo/io/writer.py diff --git a/src/opencosmo/mpi.py b/python/opencosmo/mpi.py similarity index 100% rename from src/opencosmo/mpi.py rename to python/opencosmo/mpi.py diff --git a/python/opencosmo/plugins/__init__.py b/python/opencosmo/plugins/__init__.py new file mode 100644 index 00000000..a4d67ff5 --- /dev/null +++ b/python/opencosmo/plugins/__init__.py @@ -0,0 +1,25 @@ +from .contexts import ( + DatasetInstantiateCtx, + DatasetOpenCtx, + HookPoint, + IndexUpdateCtx, + LightconeInstantiateCtx, + LightconeOpenCtx, + PartitionCtx, + PostSortCtx, +) +from .hook import fold, hook, query + +__all__ = [ + "fold", + "hook", + "query", + "HookPoint", + "DatasetOpenCtx", + "DatasetInstantiateCtx", + "LightconeOpenCtx", + "LightconeInstantiateCtx", + "IndexUpdateCtx", + "PostSortCtx", + "PartitionCtx", +] diff --git a/python/opencosmo/plugins/contexts.py b/python/opencosmo/plugins/contexts.py new file mode 100644 index 00000000..19fee3ad --- /dev/null +++ b/python/opencosmo/plugins/contexts.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + import h5py + import numpy as np + from astropy.table import Table + from mpi4py import MPI + + from opencosmo import Dataset, Lightcone + from opencosmo.dataset.state import DatasetState + from opencosmo.header import OpenCosmoHeader + from opencosmo.index import DataIndex, IndexArray + from opencosmo.spatial.tree import Tree + + +class HookPoint(StrEnum): + DatasetOpen = "dataset_open" + DatasetInstantiate = "dataset_instantiate" + LightconeOpen 
= "lightcone_open" + LightconeInstantiate = "lightcone_instantiate" + IndexUpdate = "index_update" + PostSort = "post_sort" + Partition = "partition" + + +# --- fold hooks --- +# Plugins receive and return these contexts. Use dataclasses.replace() to +# produce a modified copy rather than mutating in place. + + +@dataclass(frozen=True) +class DatasetOpenCtx: + """Fired once per dataset after it is opened from disk. + + open_kwargs holds any keyword arguments the user passed to opencosmo.open(). + Plugins can inspect them to conditionally modify the dataset. + """ + + dataset: Dataset + open_kwargs: dict[str, Any] + + +@dataclass(frozen=True) +class DatasetInstantiateCtx: + """Fired each time get_data() is called on a DatasetState. + + Plugins may add or modify derived columns before the data is materialised. + """ + + state: DatasetState + + +@dataclass(frozen=True) +class LightconeOpenCtx: + """Fired once per Lightcone after it is opened from disk. + + open_kwargs mirrors DatasetOpenCtx.open_kwargs. + """ + + lightcone: Lightcone + open_kwargs: dict[str, Any] + + +@dataclass(frozen=True) +class LightconeInstantiateCtx: + """Fired each time get_data() is called on a Lightcone. + + Plugins may re-order or re-index sub-datasets before they are stacked. + """ + + lightcone: Lightcone + + +@dataclass(frozen=True) +class IndexUpdateCtx: + """Fired whenever a filter, take, or bound operation produces a new index. + + state is read-only context for the predicate and for reading column data. + index is the value being transformed; return a modified copy via + dataclasses.replace(ctx, index=new_index). + """ + + state: DatasetState + index: DataIndex + + +@dataclass(frozen=True) +class PostSortCtx: + """Fired after a sort operation reorders rows. + + state is read-only context for the predicate. + index is the reverse-sort permutation (i.e. np.argsort of the sort order), + which can be used to remap index-valued columns. 
+ data is the value being transformed; return a modified copy via + dataclasses.replace(ctx, data=new_data). + """ + + state: DatasetState | Lightcone + data: Table | dict[str, np.ndarray] + index: IndexArray + + +# --- query hook --- +# Partition uses query() rather than fold(). The plugin returns a +# TreePartition directly (not a modified context), and the first non-None +# result wins. No plugin responding means the caller falls back to the +# default partitioning strategy. + + +@dataclass(frozen=True) +class PartitionCtx: + """Fired during MPI open to determine how rows are distributed across ranks. + + Plugins return a TreePartition, or None to defer to the default strategy. + This hook uses query() semantics: at most one plugin responds. + """ + + comm: MPI.Comm + header: OpenCosmoHeader + index_group: h5py.Group + data_group: h5py.Group + tree: Optional[Tree] = None + min_level: Optional[int] = None diff --git a/python/opencosmo/plugins/hook.py b/python/opencosmo/plugins/hook.py new file mode 100644 index 00000000..1e6e4b6d --- /dev/null +++ b/python/opencosmo/plugins/hook.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from functools import reduce +from typing import Any, Callable, TypeVar + +Ctx = TypeVar("Ctx") +R = TypeVar("R") + +type Predicate[Ctx] = Callable[[Ctx], bool] + + +@dataclass(frozen=True) +class Registration[Ctx]: + predicate: Predicate[Ctx] + transform: Callable[[Ctx], Any] + + +_registry: dict[str, list[Registration]] = defaultdict(list) + + +def hook(name: str, *, when: Predicate = lambda _: True): + """Decorator that registers a function as a hook implementation. + + The decorated function receives and returns a context dataclass. For fold + hooks, it should return a (possibly modified) copy of the context; use + dataclasses.replace() to do so functionally. For query hooks, it should + return either a result value or None to signal no match. 
+ + Parameters + ---------- + name: + The hook point name. Use a constant from HookPoint. + when: + Predicate called with the context. The hook only fires when this + returns True. + """ + + def decorator(fn: Callable) -> Callable: + _registry[name].append(Registration(when, fn)) + return fn + + return decorator + + +def fold(name: str, ctx: Ctx) -> Ctx: + """Apply all registered hooks for *name* in order, threading ctx through. + + Each matching hook receives the output of the previous one. Non-matching + hooks (predicate returns False) are skipped and ctx passes through unchanged. + """ + return reduce( + lambda c, reg: reg.transform(c) if reg.predicate(c) else c, + _registry[name], + ctx, + ) + + +def query(name: str, ctx: Any) -> Any | None: + """Return the result of the first matching hook for *name*, or None. + + Intended for hook points where at most one plugin should respond — the + first hook whose predicate matches is called, and its return value is + returned immediately. Subsequent hooks are not evaluated. 
+ """ + return next( + ( + result + for reg in _registry[name] + if reg.predicate(ctx) and (result := reg.transform(ctx)) is not None + ), + None, + ) diff --git a/src/opencosmo/spatial/__init__.py b/python/opencosmo/spatial/__init__.py similarity index 100% rename from src/opencosmo/spatial/__init__.py rename to python/opencosmo/spatial/__init__.py diff --git a/src/opencosmo/spatial/builders.py b/python/opencosmo/spatial/builders.py similarity index 100% rename from src/opencosmo/spatial/builders.py rename to python/opencosmo/spatial/builders.py diff --git a/src/opencosmo/spatial/check.py b/python/opencosmo/spatial/check.py similarity index 98% rename from src/opencosmo/spatial/check.py rename to python/opencosmo/spatial/check.py index 8c1473fe..9553cdf6 100644 --- a/src/opencosmo/spatial/check.py +++ b/python/opencosmo/spatial/check.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from opencosmo.dataset.dataset import Dataset - from opencosmo.parameters import FileParameters + from opencosmo.dtypes import FileParameters from opencosmo.spatial.protocols import Region ALLOWED_COORDINATES_3D = { diff --git a/src/opencosmo/spatial/healpix.py b/python/opencosmo/spatial/healpix.py similarity index 66% rename from src/opencosmo/spatial/healpix.py rename to python/opencosmo/spatial/healpix.py index 19a44adb..442b6db4 100644 --- a/src/opencosmo/spatial/healpix.py +++ b/python/opencosmo/spatial/healpix.py @@ -2,7 +2,9 @@ from typing import TYPE_CHECKING +import healpy as hp import numpy as np +from astropy.coordinates import SkyCoord from opencosmo.index import into_array from opencosmo.spatial.region import HealpixRegion @@ -47,4 +49,18 @@ def query( raise ValueError("Didn't recieve a 2D region!") nside = 2**level intersects = region.get_healpix_intersections(nside) - return {level: (np.array([]), intersects)} + boundaries = ( + hp.boundaries(nside, intersects, nest=True) + .transpose( + 0, + 2, + 1, + ) + .reshape(-1, 3) + ) + coords = SkyCoord(*hp.vec2ang(boundaries, lonlat=True), 
unit="deg") + coord_is_contained = region.contains(coords) + pixel_is_contained = np.all(coord_is_contained.reshape(-1, 4), axis=1) + return { + level: (intersects[pixel_is_contained], intersects[~pixel_is_contained]) + } diff --git a/src/opencosmo/spatial/models.py b/python/opencosmo/spatial/models.py similarity index 100% rename from src/opencosmo/spatial/models.py rename to python/opencosmo/spatial/models.py diff --git a/src/opencosmo/spatial/octree.py b/python/opencosmo/spatial/octree.py similarity index 98% rename from src/opencosmo/spatial/octree.py rename to python/opencosmo/spatial/octree.py index ba87a290..0a00d47e 100644 --- a/src/opencosmo/spatial/octree.py +++ b/python/opencosmo/spatial/octree.py @@ -117,7 +117,9 @@ def __init__(self, root: Octant): self.root = root def get_partition_region(self, index: SimpleIndex, level: int): - octants = [get_octant(idx, level, 2 * self.root.halfwidth) for idx in index] + octants = [ + get_octant(int(idx), level, 2 * self.root.halfwidth) for idx in index + ] return get_region(octants) @classmethod diff --git a/src/opencosmo/spatial/protocols.py b/python/opencosmo/spatial/protocols.py similarity index 100% rename from src/opencosmo/spatial/protocols.py rename to python/opencosmo/spatial/protocols.py diff --git a/src/opencosmo/spatial/region.py b/python/opencosmo/spatial/region.py similarity index 100% rename from src/opencosmo/spatial/region.py rename to python/opencosmo/spatial/region.py diff --git a/src/opencosmo/spatial/relations.py b/python/opencosmo/spatial/relations.py similarity index 99% rename from src/opencosmo/spatial/relations.py rename to python/opencosmo/spatial/relations.py index 0aae3ff7..fd19481f 100644 --- a/src/opencosmo/spatial/relations.py +++ b/python/opencosmo/spatial/relations.py @@ -160,7 +160,7 @@ def __healpix_contains_other(region: HealpixRegion, other) -> bool: intersections = other.get_healpix_intersections(region.nside) except AttributeError: raise ValueError(f"Expected a 2D Sky Region 
but received {type(other)}") - return len(np.intersect1d(intersections, region.pixels)) == len(intersections) + return bool(np.all(np.isin(intersections, region.pixels))) # --------------------------------------------------------------------------- diff --git a/src/opencosmo/spatial/tree.py b/python/opencosmo/spatial/tree.py similarity index 84% rename from src/opencosmo/spatial/tree.py rename to python/opencosmo/spatial/tree.py index 5eadad05..bbcbfaed 100644 --- a/src/opencosmo/spatial/tree.py +++ b/python/opencosmo/spatial/tree.py @@ -13,7 +13,14 @@ MPI = None # type: ignore -from opencosmo.index import from_size, get_data, n_in_range +from opencosmo.index import ( + from_size, + from_start_size_group, + get_data, + into_array, + n_in_range, + project, +) from opencosmo.io.schema import FileEntry, make_schema from opencosmo.io.writer import ( ColumnCombineStrategy, @@ -31,7 +38,7 @@ def open_tree( - tree_columns: dict[str, h5py.Dataset], + tree_group: h5py.Group, box_size: Optional[int], is_lightcone: bool = False, ): @@ -55,9 +62,7 @@ def open_tree( else: spatial_index = OctTreeIndex.from_box_size(box_size) - tree_columns = {name.split("index/")[-1]: col for name, col in tree_columns.items()} - - return Tree(spatial_index, tree_columns) + return Tree(spatial_index, tree_group) def read_tree(file: h5py.File | h5py.Group, box_size: int): @@ -150,7 +155,9 @@ def partition_index(n_partitions: int, counts: h5py.Group, min_level: int): split_level_indices = full_region_indices - return np.array_split(split_level_indices, n_partitions), split_level + return np.array_split( + split_level_indices.astype(np.int64), n_partitions + ), split_level class Tree: @@ -177,7 +184,7 @@ def __init__( def max_level(self): return self.__max_level - def get_full_index(self, level: int): + def get_partitions_with_data(self, level: int): if level > self.max_level: raise ValueError( "Requested level is greater than the max level of this tree!" 
@@ -186,6 +193,30 @@ def get_full_index(self, level: int): return np.where(sizes > 0)[0] + def get_occupied_partitions(self, level: int, index: DataIndex): + if level > self.max_level: + raise ValueError( + "Level must be less than or equal to the max level of this tree" + ) + starts = self.__columns[f"level_{level}/start"][:] + return np.searchsorted(starts, into_array(index), side="right") - 1 + + def project_on_index( + self, level: int, index: DataIndex, partitions: Optional[DataIndex] + ): + if level > self.max_level: + raise ValueError( + "Level must be less than or equal to the max level of this tree" + ) + starts = self.__columns[f"level_{level}/start"] + sizes = self.__columns[f"level_{level}/size"] + if partitions is None: + partitions = from_size(len(starts)) + + return project( + index, (get_data(starts, partitions), get_data(sizes, partitions)) + ) + def partition( self, n_partitions: int, counts: h5py.Group, min_level: Optional[int] = None ) -> Sequence[TreePartition]: @@ -201,8 +232,7 @@ def partition( n_partitions, counts, min_level ) partitions = [] - start = self.__columns[f"level_{split_level}/start"][:] - size = self.__columns[f"level_{split_level}/size"][:] + start, size = from_start_size_group(self.__columns[f"level_{split_level}"]) for index_ in partition_indices: if len(index_) == 0: continue @@ -223,8 +253,9 @@ def query(self, region: Region) -> tuple[ChunkedIndex, ChunkedIndex]: intersects = [] for level, (cidx, iidx) in indices.items(): level_key = f"level_{level}" - level_starts = self.__columns[f"{level_key}/start"] - level_sizes = self.__columns[f"{level_key}/size"] + level_starts, level_sizes = from_start_size_group( + self.__columns[f"{level_key}"] + ) c_starts = get_data(level_starts, cidx) c_sizes = get_data(level_sizes, cidx) i_starts = get_data(level_starts, iidx) @@ -239,8 +270,9 @@ def query(self, region: Region) -> tuple[ChunkedIndex, ChunkedIndex]: return (contains_start, contains_size), (intersects_start, intersects_size) 
def apply_index(self, index: DataIndex, min_counts: int = 100) -> Tree: - max_level_starts = self.__columns[f"level_{self.__max_level}/start"][:] - max_level_sizes = self.__columns[f"level_{self.__max_level}/size"][:] + max_level_starts, max_level_sizes = from_start_size_group( + self.__columns[f"level_{self.__max_level}"] + ) n = n_in_range(index, max_level_starts, max_level_sizes) target = h5py.File(f"{uuid1()}.hdf5", "w", driver="core", backing_store=False) result = combine_upwards( diff --git a/src/opencosmo/spatial/utils.py b/python/opencosmo/spatial/utils.py similarity index 88% rename from src/opencosmo/spatial/utils.py rename to python/opencosmo/spatial/utils.py index 200c91b5..bb006236 100644 --- a/src/opencosmo/spatial/utils.py +++ b/python/opencosmo/spatial/utils.py @@ -21,8 +21,8 @@ def combine_upwards( raise ValueError("Recieved invalid number of counts!") group = target.require_group(f"level_{level}") - new_starts = np.insert(np.cumsum(counts, dtype=np.int32), 0, 0)[:-1] - counts = counts.astype(np.int32) # This should be fixed + new_starts = np.insert(np.cumsum(counts, dtype=np.int64), 0, 0)[:-1] + counts = counts.astype(np.int64) # This should be fixed group.create_dataset("start", data=new_starts) group.create_dataset("size", data=counts) diff --git a/src/opencosmo/units/__init__.py b/python/opencosmo/units/__init__.py similarity index 100% rename from src/opencosmo/units/__init__.py rename to python/opencosmo/units/__init__.py diff --git a/src/opencosmo/units/convention.py b/python/opencosmo/units/convention.py similarity index 100% rename from src/opencosmo/units/convention.py rename to python/opencosmo/units/convention.py diff --git a/src/opencosmo/units/converters.py b/python/opencosmo/units/converters.py similarity index 97% rename from src/opencosmo/units/converters.py rename to python/opencosmo/units/converters.py index 24853806..1388040c 100644 --- a/src/opencosmo/units/converters.py +++ b/python/opencosmo/units/converters.py @@ -6,6 +6,7 
@@ import astropy.units as u from astropy.cosmology import units as cu +import opencosmo.dataset.state as st from opencosmo.units import UnitConvention if TYPE_CHECKING: @@ -173,12 +174,12 @@ def get_scale_factor(dataset: "DatasetState", cosmology, redshift): columns = set(dataset.columns) for column in KNOWN_SCALEFACTOR_COLUMNS: if column in columns: - col = dataset.select({column}).get_data()[column] + col = st.get_data(st.select(dataset, {column}))[column] return col for column in KNOWN_REDSHIFT_COLUMNS: if column in columns: - col = dataset.select({column}).get_data()[column] + col = st.get_data(st.select(dataset, {column}))[column] return 1 / (1 + col) return cosmology.scale_factor(redshift) diff --git a/src/opencosmo/units/get.py b/python/opencosmo/units/get.py similarity index 100% rename from src/opencosmo/units/get.py rename to python/opencosmo/units/get.py diff --git a/src/opencosmo/units/handler.py b/python/opencosmo/units/handler.py similarity index 86% rename from src/opencosmo/units/handler.py rename to python/opencosmo/units/handler.py index 3b9e53fd..de582254 100644 --- a/src/opencosmo/units/handler.py +++ b/python/opencosmo/units/handler.py @@ -13,6 +13,8 @@ ) if TYPE_CHECKING: + from uuid import UUID + import h5py import numpy as np from astropy.cosmology import Cosmology @@ -87,6 +89,16 @@ def current_convention(self): def base_convention(self): return self.__base_convention + @property + def columns_with_conversions(self): + return { + name + for name, applicator in self.__applicators.items() + if name in self.__column_conversions + or str(applicator.unit_in_convention(self.current_convention)) + in self.__conversions + } + @cached_property def base_units(self): return {key: app.base_unit for key, app in self.__applicators.items()} @@ -208,25 +220,26 @@ def apply_raw_units(self, data: dict[str, np.ndarray], unit_kwargs): return columns def apply_unit_conversions( - self, data: dict[str, u.Quantity | np.ndarray], unit_kwargs - ): - # Only apply 
the unit CONVERSIONS. Useful for cached data - # Does not return data that was not updated - output_data = {} + self, + data: dict[UUID, dict[str, u.Quantity | np.ndarray]], + unit_kwargs, + ) -> dict[UUID, dict[str, u.Quantity | np.ndarray]]: + # Only apply the unit CONVERSIONS. Useful for cached data. + # Does not return data that was not updated. if not self.__conversions and not self.__column_conversions: return {} - for colname, column in data.items(): - if not isinstance(column, u.Quantity): - continue - assert isinstance(column, u.Quantity) - column_conversion = self.__column_conversions.get(colname) - unit_conversion = self.__conversions.get(str(column.unit)) - if unit_conversion is not None and column_conversion is None: - output_data[colname] = column.to(unit_conversion) - elif column_conversion is not None: - output_data[colname] = column.to(column_conversion) - - return output_data + output: dict[UUID, dict[str, u.Quantity | np.ndarray]] = {} + for uuid, uuid_data in data.items(): + for colname, column in uuid_data.items(): + if not isinstance(column, u.Quantity): + continue + column_conversion = self.__column_conversions.get(colname) + unit_conversion = self.__conversions.get(str(column.unit)) + if unit_conversion is not None and column_conversion is None: + output.setdefault(uuid, {})[colname] = column.to(unit_conversion) + elif column_conversion is not None: + output.setdefault(uuid, {})[colname] = column.to(column_conversion) + return output def apply_units(self, data: dict[str, np.ndarray], unit_kwargs): if self.__current_convention == UnitConvention.UNITLESS: diff --git a/src/index.rs b/src/index.rs new file mode 100644 index 00000000..961fcb06 --- /dev/null +++ b/src/index.rs @@ -0,0 +1,487 @@ +use pyo3::prelude::*; +#[pymodule] +pub(crate) mod index { + use numpy::ndarray::s; + use numpy::ndarray::{Array1, ArrayView1}; + use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1}; + use pyo3::exceptions::{PyTypeError, PyValueError}; + 
use pyo3::prelude::*; + use pyo3::types::PyList; + use std::collections::HashMap; + use std::iter::zip; + + fn unpack_index_array<'py>(index: &Bound<'py, PyAny>) -> PyResult> { + let index_data = index + .cast::>() + .map_err(|_| PyTypeError::new_err("Indices should be a 1-d array of i64"))?; + + Ok(index_data.readonly()) + } + + fn unpack_chunked_index<'py>( + start: &Bound<'py, PyAny>, + size: &Bound<'py, PyAny>, + ) -> PyResult<(PyReadonlyArray1<'py, i64>, PyReadonlyArray1<'py, i64>)> { + let start_arr = unpack_index_array(start)?; + let size_arr = unpack_index_array(size)?; + if start_arr.len()? != size_arr.len()? { + return Err(PyValueError::new_err( + "start and size must be the same length in a chunked index!", + )); + } + Ok((start_arr, size_arr)) + } + + #[pyfunction(name = "get_simple_range")] + pub(crate) fn get_simple_range_py(index: &Bound<'_, PyAny>) -> PyResult<(i64, i64)> { + let index_arr = unpack_index_array(index)?; + get_simple_range(index_arr.as_array()) + } + fn get_simple_range(index: ArrayView1<'_, i64>) -> PyResult<(i64, i64)> { + if index.len() == 0 { + return Ok((0, 0)); + } + let mut index_range = (index[0], index[0]); + for &item in index { + if item < index_range.0 { + index_range = (item, index_range.1) + } + if item > index_range.1 { + index_range = (index_range.0, item) + } + } + Ok(index_range) + } + + #[pyfunction(name = "get_chunked_range")] + pub(crate) fn get_chunked_range_py( + start: &Bound<'_, PyAny>, + size: &Bound<'_, PyAny>, + ) -> PyResult<(i64, i64)> { + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + get_chunked_range(start_arr.as_array(), size_arr.as_array()) + } + + fn get_chunked_range( + start: ArrayView1<'_, i64>, + size: ArrayView1<'_, i64>, + ) -> PyResult<(i64, i64)> { + if start.len() == 0 { + return Ok((0, 0)); + } + let mut index_range = (start[0], start[0] + size[0]); + for (&st, &si) in zip(start, size) { + let end = st + si; + if st < index_range.0 { + index_range = (st, index_range.1); 
+ } + if end > index_range.1 { + index_range = (index_range.0, end); + } + } + Ok(index_range) + } + + #[pyfunction(name = "n_in_range_chunked")] + pub(crate) fn n_in_range_chunked_py<'py>( + py: Python<'py>, + start: &Bound<'_, PyAny>, + size: &Bound<'_, PyAny>, + range_start: &Bound<'_, PyAny>, + range_size: &Bound<'_, PyAny>, + ) -> PyResult>> { + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + let (range_start_arr, range_size_arr) = unpack_chunked_index(range_start, range_size)?; + let result = n_in_range_chunked( + start_arr.as_array(), + size_arr.as_array(), + range_start_arr.as_array(), + range_size_arr.as_array(), + ); + Ok(result.into_pyarray(py)) + } + + fn n_in_range_chunked( + start: ArrayView1<'_, i64>, + size: ArrayView1<'_, i64>, + range_start: ArrayView1<'_, i64>, + range_size: ArrayView1<'_, i64>, + ) -> Array1 { + let mut output = Array1::::zeros(range_start.len()); + if start.len() == 0 { + return output; + } + let end = &start + &size; + for (i, (&rst, &rsi)) in zip(range_start, range_size).enumerate() { + let chunk_end = rst + rsi; + let total = zip(start, &end) + .filter(|(s, e)| !(**s > chunk_end || **e < rst)) + .map(|(&s, &e)| { + let mut cr = (s, e); + if chunk_end < e { + cr = (cr.0, chunk_end) + } + if rst > s { + cr = (rst, cr.1) + } + cr.1 - cr.0 + }) + .sum(); + output[i] = total; + } + output + } + #[pyfunction(name = "chunked_into_array")] + pub(crate) fn chunked_into_array_py<'py>( + py: Python<'py>, + start: &Bound<'_, PyAny>, + size: &Bound<'_, PyAny>, + ) -> PyResult>> { + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + let output = chunked_into_array(start_arr.as_array(), size_arr.as_array()); + Ok(output.into_pyarray(py)) + } + fn chunked_into_array(start: ArrayView1<'_, i64>, size: ArrayView1<'_, i64>) -> Array1 { + let total_length = size.sum(); + let mut output = Array1::::zeros(total_length as usize); + let mut rs: i64 = 0; + for (&st, &si) in zip(start, size) { + let range = 
Array1::from_iter(st..st + si); + output + .slice_mut(s![rs as usize..(rs + si) as usize]) + .assign(&range); + rs += si; + } + output + } + + #[pyfunction(name = "take_chunked_from_simple")] + fn take_chunked_from_simple_py<'py>( + py: Python<'py>, + simple: &Bound<'_, PyAny>, + start: &Bound<'_, PyAny>, + size: &Bound<'_, PyAny>, + ) -> PyResult>> { + let simple_arr = unpack_index_array(simple)?; + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + let result = take_chunked_from_simple( + simple_arr.as_array(), + start_arr.as_array(), + size_arr.as_array(), + ); + Ok(result?.into_pyarray(py)) + } + + fn take_chunked_from_simple( + simple: ArrayView1<'_, i64>, + start: ArrayView1<'_, i64>, + size: ArrayView1<'_, i64>, + ) -> Result, PyErr> { + let total_length = size.sum(); + let mut output = Array1::::zeros(total_length as usize); + let mut rs: i64 = 0; + for (&st, &si) in zip(start, size) { + let end = st + si; + if end as usize > simple.len() { + return Err(PyValueError::new_err( + "The chunked index is outside of the range of the simple index!", + )); + } + let to_insert = simple.slice(s![st as usize..end as usize]); + output + .slice_mut(s![rs as usize..(rs + si) as usize]) + .assign(&to_insert); + rs += si + } + Ok(output) + } + + #[pyfunction(name = "take_chunked_from_chunked")] + fn take_chunked_from_chunked_py<'py>( + py: Python<'py>, + start: &Bound<'_, PyAny>, + size: &Bound<'_, PyAny>, + take_start: &Bound<'_, PyAny>, + take_size: &Bound<'_, PyAny>, + ) -> PyResult<(Bound<'py, PyArray1>, Bound<'py, PyArray1>)> { + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + let (take_start_arr, take_size_arr) = unpack_chunked_index(take_start, take_size)?; + let result = take_chunked_from_chunked( + start_arr.as_array(), + size_arr.as_array(), + take_start_arr.as_array(), + take_size_arr.as_array(), + )?; + Ok((result.0.into_pyarray(py), result.1.into_pyarray(py))) + } + fn find_chunk(prefix: &[i64], x: i64) -> usize { + let mut lo 
= 0usize; + let mut hi = prefix.len() - 1; + while lo + 1 < hi { + let mid = (lo + hi) / 2; + if prefix[mid] <= x { + lo = mid; + } else { + hi = mid; + } + } + lo + } + + fn take_chunked_from_chunked( + start: ArrayView1<'_, i64>, + size: ArrayView1<'_, i64>, + take_start: ArrayView1<'_, i64>, + take_size: ArrayView1<'_, i64>, + ) -> Result<(Array1, Array1), PyErr> { + if take_start.len() == 0 { + return Ok((Array1::::zeros(0), Array1::::zeros(0))); + } + let mut output_start: Vec = Vec::new(); + let mut output_size: Vec = Vec::new(); + + let mut prefix = vec![0i64; size.len() + 1]; + for i in 0..size.len() { + prefix[i + 1] = prefix[i] + size[i]; + } + let total = prefix[size.len()]; + + for (&tstart, &tsize) in zip(take_start, take_size) { + if tstart + tsize > total { + return Err(PyValueError::new_err( + "You can't take more elements than exist in an index!", + )); + } + let mut chunk_index = find_chunk(&prefix, tstart); + let mut cs = prefix[chunk_index]; + let mut start_in_chunk = tstart - cs; + let mut chunk_taken = 0i64; + + loop { + let size_in_chunk = size[chunk_index] - start_in_chunk; + let remaining = tsize - chunk_taken; + let (take, chunk_completed) = if size_in_chunk >= remaining { + (remaining, true) + } else { + (size_in_chunk, false) + }; + + output_start.push(start[chunk_index] + start_in_chunk); + output_size.push(take); + chunk_taken += take; + + if chunk_completed { + break; + } + cs += size[chunk_index]; + chunk_index += 1; + start_in_chunk = 0; + } + } + + Ok(( + Array1::from_vec(output_start), + Array1::from_vec(output_size), + )) + } + #[pyfunction(name = "reindex_column")] + fn reindex_columns_py<'py>( + py: Python<'py>, + index: &Bound<'_, PyAny>, + index_column: &Bound<'_, PyAny>, + ) -> PyResult>> { + let index_arr = unpack_index_array(index)?; + let index_column_arr = unpack_index_array(index_column)?; + Ok(reindex_column(index_arr.as_array(), index_column_arr.as_array()).into_pyarray(py)) + } + + fn reindex_column( + index: 
ArrayView1<'_, i64>, + index_column: ArrayView1<'_, i64>, + ) -> Array1 { + let mut index_map: HashMap = HashMap::new(); + for (i, &index_entry) in index.iter().enumerate() { + index_map.insert(index_entry, i as i64); + } + + let mut output: Vec = Vec::with_capacity(index_column.len()); + for val in index_column.iter() { + let val_index_opt = index_map.get(val); + if let Some(&val_index) = val_index_opt { + output.push(val_index) + } else { + output.push(-1) + } + } + Array1::from_vec(output) + } + + #[pyfunction(name = "rebuild_chunked_by_ranges")] + fn rebuild_chunked_by_ranges_py<'py>( + py: Python<'py>, + starts: &Bound<'_, PyAny>, + sizes: &Bound<'_, PyAny>, + range_starts: &Bound<'_, PyAny>, + range_sizes: &Bound<'_, PyAny>, + ) -> PyResult> { + let (start_arr, size_arr) = unpack_chunked_index(starts, sizes)?; + let (range_starts_arr, range_sizes_arr) = unpack_chunked_index(range_starts, range_sizes)?; + let mut output = rebuild_chunked_by_ranges( + start_arr.as_array(), + size_arr.as_array(), + range_starts_arr.as_array(), + range_sizes_arr.as_array(), + ); + PyList::new( + py, + output + .drain(0..) 
+ .map(|(st, si)| (st.into_pyarray(py), si.into_pyarray(py))), + ) + } + + fn rebuild_chunked_by_ranges( + starts: ArrayView1<'_, i64>, + sizes: ArrayView1<'_, i64>, + range_starts: ArrayView1<'_, i64>, + range_sizes: ArrayView1<'_, i64>, + ) -> Vec<(Array1, Array1)> { + let n_datasets = range_starts.len(); + let mut outputs: Vec<(Vec, Vec)> = + (0..n_datasets).map(|_| (Vec::new(), Vec::new())).collect(); + + if starts.len() == 0 || n_datasets == 0 { + return outputs + .into_iter() + .map(|(st, si)| (Array1::from_vec(st), Array1::from_vec(si))) + .collect(); + } + + let mut i = 0usize; + let mut j = 0usize; + + while i < starts.len() && j < n_datasets { + let chunk_start = starts[i]; + let chunk_end = chunk_start + sizes[i]; + let ds_start = range_starts[j]; + let ds_end = ds_start + range_sizes[j]; + + if chunk_end <= ds_start { + i += 1; + } else if chunk_start >= ds_end { + j += 1; + } else { + let overlap_start = chunk_start.max(ds_start); + let overlap_end = chunk_end.min(ds_end); + outputs[j].0.push(overlap_start - ds_start); + outputs[j].1.push(overlap_end - overlap_start); + + if chunk_end <= ds_end { + i += 1; + } else { + j += 1; + } + } + } + + outputs + .into_iter() + .map(|(st, si)| (Array1::from_vec(st), Array1::from_vec(si))) + .collect() + } + #[pyfunction(name = "rebuild_simple_by_ranges")] + fn rebuild_simple_by_ranges_py<'py>( + py: Python<'py>, + index: &Bound<'_, PyAny>, + range_starts: &Bound<'_, PyAny>, + range_sizes: &Bound<'_, PyAny>, + ) -> PyResult> { + let index_arr = unpack_index_array(index)?; + let (start_arr, size_arr) = unpack_chunked_index(range_starts, range_sizes)?; + let mut output = rebuild_simple_by_ranges( + index_arr.as_array(), + start_arr.as_array(), + size_arr.as_array(), + ); + PyList::new(py, output.drain(0..).map(|a| a.into_pyarray(py))) + } + + #[pyfunction(name = "project_chunked_on_simple")] + fn project_chunked_on_simple_py<'py>( + py: Python<'py>, + simple: &Bound<'_, PyAny>, + start: &Bound<'_, PyAny>, + size: 
&Bound<'_, PyAny>, + ) -> PyResult>> { + let simple_arr = unpack_index_array(simple)?; + let (start_arr, size_arr) = unpack_chunked_index(start, size)?; + let result = project_chunked_on_simple( + simple_arr.as_array(), + start_arr.as_array(), + size_arr.as_array(), + ); + Ok(result.into_pyarray(py)) + } + + fn project_chunked_on_simple( + simple: ArrayView1<'_, i64>, + start: ArrayView1<'_, i64>, + size: ArrayView1<'_, i64>, + ) -> Array1 { + let mut output: Vec = Vec::new(); + let n_chunks = start.len(); + if simple.is_empty() || n_chunks == 0 { + return Array1::from_vec(output); + } + + let mut chunk_idx = 0usize; + for (i, &val) in simple.iter().enumerate() { + // Advance past chunks whose end is at or before val. + while chunk_idx < n_chunks && val >= start[chunk_idx] + size[chunk_idx] { + chunk_idx += 1; + } + if chunk_idx >= n_chunks { + break; + } + // val is before the current chunk; skip. + if val < start[chunk_idx] { + continue; + } + output.push(i as i64); + } + Array1::from_vec(output) + } + + fn rebuild_simple_by_ranges( + index: ArrayView1<'_, i64>, + range_starts: ArrayView1<'_, i64>, + range_sizes: ArrayView1<'_, i64>, + ) -> Vec> { + let n_ranges = range_starts.len(); + let mut outputs: Vec> = (0..n_ranges).map(|_| Vec::new()).collect(); + + if n_ranges == 0 || index.len() == 0 { + return outputs.into_iter().map(Array1::from_vec).collect(); + } + + let mut j = 0usize; + for &idx in index { + // Advance past all ranges whose end is at or before idx. + // Using a while loop handles the case where idx skips multiple ranges + // and prevents an OOB panic when advancing past the last range. + while j < n_ranges && idx >= range_starts[j] + range_sizes[j] { + j += 1; + } + if j >= n_ranges { + break; + } + // idx falls before the start of range j (gap in a non-contiguous layout). 
+ if idx < range_starts[j] { + continue; + } + outputs[j].push(idx - range_starts[j]); + } + + outputs.into_iter().map(Array1::from_vec).collect() + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 00000000..6cebfacf --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,21 @@ +/// A Python module implemented in Rust. The name of this function must match +/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to +/// import the module. +/// +/// +use pyo3::prelude::*; +mod index; + +#[pymodule] +mod _lib { + use pyo3::prelude::*; + + #[pymodule_export] + use crate::index::index; + + /// Formats the sum of two numbers as string. + #[pyfunction] + fn sum_as_string(a: usize, b: usize) -> PyResult { + Ok((a + b).to_string()) + } +} diff --git a/src/opencosmo/collection/lightcone/coordinates.py b/src/opencosmo/collection/lightcone/coordinates.py deleted file mode 100644 index 47e21b67..00000000 --- a/src/opencosmo/collection/lightcone/coordinates.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -import warnings -from typing import TYPE_CHECKING - -import astropy.units as u -import numpy as np - -if TYPE_CHECKING: - import opencosmo as oc - - -def make_radec_columns(dataset: oc.Lightcone): - if "ra" in dataset.columns and "dec" in dataset.columns: - return dataset - elif "theta" in dataset.columns and "phi" in dataset.columns: - return dataset.evaluate( - radec_from_thetaphi, vectorize=True, insert=True, format="numpy" - ) - else: - warnings.warn( - "Could not find coordinates in this catalog. 
Spatial queries will not be available" - ) - - -def radec_from_thetaphi(theta, phi): - theta_deg = theta * 180 / np.pi - phi_deg = phi * 180 / np.pi - return {"ra": phi_deg * u.deg, "dec": (90.0 - theta_deg) * u.deg} diff --git a/src/opencosmo/collection/structure/io.py b/src/opencosmo/collection/structure/io.py deleted file mode 100644 index b2ff41b0..00000000 --- a/src/opencosmo/collection/structure/io.py +++ /dev/null @@ -1,223 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict -from functools import reduce -from typing import TYPE_CHECKING, Optional - -import numpy as np - -from opencosmo import io -from opencosmo.collection import lightcone as lc -from opencosmo.collection.structure import structure as sc - -if TYPE_CHECKING: - import h5py - - from opencosmo import dataset as d - from opencosmo.io.iopen import FileTarget - -ALLOWED_LINKS = { # h5py.Files that can serve as a link holder and - "halo_properties": ["halo_particles", "halo_profiles", "galaxy_properties"], - "galaxy_properties": ["galaxy_particles"], -} - - -def remove_empty(dataset): - metadata = dataset.get_metadata() - mask = np.ones(len(dataset), dtype=bool) - for name, col in metadata.items(): - if "size" in name: - mask &= col != 0 - elif "idx" in name: - mask &= col != -1 - - if not mask.all(): - dataset = dataset.take_rows(np.where(mask)[0]) - return dataset - - -def validate_linked_groups(groups: dict[str, h5py.Group]): - if "halo_properties" in groups: - if "data_linked" not in groups["halo_properties"].keys(): - raise ValueError( - "File appears to be a structure collection, but does not have links!" - ) - elif "galaxy_properties" in groups: - if "data_linked" not in groups["galaxy_properties"].keys(): - raise ValueError( - "File appears to be a structure collection, but does not have links!" 
- ) - if len(groups) == 1: - raise ValueError("Structure collections must have more than one dataset") - - -def build_structure_collection(targets: list[FileTarget], ignore_empty: bool): - link_sources: dict[str, list[io.iopen.DatasetTarget]] = defaultdict(list) - link_targets: dict[str, dict[str, list[d.Dataset | sc.StructureCollection]]] = ( - defaultdict(lambda: defaultdict(list)) - ) - dataset_targets: list[io.iopen.DatasetTarget] = reduce( - lambda acc, t: acc + t["dataset_targets"], targets, [] - ) - for target in dataset_targets: - if target["header"].file.data_type == "halo_properties": - link_sources["halo_properties"].append(target) - elif target["header"].file.data_type == "galaxy_properties": - link_sources["galaxy_properties"].append(target) - elif str(target["header"].file.data_type).startswith("halo"): - dataset = io.iopen.open_single_dataset( - target, bypass_lightcone=True, bypass_mpi=True - ) - name = target["dataset_group"].name.split("/")[-1] - if not name: - name = target["header"].file.data_type - elif name.startswith("halo_properties"): - name = name[16:] - link_targets["halo_targets"][name].append(dataset) - elif str(target["header"].file.data_type).startswith("galaxy"): - dataset = io.iopen.open_single_dataset( - target, bypass_lightcone=True, bypass_mpi=True - ) - name = target["dataset_group"].name.split("/")[-1] - if not name: - name = target["header"].file.data_type - elif name.startswith("galaxy_properties"): - name = name[18:] - link_targets["galaxy_targets"][name].append(dataset) - else: - raise ValueError( - f"Unknown data type for structure collection {target['header'].data_type}" - ) - - if ( - len(link_sources["halo_properties"]) > 1 - or len(link_sources["galaxy_properties"]) > 1 - ): - raise NotImplementedError( - "Opening structure collections that span multiple redshifts is not currently supported" - ) - # Potentially a lightcone structure collection - collections = {} - sources_by_step, targets_by_step = 
__sort_by_step(link_sources, link_targets) - if set(sources_by_step.keys()) != set(targets_by_step.keys()): - raise ValueError("Datasets are not the same across all lightcone steps!") - for step, sources in sources_by_step.items(): - halo_properties = sources.get("halo_properties") - galaxy_properties = sources.get("galaxy_properties") - targets = targets_by_step[step] - collection = __build_structure_collection( - halo_properties, galaxy_properties, targets, ignore_empty - ) - collections[step] = collection - - expected_datasets = set(next(iter(collections.values())).keys()) - for collection in collections.values(): - if set(collection.keys()) != expected_datasets: - raise ValueError( - "All structure collections in a lightcone must have the same set of datasets" - ) - return lc.Lightcone(collections) - - halo_properties_target = None - galaxy_properties_target = None - if link_sources["halo_properties"]: - halo_properties_target = link_sources["halo_properties"][0] - if link_sources["galaxy_properties"]: - galaxy_properties_target = link_sources["galaxy_properties"][0] - - input_link_targets: dict[str, dict[str, d.Dataset | sc.StructureCollection]] = ( - defaultdict(dict) - ) - for source_type, source_targets in link_targets.items(): - if any(len(ts) > 1 for ts in source_targets.values()): - raise ValueError("Found more than one linked file of a given type!") - input_link_targets[source_type] = { - key: t[0] for key, t in source_targets.items() - } - - return __build_structure_collection( - halo_properties_target, - galaxy_properties_target, - input_link_targets, - ignore_empty, - ) - - -def __sort_by_step(link_sources: dict[str, list[io.iopen.DatasetTarget]], link_targets): - sources_by_step: dict[int, dict[str, io.iopen.DatasetTarget]] = defaultdict(dict) - targets_by_step: dict[int, dict[str, dict[str, d.Dataset]]] = defaultdict( - lambda: defaultdict(dict) - ) - for source_name, sources in link_sources.items(): - for source in sources: - if not 
source["header"].file.is_lightcone: - raise ValueError( - "Recived multiple source datasets of a single type, but not all are lightcone datasets!" - ) - if source["header"].file.step is None: - raise ValueError("No step in source!") - - sources_by_step[source["header"].file.step][source_name] = source - for target_type, targets_ in link_targets.items(): - for target_name, targets in targets_.items(): - for target in targets: - if not target.header.file.is_lightcone: - raise ValueError( - "Recived multiple datasets of a single type, but not all are lightcone datasets!" - ) - targets_by_step[target.header.file.step][target_type][target_name] = ( - target - ) - - return sources_by_step, targets_by_step - - -def __build_structure_collection( - halo_properties_target: Optional[io.iopen.DatasetTarget], - galaxy_properties_target: Optional[io.iopen.DatasetTarget], - link_targets: dict[str, dict[str, d.Dataset | sc.StructureCollection]], - ignore_empty: bool, -): - if galaxy_properties_target is not None and "galaxy_targets" in link_targets: - # Galaxy properties and galaxy particles - source_dataset = io.iopen.open_single_dataset( - galaxy_properties_target, - metadata_group="data_linked", - bypass_lightcone=True, - bypass_mpi=halo_properties_target is not None, - ) - if ignore_empty and halo_properties_target is None: - source_dataset = remove_empty(source_dataset) - collection = sc.StructureCollection( - source_dataset, - source_dataset.header, - link_targets["galaxy_targets"], - ) - if halo_properties_target is not None: - link_targets["halo_targets"]["galaxy_properties"] = collection - else: - return collection - - if ( - halo_properties_target is not None - and galaxy_properties_target is not None - and "galaxy_targets" not in link_targets - ): - # Halo properties and galaxy properties, but no galaxy particles - galaxy_properties = io.iopen.open_single_dataset( - galaxy_properties_target, bypass_lightcone=True, bypass_mpi=True - ) - 
link_targets["halo_targets"]["galaxy_properties"] = galaxy_properties - - if halo_properties_target is not None and link_targets["halo_targets"]: - source_dataset = io.iopen.open_single_dataset( - halo_properties_target, metadata_group="data_linked", bypass_lightcone=True - ) - if ignore_empty: - source_dataset = remove_empty(source_dataset) - - return sc.StructureCollection( - source_dataset, - source_dataset.header, - link_targets["halo_targets"], - ) diff --git a/src/opencosmo/column/cache.py b/src/opencosmo/column/cache.py deleted file mode 100644 index f388a489..00000000 --- a/src/opencosmo/column/cache.py +++ /dev/null @@ -1,326 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Callable, Iterable, Optional -from weakref import finalize, ref - -import astropy.units as u -import numpy as np - -from opencosmo.index import DataIndex -from opencosmo.index.get import get_data -from opencosmo.index.take import take -from opencosmo.index.unary import get_length, get_range -from opencosmo.io.schema import FileEntry, make_schema -from opencosmo.io.writer import ColumnWriter - -if TYPE_CHECKING: - from opencosmo.index import DataIndex - - -ColumnUpdater = Callable[[np.ndarray | u.Quantity], np.ndarray | u.Quantity] - - -def finish( - cached_data: dict[str, np.ndarray], - index: Optional[DataIndex], - cache_ref: ref[ColumnCache], -): - cache = cache_ref() - if cache is None: - return - - columns_to_add = ( - cache.registered_columns.intersection(cached_data.keys()) - cache.columns - ) - data = {name: cached_data[name] for name in columns_to_add} - if index is not None: - data = {name: get_data(cd, index) for name, cd in data.items()} - if data: - cache.add_data(data) - - -def check_length(cache: ColumnCache, data: dict[str, np.ndarray]): - lengths = set(len(d) for d in data.values()) - if len(lengths) > 1: - raise ValueError( - "When adding data to the cache, all columns must be the same length" - ) - elif (length := len(cache)) > 0 and 
length != lengths.pop(): - raise ValueError( - "When adding data to the cache, the columns must be the same length as the columns currently in the cache" - ) - - -class ColumnCache: - """ - A column cache is used to persist data that is read from an hdf5 file. Caches can get data in one of two ways: - 1. They are explicitly given data that has been recently read from disk or - 2. They take data from a previous cache - - ColumnCaches break some of the rules that most other things follow in this library, notably that they have internal - state (which can change). This mutability is required for two reasons. - - 1. If the parent cache is garbage collected, the child cache needs to be able to copy over any data it needs - 2. If a new cache is created by adding columns, we need to signal the child to update their parent to the new - cache. This allows us to preserve the standard "operations create new objects" pattern that is present - throughout the library. - - """ - - def __init__( - self, - cached_data: dict[str, np.ndarray], - registered_column_groups: dict[int, set[str]], - column_descriptions: dict[str, str], - metadata_columns: set[str], - derived_index: Optional[DataIndex], - parent: Optional[ref[ColumnCache]], - children: Optional[list[ref[ColumnCache]]], - ): - self.__cached_data = cached_data - self.__registered_column_groups = registered_column_groups - self.__metadata_columns = metadata_columns - self.__descriptions = column_descriptions - self.__derived_index = derived_index - self.__parent = parent - if children is None: - children = [] - self.__children = children - self.__finalizer = None - - if parent is not None and (p := parent()) is not None: - self.__finalizer = finalize( - p, - finish, - p.__cached_data, - derived_index, - ref(self), - ) - self.__finalizer.atexit = False # type: ignore - - @classmethod - def empty(cls): - return ColumnCache({}, {}, {}, set(), None, None, []) - - @property - def columns(self): - return 
set(self.__cached_data.keys()) - - @property - def metadata_columns(self): - return self.__metadata_columns - - @property - def descriptions(self): - return self.__descriptions - - @property - def registered_columns(self): - return set().union(*list(self.__registered_column_groups.values())) - - def create_child(self): - return ColumnCache({}, {}, {}, self.__metadata_columns, None, ref(self), []) - - def make_schema(self, columns: Iterable[str]): - data = {} - metadata = {} - columns = set(columns) - cached_data = self.get_data(columns) - if not cached_data: - return ( - make_schema("data", FileEntry.EMPTY), - make_schema("metadata", FileEntry.EMPTY), - ) - - for name, coldata in cached_data.items(): - if isinstance(coldata, u.Quantity): - column_data = coldata.value - unit_str = str(coldata.unit) - else: - column_data = coldata - unit_str = "" - attrs = {"unit": unit_str} - attrs["description"] = self.descriptions.get(name, "None") - writer = ColumnWriter.from_numpy_array(column_data, attrs=attrs) - if name in self.metadata_columns: - metadata[name] = writer - else: - data[name] = writer - - data_schema = make_schema("data", FileEntry.COLUMNS, columns=data) - if metadata: - metadata_schema = make_schema( - "metadata", FileEntry.COLUMNS, columns=metadata - ) - else: - metadata_schema = make_schema("metadata", FileEntry.EMPTY) - - return data_schema, metadata_schema - - def __push_down(self, data: dict[str, np.ndarray]): - columns_to_keep = self.registered_columns.intersection(data.keys()).difference( - self.__cached_data.keys() - ) - cached_data = {colname: data[colname] for colname in columns_to_keep} - if self.__derived_index is not None: - cached_data = { - colname: get_data(coldata, self.__derived_index) - for colname, coldata in cached_data.items() - } - - self.__cached_data |= cached_data - - def __push_up(self, data: dict[str, np.ndarray]): - assert len(self) == 0 or all(len(d) == len(self) for d in data.values()) - columns_to_keep = 
self.registered_columns.intersection(data.keys()).difference( - self.__cached_data.keys() - ) - self.__cached_data |= {key: data[key] for key in columns_to_keep} - - def register_column_group(self, state_id: int, columns: set[str]): - assert state_id not in self.__registered_column_groups - self.__registered_column_groups[state_id] = columns - - def deregister_column_group(self, state_id: int): - assert state_id in self.__registered_column_groups - columns = self.__registered_column_groups.pop(state_id) - remaining_columns = set().union(*list(self.__registered_column_groups.values())) - - to_drop = columns.difference(remaining_columns) - cached_data = { - name: self.__cached_data.pop(name) - for name in to_drop - if name in self.__cached_data - } - if not cached_data: - return - for child_ in self.__children: - if (child := child_()) is None: - continue - child.__push_down(cached_data) - - def __update_parent(self, parent: ColumnCache): - assert self.__parent is not None - assert self.__finalizer is not None - self.__finalizer.detach() - self.__parent = ref(parent) - self.__finalizer = finalize( - parent, finish, parent.__cached_data, self.__derived_index, ref(self) - ) - self.__finalizer.atexit = False # type: ignore - - def __len__(self): - if not self.__cached_data and self.__derived_index is None: - return 0 - elif self.__derived_index is not None: - return get_length(self.__derived_index) - elif self.__cached_data: - return len(next(iter(self.__cached_data.values()))) - elif self.__parent is not None and (p := self.__parent()) is not None: - return len(p) - return 0 - - def add_data( - self, - data: dict[str, np.ndarray], - descriptions: dict[str, str] = {}, - metadata_columns: Optional[set[str]] = None, - push_up=True, - ): - """ - The in-place equivalent of with_data. Should not be used outside the context of this - file. 
- - """ - check_length(self, data) - if metadata_columns is not None: - assert metadata_columns.issubset(data.keys()) - self.__metadata_columns = self.__metadata_columns.union(metadata_columns) - - self.__descriptions |= descriptions - if ( - push_up - and self.__derived_index is None - and self.__parent is not None - and (p := self.__parent()) is not None - ): - p.__push_up(data) - - self.__cached_data = self.__cached_data | data - - def drop(self, columns: Iterable[str]): - columns = set(columns) - columns_to_drop = set(self.__cached_data.keys()).intersection(columns) - data = { - name: data - for name, data in self.__cached_data.items() - if name not in columns_to_drop - } - descriptions = { - name: desc - for name, desc in self.__descriptions.items() - if name not in columns_to_drop - } - new_meta_columns = self.__metadata_columns.difference(columns) - - return ColumnCache(data, {}, descriptions, new_meta_columns, None, None, []) - - def request(self, column_names: Iterable[str], index: Optional[DataIndex]): - column_names = set(column_names) - columns_in_cache = column_names.intersection(self.__cached_data.keys()) - - data = {name: self.__cached_data[name] for name in columns_in_cache} - if index is not None: - data = {name: get_data(cd, index) for name, cd in data.items()} - - if self.__parent is None or column_names == columns_in_cache: - return data - - parent = self.__parent() - if parent is None: - return data - - match (index, self.__derived_index): - case (None, None): - new_index = None - case (_, None): - new_index = index - case (None, _): - new_index = self.__derived_index - case _: - assert self.__derived_index is not None and index is not None - new_index = take(self.__derived_index, index) - - return data | parent.request(column_names, new_index) - - def take(self, index: DataIndex): - if len(self) == 0 and not self.columns: - return ColumnCache.empty() - if get_range(index)[1] > len(self): - raise ValueError( - "Tried to take more elements than 
the length of the cache!" - ) - new_cache = ColumnCache( - {}, {}, {}, self.__metadata_columns, index, ref(self), [] - ) - self.__children.append(ref(new_cache)) - return new_cache - - def get_data(self, columns: Iterable[str]): - columns = set(columns) - columns_in_cache = columns.intersection(self.__cached_data.keys()) - missing_columns = columns - columns_in_cache - output = {c: self.__cached_data[c] for c in columns_in_cache} - output |= self.__get_derived_columns(missing_columns) - return output - - def __get_derived_columns(self, column_names: set[str]): - if self.__parent is None: - return {} - parent = self.__parent() - if parent is None: - return {} - result = parent.request(column_names, self.__derived_index) - - self.__cached_data = self.__cached_data | result - return result diff --git a/src/opencosmo/column/evaluate.py b/src/opencosmo/column/evaluate.py deleted file mode 100644 index 7a04a119..00000000 --- a/src/opencosmo/column/evaluate.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -from enum import Enum -from typing import TYPE_CHECKING, Any, Callable - -import astropy.units as u -import numpy as np - -from opencosmo.evaluate import insert_data - -if TYPE_CHECKING: - from opencosmo import Dataset - - -class EvaluateStrategy(Enum): - VECTORIZE = "vectorize" - ROW_WISE = "row_wise" - CHUNKED = "chunked" - - -def evaluate_rows(data: dict[str, np.ndarray], func: Callable, kwargs: dict[str, Any]): - data_length = len(next(iter(data.values()))) - storage = {} - for i in range(data_length): - iterable_inputs = {name: values[i] for name, values in data.items()} - output = func(**iterable_inputs, **kwargs) - if not isinstance(output, dict): - output = {func.__name__: output} - if i == 0: - storage = __make_row_based_output_from_first_values(output, data_length) - continue - insert_data(storage, i, output) - return storage - - -def __make_row_based_output_from_first_values(values, data_length): - storage = {} - for name, value in 
values.items(): - try: - shape = (data_length,) + value.shape - except AttributeError: - shape = (data_length,) - try: - dtype = value.dtype - except AttributeError: - dtype = type(value) - column_storage = np.zeros(shape, dtype=dtype) - if isinstance(value, u.Quantity): - column_storage *= value.unit - column_storage[0] = value - storage[name] = column_storage - - return storage - - -def evaluate_chunks( - data: dict[str, np.ndarray], - func: Callable, - kwargs: dict[str, Any], - chunk_sizes: np.ndarray, -): - data_length = len(next(iter(data.values()))) - - chunk_splits = np.cumsum(chunk_sizes) - storage = {} - input_data = {name: np.split(arr, chunk_splits) for name, arr in data.items()} - for i in range(len(chunk_splits)): - chunk_input_data = {name: split[i] for name, split in input_data.items()} - output = func(**chunk_input_data, **kwargs) - if not isinstance(output, dict): - output = {func.__name__: output} - if i == 0: - storage = __make_chunked_based_output_from_first_values(output, data_length) - continue - for name, values in output.items(): - storage[name][chunk_splits[i - 1] : chunk_splits[i]] = values - return storage - - -def __make_chunked_based_output_from_first_values(values, data_length): - storage = {} - for name, value in values.items(): - shape = (data_length,) + value.shape[1:] - dtype = value.dtype - column_storage = np.zeros(shape, dtype=dtype) - if isinstance(value, u.Quantity): - column_storage *= value.unit - column_storage[0 : len(value)] = value - storage[name] = column_storage - - return storage - - -def evaluate_vectorized(data, func, kwargs): - return func(**data, **kwargs) - - -def do_first_evaluation( - func: Callable, - strategy: str, - format: str, - kwargs: dict[str, Any], - dataset: Dataset, -): - eval_strategy = EvaluateStrategy(strategy) - match eval_strategy: - case EvaluateStrategy.VECTORIZE: - values = dataset.take(1).get_data(format, unpack=False) - try: - values = dict(values) - except TypeError: - values = 
{dataset.columns[0]: values} - - return func(**values, **kwargs), eval_strategy - - case EvaluateStrategy.ROW_WISE: - values = dataset.take(1).get_data(format, unpack=True) - try: - values = dict(values) - except TypeError: - values = {dataset.columns[0]: values} - return func(**values, **kwargs), eval_strategy - - case EvaluateStrategy.CHUNKED: - index = dataset.index - assert isinstance(index, tuple) - first_chunk_size = index[1][0] - first_chunk = dataset.take(first_chunk_size, at="start").get_data(format) - first_chunk = dict(first_chunk) - return func(**first_chunk, **kwargs), eval_strategy diff --git a/src/opencosmo/dataset/derived.py b/src/opencosmo/dataset/derived.py deleted file mode 100644 index bedafb42..00000000 --- a/src/opencosmo/dataset/derived.py +++ /dev/null @@ -1,226 +0,0 @@ -from __future__ import annotations - -from functools import reduce -from itertools import product -from typing import TYPE_CHECKING, Mapping, Optional - -import rustworkx as rx - -if TYPE_CHECKING: - import astropy.units as u - import numpy as np - - from opencosmo.column.column import ConstructedColumn - from opencosmo.handler.protocols import DataCache, DataHandler - from opencosmo.index import DataIndex - from opencosmo.units.handler import UnitHandler - - -def build_dependency_graph( - derived_columns: Mapping[str, ConstructedColumn], - names_to_keep: Optional[set[str]] = None, -): - dependency_graph = rx.PyDiGraph() - all_requires: set[str] = reduce( - lambda known, dc: known.union(dc.requires), derived_columns.values(), set() - ) - nodeidx = dependency_graph.add_nodes_from(all_requires) - nodemap = {name: idx for (name, idx) in zip(all_requires, nodeidx)} - - for target, derived_column in derived_columns.items(): - requires = derived_column.requires - produces = derived_column.produces - if produces is None: - produces = set((target,)) - to_add = list(filter(lambda p: p not in nodemap, produces)) - new_map = dependency_graph.add_nodes_from(to_add) - 
nodemap.update({name: idx for (name, idx) in zip(to_add, new_map)}) - - requires_idx = tuple(nodemap[r] for r in requires) - produces_idx = tuple(nodemap[r] for r in produces) - - dependency_graph.add_edges_from_no_data(product(requires_idx, produces_idx)) - - if names_to_keep is not None: - nodes_to_keep = reduce( - lambda acc, name: acc.union(rx.ancestors(dependency_graph, nodemap[name])), - names_to_keep, - {nodemap[name] for name in names_to_keep}, - ) - names_to_keep = {dependency_graph[n] for n in nodes_to_keep} - dependency_graph = dependency_graph.subgraph(list(nodes_to_keep)) - derived_columns = { - name: dc for name, dc in derived_columns.items() if name in names_to_keep - } - - return dependency_graph - - -def replace_multi_producers( - graph: rx.PyDiGraph, - derived_columns: Mapping[str, ConstructedColumn], - columns_to_get: Optional[set[str]] = None, -): - """ - Some derived columns actually produce multiple outputs. At this stage, the dependency - graph is working solely with actual column names, meaning if any of those columns is - produced by one of these "multi-produces" they will not be in the derived_columns - dictionary and therefore cannot be instantiated. This function replaces such - columns with the name of the derived_column that produces them. 
- """ - - node_map = {name: i for i, name in enumerate(graph.nodes())} - missing = set(derived_columns.keys()).difference(node_map.keys()) - if not missing: - return graph - for missing_column in missing: - missing_column_produces = derived_columns[missing_column].produces - if missing_column_produces is None or not missing_column_produces.intersection( - columns_to_get or missing_column_produces - ): - continue - outputs = [ - node_map[name] for name in missing_column_produces if name in node_map - ] - graph.contract_nodes(outputs, missing_column) - return graph - - -def validate_derived_columns( - derived_columns: dict[str, ConstructedColumn], - known_raw_columns: set[str], - units: dict[str, u.Unit], -): - """ - Validate the network of derived columns. This - """ - dependency_graph = build_dependency_graph(derived_columns) - if cycle := rx.digraph_find_cycle(dependency_graph): - all_nodes: set[int] = reduce( - lambda known, edge: known.union(edge), cycle, set() - ) - names = [dependency_graph[i] for i in all_nodes] - raise ValueError( - f"Found derived columns that depend on each other! 
Columns: {names}" - ) - - sources = set( - filter( - lambda i: not dependency_graph.in_degree(i), - range(dependency_graph.num_nodes()), - ) - ) - source_names = map(lambda i: dependency_graph[i], sources) - if missing := set(source_names).difference(known_raw_columns): - raise ValueError(f"Tried to derive columns from unknown columns: {missing}") - - dependency_graph = replace_multi_producers(dependency_graph, derived_columns) - validate_dependency_graph( - dependency_graph, known_raw_columns, set(derived_columns.keys()) - ) - - return validate_derived_units(dependency_graph, derived_columns, units) - - -def validate_dependency_graph( - dependency_graph: rx.PyDiGraph, - known_raw_columns: set[str], - derived_columns: set[str], -): - expected = set(dependency_graph.nodes()) - known = known_raw_columns.union(derived_columns) - missing = expected.difference(known) - assert len(missing) == 0 - - -def validate_derived_units( - dependency_graph: rx.PyDiGraph, - derived_columns: dict[str, ConstructedColumn], - units: dict[str, u.Unit], -): - output_units: dict[str, Optional[u.Unit]] = {} - for node in rx.topological_sort(dependency_graph): - node_name = dependency_graph[node] - if node_name in units: - continue - new_units = derived_columns[node_name].get_units(units) - if not isinstance(new_units, dict): - new_units = {node_name: new_units} - units |= new_units - output_units |= new_units - return output_units - - -def build_derived_columns( - all_derived_columns: dict[str, ConstructedColumn], - derived_columns_to_get: set[str], - cache: DataCache, - data_handler: DataHandler, - unit_handler: UnitHandler, - unit_kwargs: dict, - index: DataIndex, -) -> dict[str, np.ndarray]: - """ - Build any derived columns that are present in this dataset. Also returns any columns that - had to be instantiated in order to build these derived columns. 
- """ - if not derived_columns_to_get: - return {} - - column_names: set[str] = reduce( - lambda known, dc: known.union(dc[1].produces) - if dc[1].produces is not None - else known.union((dc[0],)), - all_derived_columns.items(), - set(), - ) - - dependency_graph = build_dependency_graph( - all_derived_columns, derived_columns_to_get - ) - cached_data = cache.get_data(dependency_graph.nodes()) - cached_data |= unit_handler.apply_unit_conversions(cached_data, unit_kwargs) - - additional_derived = column_names.difference(cached_data.keys()) - - if not additional_derived: - return cached_data - - columns_to_fetch = ( - set(dependency_graph.nodes()) - .intersection(data_handler.columns) - .difference(cached_data.keys()) - ) - - raw_data = data_handler.get_data(columns_to_fetch) - data = cached_data | unit_handler.apply_units(raw_data, unit_kwargs) - - dependency_graph = replace_multi_producers( - dependency_graph, all_derived_columns, derived_columns_to_get - ) - new_derived: dict[str, np.ndarray] = {} - - for colidx in rx.topological_sort(dependency_graph): - colname = dependency_graph[colidx] - if colname in data: - continue - derived_column = all_derived_columns[colname] - produces = derived_column.produces - if produces is None: - produces = set((colname,)) - if all(name in data for name in produces): - continue - output = derived_column.evaluate( - data, index[1] if isinstance(index, tuple) else None - ) - if isinstance(output, dict): - data |= output - new_derived |= output - else: - data[colname] = output - new_derived[colname] = output - - if new_derived: - cache.add_data(new_derived, {}) - - return data | new_derived diff --git a/src/opencosmo/dataset/formats.py b/src/opencosmo/dataset/formats.py deleted file mode 100644 index 321916d2..00000000 --- a/src/opencosmo/dataset/formats.py +++ /dev/null @@ -1,113 +0,0 @@ -from __future__ import annotations - -from importlib import import_module - -import astropy.units as u -import numpy as np -from astropy.table 
import Column, QTable - - -def verify_format(output_format: str): - match output_format: - case "astropy": - return - case "numpy": # these two are core dependencies - return - case "pandas": - import_name = "pandas" - case "arrow": - import_name = "pyarrow" - case "polars": - import_name = "polars" - case _: - raise ValueError(f"Unknown data output format {output_format}") - - __verify_import(import_name, output_format) - - -def __verify_import(import_name: str, format_name: str): - try: - import_module(import_name) - except ImportError as e: - raise ImportError( - f"Data was requested in {format_name} format but could not import {import_name} package. Got '{e}'" - ) - - -def convert_data(data: dict[str, np.ndarray], output_format: str): - match output_format: - case "astropy": - return __convert_to_astropy(data) - case "numpy": - return __convert_to_numpy(data) - case "pandas": - return __convert_to_pandas(data) - case "polars": - return __convert_to_polars(data) - case "arrow": - return __convert_to_arrow(data) - case _: - raise ValueError(f"Unknown data output format {output_format}") - - -def __convert_to_astropy(data: dict[str, np.ndarray]) -> QTable: - if len(data) == 1: - return next(iter(data.values())) - if any( - (isinstance(d, u.Quantity) and d.isscalar) or not isinstance(d, np.ndarray) - for d in data.values() - ): - return data - - return QTable(data, copy=False) - - -def __convert_to_numpy( - data: dict[str, np.ndarray], -) -> dict[str, np.ndarray] | np.ndarray: - converted_data = dict( - map( - lambda kv: ( - kv[0], - kv[1].value if isinstance(kv[1], (u.Quantity, Column)) else kv[1], - ), - data.items(), - ) - ) - if len(converted_data) == 1: - return next(iter(converted_data.values())) - return converted_data - - -def __convert_to_pandas(data: dict[str, np.ndarray]): - import pandas as pd - - numpy_data = __convert_to_numpy(data) - if isinstance(numpy_data, np.ndarray): # only one column - return pd.Series(numpy_data, name=next(iter(data.keys()))) 
- - return pd.DataFrame(numpy_data, copy=True) - - -def __convert_to_arrow(data: dict[str, np.ndarray]): - import pyarrow as pa # type: ignore - - numpy_data = __convert_to_numpy(data) - if isinstance(numpy_data, np.ndarray): - return pa.array(numpy_data) - - converted_data = map( - lambda kv: (kv[0], pa.array(kv[1])), - data.items(), - ) - return dict(converted_data) - - -def __convert_to_polars(data: dict[str, np.ndarray]): - import polars as pl - - numpy_data = __convert_to_numpy(data) - if isinstance(numpy_data, np.ndarray): - return pl.Series(name=next(iter(data.keys())), values=numpy_data) - - return pl.from_dict(data) # type: ignore diff --git a/src/opencosmo/dataset/im.py b/src/opencosmo/dataset/im.py deleted file mode 100644 index a9ee1f86..00000000 --- a/src/opencosmo/dataset/im.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -import astropy.units as u -import numpy as np - -if TYPE_CHECKING: - from opencosmo.units.handler import UnitHandler - - -def resort(columns: dict[str, np.ndarray], sorted_index: Optional[np.ndarray]): - if sorted_index is None or not columns: - return columns - reverse_sort = np.argsort(sorted_index) - return {name: data[reverse_sort] for name, data in columns.items()} - - -def validate_in_memory_columns( - columns: dict[str, np.ndarray], unit_handler: UnitHandler, ds_length: int -): - new_units = {} - for colname, column in columns.items(): - if len(column) != ds_length: - raise ValueError(f"Column {colname} is not the same length as the dataset!") - if isinstance(column, u.Quantity): - new_units[colname] = column.unit - else: - new_units[colname] = None - - return unit_handler.with_static_columns(**new_units) diff --git a/src/opencosmo/dataset/state.py b/src/opencosmo/dataset/state.py deleted file mode 100644 index b47535d9..00000000 --- a/src/opencosmo/dataset/state.py +++ /dev/null @@ -1,687 +0,0 @@ -from __future__ import annotations - -from copy import copy 
-from functools import reduce -from typing import TYPE_CHECKING, Optional -from weakref import finalize - -import astropy.units as u -import numpy as np - -from opencosmo.column.cache import ColumnCache -from opencosmo.column.column import Column, DerivedColumn, EvaluatedColumn -from opencosmo.column.select import get_column_selection -from opencosmo.dataset.derived import ( - build_derived_columns, - validate_derived_columns, -) -from opencosmo.dataset.im import resort, validate_in_memory_columns -from opencosmo.handler.empty import EmptyHandler -from opencosmo.handler.hdf5 import Hdf5Handler -from opencosmo.index.build import single_chunk -from opencosmo.index.mask import into_array -from opencosmo.index.unary import get_range -from opencosmo.io.schema import FileEntry, combine_with_cached_schema, make_schema -from opencosmo.io.writer import ColumnCombineStrategy, ColumnWriter, NumpySource -from opencosmo.units import UnitConvention -from opencosmo.units.handler import ( - make_unit_handler_from_hdf5, - make_unit_handler_from_units, -) - -if TYPE_CHECKING: - from astropy import table - from astropy.cosmology import Cosmology - from numpy.typing import NDArray - - from opencosmo.column.column import ConstructedColumn - from opencosmo.handler.protocols import DataCache, DataHandler - from opencosmo.header import OpenCosmoHeader - from opencosmo.index import DataIndex - from opencosmo.io.iopen import DatasetTarget - from opencosmo.spatial.protocols import Region - from opencosmo.units.handler import UnitHandler - - -def deregister_state(id: int, cache: DataCache): - cache.deregister_column_group(id) - - -class DatasetState: - """ - Holds mutable state required by the dataset. Cleans up the dataset to mostly focus - on very high-level operations. Not a user facing class. 
- """ - - def __init__( - self, - raw_data_handler: DataHandler, - cache: DataCache, - derived_columns: dict[str, ConstructedColumn], - unit_handler: UnitHandler, - header: OpenCosmoHeader, - columns: set[str], - region: Region, - sort_by: Optional[tuple[str, bool]], - ): - self.__raw_data_handler = raw_data_handler - self.__cache = cache - self.__derived_columns = derived_columns - self.__unit_handler = unit_handler - self.__header = header - self.__columns = columns - self.__region = region - self.__sort_by = sort_by - self.__cache.register_column_group(id(self), self.__columns) - finalize(self, deregister_state, id(self), self.__cache) - - def __rebuild(self, **updates): - new = { - "raw_data_handler": self.__raw_data_handler, - "cache": self.__cache, - "derived_columns": self.__derived_columns, - "unit_handler": self.__unit_handler, - "header": self.__header, - "columns": self.__columns, - "region": self.__region, - "sort_by": self.__sort_by, - } | updates - return DatasetState(**new) - - def __exit__(self, *exec_details): - return None - - @classmethod - def from_target( - cls, - target: DatasetTarget, - unit_convention: UnitConvention, - region: Region, - index: Optional[DataIndex] = None, - metadata_group: Optional[str] = None, - in_memory: bool = False, - ): - data_group = target["dataset_group"] - if "load" in data_group.keys(): - load_conditions = dict(data_group["load/if"].attrs) - else: - load_conditions = None - - handler = Hdf5Handler.from_columns( - target["columns"], - index, - metadata_group, - load_conditions, - ) - unit_handler = make_unit_handler_from_hdf5( - target["columns"], target["header"], unit_convention - ) - - columns = set(handler.columns) - cache = ColumnCache.empty() - return DatasetState( - handler, - cache, - {}, - unit_handler, - target["header"], - columns, - region, - None, - ) - - @classmethod - def in_memory( - cls, - data_columns: dict, - metadata_columns: dict, - header: OpenCosmoHeader, - unit_convention: UnitConvention, - 
region: Region, - descriptions: Optional[dict[str, str]] = {}, - index: Optional[DataIndex] = None, - ): - cache = ColumnCache.empty() - cache.add_data( - data_columns | metadata_columns, descriptions, set(metadata_columns.keys()) - ) - units: dict[str, u.Unit] = {} - for name, column in data_columns.items(): - units[name] = None - if isinstance(column, u.Quantity): - units[name] = column.unit - - unit_handler = make_unit_handler_from_units(units, header, unit_convention) - - return DatasetState( - EmptyHandler(), - cache, - {}, - unit_handler, - header, - set(data_columns.keys()), - region, - None, - ) - - def __len__(self): - if isinstance(self.__raw_data_handler, EmptyHandler): - return len(self.__cache) - return len(self.__raw_data_handler) - - @property - def descriptions(self): - all_descriptions = ( - {name: col.description for name, col in self.__derived_columns.items()} - | self.__raw_data_handler.descriptions - | self.__cache.descriptions - ) - return { - name: description - for name, description in all_descriptions.items() - if name in self.columns - } - - @property - def raw_index(self): - if (si := self.get_sorted_index()) is not None: - ni = into_array(self.__raw_data_handler.index) - return ni[si] - - return self.__raw_data_handler.index - - @property - def unit_handler(self): - return self.__unit_handler - - @property - def units(self): - units = self.__unit_handler.current_units - return {name: units[name] for name in self.columns} - - @property - def convention(self): - return self.__unit_handler.current_convention - - @property - def region(self): - return self.__region - - @property - def header(self): - return self.__header - - @property - def columns(self) -> list[str]: - return list(self.__columns) - - @property - def meta_columns(self) -> list[str]: - columns = set(self.__cache.metadata_columns).union( - self.__raw_data_handler.metadata_columns - ) - return list(columns) - - def get_data( - self, - ignore_sort: bool = False, - 
metadata_columns: list = [], - unit_kwargs: dict = {}, - ) -> table.QTable: - """ - Get the data for a given handler. - """ - data = self.__build_derived_columns(unit_kwargs) - cached_data = self.__cache.get_data(self.columns) - converted_cached_data = self.__unit_handler.apply_unit_conversions( - cached_data, unit_kwargs - ) - - data |= cached_data - if converted_cached_data: - self.__cache.add_data(converted_cached_data, {}, push_up=False) - data |= converted_cached_data - - raw_columns = ( - set(self.columns) - .intersection(self.__raw_data_handler.columns) - .difference(data.keys()) - ) - if ( - self.__sort_by is not None - and self.__sort_by[0] in self.__raw_data_handler.columns - ): - raw_columns.add(self.__sort_by[0]) - - if raw_columns: - raw_data = self.__raw_data_handler.get_data(raw_columns) - raw_data = self.__unit_handler.apply_raw_units(raw_data, unit_kwargs) - if raw_data: - self.__cache.add_data(raw_data, {}) - - updated_data = self.__unit_handler.apply_unit_conversions( - raw_data, unit_kwargs - ) - if updated_data: - self.__cache.add_data(updated_data, {}, push_up=False) - - new_data = raw_data | updated_data - data |= new_data - - if missing := set(self.columns).difference(data.keys()): - raise RuntimeError( - f"Some columns are missing from the output! This is likely a bug. Please report it on GitHub. 
Missing: {missing}" - ) - - # keep ordering - - if metadata_columns: - metadata = self.__cache.get_data(metadata_columns) - additional_metadata_columns_to_fetch = set(metadata_columns).difference( - metadata.keys() - ) - metadata |= ( - self.__raw_data_handler.get_metadata( - additional_metadata_columns_to_fetch - ) - or {} - ) - - data.update(metadata) - - if not ignore_sort and self.__sort_by is not None: - sort_by = data[self.__sort_by[0]] - order = np.argsort(sort_by) - if self.__sort_by[1]: - order = order[::-1] - - data = {key: value[order] for key, value in data.items()} - - new_order = [c for c in self.columns] - if metadata_columns: - new_order.extend(metadata_columns) - - return {name: data[name] for name in new_order} - - def rows(self, metadata_columns: list = [], unit_kwargs: dict = {}): - derived_to_collect = ( - set(self.__derived_columns.keys()) - .intersection(self.columns) - .difference(self.__cache.columns) - ) - derived_storage: dict[str, list[np.ndarray]] = { - name: [] for name in derived_to_collect - } - total_length = len(self) - chunk_ranges = [ - (i, min(i + 1000, total_length)) for i in range(0, total_length, 1000) - ] - if not chunk_ranges: - raise StopIteration - - try: - for start, end in chunk_ranges: - chunk = self.take_range(start, end) - data = chunk.get_data( - metadata_columns=metadata_columns, unit_kwargs=unit_kwargs - ) - for name in derived_to_collect: - derived_storage[name].append(data[name]) - - for i in range(len(chunk)): - yield {name: column[i] for name, column in data.items()} - all_derived = { - name: np.concatenate(arr) for name, arr in derived_storage.items() - } - derived_storage = resort(all_derived, self.get_sorted_index()) - if derived_storage: - self.__cache.add_data(data, {}) - except GeneratorExit: - pass - except BaseException: - raise - - def get_metadata(self, columns=[]): - metadata = self.__raw_data_handler.get_metadata(columns) - sorted_index = self.get_sorted_index() - if sorted_index is not None: - 
metadata = {name: values[sorted_index] for name, values in metadata.items()} - return metadata - - def with_mask(self, mask: NDArray[np.bool_]): - index = np.where(mask)[0] - new_raw_handler = self.__raw_data_handler.take(index) - new_cache = self.__cache.take(index) - return self.__rebuild( - cache=new_cache, - raw_data_handler=new_raw_handler, - ) - - def make_schema(self, name: Optional[str] = None): - header = self.__header.with_region(self.__region) - raw_columns = self.__columns.intersection(self.__raw_data_handler.columns) - - data_schema, metadata_schema = self.__raw_data_handler.make_schema( - raw_columns, header - ) - derived_names = set(self.__derived_columns.keys()).intersection(self.columns) - derived_data = ( - self.select(derived_names) - .with_units(self.__unit_handler.base_convention, {}, {}, None, None) - .get_data(ignore_sort=True) - ) - cached_data_schema, cached_metadata_schema = self.__cache.make_schema( - self.columns + self.meta_columns - ) - - for colname in derived_names: - if colname in cached_data_schema.columns: - continue - coldata = derived_data[colname] - unit = "" - if isinstance(coldata, u.Quantity): - unit = str(coldata.unit) - coldata = derived_data[colname].value - - attrs = { - "unit": unit, - "description": str(self.__derived_columns[colname].description), - } - source = NumpySource(coldata) - writer = ColumnWriter([source], ColumnCombineStrategy.CONCAT, attrs=attrs) - data_schema.columns[colname] = writer - - attributes = {} - if (load_conditions := self.__raw_data_handler.load_conditions) is not None: - attributes["load/if"] = load_conditions - - data_schema = combine_with_cached_schema(data_schema, cached_data_schema) - - metadata_schema = combine_with_cached_schema( - metadata_schema, cached_metadata_schema - ) - children = {"data": data_schema} - - if metadata_schema.type != FileEntry.EMPTY: - children[metadata_schema.name] = metadata_schema - if name is None: - name = "" - - return make_schema( - name, FileEntry.DATASET, 
children=children, attributes=attributes - ) - - def with_new_columns( - self, - descriptions: dict[str, str] = {}, - **new_columns: ConstructedColumn | np.ndarray | u.Quantity, - ): - """ - Add a set of derived columns to the dataset. A derived column is a column that - has been created based on the values in another column. - """ - - existing_columns = set(self.columns) - - if inter := existing_columns.intersection(new_columns.keys()): - raise ValueError(f"Some columns are already in the dataset: {inter}") - - new_derived_columns = {} - new_in_memory_columns = {} - new_in_memory_descriptions = {} - - for colname, column in new_columns.items(): - match column: - case DerivedColumn() | EvaluatedColumn() | Column(): - column.description = descriptions.get(colname, "None") - new_derived_columns[colname] = column - case np.ndarray(): - if len(column) != len(self): - raise ValueError( - f"New column {colname} does not have the same length as this dataset!" - ) - new_in_memory_descriptions[colname] = descriptions.get( - colname, "None" - ) - new_in_memory_columns[colname] = column - case _: - raise ValueError( - f"Got an invalid new column of type {type(column)}" - ) - - new_unit_handler = self.__unit_handler - new_derived = copy(self.__derived_columns) - new_column_names: set[str] = set(self.columns) - if new_in_memory_columns: - new_unit_handler = validate_in_memory_columns( - new_in_memory_columns, self.__unit_handler, len(self) - ) - new_in_memory_columns = resort( - new_in_memory_columns, self.get_sorted_index() - ) - self.__cache.add_data( - new_in_memory_columns, descriptions=new_in_memory_descriptions - ) - new_column_names |= set(new_in_memory_columns.keys()) - - if new_derived_columns: - new_units = validate_derived_columns( - self.__derived_columns | new_derived_columns, - existing_columns.union(new_in_memory_columns.keys()).difference( - self.__derived_columns.keys() - ), - new_unit_handler.base_units, - ) - new_derived |= new_derived_columns - for colname, 
derived in new_derived.items(): - if (prod := derived.produces) is not None: - new_column_names |= prod - else: - new_column_names.add(colname) - - new_unit_handler = new_unit_handler.with_new_columns(**new_units) - - return self.__rebuild( - cache=self.__cache, - derived_columns=new_derived, - columns=new_column_names, - unit_handler=new_unit_handler, - ) - - def __build_derived_columns(self, unit_kwargs: dict) -> table.Table: - """ - Build any derived columns that are present in this dataset - """ - if not self.__derived_columns: - return {} - - all_derived_columns: set[str] = reduce( - lambda acc, dc: acc.union( - dc[1].produces if dc[1].produces is not None else {dc[0]} - ), - self.__derived_columns.items(), - set(), - ) - derived_names = all_derived_columns.intersection(self.columns) - if self.__sort_by is not None and self.__sort_by[0] in all_derived_columns: - derived_names.add(self.__sort_by[0]) - - dc = build_derived_columns( - self.__derived_columns, - derived_names, - self.__cache, - self.__raw_data_handler, - self.__unit_handler, - unit_kwargs, - self.__raw_data_handler.index, - ) - return dc - - def with_region(self, region: Region): - """ - Return the same dataset but with a different region - """ - return self.__rebuild(region=region) - - def select(self, columns: set[str], drop=False): - """ - Select a subset of columns from the dataset. It is possible for a user to select - a derived column in the dataset, but not the columns it is derived from. - This class tracks any columns which are required to materialize the dataset but - are not in the final selection in self.__hidden. When the dataset is - materialized, the columns in self.__hidden are removed before the data is - returned to the user. 
- - """ - - selections, missing = get_column_selection(self.columns, columns) - if missing: - raise ValueError( - f"Columns are included that are not in this dataset: {missing}" - ) - elif not selections and columns: - raise ValueError("No columns matched the provided wildcards!") - - if drop: - selections = set(self.columns) - selections - - return self.__rebuild(columns=selections) - - def sort_by(self, column_name: str, invert: bool): - if column_name not in self.columns: - raise ValueError(f"This dataset has no column {column_name}") - - return self.__rebuild(sort_by=(column_name, invert)) - - def get_sorted_index(self): - if self.__sort_by is not None: - column = self.select({self.__sort_by[0]}).get_data(ignore_sort=True)[ - self.__sort_by[0] - ] - sorted = np.argsort(column) - if self.__sort_by[1]: - sorted = sorted[::-1] - - else: - sorted = None - - return sorted - - def take(self, n: int, at: str): - """ - Take rows from the dataset. - """ - - take_index: DataIndex - - if at == "start": - return self.take_range(0, n) - elif at == "end": - return self.take_range(len(self) - n, len(self)) - elif at == "random": - row_indices = np.random.choice(len(self), n, replace=False) - row_indices.sort() - - sorted = self.get_sorted_index() - if sorted is None: - take_index = row_indices - else: - take_index = np.sort(sorted[row_indices]) - - new_handler = self.__raw_data_handler.take(take_index) - new_cache = self.__cache.take(take_index) - - return self.__rebuild( - raw_data_handler=new_handler, - cache=new_cache, - ) - - def take_range(self, start: int, end: int): - """ - Take a range of rows form the dataset. 
- """ - if start < 0 or end < 0: - raise ValueError("start and end must be positive.") - if end < start: - raise ValueError("end must be greater than start.") - if end > len(self): - raise ValueError("end must be less than the length of the dataset.") - - sorted = self.get_sorted_index() - take_index: DataIndex - if sorted is None: - take_index = single_chunk(start, end - start) - else: - take_index = np.sort(sorted[start:end]) - - new_raw_handler = self.__raw_data_handler.take(take_index) - new_im = self.__cache.take(take_index) - return self.__rebuild( - raw_data_handler=new_raw_handler, - cache=new_im, - ) - - def take_rows(self, rows: DataIndex): - if len(self) == 0: - return self - row_range = get_range(rows) - - if row_range[1] > len(self) or row_range[0] < 0: - raise ValueError( - "Row indices must be between 0 and the length of this dataset!" - ) - sorted = self.get_sorted_index() - new_handler = self.__raw_data_handler.take(rows, sorted) - new_cache = self.__cache.take(rows) - - return self.__rebuild( - raw_data_handler=new_handler, - cache=new_cache, - ) - - def with_units( - self, - convention: Optional[str], - conversions: dict[u.Unit, u.Unit], - columns: dict[str, u.Unit], - cosmology: Cosmology, - redshift: float | table.Column, - ): - """ - Change the unit convention - """ - - if convention is None: - convention_ = self.__unit_handler.current_convention - - else: - convention_ = UnitConvention(convention) - - if ( - convention_ == UnitConvention.SCALEFREE - and UnitConvention(self.header.file.unit_convention) - != UnitConvention.SCALEFREE - ): - raise ValueError( - f"Cannot convert units with convention {self.header.file.unit_convention} to convention scalefree" - ) - column_keys = set(columns.keys()) - missing_columns = column_keys - set(self.columns) - if missing_columns: - raise ValueError(f"Dataset does not have columns {missing_columns}") - - new_handler = self.__unit_handler.with_convention(convention_).with_conversions( - conversions, columns 
- ) - - if convention_ == self.__unit_handler.current_convention: - cache = self.__cache.create_child() - else: - all_derived_names: set[str] = reduce( - lambda acc, next: acc.union(next[1].produces or {next[0]}), - self.__derived_columns.items(), - set(), - ) - columns_to_drop = all_derived_names.union(self.__raw_data_handler.columns) - cache = self.__cache.drop(columns_to_drop) - return self.__rebuild(unit_handler=new_handler, cache=cache) diff --git a/src/opencosmo/evaluate.py b/src/opencosmo/evaluate.py deleted file mode 100644 index 74db256d..00000000 --- a/src/opencosmo/evaluate.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Any - -import astropy.units as u -import numpy as np - -""" -General helper routines for evaluating expressions on datasets and collections -""" - - -def insert_data( - storage: dict[str, np.ndarray], index: int, values_to_insert: dict[str, Any] -): - if isinstance(values_to_insert, dict): - for name, value in values_to_insert.items(): - storage[name][index] = value - return storage - - name = next(iter(storage.keys())) - storage[name][index] = values_to_insert - - -def make_output_from_first_values(first_values: dict, n_rows: int): - storage = {} - new_first_values = {} - for name, value in first_values.items(): - shape: tuple[int, ...] 
= (n_rows,) - dtype = type(value) - if not isinstance(value, np.ndarray): - new_first_values[name] = value - elif isinstance(value, u.Quantity) and value.isscalar: - dtype = value.value.dtype - new_first_values[name] = value - elif isinstance(value, np.ndarray) and len(value) == 1: - dtype = value.dtype - new_first_values[name] = value[0] - else: - dtype = value.dtype - shape = shape + value.shape - new_first_values[name] = value - - storage[name] = np.zeros(shape, dtype=dtype) - for name, value in new_first_values.items(): - if isinstance(value, u.Quantity): - storage[name] = storage[name] * value.unit - - storage[name][0] = value - return storage diff --git a/src/opencosmo/index/build.py b/src/opencosmo/index/build.py deleted file mode 100644 index 4038bce8..00000000 --- a/src/opencosmo/index/build.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np - -from .mask import into_array - -if TYPE_CHECKING: - from . 
import DataIndex - - -def from_size(size: int): - return (np.array([0], dtype=np.int64), np.array([size], dtype=np.int64)) - - -def single_chunk(start: int, size: int): - return (np.array([start], dtype=np.int64), np.array([size], np.int64)) - - -def empty(): - return (np.array([], dtype=np.int64), np.array([], dtype=np.int64)) - - -def from_range(start: int, end: int): - size = end - start - return (np.array([start], dtype=np.int64), np.array([size], np.int64)) - - -def concatenate(*indices: DataIndex): - np.concatenate(list(map(into_array, indices))) diff --git a/src/opencosmo/index/in_range.py b/src/opencosmo/index/in_range.py deleted file mode 100644 index 7f067802..00000000 --- a/src/opencosmo/index/in_range.py +++ /dev/null @@ -1,74 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numba as nb -import numpy as np - -if TYPE_CHECKING: - from numpy.typing import NDArray - - -def n_in_range( - index: NDArray[np.int_] | tuple, - range_starts: int | NDArray[np.int_], - range_sizes: int | NDArray[np.int_], -): - range_starts = np.atleast_1d(range_starts) - range_sizes = np.atleast_1d(range_sizes) - match index: - case np.ndarray(): - return __n_in_range_simple(index, range_starts, range_sizes) - case (np.ndarray(), np.ndarray()): - return __n_in_range_chunked(*index, range_starts, range_sizes) - case _: - raise ValueError(f"Unknown index type {type(index)}") - - -def __n_in_range_simple( - index: NDArray[np.int_], start: NDArray[np.int_], size: NDArray[np.int_] -) -> NDArray[np.int_]: - if len(start) != len(size): - raise ValueError("Start and size arrays must have the same length") - if np.any(size < 0): - raise ValueError("Sizes must greater than or equal to zero") - if len(index) == 0: - return np.zeros_like(start) - - ends = start + size - index_to_search = np.sort(index) - start_idxs = np.searchsorted(index_to_search, start, "left") - end_idxs = np.searchsorted(index_to_search, ends, "left") - return end_idxs - 
start_idxs - - -@nb.njit -def __n_in_range_chunked( - starts: NDArray[np.int_], - sizes: NDArray[np.int_], - range_starts: NDArray[np.int_], - range_sizes: NDArray[np.int_], -) -> NDArray[np.int_]: - """ - Return the number of elements in this index that fall within - a specified data range. Used to mask spatial index. - - - As with numpy, this is the half-open range [start, end) - """ - if len(range_starts) != len(range_sizes): - raise ValueError("Start and size arrays must have the same length") - if np.any(range_sizes < 0): - raise ValueError("Sizes must greater than or equal to zero") - if len(starts) == 0: - return np.zeros_like(range_starts) - - index_ranges = np.vstack((starts, starts + sizes)) - output = np.zeros_like(range_starts) - for i in range(len(range_starts)): - index_chunk_ranges = np.clip( - index_ranges, a_min=range_starts[i], a_max=range_starts[i] + range_sizes[i] - ) - output[i] = np.sum(index_chunk_ranges[1] - index_chunk_ranges[0]) - - return output diff --git a/src/opencosmo/index/take.py b/src/opencosmo/index/take.py deleted file mode 100644 index a83f137a..00000000 --- a/src/opencosmo/index/take.py +++ /dev/null @@ -1,139 +0,0 @@ -from __future__ import annotations - -import numba as nb # type: ignore -import numpy as np - -SimpleIndex = np.ndarray -ChunkedIndex = tuple[np.ndarray, np.ndarray] - - -def take(from_, by): - match (from_, by): - case (np.ndarray(), np.ndarray()): - return __take_simple_from_simple(from_, by) - case (np.ndarray(), (np.ndarray(), np.ndarray())): - return __take_chunked_from_simple(from_, by) - case ((np.ndarray(), np.ndarray()), np.ndarray()): - return __take_simple_from_chunked(from_, by) - case ((np.ndarray(), np.ndarray()), (np.ndarray(), np.ndarray())): - return __take_chunked_from_chunked(from_, by) - - -def __take_simple_from_chunked(from_: ChunkedIndex, by: SimpleIndex): - cumulative = np.insert(np.cumsum(from_[1]), 0, 0)[:-1] - - indices_into_chunks = np.argmax(by[:, np.newaxis] < cumulative, axis=1) - 1 
- output = by - cumulative[indices_into_chunks] + from_[0][indices_into_chunks] - return output - - -def __take_simple_from_simple(from_: np.ndarray, by: np.ndarray): - return from_[by] - - -def __take_chunked_from_simple(from_: SimpleIndex, by: ChunkedIndex): - output = np.zeros(by[1].sum(), dtype=int) - output = __cfs_helper(from_, *by, output) - return output - - -@nb.njit -def __cfs_helper(arr, starts, sizes, storage): - rs = 0 - for i in range(len(starts)): - cstart = starts[i] - csize = sizes[i] - storage[rs : rs + csize] = arr[cstart : cstart + csize] - rs += csize - return storage - - -@nb.njit -def __cfc_helper(from_starts, from_sizes, by_starts, by_sizes): - pass - - -@nb.njit -def prefix_sum(arr): - out = np.empty(len(arr) + 1, dtype=arr.dtype) - total = 0 - out[0] = 0 - for i in range(len(arr)): - total += arr[i] - out[i + 1] = total - return out - - -@nb.njit -def find_chunk(prefix, x): - """ - Returns index i such that prefix[i] <= x < prefix[i+1]. - """ - lo = 0 - hi = len(prefix) - 1 # prefix has length N+1 - while lo + 1 < hi: - mid = (lo + hi) // 2 - if prefix[mid] <= x: - lo = mid - else: - hi = mid - return lo - - -@nb.njit -def resolve_spanning_numba( - start1, size1, start2, size2, out_start, out_size, out_owner -): - """ - Resolves index2 slices into data-level chunks. - Returns the number of output segments written. - """ - prefix = prefix_sum(size1) - out_pos = 0 - - for j in range(len(start2)): - logical = start2[j] - remaining = size2[j] - - while remaining > 0: - # Find which chunk in index1 we are inside - i1 = find_chunk(prefix, logical) - - # Where inside that chunk? - offset = logical - prefix[i1] - - # How many logical units remain in this chunk? 
- chunk_left = size1[i1] - offset - - # How much we take - take = chunk_left if chunk_left < remaining else remaining - - # Emit result - out_start[out_pos] = start1[i1] + offset - out_size[out_pos] = take - out_owner[out_pos] = j - - out_pos += 1 - - # Advance - logical += take - remaining -= take - - return out_pos - - -def __take_chunked_from_chunked(from_: ChunkedIndex, by: ChunkedIndex): - if len(from_[0]) == 0 and from_[0][0] == 0: - return by - - max_out = len(by[1]) * len(from_[1]) - out_start = np.empty(max_out, dtype=np.int64) - out_size = np.empty(max_out, dtype=np.int64) - out_owner = np.empty(max_out, dtype=np.int64) - - n = resolve_spanning_numba( - from_[0], from_[1], by[0], by[1], out_start, out_size, out_owner - ) - out_start = np.resize(out_start, (n,)) - out_size = np.resize(out_size, (n,)) - return out_start, out_size diff --git a/src/opencosmo/index/unary.py b/src/opencosmo/index/unary.py deleted file mode 100644 index c260b158..00000000 --- a/src/opencosmo/index/unary.py +++ /dev/null @@ -1,61 +0,0 @@ -import numba as nb -import numpy as np -from numpy.typing import NDArray - -""" -Implementations for unary operations on indices -""" - -SimpleIndex = NDArray[np.int_] -ChunkedIndex = tuple[NDArray[np.int_], NDArray[np.int_]] - - -def get_length(index: SimpleIndex | ChunkedIndex): - match index: - case np.ndarray(): - return len(index) - case (np.ndarray(), np.ndarray()): - return int(np.sum(index[1])) - case _: - raise TypeError(f"Invalid index type {type(index)}") - - -def get_range(index: SimpleIndex | ChunkedIndex): - match index: - case np.ndarray(): - return __get_simple_range(index) - case (np.ndarray(), np.ndarray()): - return __get_chunked_range(*index) - case _: - raise ValueError(f"Unknown index type {type(index)}") - - -@nb.njit -def __get_simple_range(index: SimpleIndex): - if len(index) == 0: - return (0, 0) - - min = index[0] - max = index[0] - for val in index[1:]: - if val < min: - min = val - if val > max: - max = val - return 
(min, max) - - -@nb.njit -def __get_chunked_range(starts: NDArray[np.int_], sizes: NDArray[np.int_]): - if len(starts) == 0: - return (0, 0) - min = starts[0] - max = min + sizes[0] - for i in range(1, len(starts)): - start = starts[i] - end = start + sizes[i] - if start < min: - min = start - if end > max: - max = end - return (min, max) diff --git a/src/opencosmo/parameters/diffsky.py b/src/opencosmo/parameters/diffsky.py deleted file mode 100644 index 35502e6e..00000000 --- a/src/opencosmo/parameters/diffsky.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -from datetime import datetime # noqa -from typing import ClassVar, Optional - -from pydantic import BaseModel, ConfigDict, field_serializer - - -class DiffskyVersionInfo(BaseModel): - model_config = ConfigDict(frozen=True) - ACCESS_PATH: ClassVar[str] = "diffsky_versions" - diffmah: str - diffsky: str - diffstar: str - diffstarpop: Optional[str] = None - dsps: str - jax: str - numpy: str - - -class DiffskyCatalogInfo(BaseModel): - model_config = ConfigDict(frozen=True) - ACCESS_PATH: ClassVar[str] = "catalog_info" - README: Optional[str] = None - mock_version_name: str - zphot_table: Optional[tuple[float, ...]] = None - - @field_serializer("zphot_table") - def serialize_zphot_table(self, value): - if value is not None: - return list(value) - return None diff --git a/test/parallel/test_dataset_mpi.py b/test/parallel/test_dataset_mpi.py new file mode 100644 index 00000000..a1ae6970 --- /dev/null +++ b/test/parallel/test_dataset_mpi.py @@ -0,0 +1,309 @@ +import numpy as np +import pytest +from opencosmo.mpi import get_comm_world +from pytest_mpi import parallel_assert + +import opencosmo as oc + + +@pytest.fixture +def input_path(snapshot_path): + return snapshot_path / "haloproperties.hdf5" + + +@pytest.mark.parallel(nprocs=4) +def test_take_global(input_path): + comm = get_comm_world() + + ds = oc.open(input_path) + total_length = sum(comm.allgather(len(ds))) + n_to_take = 
np.random.randint(total_length // 4, int(total_length * 0.75)) + n_to_take = comm.bcast(n_to_take) + + ds = ds.take(n_to_take, mode="global") + all_lengths = comm.allgather(len(ds)) + + parallel_assert(sum(all_lengths) == n_to_take) + + +# ── take_range global, unsorted ────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_start(input_path): + """First n global rows land on the correct ranks with the right counts.""" + comm = get_comm_world() + ds = oc.open(input_path) + + lengths = np.array(comm.allgather(len(ds)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + + ds_taken = ds.take_range(0, n, mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max(0, min(int(lengths[rank]), n - offset)) + + parallel_assert( + len(ds_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(ds_taken)}", + ) + parallel_assert(sum(comm.allgather(len(ds_taken))) == n) + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_end(input_path): + """Last n global rows land on the correct ranks with the right counts.""" + comm = get_comm_world() + ds = oc.open(input_path) + + lengths = np.array(comm.allgather(len(ds)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + global_start = total - n + + ds_taken = ds.take_range(global_start, total, mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), total - offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(ds_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(ds_taken)}", + ) + parallel_assert(sum(comm.allgather(len(ds_taken))) == n) + + +# ── take_range global, sorted ───────────────────────────────────────────────── +# +# The sorted tests verify value-level correctness: after a global range take on +# a sorted dataset, every selected 
value must satisfy the global threshold +# implied by the range position. We derive the expected threshold by gathering +# all values from all ranks before the take, then checking the invariant after. + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_sorted_start(input_path): + """Global start on sorted data selects the n globally smallest values.""" + comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass") + + total = sum(comm.allgather(len(ds))) + n = total // 3 + + # Gather original values to compute the expected threshold before the take. + original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] + + ds_taken = ds.take_range(0, n, mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_sorted_end(input_path): + """Global end on sorted data selects the n globally largest values.""" + comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass") + + total = sum(comm.allgather(len(ds))) + n = total // 3 + + original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + ds_taken = ds.take_range(total - n, total, mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + +# ── take global end 
─────────────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_take_global_end(input_path): + """take(n, at='end', mode='global') selects the last n rows across all ranks.""" + comm = get_comm_world() + ds = oc.open(input_path) + + lengths = np.array(comm.allgather(len(ds)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + global_start = total - n + + ds_taken = ds.take(n, at="end", mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), total - offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(ds_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(ds_taken)}", + ) + parallel_assert(sum(comm.allgather(len(ds_taken))) == n) + + +@pytest.mark.parallel(nprocs=4) +def test_take_global_end_sorted(input_path): + """take(n, at='end', mode='global') on sorted data selects the n globally largest values.""" + comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass") + + total = sum(comm.allgather(len(ds))) + n = total // 3 + + original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + ds_taken = ds.take(n, at="end", mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + +# ── take_range global, middle ───────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_middle(input_path): + """A middle window of global rows lands on the correct ranks with the right counts.""" + comm = get_comm_world() + ds = oc.open(input_path) + + lengths = 
np.array(comm.allgather(len(ds)), dtype=np.int64) + total = int(np.sum(lengths)) + global_start = total // 4 + global_end = 3 * total // 4 + + ds_taken = ds.take_range(global_start, global_end, mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), global_end - offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(ds_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(ds_taken)}", + ) + parallel_assert(sum(comm.allgather(len(ds_taken))) == global_end - global_start) + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_sorted_middle(input_path): + """A middle window on sorted data selects the correct globally-ranked values.""" + comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass") + + total = sum(comm.allgather(len(ds))) + global_start = total // 4 + global_end = 3 * total // 4 + size = global_end - global_start + + original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + sorted_all = np.sort(all_original) + lower_threshold = sorted_all[global_start] + upper_threshold = sorted_all[global_end - 1] + + ds_taken = ds.take_range(global_start, global_end, mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == size) + parallel_assert( + np.all(all_selected >= lower_threshold), + "some selected values fall below the lower global threshold", + ) + parallel_assert( + np.all(all_selected <= upper_threshold), + "some selected values exceed the upper global threshold", + ) + + +# ── inverted sort ───────────────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_sorted_inverted_start(input_path): + """Inverted sort: global start selects the n globally largest values.""" + 
comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass", invert=True) + + total = sum(comm.allgather(len(ds))) + n = total // 3 + + original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + ds_taken = ds.take_range(0, n, mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + +@pytest.mark.parallel(nprocs=4) +def test_take_range_global_sorted_inverted_end(input_path): + """Inverted sort: global end selects the n globally smallest values.""" + comm = get_comm_world() + ds = oc.open(input_path).sort_by("fof_halo_mass", invert=True) + + total = sum(comm.allgather(len(ds))) + n = total // 3 + + original = ds.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] + + ds_taken = ds.take_range(total - n, total, mode="global") + + selected = ds_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index c5d9023d..f85f0bb2 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -2,15 +2,16 @@ import shutil import astropy.units as u +import h5py import numpy as np import pytest from astropy.coordinates import SkyCoord from healpy import pix2ang from mpi4py import MPI +from opencosmo.mpi import get_comm_world from pytest_mpi.parallel_assert import parallel_assert import opencosmo as oc -from 
opencosmo.mpi import get_comm_world IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" @@ -122,7 +123,9 @@ def test_healpix_index_chain_failure(haloproperties_600_path): @pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.parallel(nprocs=4) def test_healpix_write(haloproperties_600_path, per_test_dir): + comm = get_comm_world() ds = oc.open(haloproperties_600_path) + assert "redshift" in ds.columns pixel = np.random.choice(ds.region.pixels) center = pix2ang(ds.region.nside, pixel, True, True) @@ -133,12 +136,18 @@ def test_healpix_write(haloproperties_600_path, per_test_dir): oc.write(per_test_dir / "lightcone_test.hdf5", ds) new_ds = oc.open(per_test_dir / "lightcone_test.hdf5") - radius2 = 2 * u.deg + radius2 = 1 * u.deg region2 = oc.make_cone(center, radius2) new_ds = new_ds.bound(region2) ds = ds.bound(region2) - assert set(ds.get_data()["fof_halo_tag"]) == set(new_ds.get_data()["fof_halo_tag"]) + rank_tags = ds.select("fof_halo_tag").get_data("numpy", unpack=False) + new_rank_tags = new_ds.select("fof_halo_tag").get_data("numpy", unpack=False) + + all_tags = np.concatenate(comm.allgather(rank_tags)) + all_new_tags = np.concatenate(comm.allgather(new_rank_tags)) + + parallel_assert(np.all(np.sort(all_tags) == np.sort(all_new_tags))) @pytest.mark.filterwarnings("ignore::UserWarning") @@ -230,9 +239,9 @@ def test_box_search_chain_failure(haloproperties_600_path): @pytest.mark.parallel(nprocs=4) def test_box_search_write(haloproperties_600_path, per_test_dir): """Written box-search result supports a narrower refinement search on re-open.""" + comm = get_comm_world() ds = oc.open(haloproperties_600_path) - # Each rank works with a pixel it owns so the search is guaranteed to find data. 
pixel = np.random.choice(ds.region.pixels) ra_center, dec_center = pix2ang(ds.region.nside, pixel, lonlat=True, nest=True) @@ -249,7 +258,12 @@ def test_box_search_write(haloproperties_600_path, per_test_dir): ds = ds.box_search(p1_inner, p2_inner) new_ds = new_ds.box_search(p1_inner, p2_inner) - assert set(ds.get_data()["fof_halo_tag"]) == set(new_ds.get_data()["fof_halo_tag"]) + original_halo_tags = ds.select("fof_halo_tag").get_data("numpy", unpack=False) + written_halo_tags = ds.select("fof_halo_tag").get_data("numpy", unpack=False) + + all_original_tags = np.concat(comm.allgather(original_halo_tags)) + all_written_tags = np.concat(comm.allgather(written_halo_tags)) + parallel_assert(np.all(all_original_tags == all_written_tags)) @pytest.mark.parallel(nprocs=4) @@ -403,13 +417,13 @@ def test_diffsky_stack_with_synths(core_path_487, core_path_475, per_test_dir): def test_write_some_missing(core_path_487, core_path_475, per_test_dir): comm = MPI.COMM_WORLD ds = oc.open(core_path_487, core_path_475, synth_cores=False) + assert "early_index" in ds.columns if comm.Get_rank() == 0: ds = ds.with_redshift_range(0, 0.02) assert len(ds.keys()) == 1 original_data = ds.select("early_index").get_data("numpy") original_data_length = comm.allgather(len(original_data)) - ds = ds.with_new_columns(gal_id=np.arange(len(ds))) oc.write(per_test_dir / "lightcone.hdf5", ds) ds = oc.open(per_test_dir / "lightcone.hdf5", synth_cores=True) written_data = ds.select("early_index").get_data("numpy") @@ -434,13 +448,9 @@ def test_write_diffsky_some_missing_no_stack( ds.pop(475) assert len(ds.keys()) == 1 - all_lengths = comm.allgather(len(ds)) - all_ends = np.insert(np.cumsum(all_lengths), 0, 0) - rank = comm.Get_rank() - ds = ds.with_new_columns(gal_id=np.arange(all_ends[rank], all_ends[rank + 1])) - - columns_to_check = comm.bcast(np.random.choice(ds.columns, 10, replace=False)) - columns_to_check = np.insert(columns_to_check, 0, "gal_id") + # columns_to_check = 
comm.bcast(np.random.choice(ds.columns, 10, replace=False)) + # columns_to_check = np.insert(columns_to_check, 0, "gal_id") + columns_to_check = list(ds.columns) original_data = ds.select(columns_to_check).get_data("numpy") @@ -457,9 +467,8 @@ def test_write_diffsky_some_missing_no_stack( columns_to_check.sort() for column_name in columns_to_check: - if column_name == "gal_id": + if column_name in ["gal_id", "top_host_idx"]: continue - column_name = str(column_name) column_data_original = np.concat(comm.allgather(original_data.pop(column_name))) column_data_written = np.concat(comm.allgather(written_data.pop(column_name))) parallel_assert( @@ -470,6 +479,87 @@ def test_write_diffsky_some_missing_no_stack( ) +@pytest.mark.parallel(nprocs=4) +def test_open_parallel_top_host(core_path_487, core_path_475): + with h5py.File(core_path_487) as f: + core_map = _get_expected_core_tags(f["cores"]) + with h5py.File(core_path_475) as f: + core_map |= _get_expected_core_tags(f["cores"]) + + ds = oc.open(core_path_475, core_path_487) + data = ds.select("top_host_idx", "core_tag").get_data() + + _assert_top_host_idx_correct(data, core_map) + _assert_all_group_members_present(data, core_map) + + +@pytest.mark.parallel(nprocs=4) +def test_open_write_parallel_top_host(core_path_487, core_path_475, per_test_dir): + with h5py.File(core_path_475) as f: + core_map = _get_expected_core_tags(f["cores"]) + + with h5py.File(core_path_487) as f: + core_map |= _get_expected_core_tags(f["cores"]) + + ds = oc.open(core_path_475, core_path_487) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + + oc.write(per_test_dir / "test.hdf5", ds) + with h5py.File(per_test_dir / "test.hdf5") as f: + written_core_map = _get_expected_core_tags(f["475_475"]) + assert core_map == written_core_map + + data = ( + oc.open(per_test_dir / "test.hdf5") + .select("top_host_idx", "core_tag") + .get_data("numpy") + ) + + _assert_top_host_idx_correct(data, core_map) + 
_assert_all_group_members_present(data, core_map) + + +@pytest.mark.parallel(nprocs=4) +def test_open_write_parallel_top_after_filter( + core_path_487, core_path_475, per_test_dir +): + with h5py.File(core_path_475) as f: + core_map = _get_expected_core_tags(f["cores"]) + + with h5py.File(core_path_487) as f: + core_map |= _get_expected_core_tags(f["cores"]) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True).take(10) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + _assert_top_host_idx_correct(data, core_map) + _assert_all_group_members_present(data, core_map) + + oc.write(per_test_dir / "test.hdf5", ds) + + data = ( + oc.open(per_test_dir / "test.hdf5") + .select("top_host_idx", "core_tag") + .get_data("numpy") + ) + + _assert_top_host_idx_correct(data, core_map) + _assert_all_group_members_present(data, core_map) + + +@pytest.mark.parallel(nprocs=4) +def test_keep_top_host_filter(core_path_487, core_path_475): + with h5py.File(core_path_487) as f: + core_map = _get_expected_core_tags(f["cores"]) + with h5py.File(core_path_475) as f: + core_map |= _get_expected_core_tags(f["cores"]) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True) + data = ds.take(10).select("top_host_idx", "core_tag").get_data() + + _assert_top_host_idx_correct(data, core_map) + _assert_all_group_members_present(data, core_map) + + @pytest.mark.parallel(nprocs=4) def test_write_some_missing_no_stack( haloproperties_600_path, haloproperties_601_path, per_test_dir @@ -521,3 +611,491 @@ def test_lightcone_stacking( assert np.all(np.isin(all_fof_tags, all_fof_tags_new)) assert ds_new.z_range == ds.z_range assert next(iter(ds_new.values())).header.lightcone["z_range"] == ds_new.z_range + + +# ── take global ─────────────────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_global(haloproperties_600_path, haloproperties_601_path): + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, 
haloproperties_600_path) + total_length = sum(comm.allgather(len(lc))) + n_to_take = np.random.randint(total_length // 4, int(total_length * 0.75)) + n_to_take = comm.bcast(n_to_take) + + lc = lc.take(n_to_take, mode="global") + all_lengths = comm.allgather(len(lc)) + + parallel_assert(sum(all_lengths) == n_to_take) + + +# ── take_range global, unsorted ─────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_start(haloproperties_600_path, haloproperties_601_path): + """First n global rows land on the correct ranks with the right counts.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path) + + lengths = np.array(comm.allgather(len(lc)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + + lc_taken = lc.take_range(0, n, mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max(0, min(int(lengths[rank]), n - offset)) + + parallel_assert( + len(lc_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(lc_taken)}", + ) + parallel_assert(sum(comm.allgather(len(lc_taken))) == n) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_end(haloproperties_600_path, haloproperties_601_path): + """Last n global rows land on the correct ranks with the right counts.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path) + + lengths = np.array(comm.allgather(len(lc)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + global_start = total - n + + lc_taken = lc.take_range(global_start, total, mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), total - offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(lc_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(lc_taken)}", + ) + 
parallel_assert(sum(comm.allgather(len(lc_taken))) == n) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_middle(haloproperties_600_path, haloproperties_601_path): + """A middle window of global rows lands on the correct ranks with the right counts.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path) + + lengths = np.array(comm.allgather(len(lc)), dtype=np.int64) + total = int(np.sum(lengths)) + global_start = total // 4 + global_end = 3 * total // 4 + + lc_taken = lc.take_range(global_start, global_end, mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), global_end - offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(lc_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(lc_taken)}", + ) + parallel_assert(sum(comm.allgather(len(lc_taken))) == global_end - global_start) + + +# ── take_range global, sorted ───────────────────────────────────────────────── +# +# The sorted tests verify value-level correctness: after a global range take on +# a sorted lightcone, every selected value must satisfy the global threshold +# implied by the range position. 
+ + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_sorted_start( + haloproperties_600_path, haloproperties_601_path +): + """Global start on sorted data selects the n globally smallest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass" + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] + + lc_taken = lc.take_range(0, n, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_sorted_end( + haloproperties_600_path, haloproperties_601_path +): + """Global end on sorted data selects the n globally largest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass" + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + lc_taken = lc.take_range(total - n, total, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_sorted_middle( + haloproperties_600_path, haloproperties_601_path +): + """A middle window on sorted data selects 
the correct globally-ranked values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass" + ) + + total = sum(comm.allgather(len(lc))) + global_start = total // 4 + global_end = 3 * total // 4 + size = global_end - global_start + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + sorted_all = np.sort(all_original) + lower_threshold = sorted_all[global_start] + upper_threshold = sorted_all[global_end - 1] + + lc_taken = lc.take_range(global_start, global_end, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == size) + parallel_assert( + np.all(all_selected >= lower_threshold), + "some selected values fall below the lower global threshold", + ) + parallel_assert( + np.all(all_selected <= upper_threshold), + "some selected values exceed the upper global threshold", + ) + + +# ── take global end ─────────────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_global_end(haloproperties_600_path, haloproperties_601_path): + """take(n, at='end', mode='global') selects the last n rows across all ranks.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path) + + lengths = np.array(comm.allgather(len(lc)), dtype=np.int64) + total = int(np.sum(lengths)) + n = total // 3 + global_start = total - n + + lc_taken = lc.take(n, at="end", mode="global") + + rank = comm.Get_rank() + offset = int(np.sum(lengths[:rank])) + expected_local = max( + 0, + min(int(lengths[rank]), total - offset) - max(0, global_start - offset), + ) + + parallel_assert( + len(lc_taken) == expected_local, + f"rank {rank}: expected {expected_local} rows, got {len(lc_taken)}", + ) + parallel_assert(sum(comm.allgather(len(lc_taken))) == n) + + 
+@pytest.mark.parallel(nprocs=4) +def test_lc_take_global_end_sorted(haloproperties_600_path, haloproperties_601_path): + """take(n, at='end', mode='global') on sorted data selects the n globally largest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass" + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + lc_taken = lc.take(n, at="end", mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + +# ── take(at="start") global, sorted ────────────────────────────────────────── +# +# take(n, at="start") is a distinct branch from take_range(0, n) in the +# Lightcone implementation; the sort-order → physical conversion lives +# separately in each branch and must be tested independently. 
+ + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_global_start_sorted(haloproperties_600_path, haloproperties_601_path): + """take(n, at='start', mode='global') on sorted data selects the n globally smallest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass" + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] + + lc_taken = lc.take(n, at="start", mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) + + +# ── inverted sort ───────────────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_sorted_inverted_start( + haloproperties_600_path, haloproperties_601_path +): + """Inverted sort: global start selects the n globally largest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass", invert=True + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + lc_taken = lc.take_range(0, n, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + +@pytest.mark.parallel(nprocs=4) +def 
test_lc_take_range_global_sorted_inverted_end( + haloproperties_600_path, haloproperties_601_path +): + """Inverted sort: global end selects the n globally smallest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_601_path, haloproperties_600_path).sort_by( + "fof_halo_mass", invert=True + ) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] + + lc_taken = lc.take_range(total - n, total, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) + + +# ── single-step lightcone ───────────────────────────────────────────────────── + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_global_single_step(haloproperties_600_path): + """Global take on a single-step lightcone produces the correct total count.""" + comm = get_comm_world() + lc = oc.open(haloproperties_600_path) + + total = sum(comm.allgather(len(lc))) + n_to_take = np.random.randint(total // 4, int(total * 0.75)) + n_to_take = comm.bcast(n_to_take) + + lc_taken = lc.take(n_to_take, mode="global") + parallel_assert(sum(comm.allgather(len(lc_taken))) == n_to_take) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_single_step_sorted(haloproperties_600_path): + """Sorted global take_range on a single-step lightcone selects the n globally smallest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_600_path).sort_by("fof_halo_mass") + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[n - 1] 
+ + lc_taken = lc.take_range(0, n, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected <= threshold), + "some selected values exceed the global n-th smallest threshold", + ) + + +@pytest.mark.parallel(nprocs=4) +def test_lc_take_range_global_single_step_sorted_inverted(haloproperties_600_path): + """Inverted sort on a single-step lightcone: global start selects the n globally largest values.""" + comm = get_comm_world() + lc = oc.open(haloproperties_600_path).sort_by("fof_halo_mass", invert=True) + + total = sum(comm.allgather(len(lc))) + n = total // 3 + + original = lc.select("fof_halo_mass").get_data("numpy") + all_original = np.concatenate(comm.allgather(original)) + threshold = np.sort(all_original)[::-1][n - 1] + + lc_taken = lc.take_range(0, n, mode="global") + + selected = lc_taken.select("fof_halo_mass").get_data("numpy") + all_selected = np.concatenate(comm.allgather(selected)) + + parallel_assert(len(all_selected) == n) + parallel_assert( + np.all(all_selected >= threshold), + "some selected values fall below the global n-th largest threshold", + ) + + +def _get_expected_core_tags(group): + raw_top_host = group["data"]["top_host_idx"][:] + core_tag = group["data"]["core_tag"][:] + top_host_core_tag = core_tag[raw_top_host] + return dict(zip(core_tag, top_host_core_tag)) + + +def _assert_top_host_idx_correct(data, core_map): + """ + Verify top_host_idx is correctly remapped in `data` (a numpy dict with + "top_host_idx" and "core_tag" keys). Works for both local (per-rank) and + global (gathered) data. + + Synthetic cores (core_tag == -1) are checked separately: each must point + to its own row index. Real cores are checked against core_map. 
+ """ + synth_mask = data["core_tag"] == -1 + synth_indices = np.where(synth_mask)[0] + assert np.all(data["top_host_idx"][synth_mask] == synth_indices) + + # Restrict to real cores for the map check, but dereference top_host_idx + # against the full data so that indices into synthetic rows resolve correctly. + real_mask = ~synth_mask + real_top_host_idx = data["top_host_idx"][real_mask] + real_core_tag = data["core_tag"][real_mask] + + has_top_host = real_top_host_idx >= 0 + found_top_host_core_tag = data["core_tag"][real_top_host_idx[has_top_host]] + found_core_map = dict(zip(real_core_tag[has_top_host], found_top_host_core_tag)) + + filtered_core_map = { + key: val for key, val in core_map.items() if key in found_core_map + } + assert filtered_core_map == found_core_map + + should_have_core_map = { + key: val + for key, val in core_map.items() + if val in data["core_tag"] and key in real_core_tag + } + assert should_have_core_map == found_core_map + + comm = get_comm_world() + all_data_core_maps = comm.allgather(found_core_map) + seen = set() + for m in all_data_core_maps: + assert len(seen.intersection(m.keys())) == 0 + seen |= m.keys() + + +def _assert_all_group_members_present(data, core_map): + """ + Verify that for every top_host represented in the data, all rows from the + full dataset that point to that top_host are also present. 
+ """ + host_to_members: dict = {} + for ct, host_ct in core_map.items(): + host_to_members.setdefault(host_ct, set()).add(ct) + + present_core_tags = set(data["core_tag"]) + top_host_core_tags = set(data["core_tag"][data["top_host_idx"]]) + + for top_host_ct in top_host_core_tags: + expected_members = host_to_members.get(top_host_ct, set()) + missing = expected_members - present_core_tags + assert not missing, ( + f"top_host {top_host_ct}: {len(missing)} member(s) missing from result" + ) diff --git a/test/parallel/test_structure_collection_mpi.py b/test/parallel/test_structure_collection_mpi.py new file mode 100644 index 00000000..fd4ea956 --- /dev/null +++ b/test/parallel/test_structure_collection_mpi.py @@ -0,0 +1,213 @@ +import os +import shutil + +import numpy as np +import pytest +from mpi4py import MPI +from opencosmo.mpi import get_comm_world + +import opencosmo as oc + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" + + +@pytest.fixture +def halos_600_path(lightcone_path): + properties = lightcone_path / "step_600" / "haloproperties.hdf5" + particles = lightcone_path / "step_600" / "haloparticles.hdf5" + profiles = lightcone_path / "step_600" / "haloprofiles.hdf5" + return [properties, particles, profiles] + + +@pytest.fixture +def galaxies_600_path(lightcone_path): + properties = lightcone_path / "step_600" / "galaxyproperties.hdf5" + particles = lightcone_path / "step_600" / "galaxyparticles.hdf5" + return [properties, particles] + + +@pytest.fixture +def halos_601_path(lightcone_path): + properties = lightcone_path / "step_601" / "haloproperties.hdf5" + particles = lightcone_path / "step_601" / "haloparticles.hdf5" + profiles = lightcone_path / "step_601" / "haloprofiles.hdf5" + return [properties, particles, profiles] + + +@pytest.fixture +def galaxies_601_path(lightcone_path): + properties = lightcone_path / "step_601" / "galaxyproperties.hdf5" + particles = lightcone_path / "step_601" / "galaxyparticles.hdf5" + return [properties, 
particles] + + +@pytest.fixture +def per_test_dir( + tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest +): + """ + Creates a unique directory for each test and deletes it after the test finishes. + + Uses tmp_path_factory so you can control base temp location via pytest's + tempdir handling, and also so it can be used from broader-scoped fixtures + if needed. + """ + # request.node.nodeid is unique across parameterizations; sanitize for filesystem + nodeid = ( + request.node.nodeid.replace("/", "_") + .replace("::", "__") + .replace("[", "_") + .replace("]", "_") + ) + + path = tmp_path_factory.mktemp(nodeid) + comm = MPI.COMM_WORLD + path_to_return = comm.bcast(path) + + try: + yield path_to_return + finally: + # Close out storage pressure immediately after each test + if IN_GITHUB_ACTIONS: + shutil.rmtree(path, ignore_errors=True) + + +def verify_halo(halo): + gravity_particle_tags = ( + halo["dm_particles"].select("fof_halo_tag").get_data("numpy") + ) + assert np.all(gravity_particle_tags == halo["halo_properties"]["fof_halo_tag"]) + halo_bin_tags = halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy") + assert np.all(halo_bin_tags == halo["halo_properties"]["fof_halo_tag"]) + if "galaxy" not in halo: + return + for galaxy in halo["galaxies"].galaxies(): + assert ( + galaxy["galaxy_properties"]["fof_halo_tag"] + == halo["halo_properties"]["fof_halo_tag"] + ) + + if "star_particles" not in galaxy: + continue + tags = galaxy["star_particles"].select("gal_tag").get_data("numpy") + assert np.all(tags == galaxy["galaxy_properties"]["gal_tag"]) + + +@pytest.mark.parallel(nprocs=4) +def test_open_lightcone_structure_with_galaxies( + halos_600_path, galaxies_600_path, halos_601_path, galaxies_601_path +): + ds = oc.open( + *halos_600_path, *galaxies_600_path, *halos_601_path, *galaxies_601_path + ) + ds = ds.filter(oc.col("sod_halo_mass") > 1e14).take(10) + + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10).halos(): + 
verify_halo(halo) + + +@pytest.mark.parallel(nprocs=4) +def test_write_lightcone_structure(halos_600_path, halos_601_path, per_test_dir): + comm = get_comm_world() + ds = ( + oc.open( + *halos_600_path, + *halos_601_path, + ) + .filter(oc.col("fof_halo_mass") > 1e14) + .take(1000) + ) + halo_tags_start = set() + halo_tags_end = set() + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos(): + halo_tags_start.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos(): + halo_tags_end.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + all_halos_read = set( + np.concatenate( + comm.allgather( + ds["halo_properties"].select("fof_halo_tag").get_data("numpy") + ) + ) + ) + + oc.write(per_test_dir / "halos.hdf5", ds) + ds_new = oc.open(per_test_dir / "halos.hdf5") + + all_halos = set( + np.concatenate( + comm.allgather( + ds_new["halo_properties"].select("fof_halo_tag").get_data("numpy") + ) + ) + ) + + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos() + ): + verify_halo(halo) + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos() + ): + verify_halo(halo) + + assert halo_tags_start.issubset(all_halos) + assert halo_tags_end.issubset(all_halos) + assert all_halos_read == all_halos + + +@pytest.mark.parallel(nprocs=4) +def test_write_lightcone_structure_with_galaxies( + halos_600_path, halos_601_path, galaxies_600_path, galaxies_601_path, per_test_dir +): + comm = get_comm_world() + ds = ( + oc.open( + *halos_600_path, *halos_601_path, *galaxies_600_path, *galaxies_601_path + ) + .filter(oc.col("fof_halo_mass") > 1e14) + .take(1000) + ) + halo_tags_start = set() + halo_tags_end = set() + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos(): + halo_tags_start.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + + for halo 
in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos(): + halo_tags_end.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + all_halos_read = set( + np.concatenate( + comm.allgather( + ds["halo_properties"].select("fof_halo_tag").get_data("numpy") + ) + ) + ) + + oc.write(per_test_dir / "halos.hdf5", ds) + ds_new = oc.open(per_test_dir / "halos.hdf5") + + all_halos = set( + np.concatenate( + comm.allgather( + ds_new["halo_properties"].select("fof_halo_tag").get_data("numpy") + ) + ) + ) + + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos() + ): + verify_halo(halo) + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos() + ): + verify_halo(halo) + + assert halo_tags_start.issubset(all_halos) + assert halo_tags_end.issubset(all_halos) + assert all_halos_read == all_halos diff --git a/test/spatial/test_2d.py b/test/spatial/test_2d.py index 363017e7..5613e4b6 100644 --- a/test/spatial/test_2d.py +++ b/test/spatial/test_2d.py @@ -1,9 +1,9 @@ import astropy.units as u import numpy as np from astropy.coordinates import SkyCoord +from opencosmo.spatial.relations import contains_2d, intersects_2d import opencosmo as oc -from opencosmo.spatial.relations import contains_2d, intersects_2d # --------------------------------------------------------------------------- # Cone search (bound with ConeRegion) diff --git a/test/test_cache.py b/test/test_cache.py index 41932644..592c4f00 100644 --- a/test/test_cache.py +++ b/test/test_cache.py @@ -1,96 +1,115 @@ -import numpy as np +from uuid import uuid4 +import numpy as np from opencosmo.column.cache import ColumnCache +def _make_mapping(names): + """Create a {name: UUID} mapping for testing.""" + return {name: uuid4() for name in names} + + +def _add_raw_data(cache, name_to_uuid, raw_data): + """Add name-keyed data to cache via UUID-keyed interface.""" + uuid_data = {name_to_uuid[name]: {name: raw_data[name]} for name in 
raw_data} + cache.add_data(uuid_data) + + +def _get_flat_data(cache, name_to_uuid): + """Get data from cache, flattened to a plain name-keyed dict.""" + pairs = {(uuid, name) for name, uuid in name_to_uuid.items()} + uuid_data = cache.get_data(pairs) + return {name: arr for d in uuid_data.values() for name, arr in d.items()} + + def test_cache_take(): - data = {} - for name in "abcdefg": - data[name] = np.random.randint(0, 1000, 10_000) + raw_data = {name: np.random.randint(0, 1000, 10_000) for name in "abcdefg"} + name_to_uuid = _make_mapping("abcdefg") + cache = ColumnCache.empty() - cache.register_column_group(0, set("abcdefg")) - cache.add_data(data) + cache.register_column_group(0, name_to_uuid) + _add_raw_data(cache, name_to_uuid, raw_data) index2 = np.sort(np.random.choice(10_000, 100, replace=False)) cache2 = cache.take(index2) - cache2.register_column_group(0, set("abcdefg")) - cached_data = cache2.get_data("abcdefg") + cache2.register_column_group(0, name_to_uuid) + cached_data = _get_flat_data(cache2, name_to_uuid) for colname, column in cached_data.items(): - assert np.all(column == data[colname][index2]) + assert np.all(column == raw_data[colname][index2]) def test_cache_passthrough(): - data = {} - for name in "abcdefg": - data[name] = np.random.randint(0, 1000, 10_000) + raw_data = {name: np.random.randint(0, 1000, 10_000) for name in "abcdefg"} + name_to_uuid = _make_mapping("abcdefg") + cache = ColumnCache.empty() - cache.register_column_group(0, set("abcdefg")) - cache.add_data(data) + cache.register_column_group(0, name_to_uuid) + _add_raw_data(cache, name_to_uuid, raw_data) index2 = np.sort(np.random.choice(10_000, 1000, replace=False)) cache2 = cache.take(index2) - cache2.register_column_group(0, set("abcdefg")) + cache2.register_column_group(0, name_to_uuid) index3 = np.sort(np.random.choice(1000, 100, replace=False)) cache3 = cache2.take(index3) - cache3.register_column_group(0, set("abcdefg")) + cache3.register_column_group(0, name_to_uuid) 
- cached_data = cache3.get_data("abcdefg") + cached_data = _get_flat_data(cache3, name_to_uuid) for colname, column in cached_data.items(): - assert np.all(column == data[colname][index2[index3]]) + assert np.all(column == raw_data[colname][index2[index3]]) assert set(cache3.columns) == set("abcdefg") assert len(cache2.columns) == 0 def test_cache_passthrough_delete(): - data = {} - for name in "abcdefg": - data[name] = np.random.randint(0, 1000, 10_000) + raw_data = {name: np.random.randint(0, 1000, 10_000) for name in "abcdefg"} + name_to_uuid = _make_mapping("abcdefg") + cache = ColumnCache.empty() - cache.register_column_group(0, set("abcdefg")) - cache.add_data(data) + cache.register_column_group(0, name_to_uuid) + _add_raw_data(cache, name_to_uuid, raw_data) index2 = np.sort(np.random.choice(10_000, 1000, replace=False)) cache2 = cache.take(index2) - cache2.register_column_group(0, set("abcdefg")) - _ = cache2.get_data("abcdefg") + cache2.register_column_group(0, name_to_uuid) + _ = _get_flat_data(cache2, name_to_uuid) index3 = np.sort(np.random.choice(1000, 100, replace=False)) cache3 = cache2.take(index3) - cache3.register_column_group(0, set("abcdefg")) + cache3.register_column_group(0, name_to_uuid) assert set(cache3.columns) == set() del cache2 assert set(cache3.columns) == set("abcdefg") - cached_data = cache3.get_data("abcdefg") + cached_data = _get_flat_data(cache3, name_to_uuid) for colname, column in cached_data.items(): - assert np.all(column == data[colname][index2[index3]]) + assert np.all(column == raw_data[colname][index2[index3]]) def test_cache_passthrough_twice_delete(): - data = {} - for name in "abcdefg": - data[name] = np.random.randint(0, 1000, 10_000) + raw_data = {name: np.random.randint(0, 1000, 10_000) for name in "abcdefg"} + name_to_uuid = _make_mapping("abcdefg") + cache = ColumnCache.empty() - cache.register_column_group(0, set("abcdefg")) - cache.add_data(data) + cache.register_column_group(0, name_to_uuid) + 
_add_raw_data(cache, name_to_uuid, raw_data) index2 = np.sort(np.random.choice(10_000, 1000, replace=False)) cache2 = cache.take(index2) - cache2.register_column_group(0, set("abcdefg")) + cache2.register_column_group(0, name_to_uuid) index3 = np.sort(np.random.choice(1000, 100, replace=False)) cache3 = cache2.take(index3) - cache3.register_column_group(0, set("abcdefg")) + cache3.register_column_group(0, name_to_uuid) assert set(cache2.columns) == set() del cache assert set(cache2.columns) == set("abcdefg") - cached_data = cache3.get_data("abcdefg") + cached_data = _get_flat_data(cache3, name_to_uuid) for colname, column in cached_data.items(): - assert np.all(column == data[colname][index2[index3]]) + assert np.all(column == raw_data[colname][index2[index3]]) assert set(cache3.columns) == set("abcdefg") diff --git a/test/test_collection.py b/test/test_collection.py index 75488920..8cbebec9 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -108,6 +108,20 @@ def test_open_structures(halo_paths, galaxy_paths): assert isinstance(c3, oc.StructureCollection) +def test_call_lightcone_fails(halo_paths, galaxy_paths): + ds = oc.open(*halo_paths, *galaxy_paths) + with pytest.raises(AttributeError): + ds.get_pixels() + with pytest.raises(AttributeError): + ds.cone_search(None, None) + with pytest.raises(AttributeError): + ds.box_search(None, None) + with pytest.raises(AttributeError): + ds.pixel_search(None) + with pytest.raises(AttributeError): + ds.with_redshift_range(0.0, 1.0) + + def test_multi_filter(multi_path): collection = oc.open(multi_path) collection = collection.filter(oc.col("sod_halo_mass") > 0) @@ -530,6 +544,36 @@ def eval_fn(fof_halo_mass, fof_halo_tag): ) +def test_structure_collection_evaluate_overwrite(halo_paths): + collection = oc.open(*halo_paths).take(20) + + def fof_halo_mass(fof_halo_mass, fof_halo_com_vx): + return fof_halo_mass * fof_halo_com_vx + + with pytest.raises(ValueError): + collection.evaluate(fof_halo_mass, 
dataset="halo_properties") + + collection_overwritten = collection.evaluate( + fof_halo_mass, + dataset="halo_properties", + allow_overwrite=True, + vectorize=True, + ) + original = ( + collection["halo_properties"] + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + overwritten = ( + collection_overwritten["halo_properties"] + .select("fof_halo_mass") + .get_data("numpy") + ) + assert np.all( + np.isclose(overwritten, original["fof_halo_mass"] * original["fof_halo_com_vx"]) + ) + + def test_visit_dataset_in_structure_collection_nochunk(halo_paths): collection = oc.open(*halo_paths) @@ -565,6 +609,22 @@ def offset( assert np.all(offset_vec == offset_loop) +def test_evaluate_on_dataset_nested_path(halo_paths, galaxy_paths): + collection = oc.open(*halo_paths, *galaxy_paths).take(10) + + def particle_id(x, y, z): + return np.arange(len(x)) + + collection = collection.evaluate_on_dataset( + particle_id, dataset="galaxies.star_particles", vectorize=True, insert=True + ) + assert "particle_id" in collection["galaxies"]["star_particles"].columns + for halo in collection.halos(["galaxies"]): + for galaxy in halo["galaxies"].galaxies(["star_particles"]): + pid = galaxy["star_particles"].select("particle_id").get_data("numpy") + assert np.all(pid == np.arange(len(pid))) + + def test_visit_galaxies_in_halo_collection(halo_paths, galaxy_paths): collection = oc.open(*halo_paths, *galaxy_paths).take(10) @@ -1027,6 +1087,30 @@ def fof_px(fof_halo_mass, fof_halo_com_vx, random_value, other_value): ) +def test_simulation_collection_evaluate_overwrite(multi_path): + collection = oc.open(multi_path) + + def fof_halo_mass(fof_halo_mass, fof_halo_com_vx): + return fof_halo_mass * fof_halo_com_vx + + with pytest.raises(ValueError): + collection.evaluate(fof_halo_mass, vectorize=True, insert=True) + + collection_overwritten = collection.evaluate( + fof_halo_mass, vectorize=True, insert=True, allow_overwrite=True + ) + for ds_name, ds in collection.items(): + 
original = ds.select(["fof_halo_mass", "fof_halo_com_vx"]).get_data("numpy") + overwritten = ( + collection_overwritten[ds_name].select("fof_halo_mass").get_data("numpy") + ) + assert np.all( + np.isclose( + overwritten, original["fof_halo_mass"] * original["fof_halo_com_vx"] + ) + ) + + def test_simulation_collection_add(multi_path): collection = oc.open(multi_path) ds_name = next(iter(collection.keys())) @@ -1227,7 +1311,31 @@ def test_data_cached_after_objects(halo_paths): pass dataset = ds["dm_particles"] - cache = dataset._Dataset__state._DatasetState__cache - data = cache.get_data(("gpe",)) - assert data.get("gpe") is not None + state = dataset._Dataset__state + cache = state.cache + columns = state.column_map # dict[str, UUID] + gpe_uuid = columns["gpe"] + uuid_data = cache.get_data({(gpe_uuid, "gpe")}) + assert uuid_data.get(gpe_uuid, {}).get("gpe") is not None assert dataset.descriptions["gpe"] != "None" + + +def test_modify_metadata_column(halo_paths): + ds = oc.open(*halo_paths) + galaxyproperties_start = ds["halo_properties"].get_metadata( + "galaxyproperties_start" + ) + updated_galprops = oc.col("galaxyproperties_start") + 1000 + + ds = ds.with_new_columns( + "halo_properties", galaxyproperties_start=updated_galprops, allow_overwrite=True + ) + updated_galaxyproperties_start = ds["halo_properties"].get_metadata( + "galaxyproperties_start" + ) + assert np.all( + (galaxyproperties_start["galaxyproperties_start"] + 1000) + == updated_galaxyproperties_start["galaxyproperties_start"] + ) + assert "galaxyproperties_start" not in ds["halo_properties"].columns + assert "galaxyproperties_start" not in ds["halo_properties"].get_data("numpy") diff --git a/test/test_dataset.py b/test/test_dataset.py index f7e90135..d2095fa4 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -543,12 +543,15 @@ def test_rows_cache(input_path): for i, row in enumerate(dataset.rows()): assert row["fof_px"] == row["fof_halo_mass"] * row["fof_halo_com_vx"] - cache = 
dataset._Dataset__state._DatasetState__cache - cached_data = cache.get_data(["fof_halo_mass", "fof_halo_com_vx", "fof_px"]) - assert np.all( - cached_data["fof_px"] - == cached_data["fof_halo_mass"] * cached_data["fof_halo_com_vx"] - ) + # After iterating rows(), derived columns should be cached. + state = dataset._Dataset__state + cache = state.cache + columns = state.column_map + assert "fof_px" in cache.columns + pairs = {(columns["fof_px"], "fof_px")} + uuid_data = cache.get_data(pairs) + flat = {name: arr for d in uuid_data.values() for name, arr in d.items()} + assert len(flat["fof_px"]) == 100 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" @@ -575,7 +578,7 @@ def test_cache(input_path): def test_cache_select(input_path): dataset = oc.open(input_path) _ = dataset.get_data() - cache = dataset._Dataset__state._DatasetState__cache + cache = dataset._Dataset__state.cache assert set(dataset.columns) == cache.columns columns = np.random.choice(dataset.columns, 5, replace=False) dataset2 = dataset.select(columns) @@ -592,7 +595,7 @@ def test_cache_filter(input_path): dataset = oc.open(input_path) dataset = dataset.filter(oc.col("fof_halo_mass") > 1e14) - cache = dataset._Dataset__state._DatasetState__cache + cache = dataset._Dataset__state.cache assert ( len(cache.columns) > 0 or cache._ColumnCache__parent() is not None @@ -610,7 +613,7 @@ def test_cache_column_conversion(input_path): dataset2 = dataset.with_units(conversions={u.Mpc: u.lyr}) - cache = dataset2._Dataset__state._DatasetState__cache + cache = dataset2._Dataset__state.cache data2 = dataset2.get_data() assert len(cache.columns) == len(dataset2.columns) @@ -623,7 +626,7 @@ def test_cache_change_units(input_path): dataset.get_data() dataset2 = dataset.with_units("scalefree") - cache = dataset2._Dataset__state._DatasetState__cache + cache = dataset2._Dataset__state.cache assert ( len(cache.columns) == 0 and cache._ColumnCache__parent is None @@ -635,16 +638,24 @@ def 
test_cache_conversion_propogation(input_path): dataset2 = dataset.with_units(conversions={u.Mpc: u.lyr}, fof_halo_center_x=u.km) dataset2.get_data() - cache = dataset._Dataset__state._DatasetState__cache - cache2 = dataset2._Dataset__state._DatasetState__cache + state = dataset._Dataset__state + state2 = dataset2._Dataset__state + cache = state.cache + cache2 = state2.cache + col_to_uuid = state.column_map + pairs = {(uuid, name) for name, uuid in col_to_uuid.items()} assert len(cache.columns) == len(dataset2.columns) # just to be safe - cached_columns = cache.get_data(dataset.columns) - cached_columns2 = cache2.get_data(dataset.columns) - for col in cached_columns.values(): + flat = { + name: arr for d in cache.get_data(pairs).values() for name, arr in d.items() + } + flat2 = { + name: arr for d in cache2.get_data(pairs).values() for name, arr in d.items() + } + for col in flat.values(): if isinstance(col, u.Quantity): assert col.unit not in [u.lyr, u.km] - for col in cached_columns2.values(): + for col in flat2.values(): if isinstance(col, u.Quantity): assert col.unit != u.Mpc @@ -706,3 +717,39 @@ def test_description_with_insert_multiple(input_path, tmp_path): descriptions = ds.descriptions assert descriptions["random_data"] == "random data for a test" assert descriptions["halo_px"] == "halo x momentum" + + +def test_with_new_columns_overwrite(input_path): + ds = oc.open(input_path) + log_mass = oc.col("fof_halo_mass").log10() + + with pytest.raises(ValueError): + ds.with_new_columns(fof_halo_mass=log_mass) + + ds_overwritten = ds.with_new_columns(fof_halo_mass=log_mass, allow_overwrite=True) + assert "fof_halo_mass" in ds_overwritten.columns + + original = ds.select("fof_halo_mass").get_data("numpy") + overwritten = ds_overwritten.select("fof_halo_mass").get_data("numpy") + assert np.all(np.isclose(overwritten, np.log10(original))) + + +def test_evaluate_overwrite(input_path): + ds = oc.open(input_path) + + def fof_halo_mass(fof_halo_mass, fof_halo_com_vx): + 
return fof_halo_mass * fof_halo_com_vx + + with pytest.raises(ValueError): + ds.evaluate(fof_halo_mass, vectorize=True, insert=True) + + ds_overwritten = ds.evaluate( + fof_halo_mass, vectorize=True, insert=True, allow_overwrite=True + ) + assert "fof_halo_mass" in ds_overwritten.columns + + data = ds.select(["fof_halo_mass", "fof_halo_com_vx"]).get_data("numpy") + overwritten = ds_overwritten.select("fof_halo_mass").get_data("numpy") + assert np.all( + np.isclose(overwritten, data["fof_halo_mass"] * data["fof_halo_com_vx"]) + ) diff --git a/test/test_diffsky.py b/test/test_diffsky.py index 890ad4ba..3c491951 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -2,12 +2,14 @@ import shutil import astropy.units as u +import h5py +import healpy as hp import numpy as np import pytest +from opencosmo.spatial.region import HealpixRegion import opencosmo as oc from opencosmo.column import add_mag_cols -from opencosmo.spatial.region import HealpixRegion @pytest.fixture @@ -149,9 +151,13 @@ def offset( def test_open_write_with_synthetics(core_path_475, core_path_487, per_test_dir): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + n = 10_000 ds = oc.open(core_path_487, core_path_475, synth_cores=True) ds = ds.filter(oc.col("lsst_g") < 20).take(n).with_new_columns(galid=np.arange(n)) + original_data = ds.get_data() assert len(original_data) == n assert len(ds) == n @@ -163,13 +169,30 @@ def test_open_write_with_synthetics(core_path_475, core_path_487, per_test_dir): ds = oc.open(per_test_dir / "test.hdf5", synth_cores=True) assert len(ds) == n - written_data = ds.get_data() + # Check top_host_idx correctness on the spatially-sorted written data before + # any reordering, since top_host_idx values are indices into that ordering. 
+ reindex_data = ( + oc.open(per_test_dir / "test.hdf5") + .select("top_host_idx", "core_tag") + .get_data("numpy") + ) + assert_top_host_idx_correct(reindex_data, core_map) + written_data = ds.get_data() written_data.sort("galid") - columns_to_check = np.random.choice(ds.columns, size=20, replace=False) + other_columns = [c for c in ds.columns if c != "top_host_idx"] + columns_to_check = np.random.choice(other_columns, size=20, replace=False) + for column in columns_to_check: assert np.all(original_data[column] == written_data[column]) + # Check the synthetic cores still point to themselves + synth_core_check_data = ds.select("top_host_idx", "core_tag").get_data("numpy") + synth_core_rows = np.where(synth_core_check_data["core_tag"] == -1)[0] + assert np.all( + synth_core_check_data["top_host_idx"][synth_core_rows] == synth_core_rows + ) + def test_open_write_with_multiple_synthetics( core_path_475, core_path_487, per_test_dir @@ -277,10 +300,203 @@ def test_add_mag_units_unitless(core_path_475, core_path_487): def test_region(core_path_475, core_path_487): ds = oc.open(core_path_475, core_path_487) assert isinstance(ds.region, HealpixRegion) - print(ds.region.pixels) - assert len(ds.region.pixels) == 502 + assert len(ds.region.pixels) == 1610 + + +def test_lc_collection_pixel_search_with_synthetics(core_path_475, core_path_487): + ds = oc.open(core_path_487, core_path_475, synth_cores=True) + pixels = ds.get_pixels(64) + + all_coordinates = ds.select("ra", "dec").get_data("numpy") + all_pixels = hp.ang2pix( + 64, all_coordinates["ra"], all_coordinates["dec"], nest=True, lonlat=True + ) + + assert np.all(np.unique(all_pixels) == pixels) + + pixels_to_search = np.sort(np.random.choice(pixels, 20, replace=False)) + ds_bound = ds.pixel_search(pixels_to_search) + for ds_ in ds_bound.values(): + assert isinstance(ds_, oc.Lightcone) + + bound_coordinates = ds_bound.select("ra", "dec").get_data("numpy") + bound_pixels = np.unique( + hp.ang2pix( + 64, + 
bound_coordinates["ra"], + bound_coordinates["dec"], + lonlat=True, + nest=True, + ) + ) + assert np.all(pixels_to_search == bound_pixels) + expected_index = np.isin(all_pixels, pixels_to_search) + found_halo_tags = ds_bound.select("gal_id").get_data() + expected_halo_tags = ds.select("gal_id").get_data()[expected_index] + assert np.all(found_halo_tags == expected_halo_tags) def test_open_bad_data(core_path_475, core_path_487, invalid_data_path): with pytest.raises(ValueError, match=str(invalid_data_path)): oc.open(core_path_475, core_path_487, invalid_data_path) + + +def get_expected_core_tags(path): + with h5py.File(path) as f: + raw_top_host = f["cores"]["data"]["top_host_idx"][:] + core_tag = f["cores"]["data"]["core_tag"][:] + + top_host_core_tag = core_tag[raw_top_host] + return dict(zip(core_tag, top_host_core_tag)) + + +def assert_all_group_members_present(data, core_map): + """ + Verify that for every top_host represented in the data, all rows from the + full dataset that point to that top_host are also present. + """ + host_to_members: dict = {} + for ct, host_ct in core_map.items(): + host_to_members.setdefault(host_ct, set()).add(ct) + + present_core_tags = set(data["core_tag"]) + top_host_core_tags = set(data["core_tag"][data["top_host_idx"]]) + + for top_host_ct in top_host_core_tags: + expected_members = host_to_members.get(top_host_ct, set()) + missing = expected_members - present_core_tags + assert not missing, ( + f"top_host {top_host_ct}: {len(missing)} member(s) missing from result" + ) + + +def assert_top_host_idx_correct(data, core_map): + """ + Verify that top_host_idx correctly references the expected host after any + transformation. For each row with a valid index, checks that the referenced + row has the core_tag the original file associates with that host. Also + asserts that every galaxy whose host is present in the data has a valid index. + + data: numpy dict with "top_host_idx" and "core_tag" columns. 
+ core_map: dict mapping core_tag -> host core_tag from the raw HDF5 files. + """ + has_top_host = data["top_host_idx"] >= 0 + found_top_host_core_tag = data["core_tag"][data["top_host_idx"][has_top_host]] + found_core_map = dict(zip(data["core_tag"][has_top_host], found_top_host_core_tag)) + + filtered_core_map = { + key: val for key, val in core_map.items() if key in found_core_map + } + assert filtered_core_map == found_core_map + + should_have_core_map = { + key: val + for key, val in core_map.items() + if val in data["core_tag"] and key in data["core_tag"] + } + assert should_have_core_map == found_core_map + + +def test_reindex_top_host_take_none(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + + +def test_reindex_top_host_take_random(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487).take(300) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + + +def test_reindex_top_host_take_range_sorted(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487).sort_by("logsm_obs").take(300) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + + +def test_reindex_top_host_take_range(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487).take_range(100, 400) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + 
assert_top_host_idx_correct(data, core_map) + + +def test_reindex_top_host_filter(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487).filter(oc.col("logsm_obs") > 10) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + + +def test_keep_top_host_take_random(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True) + ds = ds.take(20) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + assert np.all(data["top_host_idx"] >= 0) + assert_all_group_members_present(data, core_map) + + +def test_keep_top_host_take_start(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True) + ds = ds.take(20, at="start") + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + assert np.all(data["top_host_idx"] >= 0) + assert_all_group_members_present(data, core_map) + + +def test_keep_top_host_take_range(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True) + ds = ds.take_range(20, 60) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + assert np.all(data["top_host_idx"] >= 0) + assert_all_group_members_present(data, core_map) + + +def test_keep_top_host_take_range_after_sort(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= 
get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True).sort_by("logsm_obs") + ds = ds.take_range(20, 60) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + assert np.all(data["top_host_idx"] >= 0) + assert_all_group_members_present(data, core_map) + + +def test_keep_top_host_filter(core_path_475, core_path_487): + core_map = get_expected_core_tags(core_path_475) + core_map |= get_expected_core_tags(core_path_487) + + ds = oc.open(core_path_475, core_path_487, keep_top_host=True) + ds = ds.filter(oc.col("logsm_obs") < 10) + data = ds.select("top_host_idx", "core_tag").get_data("numpy") + assert_top_host_idx_correct(data, core_map) + assert np.all(data["top_host_idx"] >= 0) + assert_all_group_members_present(data, core_map) diff --git a/test/test_evaluate_formats.py b/test/test_evaluate_formats.py new file mode 100644 index 00000000..6b1363c9 --- /dev/null +++ b/test/test_evaluate_formats.py @@ -0,0 +1,337 @@ +import jax.numpy as jnp +import numpy as np +import pandas as pd +import polars as pl +import pyarrow as pa +import pyarrow.compute as pc +import pytest + +import opencosmo as oc + + +@pytest.fixture +def input_path(snapshot_path): + return snapshot_path / "haloproperties.hdf5" + + +FORMATS = ["jax", "pandas", "polars", "arrow"] +SCALARS = { + "jax": (jnp.ndarray, np.floating, float), + "pandas": (pd.Series, np.floating, float), + "polars": (pl.Series, float, int), + "arrow": (pa.Scalar, float, int), +} + + +def _vectorized_func(format): + """Multiply two columns using each format's native multiplication path.""" + + if format == "arrow": + + def fof_px(fof_halo_mass, fof_halo_com_vx): + return pc.multiply(fof_halo_mass, fof_halo_com_vx) + + else: + + def fof_px(fof_halo_mass, fof_halo_com_vx): + return fof_halo_mass * fof_halo_com_vx + + return fof_px + + +def _row_func(format): + """Multiply two scalars; works for any format because each row is a 
scalar.""" + + def fof_px(fof_halo_mass, fof_halo_com_vx): + if isinstance(fof_halo_mass, pa.Scalar): + return fof_halo_mass.as_py() * fof_halo_com_vx.as_py() + return float(fof_halo_mass) * float(fof_halo_com_vx) + + return fof_px + + +def _expected(input_path): + data = ( + oc.open(input_path) + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + return data["fof_halo_mass"] * data["fof_halo_com_vx"] + + +def _to_numpy(value): + if isinstance(value, jnp.ndarray): + return np.asarray(value) + if isinstance(value, (pd.Series, pl.Series)): + return value.to_numpy() + if isinstance(value, pa.Array): + return value.to_numpy(zero_copy_only=False) + return np.asarray(value) + + +# --------------------------------------------------------------------------- +# insert = False (return result directly) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_vectorized_noinsert(input_path, format): + ds = oc.open(input_path) + result = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=False, format=format + ) + expected = _expected(input_path) + assert np.allclose(_to_numpy(result["fof_px"]), expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_row_wise_noinsert(input_path, format): + ds = oc.open(input_path).take(500, at="start") + result = ds.evaluate( + _row_func(format), vectorize=False, insert=False, format=format + ) + selected = ( + oc.open(input_path) + .take(500, at="start") + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + expected = selected["fof_halo_mass"] * selected["fof_halo_com_vx"] + assert np.allclose(_to_numpy(result["fof_px"]), expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_batched_noinsert(input_path, format): + ds = oc.open(input_path) + batch_size = 10_000 + result = ds.evaluate( + _vectorized_func(format), + insert=False, + batch_size=batch_size, + 
format=format, + ) + expected = _expected(input_path) + assert np.allclose(_to_numpy(result["fof_px"]), expected) + + +# --------------------------------------------------------------------------- +# insert = True (converted to numpy, stored in cache) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_vectorized_insert(input_path, format): + ds = oc.open(input_path) + ds = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=True, format=format + ) + assert "fof_px" in ds.columns + data = ds.select("fof_px").get_data("numpy") + expected = _expected(input_path) + assert np.allclose(data, expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_row_wise_insert(input_path, format): + ds = oc.open(input_path).take(500, at="start") + ds = ds.evaluate(_row_func(format), vectorize=False, insert=True, format=format) + assert "fof_px" in ds.columns + data = ds.select("fof_px").get_data("numpy") + selected = ( + oc.open(input_path) + .take(500, at="start") + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + expected = selected["fof_halo_mass"] * selected["fof_halo_com_vx"] + assert np.allclose(data, expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_evaluate_batched_insert(input_path, format): + ds = oc.open(input_path) + batch_size = 10_000 + ds = ds.evaluate( + _vectorized_func(format), + insert=True, + batch_size=batch_size, + format=format, + ) + assert "fof_px" in ds.columns + data = ds.select("fof_px").get_data("numpy") + expected = _expected(input_path) + assert np.allclose(data, expected) + + +# --------------------------------------------------------------------------- +# Output-type assertions on the not-insert path +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "format,expected_type", + [ + ("jax", jnp.ndarray), + ("pandas", 
pd.Series), + ("polars", pl.Series), + ("arrow", pa.Array), + ], +) +def test_evaluate_noinsert_returns_native_container(input_path, format, expected_type): + ds = oc.open(input_path) + result = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=False, format=format + ) + assert isinstance(result["fof_px"], expected_type) + + +# --------------------------------------------------------------------------- +# StructureCollection paths +# --------------------------------------------------------------------------- + + +@pytest.fixture +def halo_paths(snapshot_path): + files = ["haloproperties.hdf5", "haloparticles.hdf5"] + return [snapshot_path / f for f in files] + + +def _mean_x(format): + """Per-structure function: mean of dm_particles 'x' coord. Returns a scalar + in the user's format. fof_halo_center_x is a scalar Quantity that has had + its unit stripped for non-astropy formats, so it's plain float.""" + + def offset(halo_properties, dm_particles): + x = dm_particles["x"] + if format == "arrow": + mean_x = pc.mean(x).as_py() + elif format == "polars": + mean_x = x.mean() + elif format == "pandas": + mean_x = float(x.mean()) + elif format == "jax": + mean_x = float(jnp.mean(x)) + else: + mean_x = float(np.mean(x)) + return mean_x - float(halo_properties["fof_halo_center_x"]) + + return offset + + +def _arange_like(format): + """Per-structure function with dataset=`dm_particles`: must return an array + in the user's format with the same length as the input dataset.""" + + def particle_id(x, y, z): + n = len(x) + if format == "jax": + return jnp.arange(n) + if format == "pandas": + return pd.Series(np.arange(n)) + if format == "polars": + return pl.Series(values=np.arange(n)) + if format == "arrow": + return pa.array(np.arange(n)) + return np.arange(n) + + return particle_id + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_into_properties(halo_paths, format): + collection = oc.open(*halo_paths).take(50) + spec = { + 
"dm_particles": ["x"], + "halo_properties": ["fof_halo_center_x"], + } + collection = collection.evaluate( + _mean_x(format), **spec, format=format, insert=True + ) + data = collection["halo_properties"].select("offset").get_data("numpy") + assert len(data) == 50 + assert np.any(data != 0) + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_into_properties_noinsert(halo_paths, format): + collection = oc.open(*halo_paths).take(50) + spec = { + "dm_particles": ["x"], + "halo_properties": ["fof_halo_center_x"], + } + result = collection.evaluate(_mean_x(format), **spec, format=format, insert=False) + assert "offset" in result + assert len(result["offset"]) == 50 + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_into_dataset(halo_paths, format): + collection = oc.open(*halo_paths).take(20) + collection = collection.evaluate( + _arange_like(format), + dataset="dm_particles", + format=format, + insert=True, + ) + for halo in collection.halos(["dm_particles"]): + pid = halo["dm_particles"].select("particle_id").get_data("numpy") + assert np.all(pid == np.arange(len(pid))) + + +@pytest.mark.parametrize("format", FORMATS) +def test_collection_evaluate_on_dataset(halo_paths, format): + """Routes through Dataset.evaluate via the collection wrapper.""" + collection = oc.open(*halo_paths).take(50) + selected = ( + collection["halo_properties"] + .select(["fof_halo_mass", "fof_halo_com_vx"]) + .get_data("numpy") + ) + collection = collection.evaluate_on_dataset( + _vectorized_func(format), + dataset="halo_properties", + vectorize=True, + format=format, + insert=True, + ) + data = collection["halo_properties"].select("fof_px").get_data("numpy") + expected = selected["fof_halo_mass"] * selected["fof_halo_com_vx"] + assert np.allclose(data, expected) + + +# --------------------------------------------------------------------------- +# Lightcone paths +# 
--------------------------------------------------------------------------- + + +@pytest.fixture +def lc_paths(lightcone_path): + return [ + lightcone_path / "step_600" / "haloproperties.hdf5", + lightcone_path / "step_601" / "haloproperties.hdf5", + ] + + +@pytest.mark.parametrize("format", FORMATS) +def test_lightcone_evaluate_insert(lc_paths, format): + ds = oc.open(*lc_paths).take(100) + ds = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=True, format=format + ) + for name in ds.keys(): + data = ds[name].select("fof_px").get_data("numpy") + original = ( + ds[name].select(["fof_halo_mass", "fof_halo_com_vx"]).get_data("numpy") + ) + expected = original["fof_halo_mass"] * original["fof_halo_com_vx"] + assert np.allclose(data, expected) + + +@pytest.mark.parametrize("format", FORMATS) +def test_lightcone_evaluate_noinsert(lc_paths, format): + ds = oc.open(*lc_paths).take(100) + result = ds.evaluate( + _vectorized_func(format), vectorize=True, insert=False, format=format + ) + assert "fof_px" in result + assert len(result["fof_px"]) == len(ds) diff --git a/test/test_filters.py b/test/test_filters.py index 8c64d8fa..6c246f20 100644 --- a/test/test_filters.py +++ b/test/test_filters.py @@ -5,6 +5,7 @@ import opencosmo as oc from opencosmo import col +from opencosmo.column import offset_3d @pytest.fixture @@ -31,6 +32,7 @@ def test_multi_filters_single_column(input_path, max_mass): ds = ds.filter(col("sod_halo_mass") > 0, col("sod_halo_mass") < max_mass) data = ds.get_data() + assert data["sod_halo_mass"].min() > 0 assert data["sod_halo_mass"].max() < max_mass @@ -138,7 +140,6 @@ def test_or_filter(input_path): ds = ds.filter(high_mass | low_mass) data = ds.select("fof_halo_mass").get_data("numpy") - assert len(data) > 0 assert np.all((data > 1e14) | (data < 1e12)) assert np.any(data > 1e14) assert np.any(data < 1e12) @@ -169,3 +170,27 @@ def test_filter_tree(input_path): assert np.all(data["sod_halo_cdelta"] < 5) assert np.any(data["fof_halo_mass"] > 
1e14) assert np.any(data["fof_halo_mass"] < 1e12) + + +def test_filter_by_derived(input_path): + col = offset_3d("fof_halo_center", "fof_halo_com") / oc.col("sod_halo_radius") + ds = oc.open(input_path) + + ds = ds.filter(col > 0.1) + data = ds.select(xoff=col).get_data() + assert np.all(data > 0.1) + + +def test_column_comparison(input_path): + ds = oc.open(input_path) + + ds = ds.filter(oc.col("fof_halo_center_x") < oc.col("fof_halo_center_y")) + data = ds.select("fof_halo_center_x", "fof_halo_center_y").get_data() + assert np.all(data["fof_halo_center_x"] < data["fof_halo_center_y"]) + + +def test_filter_bad_units(input_path): + ds = oc.open(input_path) + + with pytest.raises(u.UnitConversionError): + ds = ds.filter(oc.col("fof_halo_center_x") < 10 * u.kg) diff --git a/test/test_formats.py b/test/test_formats.py index 533fd383..d5b46d63 100644 --- a/test/test_formats.py +++ b/test/test_formats.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np import pandas as pd import polars as pl @@ -28,6 +29,12 @@ def test_return_pyarrow(input_path): assert all(isinstance(v, pa.Array) for v in data.values()) +def test_return_jax(input_path): + data = oc.open(input_path).get_data("jax") + assert isinstance(data, dict) + assert all(isinstance(v, jnp.ndarray) for v in data.values()) + + def test_return_pandas_single(input_path): dataset = oc.open(input_path) column = np.random.choice(dataset.columns) @@ -48,3 +55,10 @@ def test_return_pyarrow_single(input_path): column = np.random.choice(dataset.columns) data = dataset.select(column).get_data("arrow") assert isinstance(data, pa.Array) + + +def test_return_jax_single(input_path): + dataset = oc.open(input_path) + column = np.random.choice(dataset.columns) + data = dataset.select(column).get_data("jax") + assert isinstance(data, jnp.ndarray) diff --git a/test/test_healpixmap.py b/test/test_healpixmap.py index d94ad778..85178398 100644 --- a/test/test_healpixmap.py +++ b/test/test_healpixmap.py @@ -4,9 +4,10 @@ import 
numpy as np import pytest from astropy.coordinates import SkyCoord +from healsparse import HealSparseMap +from opencosmo.spatial.healpix import HealpixRegion import opencosmo as oc -from opencosmo.spatial.healpix import HealpixRegion @pytest.fixture @@ -67,7 +68,7 @@ def test_healpix_downgrade(healpix_map_path): original_data = ds.get_data("healpix") downgraded_data = downgraded_ds.get_data("healpix") - downgraded_data_2 = original_data["tsz"].reshape((-1, 4)).sum(axis=1) / 4 + downgraded_data_2 = original_data.reshape((-1, 4)).sum(axis=1) / 4 center = (0 * u.deg, 0 * u.deg) radius = 1 * u.deg @@ -81,23 +82,21 @@ def test_healpix_downgrade(healpix_map_path): hp.query_disc(downgraded_nside, [1, 0, 0], 1 * (np.pi / 180.0)) ) - assert len(original_data["tsz"]) == original_npix - assert len(downgraded_data["tsz"]) == downgraded_npix + assert len(original_data) == original_npix + assert len(downgraded_data) == downgraded_npix - assert len(data_region_original["tsz"].valid_pixels) == npix_region - assert len(data_region_downgraded["tsz"].valid_pixels) == npix_region_downgraded + assert len(data_region_original.valid_pixels) == npix_region + assert len(data_region_downgraded.valid_pixels) == npix_region_downgraded assert np.all( np.isclose( downgraded_data_2, - downgraded_data["tsz"], + downgraded_data, atol=1.0e-13, ) ) - assert np.isclose( - np.mean(original_data["tsz"]), np.mean(downgraded_data["tsz"]), atol=1.0e-13 - ) + assert np.isclose(np.mean(original_data), np.mean(downgraded_data), atol=1.0e-13) def test_healpix_downgrade_doesnt_have_file_handle(healpix_map_path): @@ -113,14 +112,14 @@ def test_healpix_downgrade_doesnt_have_file_handle(healpix_map_path): # First dataset backed by actual file, data should be cached dataset = next(iter(ds.values())) - cache = dataset._Dataset__state._DatasetState__cache + cache = dataset._Dataset__state.cache assert len(cache.columns) > 1 # New dataset entirely in-memory, so no cache output.get_data() downgraded_dataset = 
next(iter(output.values())) - cache = downgraded_dataset._Dataset__state._DatasetState__cache - handler = downgraded_dataset._Dataset__state._DatasetState__raw_data_handler + cache = downgraded_dataset._Dataset__state.cache + handler = downgraded_dataset._Dataset__state.raw_data_handler assert len(cache.columns) == len(dataset.columns) assert len(handler) == 0 @@ -223,7 +222,13 @@ def test_healpix_write_after_downgrade(healpix_map_path, tmp_path): oc.write(tmp_path / "map_test.hdf5", ds) new_ds = oc.open(tmp_path / "map_test.hdf5") - print(new_ds) + + original_data = ds.get_data("healpix") + written_data = new_ds.get_data("healpix") + + assert np.all(original_data["ksz"] == written_data["ksz"]) + assert np.all(original_data["tsz"] == written_data["tsz"]) + assert np.all(ds.pixels == new_ds.pixels) def test_healpix_write_after_take_range(healpix_map_path, tmp_path): @@ -258,9 +263,8 @@ def test_healpix_collection_drop(healpix_map_path): to_drop = set(["tsz"]) ds = ds.drop(to_drop) - columns_found = set(ds.get_data().keys()) - assert not columns_found.intersection(to_drop) + assert isinstance(ds.get_data(), HealSparseMap) def test_healpix_collection_take(healpix_map_path): @@ -269,10 +273,10 @@ def test_healpix_collection_take(healpix_map_path): ds_start = ds.take(n_to_take, "start") ds_end = ds.take(n_to_take, "end") ds_random = ds.take(n_to_take, "random") - tags = ds.select("tsz").get_data()["tsz"].valid_pixels - tags_start = ds_start.select("tsz").get_data()["tsz"].valid_pixels - tags_end = ds_end.select("tsz").get_data()["tsz"].valid_pixels - tags_random = ds_random.select("tsz").get_data()["tsz"].valid_pixels + tags = ds.select("tsz").get_data().valid_pixels + tags_start = ds_start.select("tsz").get_data().valid_pixels + tags_end = ds_end.select("tsz").get_data().valid_pixels + tags_random = ds_random.select("tsz").get_data().valid_pixels assert np.all(tags[:n_to_take] == tags_start) assert np.all(tags[-n_to_take:] == tags_end) assert len(tags_random) == 
n_to_take and len(set(tags_random)) == len(tags_random) @@ -284,8 +288,8 @@ def test_healpix_collection_range(healpix_map_path): end = int(0.75 * len(ds)) ds_range = ds.take_range(start, end) - halo_tags = ds.select("tsz").get_data("healsparse")["tsz"].valid_pixels[start:end] - range_halo_tags = ds_range.select("tsz").get_data("healsparse")["tsz"].valid_pixels + halo_tags = ds.select("tsz").get_data("healsparse").valid_pixels[start:end] + range_halo_tags = ds_range.select("tsz").get_data("healsparse").valid_pixels assert np.all(halo_tags == range_halo_tags) @@ -301,6 +305,14 @@ def test_healpix_collection_select(healpix_map_path): assert columns_found == to_select +def test_healpix_collection_take_healpix(healpix_map_path): + ds = oc.open(healpix_map_path) + ds = ds.take_range(500, 1000) + data = ds.get_data("healpix") + for value in data.values(): + assert np.all(np.where(value.mask)[0] == np.arange(500, 1000)) + + def test_healpix_collection_select_healsparse(healpix_map_path): ds = oc.open(healpix_map_path) to_select = set(["tsz", "ksz"]) @@ -325,8 +337,14 @@ def test_healpix_collection_derive(healpix_map_path): ds = oc.open(healpix_map_path) sz_sqrd = oc.col("tsz") ** 2 + oc.col("ksz") ** 2 ds = ds.with_new_columns(weird_sz=sz_sqrd) - weird = ds.select("weird_sz").get_data() - assert isinstance(weird, dict) + data = ds.get_data() + assert isinstance(data, dict) + assert np.all(data["weird_sz"].valid_pixels == data["tsz"].valid_pixels) + + tsz = data["tsz"][data["tsz"].valid_pixels] + ksz = data["ksz"][data["ksz"].valid_pixels] + weird_sz = data["weird_sz"][data["weird_sz"].valid_pixels] + assert np.all(weird_sz == tsz**2 + ksz**2) def test_healpix_collection_add(healpix_map_path): @@ -334,7 +352,7 @@ def test_healpix_collection_add(healpix_map_path): map_data = ds.get_data("healsparse")["tsz"].valid_pixels data = np.zeros(len(map_data)) ds = ds.with_new_columns(random=data) - stored_data = ds.select("random").get_data("healpix")["random"] + stored_data = 
ds.select("random").get_data("healpix") assert np.all(stored_data == data) @@ -343,9 +361,7 @@ def test_healpix_collection_add_sparse(healpix_map_path): map_data = ds.get_data("healsparse")["tsz"].valid_pixels data = np.zeros(len(map_data)) ds = ds.with_new_columns(random=data) - stored_data = ( - ds.select("random").get_data("healsparse")["random"].get_values_pix(map_data) - ) + stored_data = ds.select("random").get_data("healsparse").get_values_pix(map_data) assert np.all(stored_data == data) @@ -367,9 +383,9 @@ def offset(tsz, ksz): ds_vec = ds.evaluate(offset, vectorize=True, insert=True) ds_iter = ds.evaluate(offset, insert=True) - offset_vec = ds_vec.select("offset").get_data("healsparse")["offset"] + offset_vec = ds_vec.select("offset").get_data("healsparse") offset_vec = offset_vec.get_values_pix(offset_vec.valid_pixels) - offset_iter = ds_iter.select("offset").get_data("healsparse")["offset"] + offset_iter = ds_iter.select("offset").get_data("healsparse") offset_iter = offset_iter.get_values_pix(offset_iter.valid_pixels) assert np.all(offset_vec == offset_iter) diff --git a/test/test_lightcone.py b/test/test_lightcone.py index b4efc42d..97c5fc33 100644 --- a/test/test_lightcone.py +++ b/test/test_lightcone.py @@ -1,4 +1,5 @@ import astropy.units as u +import healpy as hp import numpy as np import pytest from astropy.cosmology import units as cu @@ -19,7 +20,7 @@ def haloproperties_601_path(lightcone_path): @pytest.fixture def all_files(): - return ["haloparticles.hdf5", "haloproperties.hdf5", "sodpropertybins.hdf5"] + return ["haloparticles.hdf5", "haloproperties.hdf5", "haloprofiles.hdf5"] @pytest.fixture @@ -40,8 +41,9 @@ def test_create_theta_phi_coords(haloproperties_600_path, haloproperties_601_pat ra = (data["phi"] * u.rad).to(u.deg) dec = ((np.pi / 2 - data["theta"]) * u.rad).to(u.deg) - assert np.allclose(data["ra"], ra, rtol=1e-2) - assert np.allclose(data["dec"], dec, rtol=1e-2) + + assert np.allclose(data["ra"], ra, atol=0.0001, rtol=1e-2) + 
assert np.allclose(data["dec"], dec, atol=0.0001, rtol=1e-2) def test_lightcone_physical_units(haloproperties_600_path): @@ -69,6 +71,37 @@ def test_lc_collection_restrict_z(haloproperties_600_path, haloproperties_601_pa assert np.sum(masked_redshifts) == len(redshifts) +def test_lc_collection_pixel_search(haloproperties_600_path, haloproperties_601_path): + ds = oc.open(haloproperties_601_path, haloproperties_600_path) + pixels = ds.get_pixels(64) + + all_coordinates = ds.select("theta", "phi").get_data("numpy") + all_pixels = hp.ang2pix( + 64, all_coordinates["theta"], all_coordinates["phi"], nest=True + ) + + assert np.all(np.unique(all_pixels) == pixels) + + pixels_to_search = np.sort(np.random.choice(pixels, 20, replace=False)) + ds_bound = ds.pixel_search(pixels_to_search) + + bound_coordinates = ds_bound.select("ra", "dec").get_data("numpy") + bound_pixels = np.unique( + hp.ang2pix( + 64, + bound_coordinates["ra"], + bound_coordinates["dec"], + lonlat=True, + nest=True, + ) + ) + assert np.all(pixels_to_search == bound_pixels) + expected_index = np.isin(all_pixels, pixels_to_search) + found_halo_tags = ds_bound.select("fof_halo_tag").get_data() + expected_halo_tags = ds.select("fof_halo_tag").get_data()[expected_index] + assert np.all(found_halo_tags == expected_halo_tags) + + def test_lc_collection_write( haloproperties_600_path, haloproperties_601_path, tmp_path ): @@ -179,7 +212,7 @@ def test_lc_collection_range( def test_lc_collection_take_rows(haloproperties_600_path, haloproperties_601_path): ds = oc.open(haloproperties_600_path, haloproperties_601_path) - n_to_take = int(0.75 * len(ds)) + n_to_take = int(0.25 * len(ds)) rows = np.random.choice(len(ds), n_to_take, replace=False) rows.sort() ds_rows = ds.take_rows(rows) @@ -197,19 +230,17 @@ def test_lc_collection_take_rows(haloproperties_600_path, haloproperties_601_pat assert np.all( data["fof_halo_mass"][sorted_index][rows] == sorted_tags["fof_halo_mass"] ) - toolkit_sorted_tags_mass = dict( - 
zip(sorted_tags["fof_halo_tag"], sorted_tags["fof_halo_mass"]) - ) - sorted_tags_mass = dict( - zip( - data["fof_halo_tag"][sorted_index][rows], - data["fof_halo_mass"][sorted_index][rows], - ) - ) - # Exact order is not deterministic, because many low_mass halos have the same mass, - # So we just make sure the tag->mass mapping is the same in the two datasets. + # Verify each returned (tag, mass) pair is internally consistent with the original + # dataset. We do not compare against a reference sort because sort stability + # determines which specific halo is selected among ties, and we don't require a + # particular choice — only that the returned tag actually belongs to a halo with + # the returned mass. + tag_order = np.argsort(data["fof_halo_tag"]) + all_tags_sorted = data["fof_halo_tag"][tag_order] + all_mass_by_tag = data["fof_halo_mass"][tag_order] - assert toolkit_sorted_tags_mass == sorted_tags_mass + positions = np.searchsorted(all_tags_sorted, sorted_tags["fof_halo_tag"]) + assert np.all(all_mass_by_tag[positions] == sorted_tags["fof_halo_mass"]) def test_lc_collection_derive( @@ -543,8 +574,3 @@ def test_lightcone_stacking_nostack( def test_lightcone_structure_collection_open(structure_600): c = oc.open(*structure_600) assert isinstance(c, oc.StructureCollection) - - -def test_lightcone_structure_collection_open_multiple(structure_600, structure_601): - with pytest.raises(NotImplementedError): - _ = oc.open(*structure_600, *structure_601) diff --git a/test/test_select.py b/test/test_select.py index 5f1eac2c..afb7f575 100644 --- a/test/test_select.py +++ b/test/test_select.py @@ -88,9 +88,7 @@ def test_select_doesnt_alter_raw(input_path): selected = dataset.select(selected_cols) selected_data = selected.get_data() - raw_data = ( - dataset._Dataset__state._DatasetState__raw_data_handler._Hdf5Handler__columns - ) + raw_data = dataset._Dataset__state.raw_data_handler._Hdf5Handler__columns assert all(isinstance(raw_data[col], h5py.Dataset) for col in cols) 
assert all(data[col].unit == selected_data[col].unit for col in selected_cols) assert not all(np.all(data[col].value == raw_data[col][:]) for col in selected_cols) @@ -131,3 +129,10 @@ def test_select_with_derived(input_path): assert com_columns.issubset(data.columns) assert np.all(data["gal_px"] == data["gal_mass"] * data["gal_com_vx"]) + + +def test_select_single_check_units(input_path): + ds = oc.open(input_path) + ds = ds.select("fof_halo_center_x", center_x=oc.col("fof_halo_center_x")) + data = ds.get_data() + assert np.all(data["fof_halo_center_x"] == data["center_x"]) diff --git a/test/test_structure_collection.py b/test/test_structure_collection.py new file mode 100644 index 00000000..a08866d8 --- /dev/null +++ b/test/test_structure_collection.py @@ -0,0 +1,202 @@ +import numpy as np +import pytest + +import opencosmo as oc + + +@pytest.fixture +def halos_600_path(lightcone_path): + properties = lightcone_path / "step_600" / "haloproperties.hdf5" + particles = lightcone_path / "step_600" / "haloparticles.hdf5" + profiles = lightcone_path / "step_600" / "haloprofiles.hdf5" + return [properties, particles, profiles] + + +@pytest.fixture +def galaxies_600_path(lightcone_path): + properties = lightcone_path / "step_600" / "galaxyproperties.hdf5" + particles = lightcone_path / "step_600" / "galaxyparticles.hdf5" + return [properties, particles] + + +@pytest.fixture +def halos_601_path(lightcone_path): + properties = lightcone_path / "step_601" / "haloproperties.hdf5" + particles = lightcone_path / "step_601" / "haloparticles.hdf5" + profiles = lightcone_path / "step_601" / "haloprofiles.hdf5" + return [properties, particles, profiles] + + +@pytest.fixture +def galaxies_601_path(lightcone_path): + properties = lightcone_path / "step_601" / "galaxyproperties.hdf5" + particles = lightcone_path / "step_601" / "galaxyparticles.hdf5" + return [properties, particles] + + +def verify_halo(halo): + gravity_particle_tags = ( + 
halo["dm_particles"].select("fof_halo_tag").get_data("numpy") + ) + assert np.all(gravity_particle_tags == halo["halo_properties"]["fof_halo_tag"]) + halo_bin_tags = halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy") + assert np.all(halo_bin_tags == halo["halo_properties"]["fof_halo_tag"]) + if "galaxy" not in halo: + return + for galaxy in halo["galaxies"].galaxies(): + assert ( + galaxy["galaxy_properties"]["fof_halo_tag"] + == halo["halo_properties"]["fof_halo_tag"] + ) + + if "star_particles" not in galaxy: + continue + tags = galaxy["star_particles"].select("gal_tag").get_data("numpy") + assert np.all(tags == galaxy["galaxy_properties"]["gal_tag"]) + + +def test_open_lightcone_structure(halos_600_path, halos_601_path): + ds = oc.open(*halos_600_path, *halos_601_path) + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10).halos(): + gravity_particle_tags = ( + halo["dm_particles"].select("fof_halo_tag").get_data("numpy") + ) + assert np.all(gravity_particle_tags == halo["halo_properties"]["fof_halo_tag"]) + halo_bin_tags = ( + halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy") + ) + assert np.all(halo_bin_tags == halo["halo_properties"]["fof_halo_tag"]) + + +def test_open_lightcone_galaxy_structure_collection( + galaxies_600_path, + galaxies_601_path, +): + ds = oc.open(*galaxies_600_path, *galaxies_601_path) + for galaxy in ds.take(100).galaxies(): + if "star_particles" not in galaxy: + continue + tags = galaxy["star_particles"].select("gal_tag").get_data("numpy") + assert np.all(tags == galaxy["galaxy_properties"]["gal_tag"]) + + +def test_open_lightcone_structure_with_galaxies( + halos_600_path, galaxies_600_path, halos_601_path, galaxies_601_path +): + ds = oc.open( + *halos_600_path, *galaxies_600_path, *halos_601_path, *galaxies_601_path + ) + ds = ds.filter(oc.col("sod_halo_mass") > 1e14).take(10) + + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10).halos(): + gravity_particle_tags = ( + 
halo["dm_particles"].select("fof_halo_tag").get_data("numpy") + ) + assert np.all(gravity_particle_tags == halo["halo_properties"]["fof_halo_tag"]) + halo_bin_tags = ( + halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy") + ) + assert np.all(halo_bin_tags == halo["halo_properties"]["fof_halo_tag"]) + for galaxy in halo["galaxies"].galaxies(): + assert ( + galaxy["galaxy_properties"]["fof_halo_tag"] + == halo["halo_properties"]["fof_halo_tag"] + ) + + if "star_particles" not in galaxy: + continue + tags = galaxy["star_particles"].select("gal_tag").get_data("numpy") + assert np.all(tags == galaxy["galaxy_properties"]["gal_tag"]) + + +def test_write_lightcone_structure(halos_600_path, halos_601_path, tmp_path): + ds = ( + oc.open(*halos_600_path, *halos_601_path) + .filter(oc.col("fof_halo_mass") > 1e14) + .take(1000) + ) + halo_tags_start = set() + halo_tags_end = set() + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos(): + halo_tags_start.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos(): + halo_tags_end.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + oc.write(tmp_path / "halos.hdf5", ds) + ds_new = oc.open(tmp_path / "halos.hdf5") + + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos() + ): + assert halo["halo_properties"]["fof_halo_tag"] in halo_tags_start + verify_halo(halo) + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos() + ): + assert halo["halo_properties"]["fof_halo_tag"] in halo_tags_end + verify_halo(halo) + + +def test_write_lightcone_structure_with_galaxies( + halos_600_path, halos_601_path, galaxies_600_path, galaxies_601_path, tmp_path +): + ds = ( + oc.open( + *halos_600_path, *halos_601_path, *galaxies_600_path, *galaxies_601_path + ) + .filter(oc.col("fof_halo_mass") > 1e14) + .take(1000) + ) + 
halo_tags_start = set() + halo_tags_end = set() + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos(): + halo_tags_start.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + + for halo in ds.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos(): + halo_tags_end.add(halo["halo_properties"]["fof_halo_tag"]) + verify_halo(halo) + oc.write(tmp_path / "halos.hdf5", ds) + ds_new = oc.open(tmp_path / "halos.hdf5") + + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="start").halos() + ): + assert halo["halo_properties"]["fof_halo_tag"] in halo_tags_start + verify_halo(halo) + for halo in ( + ds_new.filter(oc.col("sod_halo_mass") > 1e14).take(10, at="end").halos() + ): + assert halo["halo_properties"]["fof_halo_tag"] in halo_tags_end + verify_halo(halo) + + +def test_data_link_sort_write_lightcone(halos_600_path, halos_601_path, tmp_path): + collection = oc.open(*halos_600_path, *halos_601_path) + collection = collection.filter(oc.col("sod_halo_mass") > 10**14).sort_by( + "fof_halo_mass" + ) + output = tmp_path / "halos.hdf5" + oc.write(output, collection) + new_collection = oc.open(output).take(10) + assert np.all( + collection["halo_properties"].select("sod_halo_mass").get_data("numpy") > 10**14 + ) + for halo in new_collection.objects(("halo_profiles",)): + assert np.all( + halo["halo_properties"]["fof_halo_tag"] + == halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy")[0] + ) + + +def test_redshift_bound(halos_600_path, halos_601_path, tmp_path): + collection = oc.open(*halos_600_path, *halos_601_path) + collection = collection.with_redshift_range(0.038, 0.039) + + collection = collection.filter(oc.col("sod_halo_mass") > 10**14) + for halo in collection.halos(): + redshift = halo["halo_properties"]["redshift"] + assert redshift > 0.038 and redshift < 0.039 + verify_halo(halo) diff --git a/test/test_take.py b/test/test_take.py index f87a0bce..8ae08696 100644 --- 
a/test/test_take.py +++ b/test/test_take.py @@ -58,5 +58,29 @@ def test_take_chain(input_path): def test_take_too_many(input_path): ds = oc.open(input_path) length = len(ds.get_data()) - with pytest.raises(ValueError): - ds.take(length + 1) + + new_ds = ds.take(length + 1) + assert len(new_ds) == len(ds) + + +def test_take_end_too_many(input_path): + ds = oc.open(input_path) + length = len(ds) + + new_ds = ds.take(length + 1, at="end") + assert len(new_ds) == length + + +def test_take_end_sorted(input_path): + ds = oc.open(input_path) + cols = ds.columns + sort_col = cols[0] + n = 10 + + all_values = ds.select(sort_col).get_data("numpy") + threshold = np.sort(all_values)[-n] + + taken = ds.sort_by(sort_col).take(n, at="end").select(sort_col).get_data("numpy") + + assert len(taken) == n + assert np.all(taken >= threshold) diff --git a/uv.lock b/uv.lock index 123aae5f..53519f59 100644 --- a/uv.lock +++ b/uv.lock @@ -1,15 +1,19 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12, <3.15" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'emscripten'", "python_full_version >= '3.14' and platform_machine == 'arm64' and sys_platform == 'darwin'", "(python_full_version >= '3.14' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", - "python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version < '3.14' and platform_machine == 'arm64' and sys_platform == 'darwin'", - "(python_full_version < '3.14' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + 
"python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and sys_platform == 'emscripten'", + "python_full_version == '3.13.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version < '3.13' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "(python_full_version == '3.13.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "(python_full_version < '3.13' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.13' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", ] [[package]] @@ -557,6 +561,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "jax" +version = "0.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jaxlib" }, + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "opt-einsum" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/24/49/b082387119c4a6bc7596296bbdc6bce034628cdd2845ebb27304cbca3624/jax-0.10.1.tar.gz", hash = "sha256:11672410faf8752429eb9a131de203dc488a2a3a012d509baa2b39878008810d", size = 2718178, upload-time = "2026-05-20T14:54:09.441Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl", hash = "sha256:47f3192c76e9e3358de1b106a8af5e943fccb10510903f25d96ea53652729134", size = 3150973, upload-time = "2026-05-20T14:51:30.066Z" 
}, +] + +[[package]] +name = "jaxlib" +version = "0.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "scipy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/53/b4/fdb6e989b142d8a8d2093f342cbc5323fe0d4a7217fd899c8ddf9e108a5a/jaxlib-0.10.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7a4399c7429c87ee6f7ab1e5712c5548ecfabde974ee9f4a12957fea4e35efb8", size = 60816137, upload-time = "2026-05-20T14:52:53.028Z" }, + { url = "https://files.pythonhosted.org/packages/a7/29/0b4eaaca005708751ff301903a2ca760dcc34175dadb7536498c57a7de85/jaxlib-0.10.1-cp312-cp312-manylinux_2_27_aarch64.whl", hash = "sha256:12126603ba472300c62480f86d20972d563143041039d9dea349ad510aa6123c", size = 80294034, upload-time = "2026-05-20T14:52:56.941Z" }, + { url = "https://files.pythonhosted.org/packages/38/69/2912ab63036e21c72748019e1d8e09e8a1fc3368b3e83fc27898a1858575/jaxlib-0.10.1-cp312-cp312-manylinux_2_27_x86_64.whl", hash = "sha256:f3cdf5b7f48470ab5455ab79aab746419694ccb6b52651cc2ce5fb27def03588", size = 85828774, upload-time = "2026-05-20T14:53:01.749Z" }, + { url = "https://files.pythonhosted.org/packages/ec/8f/993ea419eca6f34fe12613e22a03b93f40e5b1e8e0df18d4060e1313a1fc/jaxlib-0.10.1-cp312-cp312-win_amd64.whl", hash = "sha256:0acf3f8e7dca9074c0327f0f61502845792ca9f82fab23b841b00daa78e85488", size = 64830187, upload-time = "2026-05-20T14:53:05.932Z" }, + { url = "https://files.pythonhosted.org/packages/cf/76/3b637d4def229015a3035a7b44fac0dcf2536ae337540cdbffc651334d4e/jaxlib-0.10.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4167213aa00f14bb0d8fbd90f9ded75e976f71ce8baf8c3c44e04c8fb80ea0c1", size = 60815855, upload-time = "2026-05-20T14:53:11.718Z" }, + { url = "https://files.pythonhosted.org/packages/5b/76/9a1971bc9edb8728a7ba86d2693127ee46add9811230b8452321415fd4e9/jaxlib-0.10.1-cp313-cp313-manylinux_2_27_aarch64.whl", hash = 
"sha256:e53223d8f6861d33c02dd02343fa464700401ff5363784d37c33407218e328b8", size = 80293947, upload-time = "2026-05-20T14:53:15.408Z" }, + { url = "https://files.pythonhosted.org/packages/20/1d/69a0ba52fb546261e71a7209378ee6059950e9c088a2a18355e01509f474/jaxlib-0.10.1-cp313-cp313-manylinux_2_27_x86_64.whl", hash = "sha256:bb073a1224e659e01e8d32d47c000edb52ec2aa8ba97ec22b2228b3a46e5c167", size = 85829861, upload-time = "2026-05-20T14:53:19.773Z" }, + { url = "https://files.pythonhosted.org/packages/a9/df/48659e2ee57705c63a51525f810fe3e0c87af4ca9f89d4738281a872d58e/jaxlib-0.10.1-cp313-cp313-win_amd64.whl", hash = "sha256:6449f1d4a22324f5f02c843360475783f9fd1d353fe711806cbf4e927d1360ae", size = 64828863, upload-time = "2026-05-20T14:53:24.562Z" }, + { url = "https://files.pythonhosted.org/packages/c7/34/3f7c95ee1b2555d611f836988a49b522c04b8d186e0528f91d45118089bf/jaxlib-0.10.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:72a7db1242d6b7773340e430c244e73a8d13f82ce83933c0529a8b4b1520dcf3", size = 60934063, upload-time = "2026-05-20T14:53:28.549Z" }, + { url = "https://files.pythonhosted.org/packages/3a/84/855a2395d299e00e1d9ffd0aecd516b52b81780f3e4ef537527be6d8c1fc/jaxlib-0.10.1-cp313-cp313t-manylinux_2_27_aarch64.whl", hash = "sha256:649c26ca92e9bbffb3c35226b37893916f20e6b89fb3911ce79b39e5cfb27b46", size = 80402500, upload-time = "2026-05-20T14:53:32.444Z" }, + { url = "https://files.pythonhosted.org/packages/be/35/153e91a9c770a981d525d845b3f4cdb71a1e119681594a33908e9536bdff/jaxlib-0.10.1-cp313-cp313t-manylinux_2_27_x86_64.whl", hash = "sha256:cfe75d8a17e0d33a7bed27f32d7a5344a66e8d4af7073973f396e14ff4a9c503", size = 85941210, upload-time = "2026-05-20T14:53:36.982Z" }, + { url = "https://files.pythonhosted.org/packages/0b/a1/c4d4c0530313c50dd1ba07fff480cfd0c5f18c5ec49742f4a52a6edfd95f/jaxlib-0.10.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:9a657554dbe56f3691d377cb99d00b29df14e5638e579e57965f9f69c32c9315", size = 60826189, upload-time = 
"2026-05-20T14:53:40.73Z" }, + { url = "https://files.pythonhosted.org/packages/98/71/529b2439b88491e0806ee9fa6191ecf90f4447dfe092e80bd19577c85260/jaxlib-0.10.1-cp314-cp314-manylinux_2_27_aarch64.whl", hash = "sha256:93cd9989404f86a21b50c7ff4e850a31f55509664312080978bde97a664d92a9", size = 80300654, upload-time = "2026-05-20T14:53:44.751Z" }, + { url = "https://files.pythonhosted.org/packages/be/08/0bc4132fb2fe224ee9d83fd60e3650fd4d893b4e5148707a98df8f0333a4/jaxlib-0.10.1-cp314-cp314-manylinux_2_27_x86_64.whl", hash = "sha256:26b94e9640b01968cc14b8353dd6b6540d723f30579c78b1f46a477fc4aa196d", size = 85841241, upload-time = "2026-05-20T14:53:49.13Z" }, + { url = "https://files.pythonhosted.org/packages/5c/de/423d748ce3367bd5ea20d8cc34a7ceb6420da4d41c20b247f54194700d04/jaxlib-0.10.1-cp314-cp314-win_amd64.whl", hash = "sha256:375820799bbf7d515dd4e4d40f3334566b73d3fe64d340afbd6aa897d5d7c486", size = 67301520, upload-time = "2026-05-20T14:53:52.989Z" }, + { url = "https://files.pythonhosted.org/packages/2d/bf/3f7ce089d62f7ac85ea678925471f7ec88038899e67ab02079c1e7d8ad4e/jaxlib-0.10.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:bc52f16bdd61b299efeed7ea8e743e91a118059a463da2ab97196fc05b69ddb7", size = 60935393, upload-time = "2026-05-20T14:53:56.886Z" }, + { url = "https://files.pythonhosted.org/packages/65/6a/38cf1d4ff8c8f74ec8d567e0aa6e3d2082ab6fc580545f6bb51368da20a9/jaxlib-0.10.1-cp314-cp314t-manylinux_2_27_aarch64.whl", hash = "sha256:55b0a473fbd57d31dc3935c4cd5c0c38af5f7c1f41300df27923cf46676972ca", size = 80405741, upload-time = "2026-05-20T14:54:01.94Z" }, + { url = "https://files.pythonhosted.org/packages/16/9e/d3cff171aaf13a09aab26a44d1a27dcbf0d6311e4d855f5d99685965ace3/jaxlib-0.10.1-cp314-cp314t-manylinux_2_27_x86_64.whl", hash = "sha256:4b0cb8ef960a3723037db63f28ffa20083d90fff6d30085e99d7c63cfa08e4c0", size = 85942183, upload-time = "2026-05-20T14:54:05.91Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -715,30 +765,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/b2/c8/d148e041732d631fc76036f8b30fae4e77b027a1e95b7a84bb522481a940/librt-0.8.1-cp314-cp314t-win_arm64.whl", hash = "sha256:bf512a71a23504ed08103a13c941f763db13fb11177beb3d9244c98c29fb4a61", size = 48755, upload-time = "2026-02-17T16:12:47.943Z" }, ] -[[package]] -name = "llvmlite" -version = "0.47.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/01/88/a8952b6d5c21e74cbf158515b779666f692846502623e9e3c39d8e8ba25f/llvmlite-0.47.0.tar.gz", hash = "sha256:62031ce968ec74e95092184d4b0e857e444f8fdff0b8f9213707699570c33ccc", size = 193614, upload-time = "2026-03-31T18:29:53.497Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/48/4b7fe0e34c169fa2f12532916133e0b219d2823b540733651b34fdac509a/llvmlite-0.47.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:306a265f408c259067257a732c8e159284334018b4083a9e35f67d19792b164f", size = 37232769, upload-time = "2026-03-31T18:28:43.735Z" }, - { url = "https://files.pythonhosted.org/packages/e6/4b/e3f2cd17822cf772a4a51a0a8080b0032e6d37b2dbe8cfb724eac4e31c52/llvmlite-0.47.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5853bf26160857c0c2573415ff4efe01c4c651e59e2c55c2a088740acfee51cd", size = 56275178, upload-time = "2026-03-31T18:28:48.342Z" }, - { url = "https://files.pythonhosted.org/packages/b6/55/a3b4a543185305a9bdf3d9759d53646ed96e55e7dfd43f53e7a421b8fbae/llvmlite-0.47.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:003bcf7fa579e14db59c1a1e113f93ab8a06b56a4be31c7f08264d1d4072d077", size = 55128632, upload-time = "2026-03-31T18:28:52.901Z" }, - { url = "https://files.pythonhosted.org/packages/2f/f5/d281ae0f79378a5a91f308ea9fdb9f9cc068fddd09629edc0725a5a8fde1/llvmlite-0.47.0-cp312-cp312-win_amd64.whl", hash = "sha256:f3079f25bdc24cd9d27c4b2b5e68f5f60c4fdb7e8ad5ee2b9b006007558f9df7", size = 38138692, upload-time = "2026-03-31T18:28:57.147Z" }, - { 
url = "https://files.pythonhosted.org/packages/77/6f/4615353e016799f80fa52ccb270a843c413b22361fadda2589b2922fb9b0/llvmlite-0.47.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a3c6a735d4e1041808434f9d440faa3d78d9b4af2ee64d05a66f351883b6ceec", size = 37232771, upload-time = "2026-03-31T18:29:01.324Z" }, - { url = "https://files.pythonhosted.org/packages/31/b8/69f5565f1a280d032525878a86511eebed0645818492feeb169dfb20ae8e/llvmlite-0.47.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2699a74321189e812d476a43d6d7f652f51811e7b5aad9d9bba842a1c7927acb", size = 56275178, upload-time = "2026-03-31T18:29:05.748Z" }, - { url = "https://files.pythonhosted.org/packages/d6/da/b32cafcb926fb0ce2aa25553bf32cb8764af31438f40e2481df08884c947/llvmlite-0.47.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c6951e2b29930227963e53ee152441f0e14be92e9d4231852102d986c761e40", size = 55128632, upload-time = "2026-03-31T18:29:11.235Z" }, - { url = "https://files.pythonhosted.org/packages/46/9f/4898b44e4042c60fafcb1162dfb7014f6f15b1ec19bf29cfea6bf26df90d/llvmlite-0.47.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2e9adf8698d813a9a5efb2d4370caf344dbc1e145019851fee6a6f319ba760e", size = 38138695, upload-time = "2026-03-31T18:29:15.43Z" }, - { url = "https://files.pythonhosted.org/packages/1c/d4/33c8af00f0bf6f552d74f3a054f648af2c5bc6bece97972f3bfadce4f5ec/llvmlite-0.47.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:de966c626c35c9dff5ae7bf12db25637738d0df83fc370cf793bc94d43d92d14", size = 37232773, upload-time = "2026-03-31T18:29:19.453Z" }, - { url = "https://files.pythonhosted.org/packages/64/1d/a760e993e0c0ba6db38d46b9f48f6c7dceb8ac838824997fb9e25f97bc04/llvmlite-0.47.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ddbccff2aeaff8670368340a158abefc032fe9b3ccf7d9c496639263d00151aa", size = 56275176, upload-time = "2026-03-31T18:29:24.149Z" }, - { url = 
"https://files.pythonhosted.org/packages/84/3b/e679bc3b29127182a7f4aa2d2e9e5bea42adb93fb840484147d59c236299/llvmlite-0.47.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d4a7b778a2e144fc64468fb9bf509ac1226c9813a00b4d7afea5d988c4e22fca", size = 55128631, upload-time = "2026-03-31T18:29:29.536Z" }, - { url = "https://files.pythonhosted.org/packages/be/f7/19e2a09c62809c9e63bbd14ce71fb92c6ff7b7b3045741bb00c781efc3c9/llvmlite-0.47.0-cp314-cp314-win_amd64.whl", hash = "sha256:694e3c2cdc472ed2bd8bd4555ca002eec4310961dd58ef791d508f57b5cc4c94", size = 39153826, upload-time = "2026-03-31T18:29:33.681Z" }, - { url = "https://files.pythonhosted.org/packages/40/a1/581a8c707b5e80efdbbe1dd94527404d33fe50bceb71f39d5a7e11bd57b7/llvmlite-0.47.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:92ec8a169a20b473c1c54d4695e371bde36489fc1efa3688e11e99beba0abf9c", size = 37232772, upload-time = "2026-03-31T18:29:37.952Z" }, - { url = "https://files.pythonhosted.org/packages/11/03/16090dd6f74ba2b8b922276047f15962fbeea0a75d5601607edb301ba945/llvmlite-0.47.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fa1cbd800edd3b20bc141521f7fd45a6185a5b84109aa6855134e81397ffe72b", size = 56275178, upload-time = "2026-03-31T18:29:42.58Z" }, - { url = "https://files.pythonhosted.org/packages/f5/cb/0abf1dd4c5286a95ffe0c1d8c67aec06b515894a0dd2ac97f5e27b82ab0b/llvmlite-0.47.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f6725179b89f03b17dabe236ff3422cb8291b4c1bf40af152826dfd34e350ae8", size = 55128632, upload-time = "2026-03-31T18:29:46.939Z" }, - { url = "https://files.pythonhosted.org/packages/4f/79/d3bbab197e86e0ff4f9c07122895b66a3e0d024247fcff7f12c473cb36d9/llvmlite-0.47.0-cp314-cp314t-win_amd64.whl", hash = "sha256:6842cf6f707ec4be3d985a385ad03f72b2d724439e118fcbe99b2929964f0453", size = 39153839, upload-time = "2026-03-31T18:29:51.004Z" }, -] - [[package]] name = "markupsafe" version = "3.0.3" @@ 
-856,6 +882,42 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/49/d651878698a0b67f23aa28e17f45a6d6dd3d3f933fa29087fa4ce5947b5a/matplotlib-3.10.8-cp314-cp314t-win_arm64.whl", hash = "sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f", size = 8192560, upload-time = "2025-12-10T22:56:38.008Z" }, ] +[[package]] +name = "ml-dtypes" +version = "0.5.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/4a/c27b42ed9b1c7d13d9ba8b6905dece787d6259152f2309338aed29b2447b/ml_dtypes-0.5.4.tar.gz", hash = "sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453", size = 692314, upload-time = "2025-11-17T22:32:31.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/b8/3c70881695e056f8a32f8b941126cf78775d9a4d7feba8abcb52cb7b04f2/ml_dtypes-0.5.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac", size = 676927, upload-time = "2025-11-17T22:31:48.182Z" }, + { url = "https://files.pythonhosted.org/packages/54/0f/428ef6881782e5ebb7eca459689448c0394fa0a80bea3aa9262cba5445ea/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900", size = 5028464, upload-time = "2025-11-17T22:31:50.135Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cb/28ce52eb94390dda42599c98ea0204d74799e4d8047a0eb559b6fd648056/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff", size = 5009002, upload-time = "2025-11-17T22:31:52.001Z" }, + { url = "https://files.pythonhosted.org/packages/f5/f0/0cfadd537c5470378b1b32bd859cf2824972174b51b873c9d95cfd7475a5/ml_dtypes-0.5.4-cp312-cp312-win_amd64.whl", hash = 
"sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7", size = 212222, upload-time = "2025-11-17T22:31:53.742Z" }, + { url = "https://files.pythonhosted.org/packages/16/2e/9acc86985bfad8f2c2d30291b27cd2bb4c74cea08695bd540906ed744249/ml_dtypes-0.5.4-cp312-cp312-win_arm64.whl", hash = "sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460", size = 160793, upload-time = "2025-11-17T22:31:55.358Z" }, + { url = "https://files.pythonhosted.org/packages/d9/a1/4008f14bbc616cfb1ac5b39ea485f9c63031c4634ab3f4cf72e7541f816a/ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48", size = 676888, upload-time = "2025-11-17T22:31:56.907Z" }, + { url = "https://files.pythonhosted.org/packages/d3/b7/dff378afc2b0d5a7d6cd9d3209b60474d9819d1189d347521e1688a60a53/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b", size = 5036993, upload-time = "2025-11-17T22:31:58.497Z" }, + { url = "https://files.pythonhosted.org/packages/eb/33/40cd74219417e78b97c47802037cf2d87b91973e18bb968a7da48a96ea44/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d", size = 5010956, upload-time = "2025-11-17T22:31:59.931Z" }, + { url = "https://files.pythonhosted.org/packages/e1/8b/200088c6859d8221454825959df35b5244fa9bdf263fd0249ac5fb75e281/ml_dtypes-0.5.4-cp313-cp313-win_amd64.whl", hash = "sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328", size = 212224, upload-time = "2025-11-17T22:32:01.349Z" }, + { url = "https://files.pythonhosted.org/packages/8f/75/dfc3775cb36367816e678f69a7843f6f03bd4e2bcd79941e01ea960a068e/ml_dtypes-0.5.4-cp313-cp313-win_arm64.whl", hash = "sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175", size = 160798, 
upload-time = "2025-11-17T22:32:02.864Z" }, + { url = "https://files.pythonhosted.org/packages/4f/74/e9ddb35fd1dd43b1106c20ced3f53c2e8e7fc7598c15638e9f80677f81d4/ml_dtypes-0.5.4-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6", size = 702083, upload-time = "2025-11-17T22:32:04.08Z" }, + { url = "https://files.pythonhosted.org/packages/74/f5/667060b0aed1aa63166b22897fdf16dca9eb704e6b4bbf86848d5a181aa7/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d", size = 5354111, upload-time = "2025-11-17T22:32:05.546Z" }, + { url = "https://files.pythonhosted.org/packages/40/49/0f8c498a28c0efa5f5c95a9e374c83ec1385ca41d0e85e7cf40e5d519a21/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298", size = 5366453, upload-time = "2025-11-17T22:32:07.115Z" }, + { url = "https://files.pythonhosted.org/packages/8c/27/12607423d0a9c6bbbcc780ad19f1f6baa2b68b18ce4bddcdc122c4c68dc9/ml_dtypes-0.5.4-cp313-cp313t-win_amd64.whl", hash = "sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6", size = 225612, upload-time = "2025-11-17T22:32:08.615Z" }, + { url = "https://files.pythonhosted.org/packages/e5/80/5a5929e92c72936d5b19872c5fb8fc09327c1da67b3b68c6a13139e77e20/ml_dtypes-0.5.4-cp313-cp313t-win_arm64.whl", hash = "sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1", size = 164145, upload-time = "2025-11-17T22:32:09.782Z" }, + { url = "https://files.pythonhosted.org/packages/72/4e/1339dc6e2557a344f5ba5590872e80346f76f6cb2ac3dd16e4666e88818c/ml_dtypes-0.5.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22", size = 673781, upload-time = "2025-11-17T22:32:11.364Z" }, + { url = 
"https://files.pythonhosted.org/packages/04/f9/067b84365c7e83bda15bba2b06c6ca250ce27b20630b1128c435fb7a09aa/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465", size = 5036145, upload-time = "2025-11-17T22:32:12.783Z" }, + { url = "https://files.pythonhosted.org/packages/c6/bb/82c7dcf38070b46172a517e2334e665c5bf374a262f99a283ea454bece7c/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f", size = 5010230, upload-time = "2025-11-17T22:32:14.38Z" }, + { url = "https://files.pythonhosted.org/packages/e9/93/2bfed22d2498c468f6bcd0d9f56b033eaa19f33320389314c19ef6766413/ml_dtypes-0.5.4-cp314-cp314-win_amd64.whl", hash = "sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56", size = 221032, upload-time = "2025-11-17T22:32:15.763Z" }, + { url = "https://files.pythonhosted.org/packages/76/a3/9c912fe6ea747bb10fe2f8f54d027eb265db05dfb0c6335e3e063e74e6e8/ml_dtypes-0.5.4-cp314-cp314-win_arm64.whl", hash = "sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049", size = 163353, upload-time = "2025-11-17T22:32:16.932Z" }, + { url = "https://files.pythonhosted.org/packages/cd/02/48aa7d84cc30ab4ee37624a2fd98c56c02326785750cd212bc0826c2f15b/ml_dtypes-0.5.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9", size = 702085, upload-time = "2025-11-17T22:32:18.175Z" }, + { url = "https://files.pythonhosted.org/packages/5a/e7/85cb99fe80a7a5513253ec7faa88a65306be071163485e9a626fce1b6e84/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7", size = 5355358, upload-time = "2025-11-17T22:32:19.7Z" }, + { url = 
"https://files.pythonhosted.org/packages/79/2b/a826ba18d2179a56e144aef69e57fb2ab7c464ef0b2111940ee8a3a223a2/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf", size = 5366332, upload-time = "2025-11-17T22:32:21.193Z" }, + { url = "https://files.pythonhosted.org/packages/84/44/f4d18446eacb20ea11e82f133ea8f86e2bf2891785b67d9da8d0ab0ef525/ml_dtypes-0.5.4-cp314-cp314t-win_amd64.whl", hash = "sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1", size = 236612, upload-time = "2025-11-17T22:32:22.579Z" }, + { url = "https://files.pythonhosted.org/packages/ad/3f/3d42e9a78fe5edf792a83c074b13b9b770092a4fbf3462872f4303135f09/ml_dtypes-0.5.4-cp314-cp314t-win_arm64.whl", hash = "sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d", size = 168825, upload-time = "2025-11-17T22:32:23.766Z" }, +] + [[package]] name = "more-itertools" version = "11.0.1" @@ -986,34 +1048,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, ] -[[package]] -name = "numba" -version = "0.65.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "llvmlite" }, - { name = "numpy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/49/61/7299643b9c18d669e04be7c5bcb64d985070d07553274817b45b049e7bfe/numba-0.65.0.tar.gz", hash = "sha256:edad0d9f6682e93624c00125a471ae4df186175d71fd604c983c377cdc03e68b", size = 2764131, upload-time = "2026-04-01T03:52:01.946Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6c/2f/8bd31a1ea43c01ac215283d83aa5f8d5acbe7a36c85b82f1757bfe9ccb31/numba-0.65.0-cp312-cp312-macosx_12_0_arm64.whl", hash = 
"sha256:b27ee4847e1bfb17e9604d100417ee7c1d10f15a6711c6213404b3da13a0b2aa", size = 2680705, upload-time = "2026-04-01T03:51:32.597Z" }, - { url = "https://files.pythonhosted.org/packages/73/36/88406bd58600cc696417b8e5dd6a056478da808f3eaf48d18e2421e0c2d9/numba-0.65.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a52d92ffd297c10364bce60cd1fcb88f99284ab5df085f2c6bcd1cb33b529a6f", size = 3801411, upload-time = "2026-04-01T03:51:34.321Z" }, - { url = "https://files.pythonhosted.org/packages/0c/61/ce753a1d7646dd477e16d15e89473703faebb8995d2f71d7ad69a540b565/numba-0.65.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da8e371e328c06d0010c3d8b44b21858652831b85bcfba78cb22c042e22dbd8e", size = 3501622, upload-time = "2026-04-01T03:51:36.348Z" }, - { url = "https://files.pythonhosted.org/packages/7d/86/db87a5393f1b1fabef53ac3ba4e6b938bb27e40a04ad7cc512098fcae032/numba-0.65.0-cp312-cp312-win_amd64.whl", hash = "sha256:59bb9f2bb9f1238dfd8e927ba50645c18ae769fef4f3d58ea0ea22a2683b91f5", size = 2749979, upload-time = "2026-04-01T03:51:37.88Z" }, - { url = "https://files.pythonhosted.org/packages/8b/f8/eee0f1ff456218db036bfc9023995ec1f85a9dc8f2422f1594f6a87829e0/numba-0.65.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:c6334094563a456a695c812e6846288376ca02327cf246cdcc83e1bb27862367", size = 2680679, upload-time = "2026-04-01T03:51:39.491Z" }, - { url = "https://files.pythonhosted.org/packages/1b/8f/3d116e4b8e92f6abace431afa4b2b944f4d65bdee83af886f5c4b263df95/numba-0.65.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b8a9008411615c69d083d1dcf477f75a5aa727b30beb16e139799e2be945cdfd", size = 3809537, upload-time = "2026-04-01T03:51:41.42Z" }, - { url = "https://files.pythonhosted.org/packages/b5/2c/6a3ca4128e253cb67affe06deb47688f51ce968f5111e2a06d010e6f1fa6/numba-0.65.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:af96c0cba53664efcb361528b8c75e011a6556c859c7e08424c2715201c6cf7a", size = 3508615, upload-time = "2026-04-01T03:51:43.444Z" }, - { url = "https://files.pythonhosted.org/packages/96/0e/267f9a36fb282c104a971d7eecb685b411c47dce2a740fe69cf5fc2945d9/numba-0.65.0-cp313-cp313-win_amd64.whl", hash = "sha256:6254e73b9c929dc736a1fbd3d6f5680789709a5067cae1fa7198707385129c04", size = 2749938, upload-time = "2026-04-01T03:51:45.218Z" }, - { url = "https://files.pythonhosted.org/packages/56/a4/90edb01e9176053578e343d7a7276bc28356741ee67059aed8ed2c1a4e59/numba-0.65.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:ee336b398a6fca51b1f626034de99f50cb1bd87d537a166275158a3cee744b82", size = 2680878, upload-time = "2026-04-01T03:51:46.91Z" }, - { url = "https://files.pythonhosted.org/packages/24/8d/e12d6ff4b9119db3cbf7b2db1ce257576441bd3c76388c786dea74f20b02/numba-0.65.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:05c0a9fdf75d85f57dee47b719e8d6415707b80aae45d75f63f9dc1b935c29f7", size = 3778456, upload-time = "2026-04-01T03:51:48.552Z" }, - { url = "https://files.pythonhosted.org/packages/17/89/abcd83e76f6a773276fe76244140671bcc5bf820f6e2ae1a15362ae4c8c9/numba-0.65.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:583680e0e8faf124d362df23b4b593f3221a8996341a63d1b664c122401bec2f", size = 3478464, upload-time = "2026-04-01T03:51:50.527Z" }, - { url = "https://files.pythonhosted.org/packages/73/5b/fbce55ce3d933afbc7ade04df826853e4a846aaa47d58d2fbb669b8f2d08/numba-0.65.0-cp314-cp314-win_amd64.whl", hash = "sha256:add297d3e1c08dd884f44100152612fa41e66a51d15fdf91307f9dde31d06830", size = 2752012, upload-time = "2026-04-01T03:51:52.691Z" }, - { url = "https://files.pythonhosted.org/packages/1e/ab/af705f4257d9388fb2fd6d7416573e98b6ca9c786e8b58f02720978557bd/numba-0.65.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:194a243ba53a9157c8538cbb3166ec015d785a8c5d584d06cdd88bee902233c7", size = 2683961, upload-time = 
"2026-04-01T03:51:54.281Z" }, - { url = "https://files.pythonhosted.org/packages/ff/e5/8267b0adb0c01b52b553df5062fbbb42c30ed5362d08b85cc913a36f838f/numba-0.65.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c7fa502960f7a2f3f5cb025bc7bff888a3551277b92431bfdc5ba2f11a375749", size = 3816373, upload-time = "2026-04-01T03:51:56.18Z" }, - { url = "https://files.pythonhosted.org/packages/b0/f5/b8397ca360971669a93706b9274592b6864e4367a37d498fbbcb62aa2d48/numba-0.65.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5046c63f783ca3eb6195f826a50797465e7c4ce811daa17c9bea47e310c9b964", size = 3532782, upload-time = "2026-04-01T03:51:58.387Z" }, - { url = "https://files.pythonhosted.org/packages/f5/21/1e73fa16bf0393ebb74c5bb208d712152ffdfc84600a8e93a3180317856e/numba-0.65.0-cp314-cp314t-win_amd64.whl", hash = "sha256:46fd679ae4f68c7a5d5721efbd29ecee0b0f3013211591891d79b51bfdf73113", size = 2757611, upload-time = "2026-04-01T03:52:00.083Z" }, -] - [[package]] name = "numpy" version = "2.4.4" @@ -1077,7 +1111,7 @@ wheels = [ [[package]] name = "opencosmo" -version = "1.2.5" +version = "1.2.6" source = { editable = "." 
} dependencies = [ { name = "astropy" }, @@ -1087,7 +1121,6 @@ dependencies = [ { name = "hdf5plugin" }, { name = "healpy" }, { name = "healsparse" }, - { name = "numba" }, { name = "numpy" }, { name = "pydantic" }, { name = "rustworkx" }, @@ -1100,6 +1133,7 @@ io = [ [package.dev-dependencies] dev = [ + { name = "jax" }, { name = "mypy" }, { name = "pip" }, { name = "pre-commit" }, @@ -1138,7 +1172,6 @@ requires-dist = [ { name = "hdf5plugin", specifier = ">=5.0.0,!=5.1,<7.0" }, { name = "healpy", specifier = ">=1.19.0,<2.0.0" }, { name = "healsparse", specifier = ">=1.11,<2.0" }, - { name = "numba", specifier = ">=0.64.0,<1.0" }, { name = "numpy", specifier = ">=2.0,<2.5" }, { name = "pyarrow", marker = "extra == 'io'", specifier = ">=21.0.0" }, { name = "pydantic", specifier = ">=2.10.6,<3.0.0" }, @@ -1148,6 +1181,7 @@ provides-extras = ["io"] [package.metadata.requires-dev] dev = [ + { name = "jax", specifier = ">=0.10.1" }, { name = "mypy", specifier = ">=1.15,<2.0.0" }, { name = "pip", specifier = ">=25.1.1" }, { name = "pre-commit", specifier = ">=4.2.0,<5.0.0" }, @@ -1173,6 +1207,15 @@ test = [ ] test-mpi = [{ name = "mpi-pytest", specifier = ">=2025.4.0,<2026.0.0" }] +[[package]] +name = "opt-einsum" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/b9/2ac072041e899a52f20cf9510850ff58295003aa75525e58343591b0cbfb/opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac", size = 63004, upload-time = "2024-09-26T14:33:24.483Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932, upload-time = "2024-09-26T14:33:23.039Z" }, +] + [[package]] name = "packaging" version = "26.0"