diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9d4b97e8..dcf32247 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,6 +34,11 @@ jobs: - 'packages/essreduce/**' - 'pyproject.toml' - 'pixi.lock' + essnmx: + - 'packages/essnmx/**' + - 'packages/essreduce/**' + - 'pyproject.toml' + - 'pixi.lock' formatting: name: Formatting and static analysis diff --git a/README.md b/README.md index d8936b56..63e3c3be 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ Monorepo for ESS neutron scattering data reduction packages, managed with [pixi] |---------|-------------| | [essreduce](packages/essreduce/) | Common data reduction tools (core) | | [essimaging](packages/essimaging/) | Neutron imaging (ODIN, TBL, YMIR) | +| [essnmx](packages/essnmx/) | Data reduction for NMX at the European Spallation Source. | ## Dependency graph diff --git a/packages/essnmx/.copier-answers.ess.yml b/packages/essnmx/.copier-answers.ess.yml new file mode 100644 index 00000000..6425ea7c --- /dev/null +++ b/packages/essnmx/.copier-answers.ess.yml @@ -0,0 +1,3 @@ +# Changes here will be overwritten by Copier; NEVER EDIT MANUALLY +_commit: 34ca4ba +_src_path: https://github.com/scipp/ess_template diff --git a/packages/essnmx/.copier-answers.yml b/packages/essnmx/.copier-answers.yml new file mode 100644 index 00000000..aecca8d4 --- /dev/null +++ b/packages/essnmx/.copier-answers.yml @@ -0,0 +1,13 @@ +# Changes here will be overwritten by Copier; NEVER EDIT MANUALLY +_commit: 0dae45f +_src_path: gh:scipp/copier_template +description: Data reduction for NMX at the European Spallation Source. +max_python: '3.13' +min_python: '3.11' +namespace_package: ess +nightly_deps: scipp,sciline,scippnexus,plopp +orgname: scipp +prettyname: ESSnmx +projectname: essnmx +related_projects: Scipp,Sciline,Plopp,ScippNexus +year: 2025 diff --git a/packages/essnmx/.gitignore b/packages/essnmx/.gitignore new file mode 100644 index 00000000..c6e47900 --- /dev/null +++ b/packages/essnmx/.gitignore @@ -0,0 +1,49 @@ +# Build artifacts +build +dist +html +.tox +*.egg-info +# we lock dependencies with pip-compile, not uv +uv.lock + +*.sw? + +# Environments +venv +.venv + +# Caches +*.DS_Store +.clangd/ +*.ipynb_checkpoints +__pycache__/ +.vs/ +.virtual_documents +.hypothesis +.pytest_cache +.mypy_cache +docs/generated/ +.ruff_cache + +# Editor settings +.idea/ +.vscode/ + +# Data files +*.data +*.dat +*.csv +*.xye +*.h5 +*.hdf5 +*.hdf +*.nxs +*.raw +*.cif +*.rcif +*.ort +*.zip +*.sqw +*.nxspe +*.mtz diff --git a/packages/essnmx/.python-version b/packages/essnmx/.python-version new file mode 100644 index 00000000..2c073331 --- /dev/null +++ b/packages/essnmx/.python-version @@ -0,0 +1 @@ +3.11 diff --git a/packages/essnmx/CODE_OF_CONDUCT.md b/packages/essnmx/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..1c3746e4 --- /dev/null +++ b/packages/essnmx/CODE_OF_CONDUCT.md @@ -0,0 +1,134 @@ + +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. 
+ +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +scipp[at]ess.eu. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. 
+ +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at +[https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations + diff --git a/packages/essnmx/CONTRIBUTING.md b/packages/essnmx/CONTRIBUTING.md new file mode 100644 index 00000000..753fb793 --- /dev/null +++ b/packages/essnmx/CONTRIBUTING.md @@ -0,0 +1,20 @@ +## Contributing to ESSnmx + +Welcome to the developer side of ESSnmx! + +Contributions are always welcome. +This includes reporting bugs or other issues, submitting pull requests, requesting new features, etc. + +If you need help with using ESSnmx or contributing to it, have a look at the GitHub [discussions](https://github.com/scipp/essnmx/discussions) and start a new [Q&A discussion](https://github.com/scipp/essnmx/discussions/categories/q-a) if you can't find what you are looking for. + +For bug reports and other problems, please open an [issue](https://github.com/scipp/essnmx/issues/new) in GitHub. + +You are welcome to submit pull requests at any time. +But to avoid having to make large modifications during review or even have your PR rejected, please first open an issue first to discuss your idea! + +Check out the subsections of the [Developer documentation](https://scipp.github.io/essnmx/developer/index.html) for details on how ESSnmx is developed. + +## Code of conduct + +This project is a community effort, and everyone is welcome to contribute. +Everyone within the community is expected to abide by our [code of conduct](https://github.com/scipp/essnmx/blob/main/CODE_OF_CONDUCT.md). diff --git a/packages/essnmx/LICENSE b/packages/essnmx/LICENSE new file mode 100644 index 00000000..7d62083d --- /dev/null +++ b/packages/essnmx/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2025, Scipp contributors (https://github.com/scipp) +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. 
Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/packages/essnmx/MANIFEST.in b/packages/essnmx/MANIFEST.in new file mode 100644 index 00000000..1aba38f6 --- /dev/null +++ b/packages/essnmx/MANIFEST.in @@ -0,0 +1 @@ +include LICENSE diff --git a/packages/essnmx/README.md b/packages/essnmx/README.md new file mode 100644 index 00000000..a95ecc59 --- /dev/null +++ b/packages/essnmx/README.md @@ -0,0 +1,16 @@ +[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](CODE_OF_CONDUCT.md) +[![PyPI badge](http://img.shields.io/pypi/v/essnmx.svg)](https://pypi.python.org/pypi/essnmx) +[![Anaconda-Server Badge](https://anaconda.org/conda-forge/essnmx/badges/version.svg)](https://anaconda.org/conda-forge/essnmx) +[![License: BSD 3-Clause](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](LICENSE) + +# ESSnmx + +## About + +Data reduction for NMX at the European Spallation Source. 
+ +## Installation + +```sh +python -m pip install essnmx +``` diff --git a/packages/essnmx/docs/_static/anaconda-icon.js b/packages/essnmx/docs/_static/anaconda-icon.js new file mode 100644 index 00000000..024350ec --- /dev/null +++ b/packages/essnmx/docs/_static/anaconda-icon.js @@ -0,0 +1,13 @@ +FontAwesome.library.add( + (faListOldStyle = { + prefix: "fa-custom", + iconName: "anaconda", + icon: [ + 67.65, // viewBox width + 67.500267, // viewBox height + [], // ligature + "e001", // unicode codepoint - private use area + "M 33.900391 0 C 32.600392 0 31.299608 0.09921885 30.099609 0.19921875 A 39.81 39.81 0 0 1 35.199219 4.3007812 L 36.5 5.5 L 35.199219 6.8007812 A 34.65 34.65 0 0 0 32 10.199219 L 32 10.300781 A 6.12 6.12 0 0 0 31.5 10.900391 A 19.27 19.27 0 0 1 33.900391 10.800781 A 23 23 0 0 1 33.900391 56.800781 A 22.39 22.39 0 0 1 21.900391 53.400391 A 45.33 45.33 0 0 1 16.699219 53.699219 A 19.27 19.27 0 0 1 14.300781 53.599609 A 78.24 78.24 0 0 0 15 61.699219 A 33.26 33.26 0 0 0 33.900391 67.5 A 33.75 33.75 0 0 0 33.900391 0 z M 23 1.8007812 A 33.78 33.78 0 0 0 15.599609 5.4003906 A 47 47 0 0 1 20.699219 6.5996094 A 52.38 52.38 0 0 1 23 1.8007812 z M 26.5 2 A 41.8 41.8 0 0 0 23.699219 7.5996094 C 25.199217 8.1996088 26.69922 8.8000007 28.199219 9.5 C 28.799218 8.7000008 29.300391 8.0999999 29.400391 8 C 30.10039 7.2000008 30.800001 6.399218 31.5 5.6992188 A 58.59 58.59 0 0 0 26.5 2 z M 13.199219 8.1992188 A 48.47 48.47 0 0 0 13.099609 14.800781 A 44.05 44.05 0 0 1 18.300781 14.5 A 39.43 39.43 0 0 1 19.699219 9.5996094 A 46.94 46.94 0 0 0 13.199219 8.1992188 z M 10.099609 9.8007812 A 33.47 33.47 0 0 0 4.9003906 16.5 C 6.6003889 16 8.3992205 15.599218 10.199219 15.199219 C 10.099219 13.399221 10.099609 11.600779 10.099609 9.8007812 z M 22.599609 10.599609 C 22.19961 11.799608 21.8 13.100392 21.5 14.400391 A 29.18 29.18 0 0 1 26.199219 12.099609 A 27.49 27.49 0 0 0 22.599609 10.599609 z M 17.699219 17.5 C 16.19922 17.5 14.80078 17.599219 13.300781 17.699219 A 33.92 33.92 0 0 0 14.099609 22.099609 A 20.36 20.36 0 0 1 17.699219 17.5 z M 10.599609 17.900391 A 43.62 43.62 0 0 0 3.3007812 19.900391 L 3.0996094 20 L 3.1992188 20.199219 A 30.3 30.3 0 0 0 6.5 27.300781 L 6.5996094 27.5 L 6.8007812 27.400391 A 50.41 50.41 0 0 1 11.699219 24.300781 L 11.900391 24.199219 L 11.900391 24 A 38.39 38.39 0 0 1 10.800781 18.099609 L 10.800781 17.900391 L 10.599609 17.900391 z M 1.8007812 22.800781 L 1.5996094 23.400391 A 33.77 33.77 0 0 0 0 32.900391 L 0 33.5 L 0.40039062 33.099609 A 24.93 24.93 0 0 1 4.8007812 28.900391 L 5 28.800781 L 4.9003906 28.599609 A 54.49 54.49 0 0 1 2 23.300781 L 1.8007812 22.800781 z M 12.300781 26.300781 L 11.800781 26.599609 C 10.500783 27.399609 9.2003893 28.19961 7.9003906 29.099609 L 7.6992188 29.199219 L 8 29.400391 C 8.8999991 30.600389 9.8007822 31.900001 10.800781 33 L 11.099609 33.5 L 11.099609 32.900391 A 23.54 23.54 0 0 1 12.099609 26.900391 L 12.300781 26.300781 z M 6.0996094 30.5 L 5.9003906 30.699219 A 47 47 0 0 0 0.80078125 35.599609 L 0.59960938 35.800781 L 0.80078125 36 A 58.38 58.38 0 0 0 6.4003906 40.199219 L 6.5996094 40.300781 L 6.6992188 40.099609 A 45.3 45.3 0 0 1 9.6992188 35.5 L 9.8007812 35.300781 L 9.6992188 35.199219 A 52 52 0 0 1 6.1992188 30.800781 L 6.0996094 30.5 z M 11.300781 36.400391 L 11 36.900391 C 10.100001 38.200389 9.2003898 39.600001 8.4003906 41 L 8.3007812 41.199219 L 8.5 41.300781 C 9.8999986 42.10078 11.400392 42.800001 12.900391 43.5 L 13.400391 43.699219 L 13.199219 43.199219 A 23.11 23.11 0 0 1 11.400391 37 L 11.300781 
36.400391 z M 0.099609375 37.699219 L 0.19921875 38.300781 A 31.56 31.56 0 0 0 2.9003906 47.699219 L 3.0996094 48.199219 L 3.3007812 47.699219 A 55.47 55.47 0 0 1 5.6992188 42.099609 L 5.8007812 41.800781 L 5.5996094 41.699219 A 57.36 57.36 0 0 1 0.59960938 38.099609 L 0.099609375 37.699219 z M 7.4003906 42.800781 L 7.3007812 43 A 53.76 53.76 0 0 0 4.5 50 L 4.4003906 50.199219 L 4.5996094 50.300781 A 39.14 39.14 0 0 0 12.199219 51.699219 L 12.5 51.699219 L 12.5 51.5 A 36.79 36.79 0 0 1 13 45.699219 L 13 45.5 L 12.800781 45.400391 A 49.67 49.67 0 0 1 7.5996094 42.900391 L 7.4003906 42.800781 z M 14.5 45.900391 L 14.5 46.199219 A 45.53 45.53 0 0 0 14.099609 51.5 L 14.099609 51.699219 L 14.300781 51.699219 C 15.10078 51.699219 15.89922 51.800781 16.699219 51.800781 A 12.19 12.19 0 0 0 19.400391 51.800781 L 20 51.800781 L 19.5 51.400391 A 20.73 20.73 0 0 1 14.900391 46.199219 L 14.900391 46.099609 L 14.5 45.900391 z M 5.1992188 52.099609 L 5.5 52.599609 A 34.87 34.87 0 0 0 12.599609 60.400391 L 13 60.800781 L 13 60.099609 A 51.43 51.43 0 0 1 12.5 53.5 L 12.5 53.300781 L 12.300781 53.300781 A 51.94 51.94 0 0 1 5.8007812 52.199219 L 5.1992188 52.099609 z" + ], + }) +); diff --git a/packages/essnmx/docs/_static/favicon.ico b/packages/essnmx/docs/_static/favicon.ico new file mode 100644 index 00000000..af3f7a31 Binary files /dev/null and b/packages/essnmx/docs/_static/favicon.ico differ diff --git a/packages/essnmx/docs/_static/logo-dark.svg b/packages/essnmx/docs/_static/logo-dark.svg new file mode 100644 index 00000000..fc42a002 --- /dev/null +++ b/packages/essnmx/docs/_static/logo-dark.svg @@ -0,0 +1,162 @@ + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + + + + + diff --git a/packages/essnmx/docs/_static/logo.svg b/packages/essnmx/docs/_static/logo.svg new file mode 100644 index 00000000..cb6f702b --- /dev/null +++ b/packages/essnmx/docs/_static/logo.svg @@ -0,0 +1,166 @@ + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + + + + + + diff --git a/packages/essnmx/docs/_templates/class-template.rst b/packages/essnmx/docs/_templates/class-template.rst new file mode 100644 index 00000000..0200267d --- /dev/null +++ b/packages/essnmx/docs/_templates/class-template.rst @@ -0,0 +1,31 @@ +{{ fullname | escape | underline }} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :special-members: __getitem__ + + {% block methods %} + .. automethod:: __init__ + + {% if methods %} + .. rubric:: {{ _('Methods') }} + + .. autosummary:: + {% for item in methods %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block attributes %} + {% if attributes %} + .. rubric:: {{ _('Attributes') }} + + .. autosummary:: + {% for item in attributes %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/packages/essnmx/docs/_templates/doc_version.html b/packages/essnmx/docs/_templates/doc_version.html new file mode 100644 index 00000000..64fad220 --- /dev/null +++ b/packages/essnmx/docs/_templates/doc_version.html @@ -0,0 +1,2 @@ + +Current ESSnmx version: {{ version }} (older versions). diff --git a/packages/essnmx/docs/_templates/module-template.rst b/packages/essnmx/docs/_templates/module-template.rst new file mode 100644 index 00000000..6fee8d77 --- /dev/null +++ b/packages/essnmx/docs/_templates/module-template.rst @@ -0,0 +1,66 @@ +{{ fullname | escape | underline}} + +.. automodule:: {{ fullname }} + + {% block attributes %} + {% if attributes %} + .. 
rubric:: {{ _('Module Attributes') }} + + .. autosummary:: + :toctree: + {% for item in attributes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block functions %} + {% if functions %} + .. rubric:: {{ _('Functions') }} + + .. autosummary:: + :toctree: + {% for item in functions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block classes %} + {% if classes %} + .. rubric:: {{ _('Classes') }} + + .. autosummary:: + :toctree: + :template: class-template.rst + {% for item in classes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block exceptions %} + {% if exceptions %} + .. rubric:: {{ _('Exceptions') }} + + .. autosummary:: + :toctree: + {% for item in exceptions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + +{% block modules %} +{% if modules %} +.. rubric:: Modules + +.. autosummary:: + :toctree: + :template: module-template.rst + :recursive: +{% for item in modules %} + {{ item }} +{%- endfor %} +{% endif %} +{% endblock %} diff --git a/packages/essnmx/docs/about/.DS_Store b/packages/essnmx/docs/about/.DS_Store new file mode 100644 index 00000000..5008ddfc Binary files /dev/null and b/packages/essnmx/docs/about/.DS_Store differ diff --git a/packages/essnmx/docs/about/NMX_work_flow.png b/packages/essnmx/docs/about/NMX_work_flow.png new file mode 100644 index 00000000..51edeaf2 Binary files /dev/null and b/packages/essnmx/docs/about/NMX_work_flow.png differ diff --git a/packages/essnmx/docs/about/data_workflow_overview.md b/packages/essnmx/docs/about/data_workflow_overview.md new file mode 100644 index 00000000..998db451 --- /dev/null +++ b/packages/essnmx/docs/about/data_workflow_overview.md @@ -0,0 +1,128 @@ +# Data Workflow Overview +This is an overall description of the data workflow for the NMX instrument at ESS. + +The [NMX](https://europeanspallationsource.se/instruments/nmx) Macromolecular Diffractometer is a time-of-flight (TOF) +quasi-Laue diffractometer optimised for small samples and large unit cells +dedicated to the structure determination of biological macromolecules by crystallography. + +The main scientific driver is to locate the hydrogen atoms relevant to the function of the macromolecule. + +## Data reduction + +![Workflow Overview](NMX_work_flow.png) + +### From single event data to binned image-like data (scipp) +The first step in the data reduction is to reduce the data from single event data to image-like data.
+For this step, the [ESSnmx](https://github.com/scipp/essnmx) package is used.
+
+The time of arrival (TOA) is converted into time of flight (TOF).
+
+The single events are then binned into pixels and histogrammed along the TOF dimension.
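+
+As an illustrative sketch of this step (hypothetical sizes and names, not the actual ESSnmx providers), binning and histogramming with scipp could look like this:
+
+```python
+import numpy as np
+import scipp as sc
+
+rng = np.random.default_rng(0)
+n_events, n_pixels = 1_000, 1_280 * 1_280  # made-up event count and panel size
+
+# One entry per detected neutron, with pixel ID and TOF as event coordinates.
+events = sc.DataArray(
+    data=sc.ones(dims=['event'], shape=[n_events], unit='counts'),
+    coords={
+        'id': sc.array(dims=['event'], values=rng.integers(0, n_pixels, n_events), unit=None),
+        't': sc.array(dims=['event'], values=rng.uniform(0.0, 0.07, n_events), unit='s'),
+    },
+)
+binned = events.bin(id=n_pixels)  # group events by detector pixel
+image_like = binned.hist(t=50)    # histogram each pixel into 50 TOF bins
+```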
+This result can be exported to an HDF5 file
+along with additional metadata and instrument coordinates (pixel IDs).
+
+See the [workflow example](../user-guide/workflow) for more details.
+
+### Spot finding and integration (DIALS)
+For the next five steps of the data reduction, from spot finding to spot integration,
+we use a program called [DIALS](https://dials.github.io/index.html) [^1].
+[^1]: DIALS as a toolkit, DOI: 10.1002/pro.4224
+
+#### 1. Import Image-like Files
+
+First, we use [dials.import](https://dials.github.io/documentation/programs/dials_import.html) to convert image data files into a format compatible with DIALS.
+
+It processes the metadata and filenames of each image to establish relationships between different sets of images.
+Once all images are processed, the program generates an experiment object file, outlining the connections between the files.
+The images to be processed are given as command-line arguments.
+Occasionally, there is a limit on the maximum number of arguments allowed on the command line, and the number of files can exceed it.
+In such cases, image filenames can be entered through stdin, as demonstrated in the examples below.
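+
+For example, filenames can be piped in (a sketch; the exact invocation may vary between DIALS versions):
+
+```console
+find . -name "*.nxs" | dials.import
+```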
+The Format class for NMX is at modules/dxtbx/src/dxtbx/format/FormatNMX.py, where beam-line-specific parameters and file format information are stored.
+
+```console
+dials.import *.nxs
+```
+
+#### 2. Search for Strong Pixels
+
+The next step is to [search for strong pixels](https://dials.github.io/documentation/programs/dials_find_spots.html).
+In this step, the intensity of each pixel or a pixel group is compared with its local surroundings.
+From the strong pixels, strong spots are defined.
+To find these spots, their centroids and intensities are calculated.
+The results can be visualised in the image viewer or the [DIALS browser](https://toastisme.github.io/dials_browser_experiment_viewer/).
+
+```console
+dials.find_spots imported.expt find_spots.phil
+```
+
+#### 3. Index Instrument Geometry
+
+In the [indexing](https://dials.github.io/documentation/programs/dials_index.html) step, the unit cell is determined.
+A list of indexed reflections and an instrument model, including a crystal model, are returned.
+One-dimensional and three-dimensional fast Fourier transform-based methods are available.
+
+The ``imported.expt`` and ``strong.refl`` files are used as input.
+Further parameters, such as ``unit_cell`` and ``space_group``, can be given.
+
+```console
+dials.index imported.expt strong.refl space_group=P1 unit_cell=a,b,c,alpha,beta,gamma
+```
+
+#### 4. Refine the Diffraction Geometry
+
+The indexing result is then used to refine the diffraction geometry [^2].
+[^2]: https://dials.github.io/documentation/programs/dials_refine
+
+```console
+dials.refine indexed.refl indexed.expt detector.panels=hierarchical
+```
+
+#### 5. Integrate Reflections
+
+The last step in DIALS is to integrate each reflection.[^3]
+[^3]: https://dials.github.io/documentation/programs/dials_integrate.html
+
+Currently, different approaches are used for the image dimensions and the TOF dimension.
+In the image dimensions, a simple summation is used,
+and in the TOF dimension, a profile-fitting approach is used.
+
+```console
+dev.dials.simple_tof_integrate refined.expt refined.refl
+```
+
+### Scaling (LSCALE/pyscale)
+Currently, [LSCALE](https://doi.org/10.1107/S0021889898015350) can be used in a Docker container, which makes it independent of the OS.
+LSCALE is a program for scaling and normalisation of Laue intensity data.
+The source code is available on [Zenodo](https://zenodo.org/records/4381992).
+Since LSCALE is no longer maintained, we are currently developing a Python-based alternative called pyscale[^4].
+[^4]: ``pyscale`` is under development and lives in a private repository. Please ask the [owner](https://github.com/mlund) for access to the repository if needed.
+
+**To start the ``lscale`` Docker container**
+```console
+docker run -it -v $HOME:/mnt/host -v /tmp/.X11-unix:/tmp/.X11-unix -e DISPLAY=host.docker.internal:0 lscale
+```
+**Command to run ``lscale``**
+```console
+lscale < lscale.com > lscale.out
+```
+
+### Merge Intensities and Derive Structure Factors (CCP4, AIMLESS and CTRUNCATE)
+[AIMLESS](https://www.ccp4.ac.uk/html/aimless.html) and [CTRUNCATE](https://www.ccp4.ac.uk/html/ctruncate.html) are sub-programs of [CCP4](https://www.ccp4.ac.uk/html/).
+
+[AIMLESS](https://www.ccp4.ac.uk/html/aimless.html) can scale multiple observations of reflections together.
+It can also merge multiple observations into an average intensity. + +[CTRUNCATE](https://www.ccp4.ac.uk/html/ctruncate.html) converts measured intensities into structure factors.
+CTRUNCATE includes corrections for weak reflections to avoid negative intensities due to background corrections.
+
+This step can be done via the ``CCP4`` GUI:
+1. Start the ``CCP4`` GUI
+2. Go to ``all programs``
+3. Select ``Aimless``
+4. Select ``scaled *mtz file``
+
+Parameters can be modified.
+Standard parameters are fine in most cases. + +The ``mtz`` file can be used in a standard protein crystallographic program to solve and refine the structure. diff --git a/packages/essnmx/docs/about/index.md b/packages/essnmx/docs/about/index.md new file mode 100644 index 00000000..61b745df --- /dev/null +++ b/packages/essnmx/docs/about/index.md @@ -0,0 +1,34 @@ +# About + +```{toctree} +--- +maxdepth: 3 +--- + +data_workflow_overview +``` + +## Development + +ESSnmx is an open source project by the [European Spallation Source ERIC](https://ess.eu/) (ESS). + +## License + +ESSnmx is available as open source under the [BSD-3 license](https://opensource.org/license/BSD-3-Clause). + +## Citing ESSnmx + +Please cite the following: + +[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.14733001.svg)](https://doi.org/10.5281/zenodo.14733001) + +To cite a specific version of ESSnmx, select the desired version on Zenodo to get the corresponding DOI. + +## Older versions of the documentation + +Older versions of the documentation pages can be found under the assets of each [release](https://github.com/scipp/essnmx/releases). +Simply download the archive, unzip and view locally in a web browser. + +## Source code and development + +ESSnmx is hosted and developed [on GitHub](https://github.com/scipp/essnmx). diff --git a/packages/essnmx/docs/api-reference/index.md b/packages/essnmx/docs/api-reference/index.md new file mode 100644 index 00000000..40f2176b --- /dev/null +++ b/packages/essnmx/docs/api-reference/index.md @@ -0,0 +1,39 @@ +# API Reference + +## Classes + +```{eval-rst} +.. currentmodule:: ess.nmx + +.. autosummary:: + :toctree: ../generated/classes + :template: class-template.rst + :recursive: + +``` + +## Top-level functions + +```{eval-rst} +.. autosummary:: + :toctree: ../generated/functions + :recursive: + +``` + +## Submodules + +```{eval-rst} +.. autosummary:: + :toctree: ../generated/modules + :template: module-template.rst + :recursive: + + data + mcstas + types + mtz_io + scaling + configurations + +``` diff --git a/packages/essnmx/docs/conf.py b/packages/essnmx/docs/conf.py new file mode 100644 index 00000000..b427d470 --- /dev/null +++ b/packages/essnmx/docs/conf.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) + +import doctest +import os +import sys +from importlib.metadata import PackageNotFoundError +from importlib.metadata import version as get_version + +from sphinx.util import logging + +sys.path.insert(0, os.path.abspath(".")) + +logger = logging.getLogger(__name__) + +# General information about the project. 
+project = 'ESSnmx' +copyright = '2025 Scipp contributors' +author = 'Scipp contributors' + +html_show_sourcelink = True + +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.doctest', + 'sphinx.ext.githubpages', + 'sphinx.ext.intersphinx', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinx_autodoc_typehints', + 'sphinx_copybutton', + 'sphinx_design', + 'sphinxcontrib.autodoc_pydantic', + 'nbsphinx', + 'myst_parser', +] + +try: + import sciline.sphinxext.domain_types # noqa: F401 + + extensions.append("sciline.sphinxext.domain_types") + # See https://github.com/tox-dev/sphinx-autodoc-typehints/issues/457 + suppress_warnings = ["config.cache"] +except ModuleNotFoundError: + pass + + +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "fieldlist", + "html_admonition", + "html_image", + "replacements", + "smartquotes", + "strikethrough", + "substitution", + "tasklist", +] + +myst_heading_anchors = 3 + +autodoc_type_aliases = { + "array_like": "array_like", +} + +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "scipp": ("https://scipp.github.io/", None), +} + +# autodocs includes everything, even irrelevant API internals. autosummary +# looks more suitable in the long run when the API grows. +# For a nice example see how xarray handles its API documentation. +autosummary_generate = True + +napoleon_google_docstring = False +napoleon_numpy_docstring = True +napoleon_use_param = True +napoleon_use_rtype = False +napoleon_preprocess_types = True +napoleon_type_aliases = { + # objects without namespace: numpy + "ndarray": "~numpy.ndarray", +} +typehints_defaults = "comma" +typehints_use_rtype = False + + +sciline_domain_types_prefix = "ess.nmx" +sciline_domain_types_aliases = { + "scipp._scipp.core.DataArray": "scipp.DataArray", + "scipp._scipp.core.Dataset": "scipp.Dataset", + "scipp._scipp.core.DType": "scipp.DType", + "scipp._scipp.core.Unit": "scipp.Unit", + "scipp._scipp.core.Variable": "scipp.Variable", + "scipp.core.data_group.DataGroup": "scipp.DataGroup", +} + + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = [".rst", ".md"] +html_sourcelink_suffix = "" # Avoid .ipynb.txt extensions in sources + +# The master toctree document. +master_doc = "index" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# + +try: + release = get_version("essnmx") + version = ".".join(release.split(".")[:3]) # CalVer +except PackageNotFoundError: + logger.info( + "Warning: determining version from package metadata failed, falling back to " + "a dummy version number." + ) + release = version = "0.0.0-dev" + +warning_is_error = True + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = "en" + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. 
+# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + +# -- Options for HTML output ---------------------------------------------- + +html_theme = "pydata_sphinx_theme" +html_theme_options = { + "primary_sidebar_end": ["edit-this-page", "sourcelink"], + "secondary_sidebar_items": [], + "navbar_persistent": ["search-button"], + "show_nav_level": 1, + # Adjust this to ensure external links are moved to "Move" menu + "header_links_before_dropdown": 4, + "pygment_light_style": "github-light-high-contrast", + "pygment_dark_style": "github-dark-high-contrast", + "logo": { + "image_light": "_static/logo.svg", + "image_dark": "_static/logo-dark.svg", + }, + "external_links": [ + {"name": "Plopp", "url": "https://scipp.github.io/plopp"}, + {"name": "Sciline", "url": "https://scipp.github.io/sciline"}, + {"name": "Scipp", "url": "https://scipp.github.io"}, + {"name": "ScippNexus", "url": "https://scipp.github.io/scippnexus"}, + ], + "icon_links": [ + { + "name": "GitHub", + "url": "https://github.com/scipp/essnmx", + "icon": "fa-brands fa-github", + "type": "fontawesome", + }, + { + "name": "PyPI", + "url": "https://pypi.org/project/essnmx/", + "icon": "fa-brands fa-python", + "type": "fontawesome", + }, + { + "name": "Conda", + "url": "https://anaconda.org/conda-forge/essnmx", + "icon": "fa-custom fa-anaconda", + "type": "fontawesome", + }, + ], + "footer_start": ["copyright", "sphinx-version"], + "footer_end": ["doc_version", "theme-version"], +} +html_context = { + "doc_path": "docs", +} +html_sidebars = { + "**": ["sidebar-nav-bs", "page-toc"], +} + +html_title = "ESSnmx" +html_logo = "_static/logo.svg" +html_favicon = "_static/favicon.ico" + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] +html_css_files = [] +html_js_files = ["anaconda-icon.js"] + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = "essnmxdoc" + +# -- Options for Matplotlib in notebooks ---------------------------------- + +nbsphinx_execute_arguments = [ + "--Session.metadata=scipp_sphinx_build=True", +] + +# -- Options for doctest -------------------------------------------------- + +# sc.plot returns a Figure object and doctest compares that against the +# output written in the docstring. But we only want to show an image of the +# figure, not its `repr`. +# In addition, there is no need to make plots in doctest as the documentation +# build already tests if those plots can be made. +# So we simply disable plots in doctests. +doctest_global_setup = """ +import numpy as np + +try: + import scipp as sc + + def do_not_plot(*args, **kwargs): + pass + + sc.plot = do_not_plot + sc.Variable.plot = do_not_plot + sc.DataArray.plot = do_not_plot + sc.DataGroup.plot = do_not_plot + sc.Dataset.plot = do_not_plot +except ImportError: + # Scipp is not needed by docs if it is not installed. 
+ pass +""" + +# Using normalize whitespace because many __str__ functions in scipp produce +# extraneous empty lines and it would look strange to include them in the docs. +doctest_default_flags = ( + doctest.ELLIPSIS + | doctest.IGNORE_EXCEPTION_DETAIL + | doctest.DONT_ACCEPT_TRUE_FOR_1 + | doctest.NORMALIZE_WHITESPACE +) + +# -- Options for linkcheck ------------------------------------------------ + +linkcheck_ignore = [ + # Specific lines in Github blobs cannot be found by linkcheck. + r'https?://github\.com/.*?/blob/[a-f0-9]+/.+?#', + # Linkcheck seems to be denied access by some DOI resolvers. + # Since DOIs are supposed to be permanent, we don't need to check them.' + r'https?://doi\.org/', + r'https?://dx\.doi\.org/', + r'https://www\.ccp4\.ac\.uk/*', # Seems to be denied by the server. + # Manually checked and working +] diff --git a/packages/essnmx/docs/developer/coding-conventions.md b/packages/essnmx/docs/developer/coding-conventions.md new file mode 100644 index 00000000..4fafc18d --- /dev/null +++ b/packages/essnmx/docs/developer/coding-conventions.md @@ -0,0 +1,117 @@ +# Coding conventions + +## Code formatting + +There are no explicit code formatting conventions since we use `ruff` to enforce a format. + +## Docstring format + +We use the [NumPy docstring format](https://www.sphinx-doc.org/en/master/usage/extensions/example_numpy.html). +We use `sphinx-autodocs-typehints` to automatically insert type hints into the docstrings. +Our format thus deviates from the default NumPy example given by the link above. +Docstrings should therefore be laid out as follows, including spacing and punctuation: + +```python + +def foo(x: int, y: float) -> float: + """Short description. + + Long description. + + With multiple paragraphs. + + Warning + ------- + Be careful! + + Parameters + ---------- + x: + First input. + y: + Second input. + + Returns + ------- + : + The result. + + Raises + ------ + ValueError + If the input is bad. + IndexError + If some lookup failed. + + See Also + -------- + scitacean.bar: + A bit less foo. + + Examples + -------- + This is how to use it: + + >>> foo(1, 2) + 3 + + And also: + + >>> foo(1, 3) + 6 + """ +``` + +The order of sections is fixed as shown in the example. + +- **Short description** (*required*) A single sentence describing the purpose of the function / class. +- **Long description** (*optional*) One or more paragraphs of detailed explanations. + Can include additional sections like `Warning` or `Hint`. +- **Parameters** (*required for functions*) List of all function arguments including their name but not their type. + Listing arguments like this can seem ridiculous if the explanation is as devoid of content as in the example. + But it is still required in order for sphinx to show the types. +- **Returns** (*required for functions*) Description of the return value. + Required for the same reason as the parameter list. + + For a single return value, neither a name nor type should be given. + But a colon is required as in the example above in order to produce proper formatting. + + For multiple return values, to produce proper formatting, + both name and type must be given even though the latter repeats the type annotation: + + ```python + + """ + Returns + ------- + n: int + The first return value. + z: float + The second return value. + """ + ``` + +- **Raises** (*optional*) We generally do not document what exceptions can be raised from a function. 
+ But if there are some important cases, this section can list those exceptions with an explanation + of when the exception is raised. + The exception type is required. + Note that there are no colons here. +- **See Also** (*optional*) List of related functions and/or classes. + The function/class name should include the module it is in but without reST markup. + For simple cases, the explanation can be left out. + In this case, the colon should be omitted as well and multiple entries must be separated by commas. +- **Examples** (*optional*) Example code given using `>>>` as the Python prompt. + May include text before, after, and between code blocks. + Note the spacing in the example. + +Some functions can be sufficiently described by a single sentence. +In this case, the 'Parameters' and 'Returns' sections may be omitted and the docstring should be laid out on a single line. +If it does not fit on a single line, it is too complicated. +For example + +```python +def bar(self) -> int: + """Returns the number of dimensions.""" +``` + +Note that the argument types are not shown in the rendered documentation. diff --git a/packages/essnmx/docs/developer/dependency-management.md b/packages/essnmx/docs/developer/dependency-management.md new file mode 100644 index 00000000..172722dc --- /dev/null +++ b/packages/essnmx/docs/developer/dependency-management.md @@ -0,0 +1,13 @@ +# Dependency management + +essnmx is a library, so the package dependencies are never pinned. +Lower bounds are fine and individual versions can be excluded. +See, e.g., [Should You Use Upper Bound Version Constraints](https://iscinumpy.dev/post/bound-version-constraints/) for an explanation. + +Development dependencies (as opposed to dependencies of the deployed package that users need to install) are pinned to an exact version in order to ensure reproducibility. +This also includes dependencies used for the various CI builds. +This is done by specifying packages (and potential version constraints) in `requirements/*.in` files and locking those dependencies using [pip-compile-multi](https://pip-compile-multi.readthedocs.io/en/latest/index.html) to produce `requirements/*.txt` files. +Those files are then used by [tox](https://tox.wiki/en/latest/) to create isolated environments and run tests, build docs, etc. + +`tox` can be cumbersome to use for local development. +Therefore `requirements/dev.txt` can be used to create a virtual environment with all dependencies. diff --git a/packages/essnmx/docs/developer/getting-started.md b/packages/essnmx/docs/developer/getting-started.md new file mode 100644 index 00000000..a7667511 --- /dev/null +++ b/packages/essnmx/docs/developer/getting-started.md @@ -0,0 +1,91 @@ +# Getting started + +## Setting up + +### Dependencies + +Development dependencies are specified in `requirements/dev.txt` and can be installed using (see [Dependency Management](./dependency-management.md) for more information) + +```sh +pip install -r requirements/dev.txt +``` + +Additionally, building the documentation requires [pandoc](https://pandoc.org/) which is not on PyPI and needs to be installed through other means, e.g. with your OS package manager. + +### Install the package + +Install the package in editable mode using + +```sh +pip install -e . +``` + +### Set up git hooks + +The CI pipeline runs a number of code formatting and static analysis tools. +If they fail, a build is rejected. +To avoid that, you can run the same tools locally. 
+This can be done conveniently using [pre-commit](https://pre-commit.com/): + +```sh +pre-commit install +``` + +Alternatively, if you want a different workflow, take a look at ``tox.ini`` or ``.pre-commit.yaml`` to see what tools are run and how. + +## Running tests + +`````{tab-set} +````{tab-item} tox +Run the tests using + +```sh +tox -e py311 +``` + +(or just `tox` if you want to run all environments). + +```` +````{tab-item} Manually +Run the tests using + +```sh +python -m pytest +``` +```` +````` + +## Building the docs + +`````{tab-set} +````{tab-item} tox +Build the documentation using + +```sh +tox -e docs +``` + +This builds the docs and also runs `doctest`. +`linkcheck` can be run separately using + +```sh +tox -e linkcheck +``` +```` + +````{tab-item} Manually + +Build the documentation using + +```sh +python -m sphinx -v -b html -d .tox/docs_doctrees docs html +``` + +Additionally, test the documentation using + +```sh +python -m sphinx -v -b doctest -d .tox/docs_doctrees docs html +python -m sphinx -v -b linkcheck -d .tox/docs_doctrees docs html +``` +```` +````` \ No newline at end of file diff --git a/packages/essnmx/docs/developer/index.md b/packages/essnmx/docs/developer/index.md new file mode 100644 index 00000000..d183a4ca --- /dev/null +++ b/packages/essnmx/docs/developer/index.md @@ -0,0 +1,17 @@ +# Development + +```{include} ../../CONTRIBUTING.md +``` + +## Table of contents + +```{toctree} +--- +maxdepth: 2 +--- + +getting-started +coding-conventions +dependency-management +test-dataset +``` diff --git a/packages/essnmx/docs/developer/test-dataset.ipynb b/packages/essnmx/docs/developer/test-dataset.ipynb new file mode 100644 index 00000000..96c7e252 --- /dev/null +++ b/packages/essnmx/docs/developer/test-dataset.ipynb @@ -0,0 +1,154 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test Dataset\n", + "\n", + "This page has the instruction of how the test-datasets were generated and how they are used in the tests.\n", + "\n", + "## Scaling workflow - MTZ files\n", + "\n", + "MTZ test datasets are create with ``gemmi`` and random generator.\n", + "\n", + "We have multiple test MTZ files since multiple files are expected in usual cases.\n", + "\n", + "These files do not have any physical meaning and they are meant to be useful for testing the workflow.\n", + "\n", + "Here is the code cell to create the test files." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gemmi\n", + "import pandas as pd\n", + "import numpy as np\n", + "from ess.nmx.mtz_io import (\n", + " DEFAULT_INTENSITY_COLUMN_NAME,\n", + " DEFAULT_WAVELENGTH_COLUMN_NAME,\n", + " DEFAULT_STD_DEV_COLUMN_NAME,\n", + " DEFAULT_SPACE_GROUP_DESC,\n", + ")\n", + "\n", + "# Negative intensities will happen due to corrections\n", + "# and high intensities are also expected in some cases\n", + "INTENSITY_RANGE = (-20.0, 200.0)\n", + "HKL_RANGE = (-100, 100)\n", + "MANDATORY_FIELDS = (\n", + " \"H\",\n", + " \"K\",\n", + " \"L\",\n", + " DEFAULT_WAVELENGTH_COLUMN_NAME, # LAMBDA\n", + " DEFAULT_INTENSITY_COLUMN_NAME, # I\n", + " DEFAULT_STD_DEV_COLUMN_NAME, # SIGI\n", + ")\n", + "global_rng = np.random.default_rng(0)\n", + "HKL_CANDIDATES = tuple(\n", + " zip(*[global_rng.integers(*HKL_RANGE, size=100) for _ in range(3)], strict=False)\n", + ")\n", + "\n", + "def create_mtz_data_frame(random_seed: int) -> pd.DataFrame:\n", + " rng = np.random.default_rng(random_seed)\n", + " intensities = np.sort(rng.normal(50, 20, size=10_000))[::-1] + (\n", + " rng.uniform(*INTENSITY_RANGE, size=10_000)\n", + " * rng.choice([0] * 99 + [1], size=10_000)\n", + " )\n", + " std_devs = np.multiply(intensities, rng.uniform(0.1, 0.15, size=10_000))\n", + " wavelengths = np.sort(rng.uniform(2.8, 3.2, size=10_000))[::-1]\n", + "\n", + " df = pd.DataFrame(\n", + " {\n", + " DEFAULT_INTENSITY_COLUMN_NAME: intensities,\n", + " DEFAULT_STD_DEV_COLUMN_NAME: std_devs,\n", + " DEFAULT_WAVELENGTH_COLUMN_NAME: wavelengths,\n", + " }\n", + " )\n", + "\n", + " df[[\"H\", \"K\", \"L\"]] = pd.Series(\n", + " rng.choice(HKL_CANDIDATES, size=10_000).tolist()\n", + " ).to_list()\n", + "\n", + " return df\n", + "\n", + "\n", + "def dataframe_to_mtz(df: pd.DataFrame) -> gemmi.Mtz:\n", + " \"\"\"Create a random MTZ file with a single dataset.\n", + "\n", + " Columns:\n", + " - H, K, L: Miller indices\n", + " - LAMBDA: Wavelength\n", + " - I: Intensity\n", + " - SIGI: Standard deviation of intensity\n", + "\n", + " \"\"\"\n", + " assert set(df.columns) == set(MANDATORY_FIELDS)\n", + "\n", + " mtz = gemmi.Mtz()\n", + " mtz.add_dataset(\"HKL\")\n", + " column_type_map = { # Column types: https://www.ccp4.ac.uk/html/mtzformat.html#coltypes\n", + " \"H\": \"H\",\n", + " \"K\": \"H\",\n", + " \"L\": \"H\",\n", + " DEFAULT_WAVELENGTH_COLUMN_NAME: \"R\",\n", + " DEFAULT_INTENSITY_COLUMN_NAME: \"J\",\n", + " DEFAULT_STD_DEV_COLUMN_NAME: \"Q\",\n", + " }\n", + "\n", + " for col_name in df.columns:\n", + " mtz.add_column(col_name, type=column_type_map[col_name], expand_data=True)\n", + "\n", + " mtz.spacegroup = gemmi.SpaceGroup(DEFAULT_SPACE_GROUP_DESC)\n", + " mtz.set_data(df.values)\n", + " return mtz\n", + "\n", + "\n", + "for seed in range(1, 6):\n", + " sample_df = create_mtz_data_frame(seed)\n", + " sample_mtz = dataframe_to_mtz(sample_df)\n", + " # sample_mtz.write_to_file(f\"sample_{seed}.mtz\") # Uncomment to save the MTZ file\n", + "\n", + "sample_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the files were created, they were compressed into one file\n", + "and uploaded in the server where pooch can access to.\n", + "\n", + "Here is the script for compressing the files.\n", + "\n", + "```bash\n", + "tar -czvf mtz_random_samples.tar.gz sample_*.mtz\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nmx-dev-310", + "language": "python", + "name": 
"python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/packages/essnmx/docs/index.md b/packages/essnmx/docs/index.md new file mode 100644 index 00000000..2445e0fd --- /dev/null +++ b/packages/essnmx/docs/index.md @@ -0,0 +1,49 @@ +:::{image} _static/logo.svg +:class: only-light +:alt: ESSnmx +:width: 60% +:align: center +::: +:::{image} _static/logo-dark.svg +:class: only-dark +:alt: ESSnmx +:width: 60% +:align: center +::: + +```{raw} html + +``` + +```{role} transparent +``` + +# {transparent}`ESSnmx` + +
+  Data reduction for NMX at the European Spallation Source.
+ +:::{include} user-guide/installation.md +:heading-offset: 1 +::: + +## Get in touch + +- If you have questions that are not answered by these documentation pages, ask on [discussions](https://github.com/scipp/essnmx/discussions). Please include a self-contained reproducible example if possible. +- Report bugs (including unclear, missing, or wrong documentation!), suggest features or view the source code [on GitHub](https://github.com/scipp/essnmx). + +```{toctree} +--- +hidden: +--- + +user-guide/index +api-reference/index +developer/index +about/index +``` diff --git a/packages/essnmx/docs/user-guide/index.md b/packages/essnmx/docs/user-guide/index.md new file mode 100644 index 00000000..83cdd88e --- /dev/null +++ b/packages/essnmx/docs/user-guide/index.md @@ -0,0 +1,13 @@ +# User Guide + +```{toctree} +--- +maxdepth: 1 +--- + +workflow +mcstas_workflow +mcstas_workflow_chunk +scaling_workflow +installation +``` diff --git a/packages/essnmx/docs/user-guide/installation.md b/packages/essnmx/docs/user-guide/installation.md new file mode 100644 index 00000000..f29e78c9 --- /dev/null +++ b/packages/essnmx/docs/user-guide/installation.md @@ -0,0 +1,16 @@ +# Installation + +To install ESSnmx and all of its dependencies, use + +`````{tab-set} +````{tab-item} pip +```sh +pip install essnmx +``` +```` +````{tab-item} conda +```sh +conda install -c conda-forge essnmx +``` +```` +````` diff --git a/packages/essnmx/docs/user-guide/mcstas_workflow.ipynb b/packages/essnmx/docs/user-guide/mcstas_workflow.ipynb new file mode 100644 index 00000000..4924246f --- /dev/null +++ b/packages/essnmx/docs/user-guide/mcstas_workflow.ipynb @@ -0,0 +1,192 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# McStas Workflow\n", + "In this example, we will use McStas 3 simulation file.\n", + "\n", + "## Build Pipeline (Collect Parameters and Providers)\n", + "Import the providers from ``load_mcstas_nexus`` to use the ``McStas`` simulation data workflow.
\n",
    "``MaximumProbability`` can be provided manually to derive a more realistic number of events.\n",
    "This is because ``weights`` are given as probabilities, not numbers of events, in a McStas file.
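\n",
    "\n",
    "As a rough sketch of the rescaling idea (an inference from this description, not the verbatim formula; see ``mcstas_weight_to_probability_scalefactor``): with ``MaximumCounts = 10000`` and a largest event weight ``p_max`` in the file, each weight ``w`` maps to roughly ``w * 10000 / p_max`` counts.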
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ess.nmx.mcstas import NMXMcStasWorkflow\n", + "from ess.nmx.data import get_small_mcstas\n", + "\n", + "from ess.nmx.mcstas.types import *\n", + "from ess.nmx.mcstas.reduction import merge_panels\n", + "from ess.nmx.mcstas.nexus import export_as_nexus\n", + "\n", + "wf = NMXMcStasWorkflow()\n", + "# Replace with the path to your own file\n", + "wf[FilePath] = get_small_mcstas()\n", + "wf[MaximumCounts] = 10000\n", + "wf[TimeBinSteps] = 50" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To see what the workflow can produce, display it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wf" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We want to reduce all three panels, so we map the relevant part of the workflow over a list of the three panels:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# DetectorIndex selects what detector panels to include in the run\n", + "# in this case we select all three panels.\n", + "wf[NMXReducedDataGroup] = (\n", + " wf[NMXReducedDataGroup]\n", + " .map({DetectorIndex: sc.arange('panel', 3, unit=None)})\n", + " .reduce(index=\"panel\", func=merge_panels)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build Workflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wf.visualize(NMXReducedDataGroup, graph_attr={\"rankdir\": \"TD\"}, compact=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute Desired Types" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from cyclebane.graph import NodeName, IndexValues\n", + "\n", + "# Data from all selected detectors binned by panel, pixel and timeslice\n", + "targets = [NodeName(NMXReducedDataGroup, IndexValues((\"panel\",), (i,))) for i in range(3)]\n", + "dg = merge_panels(*wf.compute(targets).values())\n", + "dg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dg['counts']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Export Results\n", + "\n", + "``NMXReducedData`` object has a method to export the data into nexus or h5 file.\n", + "\n", + "You can save the result as ``test.nxs``, for example:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "export_as_nexus(dg, \"test.nxs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instrument View\n", + "\n", + "Pixel positions are not used for later steps,\n", + "but it is included in the coordinates for instrument view.\n", + "\n", + "All pixel positions are relative to the sample position,\n", + "therefore the sample is at (0, 0, 0).\n", + "\n", + "**It might be very slow or not work in the ``VS Code`` jupyter notebook editor.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import scippneutron as scn\n", + "\n", + "da = dg[\"counts\"]\n", + "da.coords[\"position\"] = dg[\"position\"]\n", + "# Plot one out of 100 pixels to reduce size of docs output\n", + "view = scn.instrument_view(da[\"id\", 
::100].sum('t'), pixel_size=0.0075)\n", + "view" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nmx-dev-313", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/packages/essnmx/docs/user-guide/mcstas_workflow_chunk.ipynb b/packages/essnmx/docs/user-guide/mcstas_workflow_chunk.ipynb new file mode 100644 index 00000000..7257ed80 --- /dev/null +++ b/packages/essnmx/docs/user-guide/mcstas_workflow_chunk.ipynb @@ -0,0 +1,261 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# McStas Workflow - Chunk by Chunk\n", + "In this example, we will process McStas events chunk by chunk, panel by panel." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build Base Workflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ess.nmx.mcstas import NMXMcStasWorkflow\n", + "from ess.nmx.data import get_small_mcstas\n", + "from ess.nmx.mcstas.types import *\n", + "\n", + "wf = NMXMcStasWorkflow()\n", + "# Replace with the path to your own file\n", + "wf[FilePath] = get_small_mcstas()\n", + "wf[MaximumCounts] = 10_000\n", + "wf[TimeBinSteps] = 50\n", + "wf.visualize(NMXReducedDataGroup, graph_attr={\"rankdir\": \"TD\"}, compact=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute Raw Data Metadata\n", + "\n", + "`time-of-flight` coordinate and `McStasWeight2CountScaleFactor` should not be different from chunk to chunk.\n", + "\n", + "Therefore we need to compute `TimeBinStep` and `McStasWeight2CoutScaleFactor` before we compute `NMXReducedData`.\n", + "\n", + "It can be done by `ess.reduce.streaming.StreamProcessor`.\n", + "\n", + "In this example, `MinimumTimeOfArrival`, `MaximumTimeOfArrival` and `MaximumProbability` will be renewed every time a chunk is added to the streaming processor.\n", + "\n", + "`(Min/Max)Accumulator` remembers the previous minimum/maximum value and compute new minimum/maximum value with the new chunk.\n", + "\n", + "``raw_event_data_chunk_generator`` yields a chunk of raw event probability from mcstas h5 file.\n", + "\n", + "This example below process the data chunk by chunk with size: ``CHUNK_SIZE``." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from functools import partial\n",
+    "from ess.reduce.streaming import StreamProcessor, MaxAccumulator, MinAccumulator\n",
+    "\n",
+    "# Stream processor building helper\n",
+    "scalefactor_stream_processor = partial(\n",
+    "    StreamProcessor,\n",
+    "    dynamic_keys=(RawEventProbability,),\n",
+    "    target_keys=(NMXRawDataMetadata,),\n",
+    "    accumulators={\n",
+    "        MaximumProbability: MaxAccumulator,\n",
+    "        MaximumTimeOfArrival: MaxAccumulator,\n",
+    "        MinimumTimeOfArrival: MinAccumulator,\n",
+    "    },\n",
+    ")\n",
+    "metadata_wf = wf.copy()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metadata_wf.visualize(NMXRawDataMetadata, graph_attr={\"rankdir\": \"TD\"}, compact=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import scipp as sc\n",
+    "from ess.nmx.mcstas.load import (\n",
+    "    raw_event_data_chunk_generator,\n",
+    "    mcstas_weight_to_probability_scalefactor,\n",
+    ")\n",
+    "from ess.nmx.mcstas.streaming import calculate_number_of_chunks\n",
+    "from ipywidgets import IntProgress\n",
+    "\n",
+    "CHUNK_SIZE = 10  # Number of event rows to process at once\n",
+    "# Increase this number to speed up the processing\n",
+    "NUM_DETECTORS = 3\n",
+    "\n",
+    "# Loop over the detectors\n",
+    "file_path = metadata_wf.compute(FilePath)\n",
+    "raw_data_metadatas = {}\n",
+    "\n",
+    "for detector_i in range(NUM_DETECTORS):\n",
+    "    temp_wf = metadata_wf.copy()\n",
+    "    temp_wf[DetectorIndex] = detector_i\n",
+    "    detector_name = temp_wf.compute(DetectorName)\n",
+    "    max_chunk_id = calculate_number_of_chunks(\n",
+    "        temp_wf.compute(FilePath), detector_name=detector_name, chunk_size=CHUNK_SIZE\n",
+    "    )\n",
+    "    cur_detector_progress_bar = IntProgress(\n",
+    "        min=0, max=max_chunk_id, description=f\"Detector {detector_i}\"\n",
+    "    )\n",
+    "    display(cur_detector_progress_bar)\n",
+    "\n",
+    "    # Build the stream processor\n",
+    "    processor = scalefactor_stream_processor(temp_wf)\n",
+    "    for da in raw_event_data_chunk_generator(\n",
+    "        file_path=file_path, detector_name=detector_name, chunk_size=CHUNK_SIZE\n",
+    "    ):\n",
+    "        # Skip empty chunks (any dimension of size 0)\n",
+    "        if any(size == 0 for size in da.sizes.values()):\n",
+    "            continue\n",
+    "        results = processor.add_chunk({RawEventProbability: da})\n",
+    "        cur_detector_progress_bar.value += 1\n",
+    "    display(results[NMXRawDataMetadata])\n",
+    "    raw_data_metadatas[detector_i] = results[NMXRawDataMetadata]\n",
+    "\n",
+    "# Take the global min/max of time-of-arrival and probability over all detectors\n",
+    "min_toa = min(meta.min_toa for meta in raw_data_metadatas.values())\n",
+    "max_toa = max(meta.max_toa for meta in raw_data_metadatas.values())\n",
+    "max_probability = max(meta.max_probability for meta in raw_data_metadatas.values())\n",
+    "\n",
+    "toa_bin_edges = sc.linspace(dim='t', start=min_toa, stop=max_toa, num=51)\n",
+    "scale_factor = mcstas_weight_to_probability_scalefactor(\n",
+    "    max_counts=wf.compute(MaximumCounts), max_probability=max_probability\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Compute Metadata\n",
+    "\n",
+    "The other metadata does not require any chunk-based computation.\n",
+    "\n",
+    "Therefore we export the metadata first and append the detector data later.\n",
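+    "\n",
+    "For example, the instrument description is chunk-independent, so it can be computed once from the base workflow (a minimal sketch; the same import and call appear in the final-output cell below):\n",
+    "\n",
+    "```python\n",
+    "from ess.nmx.mcstas.xml import McStasInstrument\n",
+    "\n",
+    "instrument = wf.compute(McStasInstrument)\n",
+    "```"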
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute Final Output\n", + "\n", + "Now with all the metadata, we can compute the final output chunk by chunk.\n", + "\n", + "We will also compute static parameters in advance so that stream processor does not compute them every time another chunk is added.\n", + "\n", + "We will as well export the reduced data detector by detector." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ess.nmx.mcstas.xml import McStasInstrument\n", + "\n", + "final_wf = wf.copy()\n", + "# Set the scale factor and time bin edges\n", + "final_wf[McStasWeight2CountScaleFactor] = scale_factor\n", + "final_wf[TimeBinSteps] = toa_bin_edges\n", + "\n", + "# Set the crystal rotation manually for now ...\n", + "final_wf[CrystalRotation] = sc.vector([0, 0, 0.0], unit='deg')\n", + "# Set static info\n", + "final_wf[McStasInstrument] = wf.compute(McStasInstrument)\n", + "final_wf.visualize(NMXReducedDataGroup, compact=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ess.nmx.mcstas.nexus import NXLauetofWriter\n", + "\n", + "\n", + "def temp_generator(file_path, detector_name):\n", + " max_chunk_id = calculate_number_of_chunks(\n", + " file_path, detector_name=detector_name, chunk_size=CHUNK_SIZE\n", + " )\n", + " cur_detector_progress_bar = IntProgress(\n", + " min=0, max=max_chunk_id, description=f\"Detector {detector_i}\"\n", + " )\n", + " display(cur_detector_progress_bar)\n", + " for da in raw_event_data_chunk_generator(\n", + " file_path=file_path, detector_name=detector_name, chunk_size=CHUNK_SIZE\n", + " ):\n", + " yield da\n", + " cur_detector_progress_bar.value += 1\n", + "\n", + "\n", + "# When a panel is added to the writer,\n", + "# the writer will start processing the data from the generator\n", + "# and store the results in memory\n", + "# The writer will then write the data to the file\n", + "# when ``save`` is called\n", + "writer = NXLauetofWriter(\n", + " chunk_generator=temp_generator,\n", + " chunk_insert_key=RawEventProbability,\n", + " workflow=final_wf,\n", + " output_filename=\"test.h5\",\n", + " overwrite=True,\n", + " extra_meta={\"McStasWeight2CountScaleFactor\": scale_factor},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for detector_i in range(3):\n", + " display(writer.add_panel(detector_id=detector_i))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nmx-dev-310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/packages/essnmx/docs/user-guide/scaling_workflow.ipynb b/packages/essnmx/docs/user-guide/scaling_workflow.ipynb new file mode 100644 index 00000000..e5ecbb75 --- /dev/null +++ b/packages/essnmx/docs/user-guide/scaling_workflow.ipynb @@ -0,0 +1,403 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Scaling\n", + "\n", + "## MTZ IO\n", + "\n", + "``ess.nmx`` has ``MTZ`` IO helper functions.\n", + "They can be used as providers in a workflow of scaling routine.\n", + "\n", + "They are wrapping ``MTZ`` IO functions of 
``gemmi``." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gemmi\n", + "from ess.nmx.mtz_io import (\n", + " read_mtz_file,\n", + " mtz_to_pandas,\n", + " MTZFilePath,\n", + " get_unique_space_group,\n", + " MtzDataFrame,\n", + " merge_mtz_dataframes,\n", + ")\n", + "from ess.nmx.data import get_small_random_mtz_samples\n", + "\n", + "\n", + "small_mtz_sample = get_small_random_mtz_samples()[0]\n", + "mtz = read_mtz_file(MTZFilePath(small_mtz_sample))\n", + "df = mtz_to_pandas(mtz)\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build Pipeline\n", + "\n", + "Scaling routine includes:\n", + "- Reducing individual MTZ dataset\n", + "- Merging MTZ dataset \n", + "- Reducing merged MTZ dataset\n", + "\n", + "These operations are done on pandas dataframe as recommended in ``gemmi``.\n", + "And multiple MTZ files are expected, so we need to use ``sciline.ParamTable``.\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import sciline as sl\n", + "import scipp as sc\n", + "\n", + "from ess.nmx.mtz_io import providers as mtz_io_providers, default_parameters as mtz_io_params\n", + "from ess.nmx.mtz_io import SpaceGroupDesc\n", + "from ess.nmx.scaling import providers as scaling_providers, default_parameters as scaling_params\n", + "from ess.nmx.scaling import (\n", + " WavelengthBins,\n", + " FilteredEstimatedScaledIntensities,\n", + " ReferenceWavelength,\n", + " ScaledIntensityLeftTailThreshold,\n", + " ScaledIntensityRightTailThreshold,\n", + ")\n", + "\n", + "pl = sl.Pipeline(\n", + " providers=mtz_io_providers + scaling_providers,\n", + " params={\n", + " SpaceGroupDesc: \"C 1 2 1\",\n", + " ReferenceWavelength: sc.scalar(\n", + " 3, unit=sc.units.angstrom\n", + " ), # Remove it if you want to use the middle of the bin\n", + " ScaledIntensityLeftTailThreshold: sc.scalar(\n", + " 0.1, # Increase it to remove more outliers\n", + " ),\n", + " ScaledIntensityRightTailThreshold: sc.scalar(\n", + " 4.0, # Decrease it to remove more outliers\n", + " ),\n", + " **mtz_io_params,\n", + " **scaling_params,\n", + " WavelengthBins: 250,\n", + " },\n", + ")\n", + "pl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "file_paths = pd.DataFrame({MTZFilePath: get_small_random_mtz_samples()}).rename_axis(\n", + " \"mtzfile\"\n", + ")\n", + "mapped = pl.map(file_paths)\n", + "pl[gemmi.SpaceGroup] = mapped[gemmi.SpaceGroup | None].reduce(\n", + " index='mtzfile', func=get_unique_space_group\n", + ")\n", + "pl[MtzDataFrame] = mapped[MtzDataFrame].reduce(\n", + " index='mtzfile', func=merge_mtz_dataframes\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build Workflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ess.nmx.scaling import WavelengthScaleFactors\n", + "\n", + "scaling_nmx_workflow = pl.get(WavelengthScaleFactors)\n", + "scaling_nmx_workflow.visualize(graph_attr={\"rankdir\": \"LR\"})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute Desired Type" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": 
[], + "source": [ + "from ess.nmx.scaling import (\n", + " SelectedReferenceWavelength,\n", + " FittingResult,\n", + " WavelengthScaleFactors,\n", + ")\n", + "\n", + "results = scaling_nmx_workflow.compute(\n", + " (\n", + " FilteredEstimatedScaledIntensities,\n", + " SelectedReferenceWavelength,\n", + " FittingResult,\n", + " WavelengthScaleFactors,\n", + " )\n", + ")\n", + "\n", + "results[WavelengthScaleFactors]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plots\n", + "\n", + "Here are plotting examples of the fitting/estimation results.\n", + "\n", + "### Estimated Scaled Intensities." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import scipy.stats as stats\n", + "import matplotlib.pyplot as plt\n", + "\n", + "fig, (density_ax, prob_ax) = plt.subplots(1, 2, figsize=(10, 5))\n", + "\n", + "densities = sc.values(results[FilteredEstimatedScaledIntensities].data).values\n", + "sc.values(results[FilteredEstimatedScaledIntensities].data).hist(intensity=50).plot(\n", + " title=\"Filtered Estimated Scaled Intensities\\nDensity Plot\",\n", + " grid=True,\n", + " linewidth=3,\n", + " ax=density_ax,\n", + ")\n", + "stats.probplot(densities, dist=\"norm\", plot=prob_ax)\n", + "prob_ax.set_title(\"Filtered Estimated Scaled Intensities\\nProbability Plot\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Curve Fitting" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import plopp as pp\n", + "import numpy as np\n", + "from ess.nmx.scaling import FittingResult\n", + "\n", + "chebyshev_func = np.polynomial.chebyshev.Chebyshev(np.array([1, -1, 1]))\n", + "scale_function = np.vectorize(\n", + " chebyshev_func / chebyshev_func(results[SelectedReferenceWavelength].value)\n", + ")\n", + "pp.plot(\n", + " {\n", + " \"Original Data\": results[FilteredEstimatedScaledIntensities],\n", + " \"Fit Result\": results[FittingResult].fit_output,\n", + " },\n", + " grid=True,\n", + " title=\"Fit Result [Intensities vs Wavelength]\",\n", + " marker={\"Chebyshev\": None, \"Fit Result\": None},\n", + " linestyle={\"Chebyshev\": \"solid\", \"Fit Result\": \"solid\"},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reference_wavelength = sc.DataArray(\n", + " data=sc.concat(\n", + " [\n", + " results[WavelengthScaleFactors].data.min(),\n", + " results[WavelengthScaleFactors].data.max(),\n", + " ],\n", + " \"wavelength\",\n", + " ),\n", + " coords={\n", + " \"wavelength\": sc.broadcast(\n", + " results[SelectedReferenceWavelength], dims=[\"wavelength\"], shape=[2]\n", + " )\n", + " },\n", + ")\n", + "wavelength_scale_factor_plot = pp.plot(\n", + " {\n", + " \"scale_factor\": results[WavelengthScaleFactors],\n", + " \"reference_wavelength\": reference_wavelength,\n", + " },\n", + " title=\"Wavelength Scale Factors\",\n", + " grid=True,\n", + " marker={\"reference_wavelength\": None},\n", + " linestyle={\"reference_wavelength\": \"solid\"},\n", + ")\n", + "wavelength_scale_factor_plot.ax.set_xlim(2.8, 3.2)\n", + "reference_wavelength = results[SelectedReferenceWavelength].value\n", + "wavelength_scale_factor_plot.ax.text(\n", + " 3.0,\n", + " 0.25,\n", + " f\"{reference_wavelength=:} [{results[SelectedReferenceWavelength].unit}]\",\n", + " fontsize=8,\n", + " color=\"black\",\n", + ")\n", + "wavelength_scale_factor_plot" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "## Change Provider\n", + "Here is an example of how to insert different filter function.\n", + "\n", + "In this example, we will swap a provider that filters ``EstimatedScaledIntensities`` and provide ``FilteredEstimatedScaledIntensities``.\n", + "\n", + "After updating the providers, you can go back to [Compute Desired Type](#Compute-Desired-Type) and start over." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import NewType\n", + "import scipp as sc\n", + "from ess.nmx.scaling import (\n", + " EstimatedScaledIntensities,\n", + " FilteredEstimatedScaledIntensities,\n", + ")\n", + "\n", + "# Define the new types for the filtering function\n", + "NRoot = NewType(\"NRoot\", int)\n", + "\"\"\"The n-th root to be taken for the standard deviation.\"\"\"\n", + "NRootStdDevCut = NewType(\"NRootStdDevCut\", float)\n", + "\"\"\"The number of standard deviations to be cut from the n-th root data.\"\"\"\n", + "\n", + "\n", + "def _calculate_sample_standard_deviation(var: sc.Variable) -> sc.Variable:\n", + " \"\"\"Calculate the sample variation of the data.\n", + "\n", + " This helper function is a temporary solution before\n", + " we release new scipp version with the statistics helper.\n", + " \"\"\"\n", + " import numpy as np\n", + "\n", + " return sc.scalar(np.nanstd(var.values))\n", + "\n", + "\n", + "# Define the filtering function with right argument types and return type\n", + "def cut_estimated_scaled_intensities_by_n_root_std_dev(\n", + " scaled_intensities: EstimatedScaledIntensities,\n", + " n_root: NRoot,\n", + " n_root_std_dev_cut: NRootStdDevCut,\n", + ") -> FilteredEstimatedScaledIntensities:\n", + " \"\"\"Filter the mtz data array by the quad root of the sample standard deviation.\n", + "\n", + " Parameters\n", + " ----------\n", + " scaled_intensities:\n", + " The scaled intensities to be filtered.\n", + "\n", + " n_root:\n", + " The n-th root to be taken for the standard deviation.\n", + " Higher n-th root means cutting is more effective on the right tail.\n", + " More explanation can be found in the notes.\n", + "\n", + " n_root_std_dev_cut:\n", + " The number of standard deviations to be cut from the n-th root data.\n", + "\n", + " Returns\n", + " -------\n", + " :\n", + " The filtered scaled intensities.\n", + "\n", + " \"\"\"\n", + " # Check the range of the n-th root\n", + " if n_root < 1:\n", + " raise ValueError(\"The n-th root should be equal to or greater than 1.\")\n", + "\n", + " copied = scaled_intensities.copy(deep=False)\n", + " nth_root = copied.data ** (1 / n_root)\n", + " # Calculate the mean\n", + " nth_root_mean = nth_root.nanmean()\n", + " # Calculate the sample standard deviation\n", + " nth_root_std_dev = _calculate_sample_standard_deviation(nth_root)\n", + " # Calculate the cut value\n", + " half_window = n_root_std_dev_cut * nth_root_std_dev\n", + " keep_range = (nth_root_mean - half_window, nth_root_mean + half_window)\n", + "\n", + " # Filter the data\n", + " return FilteredEstimatedScaledIntensities(\n", + " copied[(nth_root > keep_range[0]) & (nth_root < keep_range[1])]\n", + " )\n", + "\n", + "\n", + "pl.insert(cut_estimated_scaled_intensities_by_n_root_std_dev)\n", + "pl[NRoot] = 4\n", + "pl[NRootStdDevCut] = 1.0\n", + "\n", + "pl.compute(FilteredEstimatedScaledIntensities)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": 
{ + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/packages/essnmx/docs/user-guide/workflow.ipynb b/packages/essnmx/docs/user-guide/workflow.ipynb new file mode 100644 index 00000000..46084b32 --- /dev/null +++ b/packages/essnmx/docs/user-guide/workflow.ipynb @@ -0,0 +1,265 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# NMX Reduction Workflow\n", + "\n", + "> NMX does not expect users to use python interface directly.
\n",
+    "This documentation is aimed mainly at instrument data scientists and instrument scientists.
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TL;DR" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ess.nmx.executables import reduction\n", + "from ess.nmx.data import get_small_nmx_nexus\n", + "from ess.nmx.configurations import (\n", + " ReductionConfig,\n", + " OutputConfig,\n", + " InputConfig,\n", + " WorkflowConfig,\n", + " TimeBinCoordinate,\n", + ")\n", + "\n", + "# Build Configuration\n", + "config = ReductionConfig(\n", + " inputs=InputConfig(\n", + " input_file=[get_small_nmx_nexus().as_posix()],\n", + " detector_ids=[0, 1, 2],\n", + " ),\n", + " output=OutputConfig(\n", + " output_file=\"scipp_output.hdf\", skip_file_output=False, overwrite=True\n", + " ),\n", + " workflow=WorkflowConfig(\n", + " time_bin_coordinate=TimeBinCoordinate.time_of_flight,\n", + " nbins=10,\n", + " tof_simulation_num_neutrons=1_000_000,\n", + " tof_simulation_min_wavelength=1.8,\n", + " tof_simulation_max_wavelength=3.6,\n", + " tof_simulation_seed=42,\n", + " ),\n", + ")\n", + "\n", + "# Run reduction and display the result.\n", + "result = reduction(config=config, display=display)\n", + "dg = result.to_datagroup()\n", + "dg" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configuration\n", + "\n", + "`essnmx` provides a command line data reduction tool.
\n",
+    "The `essnmx-reduce` interface reduces a NeXus file
\n",
+    "and saves the results in an `NXlauetof`-like format (not exact, but very close) for `dials`.\n",
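+    "\n",
+    "A hypothetical invocation might look like this (the path is a placeholder and the values are the defaults used elsewhere in this guide; the full argument list is generated below):\n",
+    "\n",
+    "```\n",
+    "essnmx-reduce --input-file PATH_TO_THE_NEXUS_FILE.hdf --detector-ids 0 1 2 --nbins 50 --output-file scipp_output.hdf\n",
+    "```\n",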
+    "\n",
+    "For convenience and safety, all configuration options are wrapped in a nested pydantic model.
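\n",
+    "The nesting mirrors the sections below (a sketch of the structure, not the full field list):\n",
+    "\n",
+    "```\n",
+    "ReductionConfig\n",
+    "├── inputs:   InputConfig    (input files, detector ids, chunking)\n",
+    "├── workflow: WorkflowConfig (time binning, TOF simulation)\n",
+    "└── output:   OutputConfig   (output file, compression, verbosity)\n",
+    "```\n",
+    "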
\n", + "Here is a python API you can use to build the configuration and turn it into command line arguments.\n", + "\n", + "**The configuration object is a pydantic model, and it thus enforces strict checks on the types of the arguments.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ess.nmx.configurations import (\n", + " ReductionConfig,\n", + " OutputConfig,\n", + " InputConfig,\n", + " WorkflowConfig,\n", + " TimeBinCoordinate,\n", + " Compression,\n", + " to_command_arguments,\n", + ")\n", + "\n", + "config = ReductionConfig(\n", + " inputs=InputConfig(\n", + " input_file=[\"PATH_TO_THE_NEXUS_FILE.hdf\"],\n", + " detector_ids=[0, 1, 2], # Detector index to be reduced in alphabetical order.\n", + " ),\n", + " output=OutputConfig(output_file=\"scipp_output.hdf\", skip_file_output=True),\n", + " workflow=WorkflowConfig(\n", + " time_bin_coordinate=TimeBinCoordinate.time_of_flight,\n", + " nbins=10,\n", + " tof_simulation_num_neutrons=1_000_000,\n", + " tof_simulation_min_wavelength=1.8,\n", + " tof_simulation_max_wavelength=3.6,\n", + " tof_simulation_seed=42,\n", + " ),\n", + ")\n", + "\n", + "display(config)\n", + "print(to_command_arguments(config=config, one_line=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reduce Nexus File(s)\n", + "\n", + "`OutputConfig` has an option called `skip_file_output` if you want to reduce the file and use it only on the memory.
\n",
+    "Then you can use the `save_results` function to save the results explicitly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ess.nmx.executables import save_results\n",
+    "\n",
+    "output_config = OutputConfig(\n",
+    "    output_file=\"scipp_output.hdf\", overwrite=True, compression=Compression.GZIP\n",
+    ")\n",
+    "save_results(results=results, output_config=output_config)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Loading Reduced File\n",
+    "\n",
+    "There is a custom loader for the NXlauetof files produced for NMX.
\n", + "It reconstructs the position coordinates from the file and adds them back to the data array.
\n",
+    "The data group should look almost the same as the in-memory results.
\n",
+    "However, the loaded data group will be missing some coordinates that the in-memory results have.
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ess.nmx._nxlauetof_io import load_essnmx_nxlauetof\n", + "\n", + "loaded = load_essnmx_nxlauetof('scipp_output.hdf')\n", + "loaded" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can then plot the loaded data array exactly same as the in-memory results.\n", + "\n", + "For example, you can plot the 3D instrument view:\n", + "\n", + "```python\n", + "%matplotlib widget\n", + "import scipp as sc\n", + "import scippneutron as scn\n", + "\n", + "\n", + "dims=('y_pixel_offset', 'x_pixel_offset')\n", + "merged_2d_das = sc.concat(\n", + " [\n", + " det['data'].sum('tof').flatten(dims=dims, to='detector_number')\n", + " for det in loaded['instrument']['detectors'].values()\n", + " ],\n", + " dim='detector_number',\n", + ")\n", + "\n", + "scn.instrument_view(merged_2d_das)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compression Modes\n", + "\n", + "There are multiple compression modes for `detector counts` data(other datasets are not compressed).
\n", + "The default mode is `BITSHUFFLE_LZ4`.
\n",
+    "\n",
+    "Here are rough benchmark results with the small test dataset.
\n",
+    "With these results, users can decide which compression mode to use.\n",
+    "\n",
+    "| Compression Mode | Final Size [MB] | Compression Ratio | Writing Time [s] | Reading Time [s] |\n",
+    "| ---------------- | --------------- | ----------------- | ---------------- | ---------------- |\n",
+    "| NONE             | 1966            | 1                 | 4                | 1                |\n",
+    "| GZIP             | 5               | 370               | 18               | 5                |\n",
+    "| BITSHUFFLE_LZ4   | 17              | 114               | 10               | 3                |\n",
+    "\n",
+    "> Measured in the ESS standard VISA environment (64 GB memory / 6 vCPUs).\n",
+    "\n",
+    "`BITSHUFFLE_LZ4` was almost twice as fast as `GZIP` at writing and reading the reduced file.
\n",
+    "`GZIP` and `BITSHUFFLE_LZ4` both compressed the data by more than 99% (when the histogram was very sparse),
\n",
+    "but `GZIP` achieved about three times the compression ratio of `BITSHUFFLE_LZ4` on this particular dataset.
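\n",
+    "\n",
+    "For example, to favor compression ratio over speed on very sparse histograms, `GZIP` can be selected explicitly (the same `OutputConfig` API as in the saving example above):\n",
+    "\n",
+    "```python\n",
+    "output_config = OutputConfig(\n",
+    "    output_file=\"scipp_output.hdf\", overwrite=True, compression=Compression.GZIP\n",
+    ")\n",
+    "```\n",
+    "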
\n", + "\n", + ".. note::\n", + "Why `BITSHUFFLE` is the default compression mode? -\n", + "`Bitshuffle` is compatible with `DIALS` and other crystallography packages.\n", + "It is the primary compression mode of all data collected on `DECTRIS EIGER` detectors,\n", + "which are the primary detectors used at synchrotron X-ray MX beamlines.\n", + "Most of these packages can also read `gzip` data but the slow readout makes `gzip` less attractive than `bitshuffle`.\n", + "\n", + "\n", + ".. warning::\n", + "`Bitshuffle` may not be supported in cerntain environments, such as MacOS or Windows.\n", + "It was accepted to be default because the first step of the reduction workflow (this workflow)\n", + "is expected to be run by ESS in the specific environment that bitshuffle supports." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/packages/essnmx/pyproject.toml b/packages/essnmx/pyproject.toml new file mode 100644 index 00000000..a225a968 --- /dev/null +++ b/packages/essnmx/pyproject.toml @@ -0,0 +1,139 @@ +[build-system] +requires = [ + "setuptools>=77", + "setuptools_scm[toml]>=8.0", +] +build-backend = "setuptools.build_meta" + +[project] +name = "essnmx" +description = "Data reduction for NMX at the European Spallation Source." +authors = [{ name = "Scipp contributors" }] +license = "BSD-3-Clause" +license-files = ["../../LICENSE"] +readme = "README.md" +classifiers = [ + "Intended Audience :: Science/Research", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering", + "Typing :: Typed", +] +requires-python = ">=3.11" + +# IMPORTANT: +# Run 'tox -e deps' after making changes here. This will update requirement files. +# Make sure to list one dependency per line. +dependencies = [ + "dask>=2022.1.0", + "essreduce>=26.2.1", + "graphviz", + "plopp>=24.7.0", + "sciline>=24.06.0", + "scipp>=25.3.0", + "scippnexus>=23.12.0", + "scippneutron>=26.02.0", + "pooch>=1.5", + "pandas>=2.1.2", + "gemmi>=0.6.6", + "defusedxml>=0.7.1", + "msgpack>=1.0.8", + "tof>=25.12.1", + "numpy>=2.0.0", +] + +dynamic = ["version"] + +[project.scripts] +essnmx_reduce_mcstas = "ess.nmx.mcstas.executables:main" +essnmx-reduce = "ess.nmx.executables:main" + +[project.optional-dependencies] +test = [ + "pytest>=8.0", + "bitshuffle>=0.5.2;os_name == 'posix'", +] + +[project.urls] +"Bug Tracker" = "https://github.com/scipp/essnmx/issues" +"Documentation" = "https://scipp.github.io/essnmx" +"Source" = "https://github.com/scipp/essnmx" + +[tool.setuptools_scm] +root = "../.." 
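+# Version tags for this package are expected to look like "essnmx/1.2.3";
+# tag_regex and git_describe_command below select and parse them.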
+tag_regex = "^essnmx/(?P[vV]?\\d+(?:\\.\\d+)*(?:[._-]?\\w+)*)$" +git_describe_command = [ "git", "describe", "--dirty", "--long", "--match", "essnmx/*[0-9]*"] + +[tool.pytest.ini_options] +minversion = "7.0" +addopts = """ +--strict-config +--strict-markers +--import-mode=importlib +-ra +-v +""" +testpaths = "tests" +filterwarnings = [ + "error", +] + +[tool.ruff] +line-length = 88 +extend-include = ["*.ipynb"] +extend-exclude = [ + ".*", "__pycache__", "build", "dist", "install", +] + +[tool.ruff.lint] +# See https://docs.astral.sh/ruff/rules/ +select = ["B", "C4", "DTZ", "E", "F", "G", "I", "PERF", "PGH", "PT", "PYI", "RUF", "S", "T20", "UP", "W"] +ignore = [ + # Conflict with ruff format, see + # https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules + "COM812", "COM819", "D206", "D300", "E111", "E114", "E117", "ISC001", "ISC002", "Q000", "Q001", "Q002", "Q003", "W191", +] +fixable = ["B010", "I001", "PT001", "RUF022"] +isort.known-first-party = ["ess.nmx"] +pydocstyle.convention = "numpy" + +[tool.ruff.lint.per-file-ignores] +# those files have an increased risk of relying on import order +"tests/*" = [ + "S101", # asserts are fine in tests + "B018", # 'useless expressions' are ok because some tests just check for exceptions +] +"*.ipynb" = [ + "E501", # longer lines are sometimes more readable + "F403", # *-imports used with domain types + "F405", # linter may fail to find names because of *-imports + "I", # we don't collect imports at the top + "S101", # asserts are used for demonstration and are safe in notebooks + "T201", # printing is ok for demonstration purposes +] + +[tool.ruff.format] +quote-style = "preserve" + +[tool.mypy] +strict = true +ignore_missing_imports = true +enable_error_code = [ + "ignore-without-code", + "redundant-expr", + "truthy-bool", +] +warn_unreachable = true + +[tool.codespell] +ignore-words-list = [ + # Codespell wants "socioeconomic" which seems to be the standard spelling. + # But we use the word in our code of conduct which is the contributor covenant. + # Let's not modify it if we don't have to. 
+ "socio-economic", +] diff --git a/packages/essnmx/resources/logo-text.svg b/packages/essnmx/resources/logo-text.svg new file mode 100644 index 00000000..1f21b025 --- /dev/null +++ b/packages/essnmx/resources/logo-text.svg @@ -0,0 +1,665 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + essnmx + + diff --git a/packages/essnmx/src/ess/nmx/__init__.py b/packages/essnmx/src/ess/nmx/__init__.py new file mode 100644 index 00000000..d6e82906 --- /dev/null +++ b/packages/essnmx/src/ess/nmx/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +# ruff: noqa: RUF100, E402, I + +import importlib.metadata + +try: + __version__ = importlib.metadata.version("essnmx") +except importlib.metadata.PackageNotFoundError: + __version__ = "0.0.0" + +del importlib + +from .mcstas import NMXMcStasWorkflow + +__all__ = ["NMXMcStasWorkflow"] diff --git a/packages/essnmx/src/ess/nmx/_display_helper.py b/packages/essnmx/src/ess/nmx/_display_helper.py new file mode 100644 index 00000000..676c3b9e --- /dev/null +++ b/packages/essnmx/src/ess/nmx/_display_helper.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2026 Scipp contributors (https://github.com/scipp) +from dataclasses import fields, is_dataclass + +import scipp as sc + + +def _is_nested(obj) -> bool: + return is_dataclass(obj) or isinstance(obj, sc.DataGroup | dict) + + +def to_datagroup(obj) -> sc.DataGroup: + if is_dataclass(obj): + return sc.DataGroup( + { + field.name: to_datagroup(value) + if _is_nested(value := getattr(obj, field.name)) + else value + for field in fields(obj) + } + ) + elif isinstance(obj, sc.DataGroup | dict): + return sc.DataGroup( + { + name: to_datagroup(value) if _is_nested(value) else value + for name, value in obj.items() + } + ) + else: + return obj diff --git a/packages/essnmx/src/ess/nmx/_executable_helper.py b/packages/essnmx/src/ess/nmx/_executable_helper.py new file mode 100644 index 00000000..425c494e --- /dev/null +++ b/packages/essnmx/src/ess/nmx/_executable_helper.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +import argparse +import enum +import glob +import logging +import pathlib +import sys +from functools import partial +from types import UnionType +from typing import Literal, TypeGuard, TypeVar, Union, get_args, get_origin + +from pydantic import BaseModel +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined + +from .configurations import InputConfig, OutputConfig, ReductionConfig, WorkflowConfig + + +def _validate_annotation(annotation) -> TypeGuard[type]: + def _validate_atomic_type(annotation) -> bool: + return ( + (annotation in (int, float, str, bool)) + or (isinstance(annotation, type) and issubclass(annotation, enum.StrEnum)) + or (get_origin(annotation) is Literal) + ) + + return ( + _validate_atomic_type(annotation) + or ( + (origin := get_origin(annotation)) in (Union, UnionType) + and _validate_atomic_type(_get_no_nonetype_args(annotation)) + ) + or ( + origin in (list, tuple, set) + and len(args := get_args(annotation)) > 0 + and _validate_atomic_type(args[0]) + ) + ) + + +def _get_no_nonetype_args(annotation) -> type: + origin_type = get_origin(annotation) + if (origin_type is UnionType or origin_type is Union) and type(None) in ( + union_args := get_args(annotation) + ): + arg_types = 
set(union_args) - {type(None)} + if len(arg_types) > 1: + raise TypeError( + "Optional type with single non-None type is not supported: " + f"{annotation}" + ) + return next(iter(arg_types)) + return annotation + + +def _is_appendable_type(annotation) -> bool: + return get_origin(annotation) in (list, tuple, set) + + +def _retrieve_field_value( + field_name: str, field_info: FieldInfo, args: argparse.Namespace +): + if isinstance(field_info.annotation, type) and issubclass( + field_info.annotation, enum.StrEnum + ): + return field_info.annotation[getattr(args, field_name)] + return getattr(args, field_name) + + +def add_args_from_pydantic_model( + *, model_cls: type[BaseModel], parser: argparse.ArgumentParser +) -> argparse.ArgumentParser: + """Add arguments to the parser from the pydantic model class. + + Each field in the model class is added as a command line argument + with the name `--{field-name}`. + Arguments are added based on fields' information: + - type annotation (type, choices, nargs) + - description (help text) + - default value (default, required and help text) + + Supported annotation for command arguments: + - Atomic types: int, float, str, bool, enum.StrEnum, Literal + - Optional[AtomicType] + - List[AtomicType], Tuple[AtomicType, ...], Set[AtomicType] + + Parameters + ---------- + model_cls: + Pydantic model class to extract the arguments from. + parser: + Argument parser to add the arguments to. + It adds a new argument group for the model. + The group name is taken from the model's title config if available, + otherwise the model class name is used. + + """ + group = parser.add_argument_group( + model_cls.model_config.get("title", model_cls.__name__) + ) + for field_name, field_info in model_cls.model_fields.items(): + add_argument = partial(group.add_argument, f"--{field_name.replace('_', '-')}") + + if not _validate_annotation(field_info.annotation): + raise TypeError(f"Unsupported annotation type: {field_info.annotation}") + + arg_type = _get_no_nonetype_args(field_info.annotation) + if _is_appendable_type(arg_type): + nargs = '+' + arg_type = get_args(field_info.annotation)[0] + else: + nargs = None + arg_type = arg_type + + required = field_info.default is PydanticUndefined + default = ... if required else field_info.default + + if arg_type is bool: + add_argument = partial(add_argument, action='store_true') + elif isinstance(arg_type, type) and issubclass(arg_type, enum.StrEnum): + add_argument = partial( + add_argument, + type=str, + choices=[str(e) for e in arg_type], + ) + default = default.name if isinstance(default, enum.StrEnum) else default + elif get_origin(arg_type) is Literal: + add_argument = partial( + add_argument, + type=str, + choices=[str(lit) for lit in get_args(arg_type)], + ) + else: + add_argument = partial(add_argument, type=arg_type, nargs=nargs) + + help_text = ' '.join([field_info.description or '', f"(default: {default})"]) + add_argument(default=default, required=required, help=help_text) + + return parser + + +T = TypeVar('T', bound=BaseModel) + + +def from_args(cls: type[T], args: argparse.Namespace) -> T: + """Create an instance of the pydantic model from the argparse namespace. + + It ignores any extra arguments in the namespace that are not part of the model. 
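+
+    Examples
+    --------
+    A hypothetical sketch (the file name is made up)::
+
+        parser = build_reduction_argument_parser()
+        args = parser.parse_args(['--input-file', 'run.nxs'])
+        config = from_args(InputConfig, args)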
+ """ + kwargs = { + field_name: _retrieve_field_value(field_name, field_info, args) + for field_name, field_info in cls.model_fields.items() + } + return cls(**kwargs) + + +def build_reduction_argument_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Command line arguments for the ESS NMX reduction. " + "It assumes 14 Hz pulse speed." + ) + parser = add_args_from_pydantic_model(model_cls=InputConfig, parser=parser) + parser = add_args_from_pydantic_model(model_cls=WorkflowConfig, parser=parser) + parser = add_args_from_pydantic_model(model_cls=OutputConfig, parser=parser) + return parser + + +def reduction_config_from_args(args: argparse.Namespace) -> ReductionConfig: + return ReductionConfig( + inputs=from_args(InputConfig, args), + workflow=from_args(WorkflowConfig, args), + output=from_args(OutputConfig, args), + ) + + +def build_logger(args: argparse.Namespace | OutputConfig) -> logging.Logger: + logger = logging.getLogger(__name__) + if args.verbose: + logger.setLevel(logging.INFO) + logger.addHandler(logging.StreamHandler(sys.stdout)) + return logger + + +def collect_matching_input_files(*input_file_patterns: str) -> list[pathlib.Path]: + """Helper to collect input files matching the given patterns.""" + + input_files: list[str] = [] + for pattern in input_file_patterns: + input_files.extend(glob.glob(pattern)) + + # Remove duplicates and sort + return sorted({pathlib.Path(f).resolve() for f in input_files}) diff --git a/packages/essnmx/src/ess/nmx/_nxlauetof_io.py b/packages/essnmx/src/ess/nmx/_nxlauetof_io.py new file mode 100644 index 00000000..f21dea0a --- /dev/null +++ b/packages/essnmx/src/ess/nmx/_nxlauetof_io.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2026 Scipp contributors (https://github.com/scipp) +import warnings + +import scipp as sc +import scippnexus as snx +from ess.reduce.nexus.types import FilePath, NeXusFile +from scippneutron.metadata import RadiationProbe, SourceType + +from .types import ControlMode + + +def _validate_entry(entry: snx.Group) -> None: + if str(entry.attrs['NX_class']) != 'NXlauetof': + raise ValueError("File entry is not NXlauetof.") + _MANDATORY_FIELDS = ('control', 'instrument', 'sample') + missing_fields = [field for field in _MANDATORY_FIELDS if field not in entry] + if any(missing_fields): + raise ValueError("File entry missing mandatory fields, ", missing_fields) + + +def _as_vector(var: sc.Variable) -> sc.Variable: + if var.dims == () and var.dtype == sc.DType.vector3: + return var + elif len(var.dims) == 1 and var.sizes[var.dim] == 3: + return sc.vector(value=var.values, unit=var.unit) + else: + warnings.warn( + f"Cannot convert to vector3 scalar: {var}. 
" + "Falling back to the original form.", + UserWarning, + stacklevel=3, + ) + return var + + +def _handle_sample(sample_dg: sc.DataGroup, sample: snx.Group) -> sc.DataGroup: + sample_dg['crystal_rotation'] = _as_vector(sample_dg['crystal_rotation']) + sample_dg['position'] = _as_vector(sample_dg['position']) + unit_cell = sample_dg.pop('unit_cell') + sample_dg['unit_cell_length'] = sc.vector( + unit_cell[:3], unit=sample['unit_cell'].attrs['length-unit'] + ) + sample_dg['unit_cell_angle'] = sc.vector( + unit_cell[3:], unit=sample['unit_cell'].attrs['angle-unit'] + ) + return sample_dg + + +def _handle_monitor(control_dg: sc.DataGroup, control: snx.Group) -> sc.DataGroup: + tof_bin_coord_key = 'tof_bin_coord' + + if tof_bin_coord_key in control.attrs: + tof_bin_coord = control.attrs['tof_bin_coord'] + control_dg['tof_bin_coord'] = tof_bin_coord + data: sc.DataArray = control_dg['data'] + data.coords[tof_bin_coord] = data.coords.pop('time_of_flight') + + control_dg['mode'] = ControlMode[control_dg['mode']] + + return control_dg + + +def _handle_source(instrument_dg: sc.DataGroup, instrument: snx.Group) -> sc.DataGroup: + source_dg = instrument_dg['source'] + distance = source_dg.pop('distance') + position = sc.vector( + instrument['source']['distance'].attrs['position'], unit=distance.unit + ) + source_dg['position'] = position + source_dg['source_type'] = SourceType(source_dg.pop('type')) + source_dg['probe'] = RadiationProbe(source_dg['probe']) + + +def _restore_positions( + *, metadatas: sc.DataGroup, fast_axis_dim: str, slow_axis_dim: str, sizes: dict +) -> sc.Variable: + fast_axis = metadatas['fast_axis'] + fast_axis_size = sizes[fast_axis_dim] + slow_axis = metadatas['slow_axis'] + slow_axis_size = sizes[slow_axis_dim] + + pixel_sizes = { + 'x_pixel_offset': metadatas['x_pixel_size'], + 'y_pixel_offset': metadatas['y_pixel_size'], + } + + fast_axis_offsets = ( + sc.arange(dim=fast_axis_dim, start=0.0, stop=fast_axis_size) + * pixel_sizes[fast_axis_dim] + * fast_axis + ) + slow_axis_offsets = ( + sc.arange(dim=slow_axis_dim, start=0.0, stop=slow_axis_size) + * pixel_sizes[slow_axis_dim] + * slow_axis + ) + # The slow axis should be the outer most dimension. + detector_sizes = {slow_axis_dim: slow_axis_size, fast_axis_dim: fast_axis_size} + + pixel_offsets = fast_axis_offsets.broadcast( + sizes=detector_sizes + ) + slow_axis_offsets.broadcast(sizes=detector_sizes) + + detetor_center = metadatas['origin'] + slow_axis_width = pixel_sizes[slow_axis_dim] * slow_axis_size + fast_axis_width = pixel_sizes[fast_axis_dim] * fast_axis_size + detector_corner = ( + detetor_center + - (slow_axis_width / 2) * slow_axis + + pixel_sizes[slow_axis_dim] * slow_axis / 2 + - (fast_axis_width / 2) * fast_axis + + pixel_sizes[fast_axis_dim] * fast_axis / 2 + ) + + return pixel_offsets + detector_corner + + +def _handle_detector_data( + instrument_dg: sc.DataGroup, instrument: snx.Group +) -> sc.DataGroup: + detectors: sc.DataGroup[sc.DataGroup] = sc.DataGroup( + { + det_name: instrument_dg.pop(det_name) + for det_name in instrument[snx.NXdetector].keys() + } + ) + instrument_dg['detectors'] = detectors + time_coord_name = next(iter({'tof', 'event_time_offset'} & set(detectors.dims))) + time_field_name = 'time_of_flight' if time_coord_name == 'tof' else time_coord_name + + for det_name, det_gr in detectors.items(): + # These fields are part of the histogram as data and coordinate. 
+ non_meta_keys = ('data', 'time_of_flight', 'event_time_offset') + all_keys = list(filter(lambda key: key not in non_meta_keys, det_gr.keys())) + metadatas = sc.DataGroup({key: det_gr.pop(key) for key in all_keys}) + + for vector_field in ('slow_axis', 'fast_axis', 'origin'): + metadatas[vector_field] = _as_vector(metadatas[vector_field]) + + det_gr['metadata'] = metadatas + fast_axis_dim = instrument[det_name]['fast_axis'].attrs['dim'] + slow_axis_dim = instrument[det_name]['slow_axis'].attrs['dim'] + metadatas['fast_axis_dim'] = fast_axis_dim + metadatas['slow_axis_dim'] = slow_axis_dim + metadatas['detector_name'] = det_name + metadatas['first_pixel_position'] = sc.vector( + instrument[det_name]['origin'].attrs['first_pixel_position'], + unit=metadatas['origin'].unit, + ) + time_coord = metadatas.pop('original_time_edges') + mid_time = det_gr.pop(time_field_name) + if sc.any(sc.midpoints(time_coord) != mid_time): + warnings.warn( + "Time bin edges and mid point coordinates do not agree.", + UserWarning, + stacklevel=3, + ) + det_gr['data'] = sc.DataArray( + data=det_gr['data'], + coords={ + time_coord_name: time_coord, + 'position': _restore_positions( + metadatas=metadatas, + fast_axis_dim=fast_axis_dim, + slow_axis_dim=slow_axis_dim, + sizes=det_gr['data'].sizes, + ), + }, + ) + + +def load_essnmx_nxlauetof(file: str | FilePath | NeXusFile) -> sc.DataGroup: + with snx.File(file, mode='r') as f: + with warnings.catch_warnings(action='ignore'): + # Expecting warnings for loading NXdetectors. + # The data array reconstruction is handled manually later. + dg = f[()] + + _validate_entry(entry := f['entry']) + _handle_sample(dg['entry']['sample'], entry['sample']) + _handle_monitor(dg['entry']['control'], entry['control']) + _handle_source(dg['entry']['instrument'], entry['instrument']) + _handle_detector_data(dg['entry']['instrument'], entry['instrument']) + + return dg['entry'] diff --git a/packages/essnmx/src/ess/nmx/configurations.py b/packages/essnmx/src/ess/nmx/configurations.py new file mode 100644 index 00000000..08e821b6 --- /dev/null +++ b/packages/essnmx/src/ess/nmx/configurations.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +import enum + +from pydantic import BaseModel, Field + +from .types import Compression + + +class InputConfig(BaseModel): + # Add title of the basemodel + model_config = {"title": "Input Configuration"} + # File IO + input_file: list[str] = Field( + title="Input File", + description="Path to the input file. If multiple file paths are given," + " the output(histogram) will be merged(summed) " + "and will not save individual outputs per input file. ", + ) + swmr: bool = Field( + title="SWMR Mode", + description="Open the input file in SWMR mode", + default=False, + ) + # Detector selection + detector_ids: list[int] = Field( + title="Detector IDs", + description="Detector indices to process", + default=[0, 1, 2], + ) + # Chunking options + iter_chunk: bool = Field( + title="Iterate in Chunks", + description="Whether to process the input file in chunks " + " based on the hdf5 dataset chunk size. " + "It is ignored if hdf5 dataset is not chunked. " + "If True, it overrides chunk-size-pulse and chunk-size-events options.", + default=False, + ) + chunk_size_pulse: int = Field( + title="Chunk Size Pulse", + description="Number of pulses to process in each chunk. 
" + "If 0 or negative, process all pulses at once.", + default=0, + ) + chunk_size_events: int = Field( + title="Chunk Size Events", + description="Number of events to process in each chunk. " + "If 0 or negative, process all events at once." + "If both chunk-size-pulse and chunk-size-events are set, " + "chunk-size-pulse is preferred.", + default=0, + ) + + +class TimeBinUnit(enum.StrEnum): + ms = 'ms' + us = 'us' + ns = 'ns' + + +class TimeBinCoordinate(enum.StrEnum): + event_time_offset = 'event_time_offset' + time_of_flight = 'time_of_flight' + + +class WorkflowConfig(BaseModel): + # Add title of the basemodel + model_config = {"title": "Workflow Configuration"} + time_bin_coordinate: TimeBinCoordinate = Field( + title="Time Bin Coordinate", + description="Coordinate to bin the time data. " + "Selecting `event_time_offset` means " + "reduction steps are skipped, " + "i.e. calculating `time of flight(tof)` " + "and simply saves histograms of the raw data.", + default=TimeBinCoordinate.time_of_flight, + # Default is time of flight since + # DIALS should expect the time of flight. + ) + nbins: int = Field( + title="Number of Time Bins", + description="Number of Time bins", + default=50, + ) + min_time_bin: int | None = Field( + title="Minimum Time", + description="Minimum time edge of [time_bin_coordinate] in [time_bin_unit].", + default=None, + ) + max_time_bin: int | None = Field( + title="Maximum Time", + description="Maximum time edge of [time_bin_coordinate] in [time_bin_unit].", + default=None, + ) + time_bin_unit: TimeBinUnit = Field( + title="Unit of Time Bins", + description="Unit of time bins.", + default=TimeBinUnit.ms, + ) + tof_lookup_table_file_path: str | None = Field( + title="TOF Lookup Table File Path", + description="Path to the TOF lookup table file. " + "If None, the lookup table will be computed on-the-fly.", + default=None, + ) + tof_simulation_num_neutrons: int = Field( + title="Number of Neutrons for TOF Simulation", + description="Number of neutrons to simulate for TOF lookup table calculation.", + default=1_000_000, + ) + tof_simulation_min_wavelength: float = Field( + title="TOF Simulation Minimum Wavelength", + description="Minimum wavelength for TOF simulation in Angstrom.", + default=1.8, + ) + tof_simulation_max_wavelength: float = Field( + title="TOF Simulation Maximum Wavelength", + description="Maximum wavelength for TOF simulation in Angstrom.", + default=3.6, + ) + tof_simulation_min_ltotal: float = Field( + title="TOF Simulation Minimum Ltotal", + description="Minimum total flight path for TOF simulation in meters.", + default=150.0, + ) + tof_simulation_max_ltotal: float = Field( + title="TOF Simulation Maximum Ltotal", + description="Maximum total flight path for TOF simulation in meters.", + default=170.0, + ) + tof_simulation_seed: int = Field( + title="TOF Simulation Seed", + description="Random seed for TOF simulation.", + default=42, # No reason. + ) + + +class OutputConfig(BaseModel): + # Add title of the basemodel + model_config = {"title": "Output Configuration"} + # Log verbosity + verbose: bool = Field( + title="Verbose Logging", + description="Increase output verbosity.", + default=False, + ) + # File output + skip_file_output: bool = Field( + title="Skip File Output", + description="If True, the output file will not be written.", + default=False, + ) + output_file: str = Field( + title="Output File", + description="Path to the output file. 
" + "It will be overwritten if ``overwrite`` is True.", + default="scipp_output.h5", + ) + overwrite: bool = Field( + title="Overwrite Output File", + description="If True, overwrite the output file if ``output_file`` exists.", + default=False, + ) + compression: Compression = Field( + title="Compression", + description="Compress option of reduced output file.", + default=Compression.BITSHUFFLE_LZ4, + ) + + +class ReductionConfig(BaseModel): + """Container for all reduction configurations.""" + + inputs: InputConfig + workflow: WorkflowConfig = Field(default_factory=WorkflowConfig) + output: OutputConfig = Field(default_factory=OutputConfig) + + @property + def _children(self) -> list[BaseModel]: + return [self.inputs, self.workflow, self.output] + + +def to_command_arguments( + *, config: ReductionConfig, one_line: bool = True, separator: str = '\\\n' +) -> list[str] | str: + """Convert the config to a list of command line arguments. + + Parameters + ---------- + one_line: + If True, return a single string with all arguments joined by spaces. + If False, return a list of argument strings. + + """ + args = {} + for instance in config._children: + args.update(instance.model_dump(mode='python')) + args = {f"--{k.replace('_', '-')}": v for k, v in args.items() if v is not None} + + arg_list = [] + for k, v in args.items(): + if not isinstance(v, bool): + arg_list.append(k) + if isinstance(v, list): + arg_list.extend(str(item) for item in v) + elif isinstance(v, enum.StrEnum): + arg_list.append(v.value) + else: + arg_list.append(str(v)) + elif v is True: + arg_list.append(k) + + if one_line: + # Default separator is backslash + newline for better readability + # Users can directly copy-paste the output in a terminal or a script. + return ( + (separator + '--') + .join(" ".join(arg_list).split('--')) + .removeprefix(separator) + ) + else: + return arg_list diff --git a/packages/essnmx/src/ess/nmx/data/__init__.py b/packages/essnmx/src/ess/nmx/data/__init__.py new file mode 100644 index 00000000..ece282a9 --- /dev/null +++ b/packages/essnmx/src/ess/nmx/data/__init__.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +import pathlib + +from ess.reduce.data import Entry, make_registry + +_version = "1" + +__all__ = [ + "get_path", + "get_small_mcstas", + "get_small_mtz_samples", + "get_small_nmx_nexus", + "get_small_random_mtz_samples", +] + + +_registry = make_registry( + "ess/nmx", + version=_version, + files={ + "small_mcstas_sample.h5": "md5:2afaac205d13ee857ee5364e3f1957a7", + "mtz_samples.tar.gz": Entry( + alg="md5", chk="bed1eaf604bbe8725c1f6a20ca79fcc0", extractor="untar" + ), + "mtz_random_samples.tar.gz": Entry( + alg="md5", chk="c8259ae2e605560ab88959e7109613b6", extractor="untar" + ), + "small_nmx_nexus.hdf.zip": Entry( + alg="md5", chk="96877cddc9f6392c96890069657710ca", extractor="unzip" + ), + }, +) + + +def get_small_mcstas() -> pathlib.Path: + """McStas file that contains only ``bank0(1-3)`` in the ``data`` group. + + Real McStas file should contain more dataset under ``data`` group. + McStas version >=3. + """ + return get_path("small_mcstas_sample.h5") + + +def get_path(name: str) -> pathlib.Path: + """ + Return the path to a data file bundled with ess nmx. + + This function only works with example data and cannot handle + paths to custom files. 
+ """ + return _registry.get_path(name) + + +def get_small_mtz_samples() -> list[pathlib.Path]: + """Return a list of path to MTZ sample files randomly chosen from real dataset. + + This samples also contain optional columns. + """ + return _registry.get_paths("mtz_samples.tar.gz") + + +def get_small_random_mtz_samples() -> list[pathlib.Path]: + """Return a list of path to MTZ sample files filled with random values + + This sample only contains mandatory columns for the workflow examples. + They are made for documentation, not necessarily for testing. + Use ``get_small_mtz_samples`` for testing since they are + more representative of real data. + """ + return _registry.get_paths("mtz_random_samples.tar.gz") + + +def get_small_nmx_nexus() -> pathlib.Path: + """Return the path to a small NMX NeXus file.""" + + return get_path("small_nmx_nexus.hdf.zip") diff --git a/packages/essnmx/src/ess/nmx/dials_io.py b/packages/essnmx/src/ess/nmx/dials_io.py new file mode 100644 index 00000000..d615f5b4 --- /dev/null +++ b/packages/essnmx/src/ess/nmx/dials_io.py @@ -0,0 +1,297 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) +import json +import pathlib +from typing import NewType + +import gemmi +import numpy as np +import pandas as pd +import scipp as sc + +from .dials_reflection_io import load + +# User defined or configurable types +DialsReflectionFilePath = NewType("DialsReflectionFilePath", pathlib.Path) +"""Path to the dials reflection file""" +DialsReflectionFile = NewType("DialsReflectionFile", dict) +"""The raw DIALS reflection file, read in as a dict""" +DialsExperimentFilePath = NewType("DialsExperimentFilePath", pathlib.Path) +"""Path to the dials experiment file""" +DialsExperiment = NewType("DialsExperiment", dict) +"""Experiment details from DIALS .expt file (JSON format)""" +SpaceGroupDesc = NewType("SpaceGroupDesc", str) +"""The space group description. e.g. 
'P 21 21 21'""" +DEFAULT_SPACE_GROUP_DESC = SpaceGroupDesc("P 1") +"""The default space group description to use if not found in the input files.""" +UnitCell = NewType("UnitCell", tuple[float]) +"""The unit cell a, b, c in Angstrom, alpha, beta, gamma in degrees""" + +# Custom column names +WavelengthColumnName = NewType("WavelengthColumnName", str) +"""The name of the wavelength column in the DIALS reflection file.""" +DEFAULT_WAVELENGTH_COLUMN_NAME = WavelengthColumnName("LAMBDA") + +IntensityColumnName = NewType("IntensityColumnName", str) +"""The name of the intensity column in the DIALS reflection file.""" +DEFAULT_INTENSITY_COLUMN_NAME = IntensityColumnName("I") + +VarianceColumnName = NewType("VarianceColumnName", str) +"""The name of the variance (stdev(I)**2) of intensity column in the DIALS reflection +file.""" +DEFAULT_VARIANCE_COLUMN_NAME = VarianceColumnName("VARI") + +StdDevColumnName = NewType("StdDevColumnName", str) +"""The name of the standard deviation of intensity column in the DIALS reflection +file.""" +DEFAULT_STDEV_COLUMN_NAME = VarianceColumnName("SIGI") + + +# Computed types +DialsDataFrame = NewType("DialsDataFrame", pd.DataFrame) +"""The raw mtz dataframe.""" +NMXDialsDataFrame = NewType("NMXDialsDataFrame", pd.DataFrame) +"""The processed mtz dataframe with derived columns.""" +NMXDialsDataArray = NewType("NMXDialsDataArray", sc.DataArray) + + +def read_dials_reflection_file( + file_path: DialsReflectionFilePath, +) -> DialsReflectionFile: + """read dials reflection file""" + + return DialsReflectionFile(load(file_path.as_posix(), copy=True)) + + +def read_dials_experiment_file(file_path: DialsExperimentFilePath) -> DialsExperiment: + """Read Dials Experiment .expt file""" + + return DialsExperiment(json.load(open(file_path))) + + +def get_unit_cell(dials_expt: DialsExperiment) -> UnitCell: + """ + Get the unit cell from the expt file. + It is saved as real-space vectors so the unit cell has to be + calculated from them. + """ + crystal = dials_expt['crystal'][0] + ra, rb, rc = tuple([crystal[f'real_space_{x}'] for x in 'abc']) + a = np.linalg.norm(ra) + b = np.linalg.norm(rb) + c = np.linalg.norm(rc) + al = np.rad2deg(np.arccos(np.dot(rb, rc) / (b * c))) + be = np.rad2deg(np.arccos(np.dot(ra, rc) / (a * c))) + ga = np.rad2deg(np.arccos(np.dot(ra, rb) / (a * b))) + + return UnitCell(a, b, c, al, be, ga) + + +def get_unique_space_group(dials_expt: DialsExperiment) -> gemmi.SpaceGroup: + """ + Get space group from Dials expt file. + For cctbx/disambiguation reasons it is saved as the Hall symbol, but + the H-M notation can be back-determined with gemmi. + """ + crystal = dials_expt['crystal'][0] + sg_hall = crystal['space_group_hall_symbol'] + + return gemmi.find_spacegroup_by_ops(gemmi.symops_from_hall(sg_hall)) + + +def get_reciprocal_asu(spacegroup: gemmi.SpaceGroup) -> gemmi.ReciprocalAsu: + """Returns the reciprocal asymmetric unit from the space group.""" + + return gemmi.ReciprocalAsu(spacegroup) + + +def dials_refl_to_pandas(refls: dict) -> pd.DataFrame: + """Converts the loaded DIALS reflection file to a pandas dataframe. + + It is equivalent to the following code: + + .. code-block:: python + + import numpy as np + import pandas as pd + + data = np.array(mtz, copy=False) + columns = mtz.column_labels() + return pd.DataFrame(data, columns=columns) + + It is recommended in the gemmi documentation. 
+
+
+def dials_refl_to_pandas(refls: dict) -> pd.DataFrame:
+    """Converts the loaded DIALS reflection table to a pandas dataframe.
+
+    Multi-dimensional columns (e.g. vectors or miller indices) are stored
+    as columns of per-reflection arrays so that each reflection maps onto
+    a single dataframe row.
+    """
+    if refls.get('experiment_identifier'):  # this has no relevant information
+        del refls['experiment_identifier']  # and it complicates loading as a df
+    return pd.DataFrame(
+        {
+            key: list(val) if isinstance(val, np.ndarray) and val.ndim > 1 else val
+            for key, val in refls.items()
+        }
+    )
+
+
+def process_dials_refl_list_to_dataframe(
+    refls: dict,
+) -> DialsDataFrame:
+    """Select and derive columns from the loaded DIALS reflection table.
+
+    Parameters
+    ----------
+    refls:
+        The raw reflection table as a dict,
+        as returned by :func:`read_dials_reflection_file`.
+
+    Returns
+    -------
+    :
+        The new dataframe with derived and renamed columns.
+
+        The derived columns are:
+
+        - ``H``, ``K``, ``L``: The miller indices unpacked into integer columns.
+        - ``SIGI``: The uncertainty of the intensity value, defined as the
+          square root of the measured variance.
+
+        For consistent names of columns/coordinates, the following columns
+        are renamed:
+
+        - ``'wavelength_cal'`` -> ``'wavelength'``
+        - ``'intensity.sum.value'`` -> ``'I'``
+        - ``'intensity.sum.variance'`` -> ``'VARI'``
+
+        Other columns are kept as they are.
+
+    """
+    orig_df = dials_refl_to_pandas(refls)
+    new_df = pd.DataFrame()
+
+    new_df['H'] = orig_df['miller_index'].map(lambda x: x[0]).astype(int)
+    new_df['K'] = orig_df['miller_index'].map(lambda x: x[1]).astype(int)
+    new_df['L'] = orig_df['miller_index'].map(lambda x: x[2]).astype(int)
+
+    new_df['hkl'] = orig_df['miller_index']
+    new_df["d"] = orig_df['d']
+    new_df['wavelength'] = orig_df['wavelength_cal']
+
+    new_df[DEFAULT_INTENSITY_COLUMN_NAME] = orig_df['intensity.sum.value']
+    new_df[DEFAULT_VARIANCE_COLUMN_NAME] = orig_df['intensity.sum.variance']
+    new_df[DEFAULT_STDEV_COLUMN_NAME] = np.sqrt(orig_df['intensity.sum.variance'])
+
+    for column in [col for col in orig_df.columns if col not in new_df]:
+        new_df[column] = orig_df[column]
+
+    return DialsDataFrame(new_df)
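+
+
+# A hedged sketch of the derivation above on a toy reflection table
+# (column names follow the DIALS conventions used in this module):
+#
+#     refls = {
+#         "miller_index": np.array([[1, 2, 3]]),
+#         "d": np.array([2.1]),
+#         "wavelength_cal": np.array([1.8]),
+#         "intensity.sum.value": np.array([4.0]),
+#         "intensity.sum.variance": np.array([0.25]),
+#     }
+#     df = process_dials_refl_list_to_dataframe(refls)
+#     assert df["SIGI"][0] == 0.5  # the square root of the variance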
+
+
+def process_dials_dataframe(
+    *,
+    dials_df: DialsDataFrame,
+    reciprocal_asu: gemmi.ReciprocalAsu,
+    sg: gemmi.SpaceGroup,
+) -> NMXDialsDataFrame:
+    """Modify/add columns of a shallow copy of a DIALS dataframe.
+
+    This method must be called after merging multiple DIALS dataframes.
+    """
+    df = dials_df.copy(deep=False)
+
+    def _reciprocal_asu(row: pd.Series) -> list[int]:
+        """Converts miller indices (HKL) to ASU indices."""
+        return reciprocal_asu.to_asu(row["hkl"], sg.operations())[0]
+
+    df["hkl_asu"] = df.apply(_reciprocal_asu, axis=1)
+    # Unpack the indices for later.
+    df[["H_ASU", "K_ASU", "L_ASU"]] = pd.DataFrame(
+        df["hkl_asu"].to_list(), index=df.index
+    )
+
+    return NMXDialsDataFrame(df)
+
+
+def nmx_dials_dataframe_to_scipp_dataarray(
+    nmx_mtz_df: NMXDialsDataFrame,
+) -> NMXDialsDataArray:
+    """Converts the processed DIALS dataframe to a scipp dataarray.
+
+    The intensity, with column name :attr:`~DEFAULT_INTENSITY_COLUMN_NAME`,
+    becomes the data and the standard uncertainty of intensity,
+    with column name :attr:`~DEFAULT_STDEV_COLUMN_NAME`,
+    becomes the variances of the data.
+
+    Parameters
+    ----------
+    nmx_mtz_df:
+        The merged and processed DIALS dataframe.
+
+    Returns
+    -------
+    :
+        The scipp dataarray with the intensity and variances.
+        The ``I`` column becomes the data and the
+        squared ``SIGI`` column becomes the variances.
+        Therefore they are not in the coordinates.
+
+        The following coordinates are modified:
+
+        - ``hkl``: The miller indices as a string.
+          It is modified to have a string dtype
+          since there is no dtype that can represent this in scipp.
+
+        - ``hkl_asu``: The asymmetric unit of miller indices as a string.
+          This coordinate will be used to derive estimated scale factors.
+          It is modified to have a string dtype
+          for the same reason as the ``hkl`` coordinate.
+
+        Zero or negative intensities are removed from the dataarray.
+        They can occur due to post-processing of the data,
+        e.g. background subtraction.
+
+    """
+    from scipp.compat.pandas_compat import from_pandas_dataframe, parse_bracket_header
+
+    to_scipp = nmx_mtz_df.copy(deep=False)
+    # Convert to scipp Dataset
+    nmx_mtz_ds = from_pandas_dataframe(
+        to_scipp,
+        data_columns=[
+            DEFAULT_INTENSITY_COLUMN_NAME,
+            DEFAULT_STDEV_COLUMN_NAME,
+            DEFAULT_VARIANCE_COLUMN_NAME,
+        ],
+        header_parser=parse_bracket_header,
+    )
+    # Pop the indices columns.
+    # TODO: We can put them back once we support tuple[int] dtype.
+    # See https://github.com/scipp/scipp/issues/3046 for more details.
+    # Temporarily, we will manually convert them to a string.
+    # It is done on the scipp variable instead of the dataframe
+    # since columns with string dtype are converted to PyObject dtype
+    # instead of string by ``from_pandas_dataframe``.
+    for indices_name in ("hkl", "hkl_asu"):
+        nmx_mtz_ds.coords[indices_name] = sc.array(
+            dims=nmx_mtz_ds.coords[indices_name].dims,
+            values=nmx_mtz_df[indices_name].astype(str).tolist(),
+            # ``astype`` alone is not enough to convert the dtype to string.
+            # The result of ``astype`` will have ``PyObject`` as a dtype.
+        )
+    # Add units
+    nmx_mtz_ds.coords["wavelength"].unit = sc.units.angstrom
+    for key in nmx_mtz_ds.keys():
+        nmx_mtz_ds[key].unit = sc.units.dimensionless
+
+    # Add variances
+    nmx_mtz_da = nmx_mtz_ds[DEFAULT_INTENSITY_COLUMN_NAME].copy(deep=False)
+    nmx_mtz_da.variances = nmx_mtz_ds[DEFAULT_VARIANCE_COLUMN_NAME].values
+
+    # Return DataArray without negative intensities
+    return NMXDialsDataArray(nmx_mtz_da[nmx_mtz_da.data > 0])
diff --git a/packages/essnmx/src/ess/nmx/dials_reflection_io.py b/packages/essnmx/src/ess/nmx/dials_reflection_io.py
new file mode 100644
index 00000000..dd72f624
--- /dev/null
+++ b/packages/essnmx/src/ess/nmx/dials_reflection_io.py
@@ -0,0 +1,251 @@
+"""
+DIALS .refl file loader
+
+This loads msgpack-type DIALS reflection files, without having DIALS or
+cctbx in the python environment.
+
+Note: All modern .refl files are, at the time of writing, msgpack-based. Some
+much older files might be in pickle format, which this doesn't read.
+
+Adapted from Nick Cavendish of the DIALS team.
+"""
+
+import functools
+import logging
+import operator
+import os
+import struct
+from collections.abc import Iterable
+from dataclasses import dataclass
+from io import BytesIO
+from pathlib import Path
+from typing import IO, cast
+
+import msgpack
+import numpy as np
+
+
+@dataclass
+class Shoebox:
+    panel: int
+    bbox: tuple[int, ...]
+    data: np.ndarray | None = None
+    mask: np.ndarray | None = None
+    background: np.ndarray | None = None
+
+
+def _decode_raw_numpy(dtype, shape: int | Iterable = 1):
+    """
+    Decoding a column that maps straight to a numpy array.
+
+    Args:
+        dtype: The numpy dtype for the array
+        shape:
+            The shape of a single item. Either an int, or a collection
+            of ints, in C-array order (row major)
+    """
+    # Convert to a shape tuple
+    if isinstance(shape, int):
+        shape = (shape,)
+    else:
+        shape = tuple(shape)
+
+    def _decode_specific(data, copy):
+        num_items, raw = data
+        array = np.frombuffer(raw, dtype=dtype)
+
+        if shape != (1,):
+            item_width = functools.reduce(operator.mul, shape)
+            if len(raw) % item_width != 0:
+                raise AssertionError(
+                    f"Raw data length {len(raw)} is not divisible "
+                    f"by item width {item_width}"
+                )
+            if num_items * item_width != len(array):
+                raise AssertionError(
+                    f"(num items) {num_items} * (item width) {item_width} "
+                    f"!= (array length) {len(array)}"
+                )
+            array = array.reshape(num_items, *shape)
+        if copy:
+            return np.copy(array)
+        return array
+
+    return _decode_specific
+
+
+def _decode_shoeboxes(data: list, copy) -> Iterable[Shoebox]:
+    # The exact wire format of the shoebox records could not be fully
+    # recovered; this is a hedged reconstruction that reads only the
+    # struct-packed header of each record (panel number and the six
+    # bounding-box limits). The pixel payload (``data``/``mask``/
+    # ``background``) is not decoded here, so ``copy`` has no effect.
+    num_items, raw = data
+    shoeboxes: list[Shoebox] = []
+    sbox_header_fmt = "<I6i"  # assumed: panel (uint32) + bbox (6 x int32)
+    header_size = struct.calcsize(sbox_header_fmt)
+    pos = 0
+    while pos < len(raw) and len(shoeboxes) < num_items:
+        panel, *bbox = struct.unpack_from(sbox_header_fmt, raw, pos)
+        shoeboxes.append(Shoebox(panel=panel, bbox=tuple(bbox)))
+        pos += header_size
+    return shoeboxes
+
+
+_reftable_decoders = {
+    # Scalar decoders (dtypes assumed from the DIALS column types):
+    "bool": _decode_raw_numpy(np.uint8),
+    "int": _decode_raw_numpy(np.int32),
+    "std::size_t": _decode_raw_numpy(np.uint64),
+    "double": _decode_raw_numpy(np.double),
+    # Fixed-shape decoders:
+    "vec3<double>": _decode_raw_numpy(np.double, shape=3),
+    "cctbx::miller::index<>": _decode_raw_numpy(np.int32, shape=3),
+    "Shoebox<>": _decode_shoeboxes,
+    "vec2<double>": _decode_raw_numpy(np.double, shape=2),
+    "mat3<double>": _decode_raw_numpy(np.double, shape=(3, 3)),
+    # "std::string": _decode_wip,  # - string writing broken; dials/dials#1858
+}
+
+
+def decode_column(column_entry, copy):
+    """Decode a single column value"""
+    datatype, data = column_entry
+
+    converter = _reftable_decoders.get(datatype)
+    if not converter:
+        logging.warning(
+            "Data type '%s' does not have a converter; cannot read", datatype
+        )
+        return None
+    return converter(data, copy=copy)
+
+
+def _get_unpacked(stream_or_path: str | IO | bytes | os.PathLike):
+    """Works out the logic to pass a stream/pathlike to msgpack"""
+    try:
+        path = os.fspath(cast(str, stream_or_path))
+        is_fspathlike = True
+    except TypeError:
+        # Not path-like; assume it is an open binary stream.
+        path = stream_or_path
+        is_fspathlike = False
+
+    if is_fspathlike:
+        with open(path, "rb") as f:
+            un = msgpack.Unpacker(f, strict_map_key=False)
+            return un.unpack()
+    else:
+        un = msgpack.Unpacker(stream_or_path, strict_map_key=False)
+        return un.unpack()
+
+
+def loads(data: bytes, copy=False):
+    """
+    Load a DIALS msgpack-encoded .refl file.
+
+    Args:
+        data: bytes data, already read from the file.
+        copy: Should the data be copied into writable numpy arrays.
+
+    Returns: See .load(stream_or_path)
+    """
+    return load(BytesIO(data), copy)
+
+
+def load(stream_or_path: IO | Path | os.PathLike, copy=False) -> dict:
+    """
+    Load a DIALS msgpack-encoded .refl file
+
+    Args:
+        stream_or_path: The filename or data to load
+        copy:
+            Should the data be copied. This will cause more memory usage
+            whilst loading the raw data.
+
+    Returns:
+
+        A dictionary with one entry per column of the reflection table.
+        If there is an identifier mapping as part of the reflection table,
+        then this is returned as an extra 'experiment_identifier' column.
+        All columns are returned as numpy arrays, except Shoebox columns,
+        which are returned as dataclass objects that contain the portions
+        of data from the file.
+
+        With copy=False, all numpy arrays point into the raw memory
+        returned by msgpack, which means they are read-only.
+        With copy=True, an immediate copy is done. This causes memory
+        usage to double while loading, but the created numpy arrays own
+        their own memory.
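+
+    Example (a hedged sketch; "indexed.refl" is a hypothetical filename):
+
+        refl = load("indexed.refl")
+        refl["miller_index"].shape   # -> (n_reflections, 3)
+        "id" in refl                 # True for indexed data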
+ """ + root_data = _get_unpacked(stream_or_path) + + if not root_data[0] == "dials::af::reflection_table": + raise ValueError("Does not appear to be a dials reflection table file") + if not root_data[1] == 1: + raise ValueError( + f"reflection_table data is version {root_data[1]}." + "Only Version 1 is understood" + ) + refdata = root_data[2] + + rows = refdata["nrows"] + identifiers = refdata["identifiers"] + data = refdata["data"] + + decoded_data = { + name: decode_column(value, copy=copy) for name, value in data.items() + } + + # Filter out empty (unknown) columns + decoded_data = {k: v for k, v in decoded_data.items() if v is not None} + + # Cross-check the columns are the expected lengths + for name, column in decoded_data.items(): + if len(column) != rows: + logging.warning( + "Warning: Mismatch of column lengths: " + "[%s] is [%d] instead of expected [%s]", + name, + len(column), + rows, + ) + + # Make an "identifiers" column + if "id" in decoded_data and identifiers: + decoded_data["experiment_identifier"] = [ + identifiers[x] for x in decoded_data["id"] if x > 0 + ] + + return decoded_data diff --git a/packages/essnmx/src/ess/nmx/executables.py b/packages/essnmx/src/ess/nmx/executables.py new file mode 100644 index 00000000..21643caf --- /dev/null +++ b/packages/essnmx/src/ess/nmx/executables.py @@ -0,0 +1,332 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +import logging +import pathlib +import warnings +from collections.abc import Callable + +import numpy as np +import scipp as sc +import scippnexus as snx +from ess.reduce.nexus.types import Filename, NeXusName, RawDetector, SampleRun +from ess.reduce.time_of_flight.types import TimeOfFlightLookupTable, TofDetector + +from ._executable_helper import ( + build_logger, + build_reduction_argument_parser, + collect_matching_input_files, + reduction_config_from_args, +) +from .configurations import ( + OutputConfig, + ReductionConfig, + TimeBinCoordinate, + WorkflowConfig, +) +from .nexus import ( + _check_file, + export_detector_metadata_as_nxlauetof, + export_monitor_metadata_as_nxlauetof, + export_reduced_data_as_nxlauetof, + export_static_metadata_as_nxlauetof, +) +from .types import ( + NMXDetectorMetadata, + NMXInstrument, + NMXLauetof, + NMXMonitorMetadata, + NMXReducedDetector, + NMXSampleMetadata, + NMXSourceMetadata, +) +from .workflows import initialize_nmx_workflow, select_detector_names + +_TOF_COORD_NAME = 'tof' +"""Name of the TOF coordinate used in DataArrays.""" +_ETO_COORD_NAME = 'event_time_offset' +"""Name of the Event Time Offset Coordinate used in Nexus.""" + + +def _retrieve_input_file(input_file: list[str]) -> pathlib.Path: + """Temporary helper to retrieve a single input file from the list + Until multiple input file support is implemented. + """ + from collections import Counter + + # Check duplicated pattern or paths + _counts = Counter(input_file) + duplicating_patterns = {pattern for pattern, num in _counts.items() if num > 1} + if duplicating_patterns: + raise ValueError( + f"Duplicated file paths or pattern found. {duplicating_patterns}" + ) + + if isinstance(input_file, list): + input_files = collect_matching_input_files(*input_file) + if len(input_files) == 0: + raise ValueError( + "No input files found for reduction." + "Check if the file paths are correct.", + input_file, + ) + elif len(input_files) > 1: + raise NotImplementedError( + "Currently, only a single input file is supported for reduction." 
+            )
+        input_file_path = input_files[0]
+    else:
+        input_file_path = input_file
+
+    return input_file_path
+
+
+def _retrieve_display(
+    logger: logging.Logger | None, display: Callable | None
+) -> Callable:
+    if display is not None:
+        return display
+    elif logger is not None:
+        return logger.info
+    else:
+        return logging.getLogger(__name__).info
+
+
+def _retrieve_time_bin_coordinate_name(wf_config: WorkflowConfig) -> str:
+    if wf_config.time_bin_coordinate == TimeBinCoordinate.time_of_flight:
+        return _TOF_COORD_NAME
+    elif wf_config.time_bin_coordinate == TimeBinCoordinate.event_time_offset:
+        return _ETO_COORD_NAME
+    raise ValueError(f"Unknown time bin coordinate: {wf_config.time_bin_coordinate}")
+
+
+def _warn_bin_edge_out_of_range(
+    *, edge: sc.Variable, coord_name: str, desc: str
+) -> None:
+    warnings.warn(
+        message=f"{edge} is {desc} than all "
+        f"{coord_name} values.\n"
+        "The histogram will contain only zeros.",
+        category=UserWarning,
+        stacklevel=4,
+    )
+
+
+def _match_data_unit_dtype(config_var: sc.Variable, da: sc.Variable) -> sc.Variable:
+    return config_var.to(unit=da.unit, dtype=da.dtype)
+
+
+def _build_time_bin_edges(
+    *,
+    wf_config: WorkflowConfig,
+    result_das: sc.DataGroup,
+    t_coord_name: str,
+) -> sc.Variable:
+    # Calculate the min and max of the data itself.
+    da_min_t = min(da.bins.coords[t_coord_name].nanmin() for da in result_das.values())
+    da_max_t = max(da.bins.coords[t_coord_name].nanmax() for da in result_das.values())
+
+    # Use the user-set parameters if available
+    # and validate them against the data.
+    # Lower time bin edge
+    if wf_config.min_time_bin is not None:
+        min_t = sc.scalar(wf_config.min_time_bin, unit=wf_config.time_bin_unit)
+        min_t = _match_data_unit_dtype(min_t, da=da_min_t)
+        # The user-set minimum time bin value
+        # may be bigger than all time-bin-coordinate values.
+        if min_t > da_max_t:
+            _warn_bin_edge_out_of_range(
+                edge=min_t, coord_name=wf_config.time_bin_coordinate, desc='bigger'
+            )
+    else:
+        min_t = da_min_t
+
+    # Upper time bin edge
+    if wf_config.max_time_bin is not None:
+        max_t = sc.scalar(wf_config.max_time_bin, unit=wf_config.time_bin_unit)
+        max_t = _match_data_unit_dtype(max_t, da=da_max_t)
+        # The user-set maximum time bin value
+        # may be smaller than all time-bin-coordinate values.
+        if max_t <= da_min_t:
+            _warn_bin_edge_out_of_range(
+                edge=max_t, coord_name=wf_config.time_bin_coordinate, desc='smaller'
+            )
+    else:
+        max_t = da_max_t
+        # Avoid dropping the event that has exactly the same
+        # ``event_time_offset`` or ``tof`` value as the upper bin edge.
+        max_t.value = np.nextafter(max_t.value, np.inf)
+
+    # Validate the results.
+    if min_t >= max_t:
+        raise ValueError(
+            f"Minimum time bin edge, {min_t}, "
+            "is bigger than or equal to the "
+            f"maximum time bin edge, {max_t}.\n"
+            "Cannot build a time bin edges coordinate.\n"
+            "Please check your configurations again."
+        )
+
+    # Build the bin edges to histogram the results.
+    n_edges = wf_config.nbins + 1
+    return sc.linspace(dim=t_coord_name, start=min_t, stop=max_t, num=n_edges)
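+
+
+# A hedged sketch of the edge construction above, outside the workflow: the
+# ``nextafter`` nudge keeps an event lying exactly on the upper edge inside
+# the last bin.
+#
+#     import numpy as np
+#     import scipp as sc
+#
+#     max_t = sc.scalar(71.0e3, unit='us')
+#     max_t.value = np.nextafter(max_t.value, np.inf)
+#     edges = sc.linspace(dim='tof', start=sc.scalar(0.0, unit='us'),
+#                         stop=max_t, num=51)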
+
+
+def reduction(
+    *,
+    config: ReductionConfig,
+    logger: logging.Logger | None = None,
+    display: Callable | None = None,
+) -> NMXLauetof:
+    """Reduce NMX data from a NeXus file and export it to an NXlauetof file
+    (ESS-NMX specific).
+
+    Parameters
+    ----------
+    config:
+        Reduction configuration.
+
+        Data reduction parameters are taken from this config
+        instead of being passed directly as keyword arguments.
+        They can either be built from command-line arguments
+        using ``ReductionConfig.from_args()`` or constructed manually.
+
+        If the reduced data is successfully written to the output file,
+        the configuration is also saved there for future reference.
+    logger:
+        Logger to use for logging messages. If None, a default logger is created.
+    display:
+        Callable for displaying messages, useful in Jupyter notebooks. If None,
+        defaults to ``logger.info``.
+
+    Returns
+    -------
+    NMXLauetof:
+        The reduced data for each selected detector,
+        together with the instrument and sample metadata.
+
+    """
+    # Check the file output configuration before we start heavy computation.
+    if not config.output.skip_file_output:
+        _check_file(config.output.output_file, config.output.overwrite)
+
+    display = _retrieve_display(logger, display)
+    input_file_path = _retrieve_input_file(config.inputs.input_file).resolve()
+    display(f"Input file: {input_file_path}")
+
+    output_file_path = pathlib.Path(config.output.output_file).resolve()
+    display(f"Output file: {output_file_path}")
+
+    detector_names = select_detector_names(detector_ids=config.inputs.detector_ids)
+
+    # Initialize workflow
+    base_wf = initialize_nmx_workflow(config=config.workflow)
+    # Insert parameters and cache intermediate results
+    base_wf[Filename[SampleRun]] = input_file_path
+
+    if config.workflow.time_bin_coordinate == TimeBinCoordinate.time_of_flight:
+        # We cache the time-of-flight lookup table
+        # only if we need to calculate time-of-flight coordinates.
+        # If ``event_time_offset`` was requested,
+        # we do not have to calculate the lookup table at all.
+        base_wf[TimeOfFlightLookupTable] = base_wf.compute(TimeOfFlightLookupTable)
+
+    metadatas = base_wf.compute((NMXSampleMetadata, NMXSourceMetadata))
+
+    tof_das = sc.DataGroup()
+    detector_metas = sc.DataGroup()
+
+    if config.workflow.time_bin_coordinate == TimeBinCoordinate.event_time_offset:
+        target_type = RawDetector[SampleRun]
+    elif config.workflow.time_bin_coordinate == TimeBinCoordinate.time_of_flight:
+        target_type = TofDetector[SampleRun]
+
+    for detector_name in detector_names:
+        cur_wf = base_wf.copy()
+        cur_wf[NeXusName[snx.NXdetector]] = detector_name
+        results = cur_wf.compute((target_type, NMXDetectorMetadata))
+        detector_metas[detector_name] = results[NMXDetectorMetadata]
+        # Keep the binned events for now; they are histogrammed once the
+        # common time bin edges are known below.
+ tof_das[detector_name] = results[target_type] + + # Make tof bin edges covering all detectors + t_coord_name = _retrieve_time_bin_coordinate_name(wf_config=config.workflow) + t_bin_edges = _build_time_bin_edges( + wf_config=config.workflow, result_das=tof_das, t_coord_name=t_coord_name + ) + + monitor_metadata = NMXMonitorMetadata( + tof_bin_coord=t_coord_name, + # TODO: Use real monitor data + # Currently NMX simulations or experiments do not have monitors + data=sc.DataArray( + coords={t_coord_name: t_bin_edges}, + data=sc.ones_like(t_bin_edges[:-1]), + ), + ) + + # Histogram detector counts + tof_histograms = sc.DataGroup() + for detector_name, tof_da in tof_das.items(): + histogram = tof_da.hist({t_coord_name: t_bin_edges}) + tof_histograms[detector_name] = histogram + + detector_results = sc.DataGroup( + { + detector_name: NMXReducedDetector( + data=histogram, metadata=detector_metas[detector_name] + ) + for detector_name, histogram in tof_histograms.items() + } + ) + source_meta: NMXSourceMetadata = metadatas[NMXSourceMetadata] + sample_meta: NMXSampleMetadata = metadatas[NMXSampleMetadata] + + results = NMXLauetof( + control=monitor_metadata, + instrument=NMXInstrument(detectors=detector_results, source=source_meta), + sample=sample_meta, + ) + + if config.workflow.time_bin_coordinate == TimeBinCoordinate.time_of_flight: + results.lookup_table = base_wf.compute(TimeOfFlightLookupTable) + + if not config.output.skip_file_output: + save_results(results=results, output_config=config.output) + + return results + + +def save_results(*, results: NMXLauetof, output_config: OutputConfig) -> None: + # Validate if results have expected fields + + export_static_metadata_as_nxlauetof( + sample_metadata=results.sample, + source_metadata=results.instrument.source, + program=results.reducer, + output_file=output_config.output_file, + overwrite=output_config.overwrite, + ) + export_monitor_metadata_as_nxlauetof( + monitor_metadata=results.control, + output_file=output_config.output_file, + ) + for detector_name, detector_result in results.instrument.detectors.items(): + export_detector_metadata_as_nxlauetof( + detector_metadata=detector_result.metadata, + output_file=output_config.output_file, + ) + if isinstance(detector_result.data, sc.DataArray): + export_reduced_data_as_nxlauetof( + detector_name=detector_name, + da=detector_result.data, + output_file=output_config.output_file, + compress_mode=output_config.compression, + ) + else: + raise ValueError(f"Detector counts histogram missing in {detector_name}") + + +def main() -> None: + parser = build_reduction_argument_parser() + config = reduction_config_from_args(parser.parse_args()) + logger = build_logger(config.output) + + reduction(config=config, logger=logger) diff --git a/packages/essnmx/src/ess/nmx/mcstas/__init__.py b/packages/essnmx/src/ess/nmx/mcstas/__init__.py new file mode 100644 index 00000000..11700114 --- /dev/null +++ b/packages/essnmx/src/ess/nmx/mcstas/__init__.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) +from .types import MaximumCounts + +default_parameters = {MaximumCounts: 10000} + + +def NMXMcStasWorkflow(): + import sciline as sl + + from .load import providers as loader_providers + from .reduction import ( + calculate_maximum_toa, + calculate_minimum_toa, + format_nmx_reduced_data, + proton_charge_from_event_counts, + raw_event_probability_to_counts, + reduce_raw_event_probability, + ) + from .xml import read_mcstas_geometry_xml + + 
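+    # Computing a target type from this pipeline (e.g. the reduced data)
+    # pulls in only the providers it needs; ``params`` below supplies the
+    # default maximum counts and can be overridden per run.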
+    return sl.Pipeline(
+        (
+            *loader_providers,
+            calculate_maximum_toa,
+            calculate_minimum_toa,
+            read_mcstas_geometry_xml,
+            proton_charge_from_event_counts,
+            reduce_raw_event_probability,
+            raw_event_probability_to_counts,
+            format_nmx_reduced_data,
+        ),
+        params=default_parameters,
+    )
diff --git a/packages/essnmx/src/ess/nmx/mcstas/executables.py b/packages/essnmx/src/ess/nmx/mcstas/executables.py
new file mode 100644
index 00000000..3bef4e6e
--- /dev/null
+++ b/packages/essnmx/src/ess/nmx/mcstas/executables.py
@@ -0,0 +1,352 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
+import argparse
+import logging
+import pathlib
+from collections.abc import Callable
+from functools import partial
+
+import sciline as sl
+import scipp as sc
+from ess.reduce.streaming import (
+    EternalAccumulator,
+    MaxAccumulator,
+    MinAccumulator,
+    StreamProcessor,
+)
+
+from ..types import Compression
+from . import NMXMcStasWorkflow
+from .load import (
+    mcstas_weight_to_probability_scalefactor,
+    raw_event_data_chunk_generator,
+)
+from .nexus import (
+    _export_detector_metadata_as_nxlauetof,
+    _export_reduced_data_as_nxlauetof,
+    _export_static_metadata_as_nxlauetof,
+)
+from .streaming import calculate_number_of_chunks
+from .types import (
+    DetectorIndex,
+    DetectorName,
+    FilePath,
+    MaximumCounts,
+    MaximumProbability,
+    MaximumTimeOfArrival,
+    McStasWeight2CountScaleFactor,
+    MinimumTimeOfArrival,
+    NMXDetectorMetadata,
+    NMXExperimentMetadata,
+    NMXRawDataMetadata,
+    NMXReducedCounts,
+    NMXReducedDataGroup,
+    PixelIds,
+    RawEventProbability,
+    TimeBinSteps,
+)
+from .xml import McStasInstrument
+
+
+def _build_metadata_streaming_processor_helper() -> Callable[
+    [sl.Pipeline], StreamProcessor
+]:
+    return partial(
+        StreamProcessor,
+        dynamic_keys=(RawEventProbability,),
+        target_keys=(NMXRawDataMetadata,),
+        accumulators={
+            MaximumProbability: MaxAccumulator,
+            MaximumTimeOfArrival: MaxAccumulator,
+            MinimumTimeOfArrival: MinAccumulator,
+        },
+    )
+
+
+def _build_final_streaming_processor_helper() -> Callable[
+    [sl.Pipeline], StreamProcessor
+]:
+    return partial(
+        StreamProcessor,
+        dynamic_keys=(RawEventProbability,),
+        target_keys=(NMXReducedDataGroup,),
+        accumulators={NMXReducedCounts: EternalAccumulator},
+    )
+
+
+def calculate_raw_data_metadata(
+    *detector_ids: DetectorIndex | DetectorName,
+    wf: sl.Pipeline,
+    chunk_size: int = 10_000_000,
+    logger: logging.Logger | None = None,
+) -> NMXRawDataMetadata:
+    # Stream processor building helper
+    scalefactor_stream_processor = _build_metadata_streaming_processor_helper()
+    metadata_wf = wf.copy()
+    # Loop over the detectors
+    file_path = metadata_wf.compute(FilePath)
+    raw_data_metadatas = {}
+
+    for detector_i in detector_ids:
+        temp_wf = metadata_wf.copy()
+        if isinstance(detector_i, str):
+            temp_wf[DetectorName] = detector_i
+        else:
+            temp_wf[DetectorIndex] = detector_i
+
+        detector_name = temp_wf.compute(DetectorName)
+        max_chunk_id = calculate_number_of_chunks(
+            temp_wf.compute(FilePath),
+            detector_name=detector_name,
+            chunk_size=chunk_size,
+        )
+        # Build the stream processor
+        processor = scalefactor_stream_processor(temp_wf)
+        for i_da, da in enumerate(
+            raw_event_data_chunk_generator(
+                file_path=file_path, detector_name=detector_name, chunk_size=chunk_size
+            )
+        ):
+            if any(size == 0 for size in da.sizes.values()):
+                continue
+            results = processor.add_chunk({RawEventProbability: da})
+            if logger is not None:
+                logger.info(
+                    "[%s/%s] Processed chunk for %s",
+                    i_da + 1,
+                    max_chunk_id,
+                    detector_name,
+                )
+
+        raw_data_metadatas[detector_i] = results[NMXRawDataMetadata]
+
+    # Take the min/max values over all detectors.
+    # We are doing it manually because it is not possible to update parameters
+    # in the workflow that the stream processor uses.
+    min_toa = min(dg.min_toa for dg in raw_data_metadatas.values())
+    max_toa = max(dg.max_toa for dg in raw_data_metadatas.values())
+    max_probability = max(dg.max_probability for dg in raw_data_metadatas.values())
+
+    return NMXRawDataMetadata(
+        min_toa=min_toa, max_toa=max_toa, max_probability=max_probability
+    )
+
+
+def reduction(
+    *,
+    input_file: pathlib.Path,
+    output_file: pathlib.Path,
+    chunk_size: int = 10_000_000,
+    nbins: int = 50,
+    max_counts: int | None = None,
+    detector_ids: list[int | str],
+    compression: Compression = Compression.BITSHUFFLE_LZ4,
+    wf: sl.Pipeline | None = None,
+    logger: logging.Logger | None = None,
+    toa_min_max_prob: tuple[float, float, float] | None = None,
+) -> NMXReducedDataGroup:
+    wf = wf.copy() if wf is not None else NMXMcStasWorkflow()
+    wf[FilePath] = input_file
+    # Set static info
+    wf[McStasInstrument] = wf.compute(McStasInstrument)
+
+    if not toa_min_max_prob:
+        # Calculate parameters for data reduction
+        data_metadata = calculate_raw_data_metadata(
+            *detector_ids, wf=wf, logger=logger, chunk_size=chunk_size
+        )
+        if logger is not None:
+            logger.info("Metadata retrieved: %s", data_metadata)
+
+        toa_bin_edges = sc.linspace(
+            dim='t',
+            start=data_metadata.min_toa,
+            stop=data_metadata.max_toa,
+            num=nbins + 1,
+        )
+        max_probability = data_metadata.max_probability
+    else:
+        if logger is not None:
+            logger.info("Metadata given: %s", toa_min_max_prob)
+        toa_min = sc.scalar(toa_min_max_prob[0], unit='s')
+        toa_max = sc.scalar(toa_min_max_prob[1], unit='s')
+        max_probability = sc.scalar(toa_min_max_prob[2])
+        toa_bin_edges = sc.linspace(dim='t', start=toa_min, stop=toa_max, num=nbins + 1)
+
+    # A user-given ``max_counts`` overrides the workflow default.
+    scale_factor = mcstas_weight_to_probability_scalefactor(
+        max_counts=MaximumCounts(max_counts)
+        if max_counts
+        else wf.compute(MaximumCounts),
+        max_probability=max_probability,
+    )
+    # Compute metadata and make the skeleton output file
+    experiment_metadata = wf.compute(NMXExperimentMetadata)
+    detector_metas = []
+    for detector_i in range(3):  # NMX has three detector panels.
+        temp_wf = wf.copy()
+        temp_wf[DetectorIndex] = detector_i
+        detector_metas.append(temp_wf.compute(NMXDetectorMetadata))
+
+    if logger is not None:
+        logger.info("Exporting metadata into the output file %s", output_file)
+
+    _export_static_metadata_as_nxlauetof(
+        experiment_metadata=experiment_metadata,
+        output_file=output_file,
+        # Arbitrary metadata falls into the ``entry`` group as a variable.
+        mcstas_weight2count_scale_factor=scale_factor,
+    )
+    _export_detector_metadata_as_nxlauetof(*detector_metas, output_file=output_file)
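+    # A hedged sketch of one pass of the chunked streaming below: the stream
+    # processor accumulates ``NMXReducedCounts`` across chunks, so the result
+    # of the last ``add_chunk`` call carries the totals:
+    #
+    #     processor = _build_final_streaming_processor_helper()(temp_wf)
+    #     for da in raw_event_data_chunk_generator(file_path=..., ...):
+    #         results = processor.add_chunk({RawEventProbability: da})
+    #     results[NMXReducedDataGroup]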
+    # Compute histogram
+    final_wf = wf.copy()
+    # Set the scale factor and time bin edges
+    final_wf[McStasWeight2CountScaleFactor] = scale_factor
+    final_wf[TimeBinSteps] = toa_bin_edges
+
+    file_path = final_wf.compute(FilePath)
+    final_stream_processor = _build_final_streaming_processor_helper()
+    # Loop over the detectors
+    result_list = []
+    for detector_i in detector_ids:
+        temp_wf = final_wf.copy()
+        if isinstance(detector_i, str):
+            temp_wf[DetectorName] = detector_i
+        else:
+            temp_wf[DetectorIndex] = detector_i
+        # Set static information as parameters
+        detector_name = temp_wf.compute(DetectorName)
+        temp_wf[PixelIds] = temp_wf.compute(PixelIds)
+        max_chunk_id = calculate_number_of_chunks(
+            file_path, detector_name=detector_name, chunk_size=chunk_size
+        )
+
+        # Build the stream processor
+        processor = final_stream_processor(temp_wf)
+        for i_da, da in enumerate(
+            raw_event_data_chunk_generator(
+                file_path=file_path, detector_name=detector_name, chunk_size=chunk_size
+            )
+        ):
+            if any(size == 0 for size in da.sizes.values()):
+                continue
+            results = processor.add_chunk({RawEventProbability: da})
+            if logger is not None:
+                logger.info(
+                    "[%s/%s] Processed chunk for %s",
+                    i_da + 1,
+                    max_chunk_id,
+                    detector_name,
+                )
+
+        result = results[NMXReducedDataGroup]
+        result_list.append(result)
+        if logger is not None:
+            logger.info("Appending reduced data into the output file %s", output_file)
+
+        _export_reduced_data_as_nxlauetof(
+            result,
+            output_file=output_file,
+            compress_counts=(compression != Compression.NONE),
+        )
+    from ess.nmx.reduction import merge_panels
+
+    return merge_panels(*result_list)
+
+
+def _add_mcstas_args(parser: argparse.ArgumentParser) -> None:
+    mcstas_arg_group = parser.add_argument_group("McStas Data Reduction Options")
+    mcstas_arg_group.add_argument(
+        "--max_counts",
+        type=int,
+        default=None,
+        help="Maximum counts after scaling the McStas weights",
+    )
+    mcstas_arg_group.add_argument(
+        "--chunk_size",
+        type=int,
+        default=10_000_000,
+        help="Chunk size for processing (number of events per chunk)",
+    )
+
+
+def build_reduction_arg_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description="Command line arguments for the NMX reduction. "
+        "It assumes a 14 Hz pulse rate."
+    )
+    input_arg_group = parser.add_argument_group("Input Options")
+    input_arg_group.add_argument(
+        "--input_file", type=str, help="Path to the input file", required=True
+    )
+    input_arg_group.add_argument(
+        "--nbins",
+        type=int,
+        default=50,
+        help="Number of TOF bins",
+    )
+    input_arg_group.add_argument(
+        "--detector_ids",
+        type=int,
+        nargs="+",
+        default=[0, 1, 2],
+        help="Detector indices to process",
+    )
+
+    output_arg_group = parser.add_argument_group("Output Options")
+    output_arg_group.add_argument(
+        "--output_file",
+        type=str,
+        default="scipp_output.h5",
+        help="Path to the output file",
+    )
+    output_arg_group.add_argument(
+        "--compression",
+        type=str,
+        default=Compression.BITSHUFFLE_LZ4.name,
+        choices=[compression_key.name for compression_key in Compression],
+        help="Compression option for the reduced output file. Default: BITSHUFFLE_LZ4",
+    )
+    output_arg_group.add_argument(
+        "--verbose", "-v", action="store_true", help="Increase output verbosity"
+    )
+
+    return parser
+
+
+def main() -> None:
+    from .._executable_helper import build_logger
+
+    parser = build_reduction_arg_parser()
+    _add_mcstas_args(parser)
+    args = parser.parse_args()
+
+    input_file = pathlib.Path(args.input_file).resolve()
+    output_file = pathlib.Path(args.output_file).resolve()
+
+    logger = build_logger(args)
+
+    wf = NMXMcStasWorkflow()
+    reduction(
+        input_file=input_file,
+        output_file=output_file,
+        chunk_size=args.chunk_size,
+        nbins=args.nbins,
+        max_counts=args.max_counts,
+        detector_ids=args.detector_ids,
+        compression=Compression[args.compression],
+        logger=logger,
+        wf=wf,
+    )
diff --git a/packages/essnmx/src/ess/nmx/mcstas/load.py b/packages/essnmx/src/ess/nmx/mcstas/load.py
new file mode 100644
index 00000000..c3830d0b
--- /dev/null
+++ b/packages/essnmx/src/ess/nmx/mcstas/load.py
@@ -0,0 +1,362 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+import re
+from collections.abc import Generator
+
+import scipp as sc
+import scippnexus as snx
+
+from .types import (
+    CrystalRotation,
+    DetectorBankPrefix,
+    DetectorIndex,
+    DetectorName,
+    FilePath,
+    MaximumCounts,
+    MaximumProbability,
+    MaximumTimeOfArrival,
+    McStasWeight2CountScaleFactor,
+    MinimumTimeOfArrival,
+    NMXDetectorMetadata,
+    NMXExperimentMetadata,
+    NMXRawDataMetadata,
+    NMXRawEventCountsDataGroup,
+    PixelIds,
+    RawEventProbability,
+)
+from .xml import McStasInstrument, read_mcstas_geometry_xml
+
+
+def detector_name_from_index(index: DetectorIndex) -> DetectorName:
+    return f'nD_Mantid_{getattr(index, "value", index)}'
+
+
+def load_event_data_bank_name(
+    detector_name: DetectorName, file_path: FilePath
+) -> DetectorBankPrefix:
+    '''Finds the filename associated with a detector'''
+    with snx.File(file_path) as file:
+        description = file['entry1/instrument/description'][()]
+        for bank_name, det_names in bank_names_to_detector_names(description).items():
+            if detector_name in det_names:
+                return DetectorBankPrefix(bank_name.partition('.')[0])
+    raise KeyError(
+        f"{DetectorBankPrefix.__name__} cannot be found for "
+        f"{DetectorName.__name__} from the file {FilePath.__name__}"
+    )
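+
+
+# Hedged note: McStas detector components follow the ``nD_Mantid_{index}``
+# naming convention, e.g. ``detector_name_from_index(0)`` -> 'nD_Mantid_0';
+# ``load_event_data_bank_name`` then maps that name onto its event-data bank.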
+ """ + data = data[(data != sc.scalar(0.0, unit=data.unit)).any(dim="dim_1")] + return data + + +def _wrap_raw_event_data(data: sc.Variable) -> RawEventProbability: + data = data.rename_dims({'dim_0': 'event'}) + data = _exclude_zero_events(data) + try: + event_da = sc.DataArray( + coords={ + 'id': sc.array( + dims=['event'], + values=data['dim_1', 4].values, + dtype='int64', + unit=None, + ), + 't': sc.array(dims=['event'], values=data['dim_1', 5].values, unit='s'), + }, + data=sc.array( + dims=['event'], values=data['dim_1', 0].values, unit='counts' + ), + ) + except IndexError: + event_da = sc.DataArray( + coords={ + 'id': sc.array( + dims=['event'], + values=data['dim_1', 1].values, + dtype='int64', + unit=None, + ), + 't': sc.array(dims=['event'], values=data['dim_1', 2].values, unit='s'), + }, + data=sc.array( + dims=['event'], values=data['dim_1', 0].values, unit='counts' + ), + ) + + return RawEventProbability(event_da) + + +def load_raw_event_data( + file_path: FilePath, *, detector_name: DetectorName, bank_prefix: DetectorBankPrefix +) -> RawEventProbability: + """Retrieve events from the nexus file. + + Parameters + ---------- + file_path: + Path to the nexus file + detector_name: + Name of the detector to load + bank_prefix: + Prefix identifying the event data array containing the events of the detector + If None, the bank name is determined automatically from the detector name. + + """ + if bank_prefix is None: + bank_prefix = load_event_data_bank_name(detector_name, file_path) + bank_name = f'{bank_prefix}_dat_list_p_x_y_n_id_t' + with snx.File(file_path, 'r') as f: + root = f["entry1/data"] + (bank_name,) = (name for name in root.keys() if bank_name in name) + data = root[bank_name]["events"][()] + return _wrap_raw_event_data(data) + + +def _check_chunk_size(chunk_size: int) -> None: + if 0 < chunk_size < 10_000_000: + import warnings + + warnings.warn( + "The chunk size may be too small < 10_000_000.\n" + "Consider increasing the chunk size for better performance.\n" + "Hint: NMX typically expect ~10^8 bins as reduced data.", + UserWarning, + stacklevel=2, + ) + + +def _check_maximum_chunk_size(d_slices: tuple[slice, ...]) -> None: + """Check the maximum size of the slices.""" + max_chunk_size = max( + (d_slice.stop - d_slice.start) / d_slice.step for d_slice in d_slices + ) + _check_chunk_size(max_chunk_size) + + +def _validate_chunk_size(chunk_size: int) -> None: + """Validate the chunk size.""" + if not isinstance(chunk_size, int): + raise TypeError("Chunk size must be an integer.") + if chunk_size < -1: + raise ValueError("Invalid chunk size. It should be -1(for all) or > 0.") + + +def raw_event_data_chunk_generator( + file_path: FilePath, + *, + detector_name: DetectorName, + bank_prefix: DetectorBankPrefix | None = None, + chunk_size: int = 0, # Number of rows to read at a time +) -> Generator[RawEventProbability, None, None]: + """Chunk events from the nexus file. + + Parameters + ---------- + file_path: + Path to the nexus file + detector_name: + Name of the detector to load + pixel_ids: + Pixel ids to generate the data array with the events + chunk_size: + Number of rows to read at a time. + If 0, chunk slice is determined automatically by the ``iter_chunks``. + Note that it only works if the dataset is already chunked. + + Yields + ------ + RawEventProbability: + Data array containing the events of the detector. + + Raises + ------ + ValueError: + If the chunk size is not valid. (>= -1) + TypeError: + If the chunk size is not an integer. 
+
+
+def raw_event_data_chunk_generator(
+    file_path: FilePath,
+    *,
+    detector_name: DetectorName,
+    bank_prefix: DetectorBankPrefix | None = None,
+    chunk_size: int = 0,  # Number of rows to read at a time
+) -> Generator[RawEventProbability, None, None]:
+    """Chunk events from the nexus file.
+
+    Parameters
+    ----------
+    file_path:
+        Path to the nexus file
+    detector_name:
+        Name of the detector to load
+    bank_prefix:
+        Prefix identifying the event data bank of the detector.
+        If None, the bank name is determined automatically from the detector name.
+    chunk_size:
+        Number of rows to read at a time.
+        If 0, the chunk slices are determined automatically by ``iter_chunks``.
+        Note that this only works if the dataset is already chunked.
+        If -1, the whole dataset is yielded as a single chunk.
+
+    Yields
+    ------
+    RawEventProbability:
+        Data array containing the events of the detector.
+
+    Raises
+    ------
+    ValueError:
+        If the chunk size is not valid (< -1).
+    TypeError:
+        If the chunk size is not an integer.
+
+    Warns
+    -----
+    UserWarning
+        If the chunk size is too small (< 10_000_000).
+
+    """
+    _validate_chunk_size(chunk_size)
+    _check_chunk_size(chunk_size)
+
+    if bank_prefix is None:
+        # Find the data bank name associated with the detector
+        bank_prefix = load_event_data_bank_name(
+            detector_name=detector_name, file_path=file_path
+        )
+    bank_name = f'{bank_prefix}_dat_list_p_x_y_n_id_t'
+
+    with snx.File(file_path, 'r') as f:
+        root = f["entry1/data"]
+        (bank_name,) = (name for name in root.keys() if bank_name in name)
+        dset = root[bank_name]["events"]
+        if chunk_size == 0:
+            # dset.dataset.iter_chunks() yields (dim_0_slice, dim_1_slice)
+            dim_0_slices = tuple(dim0_sl for dim0_sl, _ in dset.dataset.iter_chunks())
+            # Only checking the maximum chunk size
+            # since the last chunk may be smaller than the rest of the chunks
+            _check_maximum_chunk_size(dim_0_slices)
+            for dim_0_slice in dim_0_slices:
+                da = _wrap_raw_event_data(dset["dim_0", dim_0_slice])
+                yield da
+        elif chunk_size == -1:
+            yield _wrap_raw_event_data(dset[()])
+        else:
+            num_events = dset.shape[0]
+            for start in range(0, num_events, chunk_size):
+                data = dset["dim_0", start : start + chunk_size]
+                yield _wrap_raw_event_data(data)
+
+
+def load_crystal_rotation(
+    file_path: FilePath, instrument: McStasInstrument
+) -> CrystalRotation:
+    """Retrieve crystal rotation from the file.
+
+    Raises
+    ------
+    KeyError
+        If the crystal rotation is not found in the file.
+
+    """
+    with snx.File(file_path, 'r') as file:
+        param_keys = tuple(f"entry1/simulation/Param/XtalPhi{key}" for key in "XYZ")
+        if not all(key in file for key in param_keys):
+            raise KeyError(
+                f"Crystal rotations [{', '.join(param_keys)}] not found in file."
+            )
+        return CrystalRotation(
+            sc.vector(
+                value=[file[param_key][...] for param_key in param_keys],
+                unit=instrument.simulation_settings.angle_unit,
+            )
+        )
+
+
+def maximum_probability(da: RawEventProbability) -> MaximumProbability:
+    """Find the maximum probability in the data."""
+    return MaximumProbability(da.data.max())
+
+
+def mcstas_weight_to_probability_scalefactor(
+    max_counts: MaximumCounts, max_probability: MaximumProbability
+) -> McStasWeight2CountScaleFactor:
+    """Calculate the scale factor to convert McStas weights to counts.
+
+    The scaled counts are ``max_counts * (probabilities / max_probability)``.
+
+    Parameters
+    ----------
+    max_counts:
+        The maximum number of counts after scaling the event counts.
+
+    max_probability:
+        The maximum probability (McStas weight) found in the data.
+
+    Returns
+    -------
+    :
+        The scale factor to convert McStas weights to counts.
+
+    """
+    return McStasWeight2CountScaleFactor(
+        sc.scalar(max_counts, unit="counts") / max_probability
+    )
+
+
+def bank_names_to_detector_names(description: str) -> dict[str, list[str]]:
+    """Associates event data names with the names of the detectors
+    where the events were detected."""
+
+    detector_component_regex = (
+        # Start of the detector component definition, contains the detector name.
+        # r'^COMPONENT (?P<detector_name>.*) = Monitor_nD\(\n'
+        r'^COMPONENT (?P<detector_name>.*) = (Monitor_nD|Union_abs_logger_nD)\(\n'
+        # Some uninteresting lines; we're looking for 'filename'.
+        # Make sure no new component begins.
+        r'(?:(?!COMPONENT)(?!filename)(?:.|\s))*'
+        # The line that defines the filename of the file that stores the
+        # events associated with the detector.
+        r'(?:filename = \"(?P<bank_name>[^\"]*)\")?'
+ ) + matches = re.finditer(detector_component_regex, description, re.MULTILINE) + bank_names_to_detector_names = {} + for m in matches: + bank_names_to_detector_names.setdefault( + # If filename was not set for the detector the filename for the + # event data defaults to the name of the detector. + m.group('bank_name') or m.group('detector_name'), + [], + ).append(m.group('detector_name')) + return bank_names_to_detector_names + + +def load_experiment_metadata( + instrument: McStasInstrument, crystal_rotation: CrystalRotation +) -> NMXExperimentMetadata: + """Load the experiment metadata from the McStas file.""" + return NMXExperimentMetadata( + sc.DataGroup( + crystal_rotation=crystal_rotation, **instrument.experiment_metadata() + ) + ) + + +def load_detector_metadata( + instrument: McStasInstrument, detector_name: DetectorName +) -> NMXDetectorMetadata: + """Load the detector metadata from the McStas file.""" + return NMXDetectorMetadata( + sc.DataGroup(**instrument.detector_metadata(detector_name)) + ) + + +def load_mcstas( + *, + da: RawEventProbability, + experiment_metadata: NMXExperimentMetadata, + detector_metadata: NMXDetectorMetadata, +) -> NMXRawEventCountsDataGroup: + return NMXRawEventCountsDataGroup( + sc.DataGroup(weights=da, **experiment_metadata, **detector_metadata) + ) + + +def retrieve_pixel_ids( + instrument: McStasInstrument, detector_name: DetectorName +) -> PixelIds: + """Retrieve the pixel IDs for a given detector.""" + return PixelIds(instrument.pixel_ids(detector_name)) + + +def retrieve_raw_data_metadata( + min_toa: MinimumTimeOfArrival, + max_toa: MaximumTimeOfArrival, + max_probability: MaximumProbability, +) -> NMXRawDataMetadata: + """Retrieve the metadata of the raw data.""" + return NMXRawDataMetadata( + min_toa=min_toa, max_toa=max_toa, max_probability=max_probability + ) + + +providers = ( + retrieve_raw_data_metadata, + read_mcstas_geometry_xml, + detector_name_from_index, + load_event_data_bank_name, + load_raw_event_data, + maximum_probability, + mcstas_weight_to_probability_scalefactor, + retrieve_pixel_ids, + load_crystal_rotation, + load_mcstas, + load_experiment_metadata, + load_detector_metadata, +) diff --git a/packages/essnmx/src/ess/nmx/mcstas/nexus.py b/packages/essnmx/src/ess/nmx/mcstas/nexus.py new file mode 100644 index 00000000..404772b3 --- /dev/null +++ b/packages/essnmx/src/ess/nmx/mcstas/nexus.py @@ -0,0 +1,657 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +import io +import pathlib +import warnings +from collections.abc import Callable, Generator +from functools import partial, wraps +from typing import Any, TypeVar + +import h5py +import numpy as np +import sciline as sl +import scipp as sc + +from .types import ( + DetectorIndex, + DetectorName, + FilePath, + NMXDetectorMetadata, + NMXExperimentMetadata, + NMXReducedDataGroup, +) + + +def _fallback_compute_positions(dg: sc.DataGroup) -> sc.DataGroup: + import warnings + + import scippnexus as snx + + warnings.warn( + "Using fallback compute_positions due to empty log entries. " + "This may lead to incorrect results. Please check the data carefully." 
+ "The fallback will replace empty logs with a scalar value of zero.", + UserWarning, + stacklevel=2, + ) + + empty_transformations = [ + transformation + for transformation in dg['depends_on'].transformations.values() + if 'time' in transformation.value.dims + and transformation.sizes['time'] == 0 # empty log + ] + for transformation in empty_transformations: + orig_value = transformation.value + orig_value = sc.scalar(0, unit=orig_value.unit, dtype=orig_value.dtype) + transformation.value = orig_value + return snx.compute_positions(dg, store_transform='transform_matrix') + + +def _compute_positions( + dg: sc.DataGroup, auto_fix_transformations: bool = False +) -> sc.DataGroup: + """Compute positions of the data group from transformations. + + Wraps the `scippnexus.compute_positions` function + and provides a fallback for cases where the transformations + contain empty logs. + + Parameters + ---------- + dg: + Data group containing the transformations and data. + auto_fix_transformations: + If `True`, it will attempt to fix empty transformations. + It will replace them with a scalar value of zero. + It is because adding a time dimension will make it not possible + to compute positions of children due to time-dependent transformations. + + Returns + ------- + : + Data group with computed positions. + + Warnings + -------- + If `auto_fix_transformations` is `True`, it will warn about the fallback + being used due to empty logs or scalar transformations. + This is because the fallback may lead to incorrect results. + + """ + import scippnexus as snx + + try: + return snx.compute_positions(dg, store_transform='transform_matrix') + except ValueError as e: + if auto_fix_transformations: + return _fallback_compute_positions(dg) + raise e + + +def _create_dataset_from_string(*, root_entry: h5py.Group, name: str, var: str) -> None: + root_entry.create_dataset(name, dtype=h5py.string_dtype(), data=var) + + +def _create_dataset_from_var( + *, + root_entry: h5py.Group, + var: sc.Variable, + name: str, + long_name: str | None = None, + compression: str | None = None, + compression_opts: int | tuple[int, int] | None = None, + chunks: tuple[int, ...] | int | bool | None = None, + dtype: Any = None, +) -> h5py.Dataset: + compression_options = {} + if compression is not None: + compression_options["compression"] = compression + if compression_opts is not None: + compression_options["compression_opts"] = compression_opts + + dataset = root_entry.create_dataset( + name, + data=var.values if dtype is None else var.values.astype(dtype, copy=False), + chunks=chunks, + **compression_options, + ) + if var.unit is not None: + dataset.attrs["units"] = str(var.unit) + if long_name is not None: + dataset.attrs["long_name"] = long_name + return dataset + + +@wraps(_create_dataset_from_var) +def _create_compressed_dataset(*args, **kwargs): + """Create dataset with compression options. + + It will try to use ``bitshuffle`` for compression if available. + Otherwise, it will fall back to ``gzip`` compression. + + [``Bitshuffle/LZ4``](https://github.com/kiyo-masui/bitshuffle) + is used for convenience. + Since ``Dectris`` uses it for their Nexus file compression, + it is compatible with DIALS. + ``Bitshuffle/LZ4`` tends to give similar results to + GZIP and other compression algorithms with better performance. + A naive implementation of bitshuffle/LZ4 compression, + shown in [issue #124](https://github.com/scipp/essnmx/issues/124), + led to 80% file reduction (365 MB vs 1.8 GB). 
+ + """ + try: + import bitshuffle.h5 + + compression_filter = bitshuffle.h5.H5FILTER + default_compression_opts = (0, bitshuffle.h5.H5_COMPRESS_LZ4) + except ImportError: + warnings.warn( + UserWarning( + "Could not find the bitshuffle.h5 module from bitshuffle package. " + "The bitshuffle package is not installed or only partially installed. " + "Exporting to NeXus files with bitshuffle compression is not possible." + ), + stacklevel=2, + ) + compression_filter = "gzip" + default_compression_opts = 4 + + return _create_dataset_from_var( + *args, + **kwargs, + compression=compression_filter, + compression_opts=default_compression_opts, + ) + + +def _create_root_data_entry(file_obj: h5py.File) -> h5py.Group: + nx_entry = file_obj.create_group("NMX_data") + nx_entry.attrs["NX_class"] = "NXentry" + nx_entry.attrs["default"] = "data" + nx_entry.attrs["name"] = "NMX" + nx_entry["name"] = "NMX" + nx_entry["definition"] = "TOFRAW" + return nx_entry + + +def _create_sample_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group: + nx_sample = nx_entry.create_group("NXsample") + nx_sample["name"] = data['sample_name'].value + _create_dataset_from_var( + root_entry=nx_sample, + var=data['crystal_rotation'], + name='crystal_rotation', + long_name='crystal rotation in Phi (XYZ)', + ) + return nx_sample + + +def _create_instrument_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group: + nx_instrument = nx_entry.create_group("NXinstrument") + nx_instrument.create_dataset("proton_charge", data=data['proton_charge'].values) + + nx_detector_1 = nx_instrument.create_group("detector_1") + # Detector counts + _create_compressed_dataset( + root_entry=nx_detector_1, + name="counts", + var=data['counts'], + ) + # Time of arrival bin edges + _create_dataset_from_var( + root_entry=nx_detector_1, + var=data['counts'].coords['t'], + name="t_bin", + long_name="t_bin TOF (ms)", + ) + # Pixel IDs + _create_compressed_dataset( + root_entry=nx_detector_1, + name="pixel_id", + var=data['counts'].coords['id'], + long_name="pixel ID", + ) + return nx_instrument + + +def _create_detector_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group: + nx_detector = nx_entry.create_group("NXdetector") + # Position of the first pixel (lowest ID) in the detector + _create_compressed_dataset( + root_entry=nx_detector, + name="origin", + var=data['origin_position'], + ) + # Fast axis, along where the pixel ID increases by 1 + _create_dataset_from_var( + root_entry=nx_detector, var=data['fast_axis'], name="fast_axis" + ) + # Slow axis, along where the pixel ID increases + # by the number of pixels in the fast axis + _create_dataset_from_var( + root_entry=nx_detector, var=data['slow_axis'], name="slow_axis" + ) + return nx_detector + + +def _create_source_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group: + nx_source = nx_entry.create_group("NXsource") + nx_source["name"] = "European Spallation Source" + nx_source["short_name"] = "ESS" + nx_source["type"] = "Spallation Neutron Source" + nx_source["distance"] = sc.norm(data['source_position']).value + nx_source["probe"] = "neutron" + nx_source["target_material"] = "W" + return nx_source + + +def export_as_nexus( + data: sc.DataGroup, output_file: str | pathlib.Path | io.BytesIO +) -> None: + """Export the reduced data to a NeXus file. + + Currently exporting step is not expected to be part of sciline pipelines. + """ + warnings.warn( + DeprecationWarning( + "Exporting to custom NeXus format will be deprecated in the near future " + ">=26.12.0. 
" + "Please use ``export_as_nxlauetof`` instead." + ), + stacklevel=2, + ) + with h5py.File(output_file, "w") as f: + f.attrs["default"] = "NMX_data" + nx_entry = _create_root_data_entry(f) + _create_sample_group(data, nx_entry) + _create_instrument_group(data, nx_entry) + _create_detector_group(data, nx_entry) + _create_source_group(data, nx_entry) + + +def _create_lauetof_data_entry(file_obj: h5py.File) -> h5py.Group: + nx_entry = file_obj.create_group("entry") + nx_entry.attrs["NX_class"] = "NXentry" + return nx_entry + + +def _add_lauetof_definition(nx_entry: h5py.Group) -> None: + _create_dataset_from_string(root_entry=nx_entry, name="definition", var="NXlauetof") + + +def _add_lauetof_instrument(nx_entry: h5py.Group) -> h5py.Group: + nx_instrument = nx_entry.create_group("instrument") + nx_instrument.attrs["NX_class"] = "NXinstrument" + _create_dataset_from_string(root_entry=nx_instrument, name="name", var="NMX") + return nx_instrument + + +def _add_lauetof_source_group( + dg: NMXExperimentMetadata, nx_instrument: h5py.Group +) -> None: + nx_source = nx_instrument.create_group("source") + nx_source.attrs["NX_class"] = "NXsource" + _create_dataset_from_string( + root_entry=nx_source, name="name", var="European Spallation Source" + ) + _create_dataset_from_string(root_entry=nx_source, name="short_name", var="ESS") + _create_dataset_from_string( + root_entry=nx_source, name="type", var="Spallation Neutron Source" + ) + _create_dataset_from_var( + root_entry=nx_source, name="distance", var=sc.norm(dg["source_position"]) + ) + # Legacy probe information. + _create_dataset_from_string(root_entry=nx_source, name="probe", var="neutron") + + +def _add_lauetof_detector_group(dg: sc.DataGroup, nx_instrument: h5py.Group) -> None: + nx_detector = nx_instrument.create_group(dg["detector_name"].value) # Detector name + nx_detector.attrs["NX_class"] = "NXdetector" + _create_dataset_from_var( + name="polar_angle", + root_entry=nx_detector, + var=sc.scalar(0, unit='deg'), # TODO: Add real data + ) + _create_dataset_from_var( + name="azimuthal_angle", + root_entry=nx_detector, + var=sc.scalar(0, unit='deg'), # TODO: Add real data + ) + _create_dataset_from_var( + name="x_pixel_size", root_entry=nx_detector, var=dg["x_pixel_size"] + ) + _create_dataset_from_var( + name="y_pixel_size", root_entry=nx_detector, var=dg["y_pixel_size"] + ) + _create_dataset_from_var( + name="distance", + root_entry=nx_detector, + var=sc.scalar(0, unit='m'), # TODO: Add real data + ) + # Legacy geometry information until we have a better way to store it + _create_dataset_from_var( + name="origin", root_entry=nx_detector, var=dg['origin_position'] + ) + # Fast axis, along where the pixel ID increases by 1 + _create_dataset_from_var( + root_entry=nx_detector, var=dg['fast_axis'], name="fast_axis" + ) + # Slow axis, along where the pixel ID increases + # by the number of pixels in the fast axis + _create_dataset_from_var( + root_entry=nx_detector, var=dg['slow_axis'], name="slow_axis" + ) + + +def _add_lauetof_sample_group(dg: NMXExperimentMetadata, nx_entry: h5py.Group) -> None: + nx_sample = nx_entry.create_group("sample") + nx_sample.attrs["NX_class"] = "NXsample" + _create_dataset_from_var( + root_entry=nx_sample, + var=dg['crystal_rotation'], + name='crystal_rotation', + long_name='crystal rotation in Phi (XYZ)', + ) + _create_dataset_from_string( + root_entry=nx_sample, + name='name', + var=dg['sample_name'].value, + ) + _create_dataset_from_var( + name='orientation_matrix', + root_entry=nx_sample, + var=sc.array( + 
dims=['i', 'j'], + values=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + unit="dimensionless", + ), # TODO: Add real data, the sample orientation matrix + ) + _create_dataset_from_var( + name='unit_cell', + root_entry=nx_sample, + var=sc.array( + dims=['i'], + values=[1.0, 1.0, 1.0, 90.0, 90.0, 90.0], + unit="dimensionless", # TODO: Add real data, + # a, b, c, alpha, beta, gamma + ), + ) + + +def _add_lauetof_monitor_group(data: sc.DataGroup, nx_entry: h5py.Group) -> None: + nx_monitor = nx_entry.create_group("control") + nx_monitor.attrs["NX_class"] = "NXmonitor" + _create_dataset_from_string(root_entry=nx_monitor, name='mode', var='monitor') + nx_monitor["preset"] = 0.0 # Check if this is the correct value + data_dset = _create_dataset_from_var( + name='data', + root_entry=nx_monitor, + var=sc.array( + dims=['tof'], values=[1, 1, 1], unit="counts" + ), # TODO: Add real data, bin values + ) + data_dset.attrs["signal"] = 1 + data_dset.attrs["primary"] = 1 + _create_dataset_from_var( + name='time_of_flight', + root_entry=nx_monitor, + var=sc.array( + dims=['tof'], values=[1, 1, 1], unit="s" + ), # TODO: Add real data, bin edges + ) + + +def _add_arbitrary_metadata( + nx_entry: h5py.Group, **arbitrary_metadata: sc.Variable +) -> None: + if not arbitrary_metadata: + return + + metadata_group = nx_entry.create_group("metadata") + for key, value in arbitrary_metadata.items(): + if not isinstance(value, sc.Variable): + import warnings + + msg = f"Skipping metadata key '{key}' as it is not a scipp.Variable." + warnings.warn(UserWarning(msg), stacklevel=2) + continue + else: + _create_dataset_from_var( + name=key, + root_entry=metadata_group, + var=value, + ) + + +def _export_static_metadata_as_nxlauetof( + experiment_metadata: NMXExperimentMetadata, + output_file: str | pathlib.Path | io.BytesIO, + **arbitrary_metadata: sc.Variable, +) -> None: + """Export the metadata to a NeXus file with the LAUE_TOF application definition. + + ``Metadata`` in this context refers to the information + that is not part of the reduced detector counts itself, + but is necessary for the interpretation of the reduced data. + Since NMX can have arbitrary number of detectors, + this function can take multiple detector metadata objects. + + Parameters + ---------- + experiment_metadata: + Experiment metadata object. + output_file: + Output file path. + arbitrary_metadata: + Arbitrary metadata that does not fit into the existing metadata objects. + + """ + with h5py.File(output_file, "w") as f: + f.attrs["NX_class"] = "NXlauetof" + nx_entry = _create_lauetof_data_entry(f) + _add_lauetof_definition(nx_entry) + _add_lauetof_sample_group(experiment_metadata, nx_entry) + nx_instrument = _add_lauetof_instrument(nx_entry) + _add_lauetof_source_group(experiment_metadata, nx_instrument) + # Placeholder for ``monitor`` group + _add_lauetof_monitor_group(experiment_metadata, nx_entry) + # Skipping ``NXdata``(name) field with data link + # Add arbitrary metadata + _add_arbitrary_metadata(nx_entry, **arbitrary_metadata) + + +def _export_detector_metadata_as_nxlauetof( + *detector_metadatas: NMXDetectorMetadata, + output_file: str | pathlib.Path | io.BytesIO, + append_mode: bool = True, +) -> None: + """Export the detector specific metadata to a NeXus file. + + Since NMX can have arbitrary number of detectors, + this function can take multiple detector metadata objects. + + Parameters + ---------- + detector_metadatas: + Detector metadata objects. + output_file: + Output file path. 
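+    append_mode:
+        If ``True`` (currently the only supported mode), the detector
+        groups are appended to the existing skeleton file.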
+ + """ + + if not append_mode: + raise NotImplementedError("Only append mode is supported for now.") + + with h5py.File(output_file, "r+") as f: + nx_entry = f["entry"] + if "instrument" not in nx_entry: + nx_instrument = _add_lauetof_instrument(f["entry"]) + else: + nx_instrument = nx_entry["instrument"] + # Add detector group metadata + for detector_metadata in detector_metadatas: + _add_lauetof_detector_group(detector_metadata, nx_instrument) + + +def _extract_counts(dg: sc.DataGroup) -> sc.Variable: + counts: sc.DataArray = dg['counts'].data + if 'id' in counts.dims: + num_x, num_y = dg["detector_shape"].value + return sc.fold(counts, dim='id', sizes={'x': num_x, 'y': num_y}) + else: + # If there is no 'id' dimension, we assume it is already in the correct shape + return counts + + +def _export_reduced_data_as_nxlauetof( + dg: NMXReducedDataGroup, + output_file: str | pathlib.Path | io.BytesIO, + *, + append_mode: bool = True, + compress_counts: bool = True, +) -> None: + """Export the reduced data to a NeXus file with the LAUE_TOF application definition. + + Even though this function only exports + reduced data(detector counts and its coordinates), + the input should contain all the necessary metadata + for minimum sanity check. + + Parameters + ---------- + dg: + Reduced data and metadata. + output_file: + Output file path. + append_mode: + If ``True``, the file is opened in append mode. + If ``False``, the file is opened in None-append mode. + > None-append mode is not supported for now. + > Only append mode is supported for now. + compress_counts: + If ``True``, the detector counts are compressed using bitshuffle. + It is because only the detector counts are expected to be large. + + """ + if not append_mode: + raise NotImplementedError("Only append mode is supported for now.") + + with h5py.File(output_file, "r+") as f: + nx_detector: h5py.Group = f[f"entry/instrument/{dg['detector_name'].value}"] + # Data - shape: [n_x_pixels, n_y_pixels, n_tof_bins] + # The actual application definition defines it as integer, + # but we keep the original data type for now + num_x, num_y = dg["detector_shape"].value # Probably better way to do this + if compress_counts: + data_dset = _create_compressed_dataset( + name="data", + root_entry=nx_detector, + var=_extract_counts(dg), + chunks=(num_x, num_y, 1), + dtype=np.uint, + ) + else: + data_dset = _create_dataset_from_var( + name="data", + root_entry=nx_detector, + var=_extract_counts(dg), + dtype=np.uint, + ) + data_dset.attrs["signal"] = 1 + _create_dataset_from_var( + name='time_of_flight', + root_entry=nx_detector, + var=sc.midpoints(dg['counts'].coords['t'], dim='t'), + ) + + +def _check_file( + filename: str | pathlib.Path | io.BytesIO, overwrite: bool +) -> pathlib.Path | io.BytesIO: + if isinstance(filename, str | pathlib.Path): + filename = pathlib.Path(filename) + if filename.exists() and not overwrite: + raise FileExistsError( + f"File '{filename}' already exists. Use `overwrite=True` to overwrite." 
+            )
+    return filename
+
+
+T = TypeVar("T", bound=sc.DataArray)
+
+
+class NXLauetofWriter:
+    def __init__(
+        self,
+        *,
+        output_filename: str | pathlib.Path | io.BytesIO,
+        workflow: sl.Pipeline,
+        chunk_generator: Callable[[FilePath, DetectorName], Generator[T, None, None]],
+        chunk_insert_key: type[T],
+        extra_meta: dict[str, sc.Variable] | None = None,
+        compress_counts: bool = True,
+        overwrite: bool = False,
+    ) -> None:
+        from ess.reduce.streaming import EternalAccumulator, StreamProcessor
+
+        from .types import FilePath, NMXReducedCounts
+
+        self.compress_counts = compress_counts
+        self._chunk_generator = chunk_generator
+        self._chunk_insert_key = chunk_insert_key
+        self._workflow = workflow
+        self._output_filename = _check_file(output_filename, overwrite)
+        self._input_filename = workflow.compute(FilePath)
+        self._final_stream_processor = partial(
+            StreamProcessor,
+            dynamic_keys=(chunk_insert_key,),
+            target_keys=(NMXReducedDataGroup,),
+            accumulators={NMXReducedCounts: EternalAccumulator},
+        )
+        self._detector_metas: dict[DetectorName, NMXDetectorMetadata] = {}
+        self._detector_reduced: dict[DetectorName, NMXReducedDataGroup] = {}
+        _export_static_metadata_as_nxlauetof(
+            experiment_metadata=self._workflow.compute(NMXExperimentMetadata),
+            output_file=self._output_filename,
+            **(extra_meta or {}),
+        )
+
+    def add_panel(
+        self, *, detector_id: DetectorIndex | DetectorName
+    ) -> NMXReducedDataGroup:
+        from .types import PixelIds
+
+        temp_wf = self._workflow.copy()
+        if isinstance(detector_id, int):
+            temp_wf[DetectorIndex] = detector_id
+        elif isinstance(detector_id, str):
+            temp_wf[DetectorName] = detector_id
+        else:
+            raise TypeError(
+                f"Expected detector_id to be an int or str, got {type(detector_id)}"
+            )
+
+        _export_detector_metadata_as_nxlauetof(
+            temp_wf.compute(NMXDetectorMetadata),
+            output_file=self._output_filename,
+        )
+        # First compute static information
+        detector_name = temp_wf.compute(DetectorName)
+        temp_wf[PixelIds] = temp_wf.compute(PixelIds)
+        processor = self._final_stream_processor(temp_wf)
+        # Then iterate over the chunks, skipping empty ones
+        results: dict | None = None
+        for da in self._chunk_generator(self._input_filename, detector_name):
+            if any(size == 0 for size in da.sizes.values()):
+                continue
+            results = processor.add_chunk({self._chunk_insert_key: da})
+        if results is None:
+            raise ValueError(f"No events found for detector '{detector_name}'.")
+
+        _export_reduced_data_as_nxlauetof(
+            results[NMXReducedDataGroup],
+            self._output_filename,
+            compress_counts=self.compress_counts,
+        )
+        return results[NMXReducedDataGroup]
diff --git a/packages/essnmx/src/ess/nmx/mcstas/reduction.py b/packages/essnmx/src/ess/nmx/mcstas/reduction.py
new file mode 100644
index 00000000..19e70535
--- /dev/null
+++ b/packages/essnmx/src/ess/nmx/mcstas/reduction.py
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+import scipp as sc
+
+from .types import (
+    MaximumTimeOfArrival,
+    McStasWeight2CountScaleFactor,
+    MinimumTimeOfArrival,
+    NMXDetectorMetadata,
+    NMXExperimentMetadata,
+    NMXReducedCounts,
+    NMXReducedDataGroup,
+    NMXReducedProbability,
+    PixelIds,
+    ProtonCharge,
+    RawEventProbability,
+    TimeBinSteps,
+)
+
+
+def calculate_minimum_toa(da: RawEventProbability) -> MinimumTimeOfArrival:
+    """Calculate the minimum time of arrival from the data."""
+    return MinimumTimeOfArrival(da.coords['t'].min())
+
+
+def calculate_maximum_toa(da: RawEventProbability) -> MaximumTimeOfArrival:
+    """Calculate the maximum time of arrival from the data."""
+    return MaximumTimeOfArrival(da.coords['t'].max())
+
+
+def proton_charge_from_event_counts(da: NMXReducedCounts) -> ProtonCharge:
+    """Derive a stand-in proton charge from the event counts.
+
+    The proton charge is proportional to the number of neutrons,
+    which in turn is proportional to the number of events.
+    The scale factor was chosen manually, based on previous results,
+    to be convenient for data manipulation in the subsequent steps.
+    It is derived this way because the protons are not part of the
+    McStas simulation, and the number of neutrons is not included
+    in the result.
+
+    Parameters
+    ----------
+    da:
+        The event data.
+
+    """
+    # Arbitrary number to scale the proton charge
+    return ProtonCharge(sc.scalar(1 / 10_000, unit='dimensionless') * da.data.sum())
+
+
+def reduce_raw_event_probability(
+    da: RawEventProbability, pixel_ids: PixelIds, time_bin_step: TimeBinSteps
+) -> NMXReducedProbability:
+    """Group the events by pixel id and histogram them in time of arrival."""
+    return NMXReducedProbability(da.group(pixel_ids).hist(t=time_bin_step))
+
+
+def raw_event_probability_to_counts(
+    da: NMXReducedProbability,
+    scale_factor: McStasWeight2CountScaleFactor,
+) -> NMXReducedCounts:
+    """Convert McStas event probabilities to counts using the scale factor."""
+    return NMXReducedCounts(da * scale_factor)
+
+
+def format_nmx_reduced_data(
+    da: NMXReducedCounts,
+    proton_charge: ProtonCharge,
+    experiment_metadata: NMXExperimentMetadata,
+    detector_metadata: NMXDetectorMetadata,
+) -> NMXReducedDataGroup:
+    """Assemble the reduced counts and the metadata into a single data group."""
+
+    return NMXReducedDataGroup(
+        sc.DataGroup(
+            counts=da,
+            proton_charge=proton_charge,
+            **experiment_metadata,
+            **detector_metadata,
+        )
+    )
+
+
+def _concat_or_same(
+    obj: list[sc.Variable | sc.DataArray], dim: str
+) -> sc.Variable | sc.DataArray:
+    first = obj[0]
+    # instrument.to_coords in bin_time_of_arrival adds a panel coord to some fields,
+    # even if it has only length 1. If this is the case we concat, even if identical.
+    # Maybe McStasInstrument.to_coords should be changed to only handle a single
+    # panel, and not perform concatenation?
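+    # Hypothetical illustration: with dim='panel', two identical scalar
+    # 'sample_name' entries collapse to the first one, while two distinct
+    # 'origin_position' vectors are concatenated into a length-2 'panel' axis.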
+    if all(dim not in o.dims and sc.identical(first, o) for o in obj):
+        return first
+    return sc.concat(obj, dim)
+
+
+def merge_panels(*panel: NMXReducedDataGroup) -> NMXReducedDataGroup:
+    """Merge a list of panels by concatenating along the 'panel' dimension."""
+    keys = panel[0].keys()
+    if not all(p.keys() == keys for p in panel):
+        raise ValueError("All panels must have the same keys.")
+    return NMXReducedDataGroup(
+        sc.DataGroup(
+            {key: _concat_or_same([p[key] for p in panel], 'panel') for key in keys}
+        )
+    )
diff --git a/packages/essnmx/src/ess/nmx/mcstas/streaming.py b/packages/essnmx/src/ess/nmx/mcstas/streaming.py
new file mode 100644
index 00000000..dbf3da51
--- /dev/null
+++ b/packages/essnmx/src/ess/nmx/mcstas/streaming.py
@@ -0,0 +1,103 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
+from typing import Any
+
+import scipp as sc
+import scippnexus as snx
+from ess.reduce.streaming import Accumulator
+
+from .load import _validate_chunk_size, load_event_data_bank_name
+from .types import DetectorBankPrefix, DetectorName, FilePath
+
+
+class MinAccumulator(Accumulator):
+    """Accumulator that keeps track of the minimum value seen so far."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self._cur_min: sc.Variable | None = None
+
+    @property
+    def value(self) -> sc.Variable | None:
+        return self._cur_min
+
+    def _do_push(self, value: sc.Variable) -> None:
+        new_min = value.min()
+        if self._cur_min is None:
+            self._cur_min = new_min
+        else:
+            self._cur_min = min(self._cur_min, new_min)
+
+
+class MaxAccumulator(Accumulator):
+    """Accumulator that keeps track of the maximum value seen so far."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self._cur_max: sc.Variable | None = None
+
+    @property
+    def value(self) -> sc.Variable | None:
+        return self._cur_max
+
+    def _do_push(self, value: sc.Variable) -> None:
+        new_max = value.max()
+        if self._cur_max is None:
+            self._cur_max = new_max
+        else:
+            self._cur_max = max(self._cur_max, new_max)
+
+
+def calculate_number_of_chunks(
+    file_path: FilePath,
+    *,
+    detector_name: DetectorName,
+    bank_prefix: DetectorBankPrefix | None = None,
+    chunk_size: int = 0,  # Number of rows to read at a time
+) -> int:
+    """Calculate the number of chunks in the event data.
+
+    Parameters
+    ----------
+    file_path:
+        Path to the nexus file.
+    detector_name:
+        Name of the detector to load.
+    bank_prefix:
+        Prefix of the event data bank.
+        Currently it is always re-derived from the file.
+    chunk_size:
+        Number of rows to read at a time.
+        If 0, the chunk slices are determined automatically by ``iter_chunks``.
+        Note that this only works if the dataset is already chunked.
+
+    Returns
+    -------
+    :
+        Number of chunks in the event data.
+
+    Raises
+    ------
+    ValueError:
+        If the chunk size is not valid (it must be >= -1).
+    TypeError:
+        If the chunk size is not an integer.
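+
+    Examples
+    --------
+    A sketch of the resulting chunk count
+    (the file path and detector name are illustrative):
+
+    .. code-block:: python
+
+        # For an event list with 10 rows and chunk_size=3:
+        # 10 // 3 + int(10 % 3 != 0) == 4 chunks
+        n_chunks = calculate_number_of_chunks(
+            file_path, detector_name="nD0", chunk_size=3
+        )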
+ + """ + _validate_chunk_size(chunk_size) + # Find the data bank name associated with the detector + bank_prefix = load_event_data_bank_name( + detector_name=detector_name, file_path=file_path + ) + bank_name = f'{bank_prefix}_dat_list_p_x_y_n_id_t' + with snx.File(file_path, 'r') as f: + root = f["entry1/data"] + (bank_name,) = (name for name in root.keys() if bank_name in name) + with snx.File(file_path, 'r') as f: + root = f["entry1/data"] + dset: snx.Field = root[bank_name]["events"] + if chunk_size == 0: + return len(list(dset.dataset.iter_chunks())) + elif chunk_size == -1: + return 1 # Read all at once + else: + return dset.shape[0] // chunk_size + int(dset.shape[0] % chunk_size != 0) diff --git a/packages/essnmx/src/ess/nmx/mcstas/types.py b/packages/essnmx/src/ess/nmx/mcstas/types.py new file mode 100644 index 00000000..0d629021 --- /dev/null +++ b/packages/essnmx/src/ess/nmx/mcstas/types.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass +from typing import Any, NewType + +import scipp as sc + +FilePath = NewType("FilePath", str) +"""File name of a file containing the results of a McStas run""" + +DetectorIndex = NewType("DetectorIndex", int | sc.Variable | sc.DataArray) +"""Index of the detector to load. Index ordered by the id:s of the pixels""" + +DetectorName = NewType("DetectorName", str) +"""Name of the detector to load""" + +DetectorBankPrefix = NewType("DetectorBankPrefix", str) +"""Prefix identifying the event data array containing +the events from the selected detector""" + +MaximumCounts = NewType("MaximumCounts", int) +"""Maximum number of counts after scaling the event counts""" + +MaximumProbability = NewType("MaximumProbability", sc.Variable) +"""Maximum probability to scale the McStas event counts""" + +McStasWeight2CountScaleFactor = NewType("McStasWeight2CountScaleFactor", sc.Variable) +"""Scale factor to convert McStas weights to counts""" + +NMXExperimentMetadata = NewType("NMXExperimentMetadata", sc.DataGroup) +"""Metadata of the experiment""" + +NMXDetectorMetadata = NewType("NMXDetectorMetadata", sc.DataGroup) +"""Metadata of the detector""" + +RawEventProbability = NewType("RawEventProbability", sc.DataArray) +"""DataArray containing the event probabilities read from the McStas file, +has coordinates 'id' and 't' """ + +NMXRawEventCountsDataGroup = NewType("NMXRawEventCountsDataGroup", sc.DataGroup) +"""DataGroup containing the RawEventData, experiment metadata and detector metadata""" + +ProtonCharge = NewType("ProtonCharge", sc.Variable) +"""The proton charge signal""" + +CrystalRotation = NewType("CrystalRotation", sc.Variable) +"""Rotation of the crystal""" + +DetectorGeometry = NewType("DetectorGeometry", Any) +"""Description of the geometry of the detector banks""" + +TimeBinSteps = NewType("TimeBinSteps", int) +"""Number of bins in the binning of the time coordinate""" + +PixelIds = NewType("PixelIds", sc.Variable) +"""The pixel ids of the detector""" + +NMXReducedProbability = NewType("NMXReducedProbability", sc.DataArray) +"""Histogram of time-of-arrival and pixel-id.""" + +NMXReducedCounts = NewType("NMXReducedCounts", sc.DataArray) +"""Histogram of time-of-arrival and pixel-id.""" + +NMXReducedDataGroup = NewType("NMXReducedDataGroup", sc.DataGroup) +"""Datagroup containing Histogram(id, t), experiment metadata and detector metadata""" + +MinimumTimeOfArrival = NewType("MinimumTimeOfArrival", sc.Variable) +"""Minimum time of arrival of the raw data""" + +MaximumTimeOfArrival = NewType("MaximumTimeOfArrival", sc.Variable) +"""Maximum time of 
arrival of the raw data""" + + +@dataclass +class NMXRawDataMetadata: + """Metadata of the raw data, i.e. maximum weight and min/max time of arrival""" + + max_probability: MaximumProbability + min_toa: MinimumTimeOfArrival + max_toa: MaximumTimeOfArrival diff --git a/packages/essnmx/src/ess/nmx/mcstas/xml.py b/packages/essnmx/src/ess/nmx/mcstas/xml.py new file mode 100644 index 00000000..36ad0ab7 --- /dev/null +++ b/packages/essnmx/src/ess/nmx/mcstas/xml.py @@ -0,0 +1,457 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +# McStas instrument geometry xml description related functions. +from collections.abc import Iterable +from dataclasses import dataclass +from types import MappingProxyType +from typing import Protocol, TypeVar + +import h5py +import scipp as sc +from defusedxml.ElementTree import fromstring + +from ..rotation import axis_angle_to_quaternion, quaternion_to_matrix +from .types import FilePath + +T = TypeVar('T') + + +_AXISNAME_TO_UNIT_VECTOR = MappingProxyType( + { + 'x': sc.vector([1.0, 0.0, 0.0]), + 'y': sc.vector([0.0, 1.0, 0.0]), + 'z': sc.vector([0.0, 0.0, 1.0]), + } +) + + +class _XML(Protocol): + """XML element or tree type. + + Temporarily used for type hinting. + Builtin XML type is blocked by bandit security check.""" + + tag: str + attrib: dict[str, str] + + def find(self, name: str) -> '_XML | None': ... + + def __iter__(self) -> '_XML': ... + + def __next__(self) -> '_XML': ... + + +def _check_and_unpack_if_only_one(xml_items: list[_XML], name: str) -> _XML: + """Check if there is only one element with ``name``.""" + if len(xml_items) > 1: + raise ValueError(f"Multiple {name}s found.") + elif len(xml_items) == 0: + raise ValueError(f"No {name} found.") + + return xml_items.pop() + + +def select_by_tag(xml_items: _XML, tag: str) -> _XML: + """Select element with ``tag`` if there is only one.""" + + return _check_and_unpack_if_only_one(list(filter_by_tag(xml_items, tag)), tag) + + +def filter_by_tag(xml_items: Iterable[_XML], tag: str) -> Iterable[_XML]: + """Filter xml items by tag.""" + return (item for item in xml_items if item.tag == tag) + + +def filter_by_type_prefix(xml_items: Iterable[_XML], prefix: str) -> Iterable[_XML]: + """Filter xml items by type prefix.""" + return ( + item for item in xml_items if item.attrib.get('type', '').startswith(prefix) + ) + + +def select_by_type_prefix(xml_items: Iterable[_XML], prefix: str) -> _XML: + """Select xml item by type prefix.""" + + cands = list(filter_by_type_prefix(xml_items, prefix)) + return _check_and_unpack_if_only_one(cands, prefix) + + +def find_attributes(component: _XML, *args: str) -> dict[str, float]: + """Retrieve ``args`` as float from xml.""" + + return {key: float(component.attrib[key]) for key in args} + + +@dataclass +class SimulationSettings: + """Simulation settings extracted from McStas instrument xml description.""" + + # From + length_unit: str # 'unit' of + angle_unit: str # 'unit' of + # From + beam_axis: str # 'axis' of + handedness: str # 'val' of + + @classmethod + def from_xml(cls, tree: _XML) -> 'SimulationSettings': + """Create simulation settings from xml.""" + defaults = select_by_tag(tree, 'defaults') + length_desc = select_by_tag(defaults, 'length') + angle_desc = select_by_tag(defaults, 'angle') + reference_frame = select_by_tag(defaults, 'reference-frame') + along_beam = select_by_tag(reference_frame, 'along-beam') + handedness = select_by_tag(reference_frame, 'handedness') + + return cls( + 
length_unit=length_desc.attrib['unit'], + angle_unit=angle_desc.attrib['unit'], + beam_axis=along_beam.attrib['axis'], + handedness=handedness.attrib['val'], + ) + + +def _position_from_location(location: _XML, unit: str = 'm') -> sc.Variable: + """Retrieve position from location.""" + x, y, z = find_attributes(location, 'x', 'y', 'z').values() + return sc.vector([x, y, z], unit=unit) + + +def _rotation_matrix_from_location( + location: _XML, angle_unit: str = 'degree' +) -> sc.Variable: + """Retrieve rotation matrix from location.""" + + attribs = find_attributes(location, 'axis-x', 'axis-y', 'axis-z', 'rot') + x, y, z, w = axis_angle_to_quaternion( + x=attribs['axis-x'], + y=attribs['axis-y'], + z=attribs['axis-z'], + theta=sc.scalar(-attribs['rot'], unit=angle_unit), + ) + return quaternion_to_matrix(x=x, y=y, z=z, w=w) + + +@dataclass +class DetectorDesc: + """Detector information extracted from McStas instrument xml description.""" + + # From + component_type: str # 'type' + name: str + id_start: int # 'idstart' + fast_axis_name: str # 'idfillbyfirst' + # From + num_x: int # 'xpixels' + num_y: int # 'ypixels' + step_x: sc.Variable # 'xstep' + step_y: sc.Variable # 'ystep' + start_x: float # 'xstart' + start_y: float # 'ystart' + # From under + position: sc.Variable # 'x', 'y', 'z' + # Calculated fields + rotation_matrix: sc.Variable + slow_axis_name: str + fast_axis: sc.Variable + slow_axis: sc.Variable + + @classmethod + def from_xml( + cls, + *, + component: _XML, + type_desc: _XML, + simulation_settings: SimulationSettings, + ) -> 'DetectorDesc': + """Create detector description from xml component and type.""" + + location = select_by_tag(component, 'location') + rotation_matrix = _rotation_matrix_from_location( + location, simulation_settings.angle_unit + ) + fast_axis_name = component.attrib['idfillbyfirst'] + slow_axis_name = 'xy'.replace(fast_axis_name, '') + + length_unit = simulation_settings.length_unit + + # Type casting from str to float and then int to allow *e* notation + # For example, '1e4' -> 10000.0 -> 10_000 + return cls( + component_type=type_desc.attrib['name'], + name=component.attrib['name'], + id_start=int(float(component.attrib['idstart'])), + fast_axis_name=fast_axis_name, + slow_axis_name=slow_axis_name, + num_x=int(float(type_desc.attrib['xpixels'])), + num_y=int(float(type_desc.attrib['ypixels'])), + step_x=sc.scalar(float(type_desc.attrib['xstep']), unit=length_unit), + step_y=sc.scalar(float(type_desc.attrib['ystep']), unit=length_unit), + start_x=float(type_desc.attrib['xstart']), + start_y=float(type_desc.attrib['ystart']), + position=_position_from_location(location, simulation_settings.length_unit), + rotation_matrix=rotation_matrix, + fast_axis=rotation_matrix * _AXISNAME_TO_UNIT_VECTOR[fast_axis_name], + slow_axis=rotation_matrix * _AXISNAME_TO_UNIT_VECTOR[slow_axis_name], + ) + + @property + def total_pixels(self) -> int: + return self.num_x * self.num_y + + @property + def slow_step(self) -> sc.Variable: + return self.step_y if self.fast_axis_name == 'x' else self.step_x + + @property + def fast_step(self) -> sc.Variable: + return self.step_x if self.fast_axis_name == 'x' else self.step_y + + @property + def num_fast_pixels_per_row(self) -> int: + """Number of pixels in each row of the detector along the fast axis.""" + return self.num_x if self.fast_axis_name == 'x' else self.num_y + + @property + def detector_shape(self) -> tuple: + """Shape of the detector panel. 
(num_x, num_y)""" + return (self.num_x, self.num_y) + + +def _collect_detector_descriptions(tree: _XML) -> tuple[DetectorDesc, ...]: + """Retrieve detector geometry descriptions from mcstas file.""" + type_list = list(filter_by_tag(tree, 'type')) + simulation_settings = SimulationSettings.from_xml(tree) + + def _find_type_desc(det: _XML) -> _XML: + for type_ in type_list: + if type_.attrib['name'] == det.attrib['type']: + return type_ + + raise ValueError( + f"Cannot find type {det.attrib['type']} for {det.attrib['name']}." + ) + + detector_components = [ + DetectorDesc.from_xml( + component=det, + type_desc=_find_type_desc(det), + simulation_settings=simulation_settings, + ) + for det in filter_by_type_prefix(filter_by_tag(tree, 'component'), 'MonNDtype') + ] + + return tuple(sorted(detector_components, key=lambda x: x.id_start)) + + +@dataclass +class SampleDesc: + """Sample description extracted from McStas instrument xml description.""" + + # From + component_type: str + name: str + # From under + position: sc.Variable + rotation_matrix: sc.Variable | None + + @classmethod + def from_xml( + cls, *, tree: _XML, simulation_settings: SimulationSettings + ) -> 'SampleDesc': + """Create sample description from xml component.""" + source_xml = select_by_type_prefix(tree, 'sampleMantid-type') + location = select_by_tag(source_xml, 'location') + try: + rotation_matrix = _rotation_matrix_from_location( + location, simulation_settings.angle_unit + ) + except KeyError: + rotation_matrix = None + + return cls( + component_type=source_xml.attrib['type'], + name=source_xml.attrib['name'], + position=_position_from_location(location, simulation_settings.length_unit), + rotation_matrix=rotation_matrix, + ) + + def position_from_sample(self, other: sc.Variable) -> sc.Variable: + """Position of ``other`` relative to the sample. + + All positions and distance are stored relative to the sample position. + + Parameters + ---------- + other: + Position of the other object in 3D vector. + + """ + return other - self.position + + +@dataclass +class SourceDesc: + """Source description extracted from McStas instrument xml description.""" + + # From + component_type: str + name: str + # From under + position: sc.Variable + + @classmethod + def from_xml( + cls, *, tree: _XML, simulation_settings: SimulationSettings + ) -> 'SourceDesc': + """Create source description from xml component.""" + source_xml = select_by_type_prefix(tree, 'sourceMantid-type') + location = select_by_tag(source_xml, 'location') + + return cls( + component_type=source_xml.attrib['type'], + name=source_xml.attrib['name'], + position=_position_from_location(location, simulation_settings.length_unit), + ) + + +def _construct_pixel_id(detector_desc: DetectorDesc) -> sc.Variable: + """Pixel IDs for single detector.""" + start, stop = ( + detector_desc.id_start, + detector_desc.id_start + detector_desc.total_pixels, + ) + return sc.arange('id', start, stop, unit=None) + + +def _construct_pixel_ids(detector_descs: tuple[DetectorDesc, ...]) -> sc.Variable: + """Pixel IDs for all detectors.""" + ids = [_construct_pixel_id(det) for det in detector_descs] + return sc.concat(ids, 'id') + + +def _pixel_positions( + detector: DetectorDesc, position_offset: sc.Variable +) -> sc.Variable: + """Position of pixels of the ``detector``. + + Position of each pixel is relative to the position_offset. 
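+
+    Notes
+    -----
+    A sketch of the computation below: for pixel index ``i``
+    with ``n_fast`` pixels per row along the fast axis,
+
+    .. code-block:: python
+
+        position[i] = (
+            (i // n_fast) * slow_step * slow_axis
+            + (i % n_fast) * fast_step * fast_axis
+            + rotation_matrix * start_offset  # rotated (start_x, start_y, 0)
+            + position_offset
+        )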
+ """ + pixel_idx = sc.arange('id', detector.total_pixels) + n_col = sc.scalar(detector.num_fast_pixels_per_row) + + pixel_n_slow = pixel_idx // n_col + pixel_n_fast = pixel_idx % n_col + + fast_axis_steps = detector.fast_axis * detector.fast_step + slow_axis_steps = detector.slow_axis * detector.slow_step + + return ( + (pixel_n_slow * slow_axis_steps) + + (pixel_n_fast * fast_axis_steps) + + detector.rotation_matrix + * sc.vector( + [detector.start_x, detector.start_y, 0.0], unit=position_offset.unit + ) # Detector pixel offset should also be rotated first. + ) + position_offset + + +@dataclass +class McStasInstrument: + simulation_settings: SimulationSettings + detectors: tuple[DetectorDesc, ...] + source: SourceDesc + sample: SampleDesc + + @classmethod + def from_xml(cls, tree: _XML) -> 'McStasInstrument': + """Create McStas instrument from xml.""" + simulation_settings = SimulationSettings.from_xml(tree) + + return cls( + simulation_settings=simulation_settings, + detectors=_collect_detector_descriptions(tree), + source=SourceDesc.from_xml( + tree=tree, simulation_settings=simulation_settings + ), + sample=SampleDesc.from_xml( + tree=tree, simulation_settings=simulation_settings + ), + ) + + def pixel_ids(self, *det_names: str) -> sc.Variable: + """Pixel IDs for the detectors. + + If multiple detectors are requested, all pixel IDs will be concatenated along + the 'id' dimension. + + Parameters + ---------- + det_names: + Names of the detectors to extract pixel IDs for. + + """ + detectors = tuple(det for det in self.detectors if det.name in det_names) + return _construct_pixel_ids(detectors) + + def experiment_metadata(self) -> dict[str, sc.Variable]: + """Extract experiment metadata from the McStas instrument description.""" + return { + 'sample_position': self.sample.position_from_sample(self.sample.position), + 'source_position': self.sample.position_from_sample(self.source.position), + 'sample_name': sc.scalar(self.sample.name), + } + + def _detector_metadata(self, det_name: str) -> dict[str, sc.Variable]: + try: + detector = next(det for det in self.detectors if det.name == det_name) + except StopIteration as e: + raise KeyError(f"Detector {det_name} not found.") from e + return { + 'fast_axis': detector.fast_axis, + 'slow_axis': detector.slow_axis, + 'origin_position': self.sample.position_from_sample(detector.position), + 'position': _pixel_positions( + detector, self.sample.position_from_sample(detector.position) + ), + 'detector_shape': sc.scalar(detector.detector_shape), + 'x_pixel_size': detector.step_x, + 'y_pixel_size': detector.step_y, + 'detector_name': sc.scalar(detector.name), + } + + def detector_metadata(self, *det_names: str) -> dict[str, sc.Variable]: + """Extract detector metadata from the McStas instrument description. + + If multiple detector is requested, all metadata will be concatenated along the + 'panel' dimension. + + Parameters + ---------- + det_names: + Names of the detectors to extract metadata for. 
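+
+        Examples
+        --------
+        A sketch with illustrative detector names:
+
+        .. code-block:: python
+
+            meta = instrument.detector_metadata("nD0", "nD1")
+            meta["origin_position"].dims  # ('panel',)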
+
+        """
+        if len(det_names) == 1:
+            return self._detector_metadata(det_names[0])
+        detector_metadatas = {
+            det_name: self._detector_metadata(det_name) for det_name in det_names
+        }
+        # Concat all metadata into the panel dimension
+        metadata_keys: set[str] = set().union(
+            *(set(detector_metadatas[det_name].keys()) for det_name in det_names)
+        )
+        return {
+            key: sc.concat(
+                [metadata[key] for metadata in detector_metadatas.values()], 'panel'
+            )
+            for key in metadata_keys
+        }
+
+
+def read_mcstas_geometry_xml(file_path: FilePath) -> McStasInstrument:
+    """Retrieve geometry parameters from a McStas file."""
+    instrument_xml_path = 'entry1/instrument/instrument_xml/data'
+    with h5py.File(file_path) as file:
+        tree = fromstring(file[instrument_xml_path][...][0])
+    return McStasInstrument.from_xml(tree)
diff --git a/packages/essnmx/src/ess/nmx/mtz_io.py b/packages/essnmx/src/ess/nmx/mtz_io.py
new file mode 100644
index 00000000..87ee79ec
--- /dev/null
+++ b/packages/essnmx/src/ess/nmx/mtz_io.py
@@ -0,0 +1,340 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
+import pathlib
+from typing import NewType
+
+import gemmi
+import numpy as np
+import pandas as pd
+import scipp as sc
+
+# Index types for param table.
+MTZFileIndex = NewType("MTZFileIndex", int)
+"""The index of the mtz file when iterating over multiple mtz files."""
+
+# User defined or configurable types
+MTZFilePath = NewType("MTZFilePath", pathlib.Path)
+"""Path to the mtz file"""
+SpaceGroupDesc = NewType("SpaceGroupDesc", str)
+"""The space group description. e.g. 'P 21 21 21'"""
+DEFAULT_SPACE_GROUP_DESC = SpaceGroupDesc("P 1")
+"""The default space group description to use if not found in the mtz files."""
+
+# Custom column names
+WavelengthColumnName = NewType("WavelengthColumnName", str)
+"""The name of the wavelength column in the mtz file."""
+DEFAULT_WAVELENGTH_COLUMN_NAME = WavelengthColumnName("LAMBDA")
+
+IntensityColumnName = NewType("IntensityColumnName", str)
+"""The name of the intensity column in the mtz file."""
+DEFAULT_INTENSITY_COLUMN_NAME = IntensityColumnName("I")
+
+StdDevColumnName = NewType("StdDevColumnName", str)
+"""The name of the standard uncertainty of intensity column in the mtz file."""
+DEFAULT_STD_DEV_COLUMN_NAME = StdDevColumnName("SIGI")
+
+# Computed types
+MtzDataFrame = NewType("MtzDataFrame", pd.DataFrame)
+"""The raw mtz dataframe."""
+NMXMtzDataFrame = NewType("NMXMtzDataFrame", pd.DataFrame)
+"""The processed mtz dataframe with derived columns."""
+NMXMtzDataArray = NewType("NMXMtzDataArray", sc.DataArray)
+
+
+def read_mtz_file(file_path: MTZFilePath) -> gemmi.Mtz:
+    """Read an mtz file."""
+
+    return gemmi.read_mtz_file(file_path.as_posix())
+
+
+def mtz_to_pandas(mtz: gemmi.Mtz) -> pd.DataFrame:
+    """Convert the mtz file to a pandas dataframe.
+
+    It is equivalent to the following code:
+
+    .. code-block:: python
+
+        import numpy as np
+        import pandas as pd
+
+        data = np.array(mtz, copy=False)
+        columns = mtz.column_labels()
+        return pd.DataFrame(data, columns=columns)
+
+    It is the approach recommended in the gemmi documentation.
+
+    """
+
+    return pd.DataFrame(  # Recommended in the gemmi documentation.
+        data=np.array(mtz, copy=False), columns=mtz.column_labels()
+    )
+
+
+def process_single_mtz_to_dataframe(
+    mtz: gemmi.Mtz,
+    wavelength_column_name: WavelengthColumnName = DEFAULT_WAVELENGTH_COLUMN_NAME,
+    intensity_column_name: IntensityColumnName = DEFAULT_INTENSITY_COLUMN_NAME,
+    intensity_sig_col_name: StdDevColumnName = DEFAULT_STD_DEV_COLUMN_NAME,
+) -> MtzDataFrame:
+    """Select and derive columns from the original ``MtzDataFrame``.
+
+    Parameters
+    ----------
+    mtz:
+        The raw mtz dataset.
+
+    wavelength_column_name:
+        The name of the wavelength column in the mtz file.
+
+    intensity_column_name:
+        The name of the intensity column in the mtz file.
+
+    intensity_sig_col_name:
+        The name of the standard uncertainty of intensity column in the mtz file.
+
+    Returns
+    -------
+    :
+        The new mtz dataframe with derived and renamed columns.
+
+        The derived columns are:
+
+        - ``hkl``: The miller indices as a list of integers.
+        - ``d``: The d-spacing calculated from the miller indices.
+        - ``resolution``: The resolution calculated from the d-spacing as
+          :math:`\\dfrac{1}{4d^{2}} = \\dfrac{\\sin^2(\\theta)}{\\lambda^2}`
+
+        For consistent names of columns/coordinates,
+        the following columns are renamed:
+
+        - ``wavelength_column_name`` -> ``'wavelength'``
+        - ``intensity_column_name`` -> ``'I'``
+        - ``intensity_sig_col_name`` -> ``'SIGI'``
+
+        Other columns are kept as they are.
+
+    Notes
+    -----
+    :class:`pandas.DataFrame` is used from loading to merging,
+    but :class:`gemmi.Mtz` has :func:`gemmi.Mtz.calculate_d`
+    that can derive the ``d`` spacing from ``HKL``.
+    This is why this function must be called on each mtz file separately.
+
+    """
+    orig_df = mtz_to_pandas(mtz)
+    mtz_df = pd.DataFrame()
+
+    # HKL should always be integer.
+    mtz_df[["H", "K", "L"]] = orig_df[["H", "K", "L"]].astype(int)
+    mtz_df["hkl"] = mtz_df[["H", "K", "L"]].values.tolist()
+
+    def _calculate_d(row: pd.Series) -> float:
+        return mtz.get_cell().calculate_d(row["hkl"])
+
+    mtz_df["d"] = mtz_df.apply(_calculate_d, axis=1)
+    mtz_df["resolution"] = (1 / mtz_df["d"]) ** 2 / 4
+    mtz_df["wavelength"] = orig_df[wavelength_column_name]
+    mtz_df[DEFAULT_INTENSITY_COLUMN_NAME] = orig_df[intensity_column_name]
+    mtz_df[DEFAULT_STD_DEV_COLUMN_NAME] = orig_df[intensity_sig_col_name]
+    # Keep other columns
+    for column in [col for col in orig_df.columns if col not in mtz_df]:
+        mtz_df[column] = orig_df[column]
+
+    return MtzDataFrame(mtz_df)
+
+
+def get_space_group_from_description(desc: SpaceGroupDesc) -> gemmi.SpaceGroup:
+    """Retrieve the space group from its description.
+
+    Parameters
+    ----------
+    desc:
+        The space group description to use if not found in the mtz files.
+
+    Returns
+    -------
+    :
+        The space group.
+    """
+    return gemmi.SpaceGroup(desc)
+
+
+def get_space_group_from_mtz(mtz: gemmi.Mtz) -> gemmi.SpaceGroup | None:
+    """Retrieve the space group from a file.
+
+    A space group is expected in every MTZ file, but it may be missing.
+
+    Parameters
+    ----------
+    mtz:
+        Raw mtz dataset.
+
+    Returns
+    -------
+    :
+        The space group, or None if not found.
+    """
+    return mtz.spacegroup
+
+
+def get_unique_space_group(*spacegroups: gemmi.SpaceGroup | None) -> gemmi.SpaceGroup:
+    """Retrieve the unique space group from multiple space groups.
+
+    Parameters
+    ----------
+    spacegroups:
+        The space groups to check.
+
+    Returns
+    -------
+    :
+        The unique space group.
+
+    Raises
+    ------
+    ValueError:
+        If no space group is found, or if multiple distinct space groups are found.
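+
+    Examples
+    --------
+    A sketch (the space group is illustrative):
+
+    .. code-block:: python
+
+        import gemmi
+
+        sg = gemmi.SpaceGroup("P 1")
+        # ``None`` entries, e.g. from files without a space group, are ignored.
+        assert get_unique_space_group(sg, None, sg) is sg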
+ """ + spacegroups = [sgrp for sgrp in spacegroups if sgrp is not None] + if len(spacegroups) == 0: + raise ValueError("No space group found.") + first = spacegroups[0] + if all(sgrp == first for sgrp in spacegroups): + return first + raise ValueError(f"Multiple space groups found: {spacegroups}") + + +def get_reciprocal_asu(spacegroup: gemmi.SpaceGroup) -> gemmi.ReciprocalAsu: + """Returns the reciprocal asymmetric unit from the space group.""" + + return gemmi.ReciprocalAsu(spacegroup) + + +def merge_mtz_dataframes(*mtz_dfs: MtzDataFrame) -> MtzDataFrame: + """Merge multiple mtz dataframes into one.""" + + return MtzDataFrame(pd.concat(mtz_dfs, ignore_index=True)) + + +def process_mtz_dataframe( + *, + mtz_df: MtzDataFrame, + reciprocal_asu: gemmi.ReciprocalAsu, + sg: gemmi.SpaceGroup, +) -> NMXMtzDataFrame: + """Modify/Add columns of the shallow copy of a mtz dataframe. + + This method must be called after merging multiple mtz dataframe. + """ + df = mtz_df.copy(deep=False) + + def _reciprocal_asu(row: pd.Series) -> list[int]: + """Converts miller indices(HKL) to ASU indices.""" + + return reciprocal_asu.to_asu(row["hkl"], sg.operations())[0] + + df["hkl_asu"] = df.apply(_reciprocal_asu, axis=1) + # Unpack the indices for later. + df[["H_ASU", "K_ASU", "L_ASU"]] = pd.DataFrame( + df["hkl_asu"].to_list(), index=df.index + ) + + return NMXMtzDataFrame(df) + + +def nmx_mtz_dataframe_to_scipp_dataarray( + nmx_mtz_df: NMXMtzDataFrame, +) -> NMXMtzDataArray: + """Converts the processed mtz dataframe to a scipp dataarray. + + The intensity, with column name :attr:`~DEFAULT_INTENSITY_COLUMN_NAME` + becomes the data and the standard uncertainty of intensity, + with column name :attr:`~DEFAULT_SIGMA_INTENSITY_COLUMN_NAME` + becomes the variances of the data. + + Parameters + ---------- + nmx_mtz_df: + The merged and processed mtz dataframe. + + Returns + ------- + : + The scipp dataarray with the intensity and variances. + The ``I`` column becomes the data and the + squared ``SIGI`` column becomes the variances. + Therefore they are not in the coordinates. + + Following coordinates are modified: + + - ``hkl``: The miller indices as a string. + It is modified to have a string dtype + since is no dtype that can represent this in scipp. + + - ``hkl_asu``: The asymmetric unit of miller indices as a string. + This coordinate will be used to derive estimated scale factors. + It is modified to have a string dtype + as the same reason as why ``hkl`` coordinate is modified. + + Zero or negative intensities are removed from the dataarray. + It can happen due to the post-processing of the data, + e.g. background subtraction. + + """ + from scipp.compat.pandas_compat import from_pandas_dataframe, parse_bracket_header + + to_scipp = nmx_mtz_df.copy(deep=False) + # Convert to scipp Dataset + nmx_mtz_ds = from_pandas_dataframe( + to_scipp, + data_columns=[ + DEFAULT_INTENSITY_COLUMN_NAME, + DEFAULT_STD_DEV_COLUMN_NAME, + ], + header_parser=parse_bracket_header, + ) + # Pop the indices columns. + # TODO: We can put them back once we support tuple[int] dtype. + # See https://github.com/scipp/scipp/issues/3046 for more details. + # Temporarily, we will manually convert them to a string. + # It is done on the scipp variable instead of the dataframe + # since columns with string dtype are converted to PyObject dtype + # instead of string by `from_pandas_dataframe`. 
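+    # For example, a miller-index list [1, 2, 3] becomes the string '[1, 2, 3]'.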
+ for indices_name in ("hkl", "hkl_asu"): + nmx_mtz_ds.coords[indices_name] = sc.array( + dims=nmx_mtz_ds.coords[indices_name].dims, + values=nmx_mtz_df[indices_name].astype(str).tolist(), + # `astype`` is not enough to convert the dtype to string. + # The result of `astype` will have `PyObject` as a dtype. + ) + # Add units + nmx_mtz_ds.coords["wavelength"].unit = sc.units.angstrom + for key in nmx_mtz_ds.keys(): + nmx_mtz_ds[key].unit = sc.units.dimensionless + + # Add variances + nmx_mtz_da = nmx_mtz_ds[DEFAULT_INTENSITY_COLUMN_NAME].copy(deep=False) + nmx_mtz_da.variances = (nmx_mtz_ds[DEFAULT_STD_DEV_COLUMN_NAME].data ** 2).values + + # Return DataArray without negative intensities + return NMXMtzDataArray(nmx_mtz_da[nmx_mtz_da.data > 0]) + + +providers = ( + read_mtz_file, + process_single_mtz_to_dataframe, + # get_space_group_from_description, + get_space_group_from_mtz, + get_reciprocal_asu, + process_mtz_dataframe, + nmx_mtz_dataframe_to_scipp_dataarray, +) +"""The providers related to the MTZ IO.""" + +default_parameters = { + WavelengthColumnName: DEFAULT_WAVELENGTH_COLUMN_NAME, + IntensityColumnName: DEFAULT_INTENSITY_COLUMN_NAME, + StdDevColumnName: DEFAULT_STD_DEV_COLUMN_NAME, +} +"""The parameters related to the MTZ IO.""" diff --git a/packages/essnmx/src/ess/nmx/nexus.py b/packages/essnmx/src/ess/nmx/nexus.py new file mode 100644 index 00000000..20a71d3e --- /dev/null +++ b/packages/essnmx/src/ess/nmx/nexus.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +import io +import pathlib +import warnings +from typing import Any + +import h5py +import numpy as np +import scipp as sc +import scippnexus as snx + +from .configurations import Compression +from .types import ( + NMXDetectorMetadata, + NMXMonitorMetadata, + NMXProgram, + NMXSampleMetadata, + NMXSourceMetadata, +) + + +def _check_file( + filename: str | pathlib.Path | io.BytesIO, overwrite: bool +) -> pathlib.Path | io.BytesIO: + if isinstance(filename, str | pathlib.Path): + filename = pathlib.Path(filename) + if filename.exists() and not overwrite: + raise FileExistsError( + f"File '{filename}' already exists. Use `overwrite=True` to overwrite." + ) + return filename + + +def _create_dataset_from_var( + *, + root_entry: h5py.Group, + var: sc.Variable, + name: str, + long_name: str | None = None, + compression: str | None = None, + compression_opts: int | tuple[int, int] | None = None, + chunks: tuple[int, ...] | int | bool | None = None, + dtype: Any = None, +) -> h5py.Dataset: + compression_options = {} + if compression is not None: + compression_options["compression"] = compression + if compression_opts is not None: + compression_options["compression_opts"] = compression_opts + + dataset = root_entry.create_dataset( + name, + data=var.values if dtype is None else var.values.astype(dtype, copy=False), + chunks=chunks, + **compression_options, + ) + if var.unit is not None: + dataset.attrs["units"] = str(var.unit) + if long_name is not None: + dataset.attrs["long_name"] = long_name + return dataset + + +def _retrieve_compression_arguments(compress_mode: Compression) -> dict: + """Returns compression filter and opts arguments for the ``compress_mode``. + + Returns an empty dictionary if an unimplemented compression mode + or `NONE` compression mode is selected. 
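+
+    Examples
+    --------
+    A sketch of the returned arguments:
+
+    .. code-block:: python
+
+        _retrieve_compression_arguments(Compression.GZIP)
+        # {'compression': 'gzip', 'compression_opts': 4}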
+ + """ + if compress_mode == Compression.BITSHUFFLE_LZ4: + try: + import bitshuffle.h5 + + compression_filter = bitshuffle.h5.H5FILTER + compression_opts = (0, bitshuffle.h5.H5_COMPRESS_LZ4) + except ImportError: + warnings.warn( + UserWarning( + "Could not find the bitshuffle.h5 module from bitshuffle package. " + "The bitshuffle package is not installed properly. " + "Trying with gzip compression instead..." + ), + stacklevel=2, + ) + compression_filter = "gzip" + compression_opts = 4 + elif compress_mode == Compression.GZIP: + compression_filter = "gzip" + compression_opts = 4 + elif compress_mode == Compression.NONE: + return {} + else: + warnings.warn( + UserWarning( + f"Compression Mode {compress_mode} is not implemented yet. " + "Not Compressing the dataset... " + "Try `GZIP` or `BITSHUFFLE_LZ4` if compression is needed." + ), + stacklevel=2, + ) + return {} + + return {"compression": compression_filter, "compression_opts": compression_opts} + + +def _add_arbitrary_metadata( + nx_entry: h5py.Group, **arbitrary_metadata: sc.Variable +) -> None: + if not arbitrary_metadata: + return + + metadata_group = nx_entry.create_group("metadata") + for key, value in arbitrary_metadata.items(): + if not isinstance(value, sc.Variable): + import warnings + + msg = f"Skipping metadata key '{key}' as it is not a scipp.Variable." + warnings.warn(UserWarning(msg), stacklevel=2) + continue + else: + _create_dataset_from_var( + name=key, + root_entry=metadata_group, + var=value, + ) + + +def _set_default_instrument(nx_entry: snx.Group) -> snx.Group: + """Return NXinstrument group. + + If 'instrument' exists in the NXentry group, it returns the existing one. + Otherwise, new NXinstrument group is created and returned. + The default NXinstrument group has a field 'name' with the instrument name, 'NMX'. + """ + if "instrument" not in nx_entry: + nx_instrument = nx_entry.create_class("instrument", 'NXinstrument') + nx_instrument.create_field(key='name', value='NMX') + else: + nx_instrument = nx_entry["instrument"] + + return nx_instrument + + +def export_static_metadata_as_nxlauetof( + *, + sample_metadata: NMXSampleMetadata, + source_metadata: NMXSourceMetadata, + program: NMXProgram, + output_file: str | pathlib.Path | io.BytesIO, + overwrite: bool = False, + **arbitrary_metadata: sc.Variable, +) -> None: + """Export the metadata to a NeXus file with the LAUE_TOF application definition. + + ``Metadata`` in this context refers to the information + that is not part of the reduced detector counts itself, + but is necessary for the interpretation of the reduced data. + Since NMX can have arbitrary number of detectors, + this function can take multiple detector metadata objects. + + Parameters + ---------- + sample_metadata: + Sample metadata object. + source_metadata: + Source metadata object. + monitor_metadata: + Monitor metadata object. + output_file: + Output file path. + arbitrary_metadata: + Arbitrary metadata that does not fit into the existing metadata objects. 
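+
+    Examples
+    --------
+    A minimal sketch of the intended usage
+    (the metadata objects and file name are illustrative placeholders):
+
+    .. code-block:: python
+
+        export_static_metadata_as_nxlauetof(
+            sample_metadata=sample_metadata,
+            source_metadata=source_metadata,
+            program=program,
+            output_file="reduced.nxs",
+            overwrite=True,
+        )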
+ + """ + _check_file(output_file, overwrite=overwrite) + with snx.File(output_file, "w") as f: + f._group.attrs["NX_class"] = "NXlauetof" + nx_entry = f.create_class(name='entry', class_name='NXlauetof') + nx_entry.create_field('definitions', value='NXlauetof') + nx_entry['sample'] = sample_metadata + nx_entry['reducer'] = program + + nx_instrument = _set_default_instrument(nx_entry) + nx_instrument['source'] = source_metadata + _add_arbitrary_metadata(nx_entry._group, **arbitrary_metadata) + + +def export_monitor_metadata_as_nxlauetof( + monitor_metadata: NMXMonitorMetadata, + output_file: str | pathlib.Path | io.BytesIO, + append_mode: bool = True, +) -> None: + """Export the detector specific metadata to a NeXus file. + + Since NMX can have arbitrary number of detectors, + this function can take multiple detector metadata objects. + + Parameters + ---------- + monitor_metadata: + Monitor metadata object. + output_file: + Output file path. + + """ + if not append_mode: + raise NotImplementedError("Only append mode is supported for now.") + + with snx.File(output_file, "r+") as f: + nx_entry = f["entry"] + nx_entry["control"] = monitor_metadata + + +def export_detector_metadata_as_nxlauetof( + detector_metadata: NMXDetectorMetadata, + output_file: str | pathlib.Path | io.BytesIO, + append_mode: bool = True, +) -> None: + """Export the detector specific metadata to a NeXus file. + + Since NMX can have arbitrary number of detectors, + this function can take multiple detector metadata objects. + + Parameters + ---------- + detector_metadatas: + Detector metadata objects. + output_file: + Output file path. + + """ + + if not append_mode: + raise NotImplementedError("Only append mode is supported for now.") + + with snx.File(output_file, "r+") as f: + nx_entry: snx.Group = f["entry"] + nx_instrument = _set_default_instrument(nx_entry) + nx_instrument[detector_metadata.detector_name] = detector_metadata + + +def export_reduced_data_as_nxlauetof( + detector_name: str, + da: sc.DataArray, + output_file: str | pathlib.Path | io.BytesIO, + *, + append_mode: bool = True, + compress_mode: Compression = Compression.BITSHUFFLE_LZ4, +) -> None: + """Export the reduced data to a NeXus file with the LAUE_TOF application definition. + + Even though this function only exports + reduced data(detector counts and its coordinates), + the input should contain all the necessary metadata + for minimum sanity check. + + Parameters + ---------- + dg: + Reduced data and metadata. + output_file: + Output file path. + append_mode: + If ``True``, the file is opened in append mode. + If ``False``, the file is opened in None-append mode. + > None-append mode is not supported for now. + > Only append mode is supported for now. + compress_mode: + The detector counts are compressed using the ``compress_mode``. + It is because only the detector counts are expected to be large. + If ``Compression.BITSHUFFLE_LZ4`` is selected + but the bitshuffle is not supported for the environment, + it will fall back to ``Compression.GZIP``. + Select ``Compression.NONE`` if compression is not needed. + + """ + if not append_mode: + raise NotImplementedError("Only append mode is supported for now.") + + with h5py.File(output_file, "r+") as f: + nx_detector: h5py.Group = f[f"entry/instrument/{detector_name}"] + # Data - shape: [n_x_pixels, n_y_pixels, n_tof_bins] + # The actual application definition defines it as integer, + # so we overwrite the dtype here. 
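+        # Chunks of shape (num_x, num_y, 1) below make each TOF slice its own
+        # compression unit, so a single 2-D detector frame can be read back
+        # without decompressing the whole 3-D stack.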
+ + compression_args = _retrieve_compression_arguments(compress_mode) + if compress_mode != Compression.NONE: # Calculate the chunk sizes + num_x, num_y = da.sizes['x_pixel_offset'], da.sizes['y_pixel_offset'] + compression_args['chunks'] = (num_x, num_y, 1) # Chunk along tof axis + + data_dset = _create_dataset_from_var( + name="data", + root_entry=nx_detector, + var=da.data, + dtype=np.uint, + **compression_args, + ) + + data_dset.attrs["signal"] = 1 + data_dset.attrs["axes"] = list(da.dims) + + if 'tof' in da.coords: + time_field_name = "time_of_flight" + time_coord_name = "tof" + time_dim = "tof" + elif 'event_time_offset' in da.coords: + time_field_name = "event_time_offset" + time_coord_name = "event_time_offset" + time_dim = "event_time_offset" + else: + raise ValueError("Could not find time-related bin edges to store.") + + _create_dataset_from_var( + name=time_field_name, + root_entry=nx_detector, + var=sc.midpoints(da.coords[time_coord_name], dim=time_dim), + ) + + _create_dataset_from_var( + name='original_time_edges', + root_entry=nx_detector, + var=da.coords[time_coord_name], + ) diff --git a/packages/essnmx/src/ess/nmx/py.typed b/packages/essnmx/src/ess/nmx/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/packages/essnmx/src/ess/nmx/rotation.py b/packages/essnmx/src/ess/nmx/rotation.py new file mode 100644 index 00000000..4ec91c84 --- /dev/null +++ b/packages/essnmx/src/ess/nmx/rotation.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +# Rotation related functions for NMX +import numpy as np +import scipp as sc +from numpy.typing import NDArray + + +def axis_angle_to_quaternion( + *, x: float, y: float, z: float, theta: sc.Variable +) -> NDArray: + """Convert axis-angle to queternions, [x, y, z, w]. + + Parameters + ---------- + x: + X component of axis of rotation. + y: + Y component of axis of rotation. + z: + Z component of axis of rotation. + theta: + Angle of rotation, with unit of ``rad`` or ``deg``. + + Returns + ------- + : + A list of (normalized) quaternions, [x, y, z, w]. + + Notes + ----- + Axis of rotation (x, y, z) does not need to be normalized, + but it returns a unit quaternion (x, y, z, w). + + """ + + w: sc.Variable = sc.cos(theta.to(unit='rad') / 2) + xyz: sc.Variable = -sc.sin(theta.to(unit='rad') / 2) * sc.vector([x, y, z]) + q = np.array([*xyz.values, w.value]) + return q / np.linalg.norm(q) + + +def quaternion_to_matrix(*, x: float, y: float, z: float, w: float) -> sc.Variable: + """Convert quaternion to rotation matrix. + + Parameters + ---------- + x: + x(a) component of quaternion. + y: + y(b) component of quaternion. + z: + z(c) component of quaternion. + w: + w component of quaternion. + + Returns + ------- + : + A 3x3 rotation matrix. 
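+
+    Examples
+    --------
+    The identity quaternion yields the identity rotation:
+
+    .. code-block:: python
+
+        rot = quaternion_to_matrix(x=0.0, y=0.0, z=0.0, w=1.0)
+        # ``rot`` is a scipp rotation matrix equal to the 3x3 identity.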
+ + """ + from scipy.spatial.transform import Rotation + + return sc.spatial.rotations_from_rotvecs( + rotation_vectors=sc.vector( + Rotation.from_quat([x, y, z, w]).as_rotvec(), + unit='rad', + ) + ) diff --git a/packages/essnmx/src/ess/nmx/scaling.py b/packages/essnmx/src/ess/nmx/scaling.py new file mode 100644 index 00000000..97ae4174 --- /dev/null +++ b/packages/essnmx/src/ess/nmx/scaling.py @@ -0,0 +1,463 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from typing import NewType, TypeVar + +import scipp as sc + +from .mtz_io import NMXMtzDataArray + +# User defined or configurable types +WavelengthBins = NewType("WavelengthBins", sc.Variable | int) +"""User configurable wavelength binning""" +ReferenceWavelength = NewType("ReferenceWavelength", sc.Variable | None) +"""The wavelength to select reference intensities.""" + +# Computed types +"""Filtered mtz dataframe by the quad root of the sample standard deviation.""" +WavelengthBinned = NewType("WavelengthBinned", sc.DataArray) +"""Binned mtz dataframe by wavelength(LAMBDA) with derived columns.""" +SelectedReferenceWavelength = NewType("SelectedReferenceWavelength", sc.Variable) +"""The wavelength to select reference intensities.""" +ReferenceIntensities = NewType("ReferenceIntensities", sc.DataArray) +"""Reference intensities selected by the wavelength.""" +EstimatedScaleFactor = NewType("EstimatedScaleFactor", sc.DataArray) +"""The estimated scale factor from the reference intensities per ``hkl_asu``.""" +EstimatedScaledIntensities = NewType("EstimatedScaledIntensities", sc.DataArray) +"""Scaled intensities by the estimated scale factor.""" +FilteredEstimatedScaledIntensities = NewType( + "FilteredEstimatedScaledIntensities", sc.DataArray +) + +T = TypeVar("T") + + +def get_wavelength_binned( + mtz_da: NMXMtzDataArray, + wavelength_bins: WavelengthBins, +) -> WavelengthBinned: + """Bin the whole dataset by wavelength(LAMBDA). + + Parameters + ---------- + mtz_da: + The merged dataset. + + wavelength_bins: + The wavelength(LAMBDA) bins. + + Notes + ----- + Wavelength(LAMBDA) binning should always be done on the merged dataset. + + """ + return WavelengthBinned(mtz_da.bin({"wavelength": wavelength_bins})) + + +def _is_bin_empty(binned: sc.DataArray, idx: int) -> bool: + """Check if the bin is empty.""" + return binned[idx].values.size == 0 + + +def _get_middle_bin_idx(binned: sc.DataArray) -> int: + """Find the middle bin index. + + If the middle one is empty, the function will search for the nearest. + """ + middle_number, offset = len(binned) // 2, 0 + + while 0 < (cur_idx := middle_number + offset) < len(binned) and _is_bin_empty( + binned, cur_idx + ): + offset = -offset + 1 if offset <= 0 else -offset + + if _is_bin_empty(binned, cur_idx): + raise ValueError("No reference group found.") + + return cur_idx + + +def get_reference_wavelength( + binned: WavelengthBinned, + reference_wavelength: ReferenceWavelength, +) -> SelectedReferenceWavelength: + """Select the reference wavelength. + + Parameters + ---------- + binned: + The wavelength binned data. + + reference_wavelength: + The reference wavelength to select the intensities. + If ``None``, the middle group is selected. + It should be a scalar variable as it is selecting one of bins. 
+ + """ + if reference_wavelength is None: + ref_idx = _get_middle_bin_idx(binned) + return SelectedReferenceWavelength(binned.coords["wavelength"][ref_idx]) + else: + return SelectedReferenceWavelength(reference_wavelength) + + +def get_reference_intensities( + binned: WavelengthBinned, + reference_wavelength: SelectedReferenceWavelength, +) -> ReferenceIntensities: + """Find the reference intensities by the wavelength. + + Parameters + ---------- + binned: + The wavelength binned data. + + reference_wavelength: + The reference wavelength to select the intensities. + + Raises + ------ + ValueError: + If no reference group is found. + + """ + if reference_wavelength is None: + ref_idx = _get_middle_bin_idx(binned) + return binned[ref_idx].values.copy(deep=False) + else: + if reference_wavelength.dims: + raise ValueError("Reference wavelength should be a scalar.") + try: + return binned["wavelength", reference_wavelength].values.copy(deep=False) + except IndexError as err: + raise IndexError(f"{reference_wavelength} out of range.") from err + + +def estimate_scale_factor_per_hkl_asu_from_reference( + reference_intensities: ReferenceIntensities, +) -> EstimatedScaleFactor: + """Calculate the estimated scale factor per ``hkl_asu``. + + The estimated scale factor is calculated as the average + of the inverse of the non-empty reference intensities. + + It is part of the calculation of estimated scaled intensities + for fitting the scaling model. + + .. math:: + + EstimatedScaleFactor_{(hkl)} = \\dfrac{ + \\sum_{i=1}^{N_{(hkl)}} \\dfrac{1}{I_{i}} + }{ + N_{(hkl)} + } + = average( \\dfrac{1}{I_{(hkl)}} ) + + Estimated scale factor is calculated per ``hkl_asu``. + This is part of the calculation of roughly-scaled-intensities + for fitting the scaling model. + The whole procedure is described in :func:`average_roughly_scaled_intensities`. + + Parameters + ---------- + reference_intensities: + The reference intensities selected by wavelength. + + Returns + ------- + : + The estimated scale factor per ``hkl_asu``. + The result should have a dimension of ``hkl_asu``. + + It does not have a dimension of ``wavelength`` since + it is calculated from the reference intensities, + which is selected by one ``wavelength``. + + """ + # Workaround for https://github.com/scipp/scipp/issues/3046 + # and https://github.com/scipp/scipp/issues/3425 + import numpy as np + + unique_hkl = np.unique(reference_intensities.coords["hkl_asu"].values) + group_var = sc.array(dims=["hkl_asu"], values=unique_hkl) + grouped = reference_intensities.group(group_var) + + return EstimatedScaleFactor((1 / grouped).bins.mean()) + + +def average_roughly_scaled_intensities( + binned: WavelengthBinned, + scale_factor: EstimatedScaleFactor, +) -> EstimatedScaledIntensities: + """Scale the intensities by the estimated scale factor. + + Parameters + ---------- + binned: + Intensities binned in the wavelength dimension. + It will be grouped by reflection (hkl) in the process. + + scale_factor: + The estimated scale factor per reflection(hkl) of the reference wavelength bin. + See :func:`estimate_scale_factor_per_hkl_asu_from_reference` + for the calculation of the estimated scale factor. + + .. math:: + + EstimatedScaleFactor_{(hkl)} = + average( \\dfrac{1}{I_{\\lambda=reference, (hkl)}} ) + + Returns + ------- + : + Average scaled intensities on ``hkl(asu)`` indices per wavelength. + + Notes + ----- + The average of roughly scaled intensities are calculated by the following formula: + + .. 
+
+
+def average_roughly_scaled_intensities(
+    binned: WavelengthBinned,
+    scale_factor: EstimatedScaleFactor,
+) -> EstimatedScaledIntensities:
+    """Scale the intensities by the estimated scale factor.
+
+    Parameters
+    ----------
+    binned:
+        Intensities binned in the wavelength dimension.
+        They are grouped by reflection (hkl) in the process.
+
+    scale_factor:
+        The estimated scale factor per reflection (hkl) of the reference
+        wavelength bin.
+        See :func:`estimate_scale_factor_per_hkl_asu_from_reference`
+        for the calculation of the estimated scale factor.
+
+        .. math::
+
+            EstimatedScaleFactor_{(hkl)} =
+            average( \\dfrac{1}{I_{\\lambda=reference, (hkl)}} )
+
+    Returns
+    -------
+    :
+        Average scaled intensities on ``hkl(asu)`` indices per wavelength.
+
+    Notes
+    -----
+    The average of the roughly scaled intensities is calculated
+    by the following formula:
+
+    .. math::
+
+        EstimatedScaledI_{\\lambda}
+        = \\dfrac{
+            \\sum_{i=1}^{N_{\\lambda, (hkl)}}
+            EstimatedScaledI_{\\lambda, (hkl)}
+        }{
+            N_{\\lambda, (hkl)}
+        }
+
+    The scaled intensities on each ``hkl(asu)`` index per wavelength
+    are calculated by the following formula:
+
+    .. math::
+        :nowrap:
+
+        \\begin{eqnarray}
+        EstimatedScaledI_{\\lambda, (hkl)} \\\\
+        = \\dfrac{
+            \\sum_{i=1}^{N_{\\lambda=reference, (hkl)}}
+            \\sum_{j=1}^{N_{\\lambda, (hkl)}}
+            \\dfrac{I_{j}}{I_{i}}
+        }{
+            N_{\\lambda=reference, (hkl)}*N_{\\lambda, (hkl)}
+        } \\\\
+        = \\dfrac{
+            \\sum_{i=1}^{N_{\\lambda=reference, (hkl)}} \\dfrac{1}{I_{i}}
+        }{
+            N_{\\lambda=reference, (hkl)}
+        } * \\dfrac{
+            \\sum_{j=1}^{N_{\\lambda, (hkl)}} I_{j}
+        }{
+            N_{\\lambda, (hkl)}
+        } \\\\
+        = average( \\dfrac{1}{I_{\\lambda=reference, (hkl)}} )
+        * average( I_{\\lambda, (hkl)} )
+        \\end{eqnarray}
+
+    Therefore ``binned`` (wavelength dimension) is
+    grouped along the ``hkl(asu)`` coordinate in the calculation.
+
+    """
+    # Group by the ``hkl_asu`` coordinate of the estimated scale factor
+    # from the reference intensities.
+    grouped = binned.group(scale_factor.coords["hkl_asu"])
+
+    # Scale each bin of each group by the scale factor,
+    # dropping the variances of the scale factor.
+    intensities = sc.nanmean(
+        grouped.bins.nanmean() * sc.values(scale_factor), dim="hkl_asu"
+    )
+    # Take the midpoints of the wavelength bin coordinates
+    # to represent the average wavelength of each bin,
+    # because the bin edges are dropped when the data is flattened
+    # and the data is expected to be filtered after this step.
+    intensities.coords["wavelength"] = sc.midpoints(
+        intensities.coords["wavelength"],
+    )
+    return EstimatedScaledIntensities(intensities)
+
+
+ScaledIntensityLeftTailThreshold = NewType(
+    "ScaledIntensityLeftTailThreshold", sc.Variable
+)
+"""The threshold to cut the left tail of the estimated scaled intensities."""
+DEFAULT_LEFT_TAIL_THRESHOLD = ScaledIntensityLeftTailThreshold(sc.scalar(0.1))
+ScaledIntensityRightTailThreshold = NewType(
+    "ScaledIntensityRightTailThreshold", sc.Variable
+)
+"""The threshold to cut the right tail of the estimated scaled intensities."""
+DEFAULT_RIGHT_TAIL_THRESHOLD = ScaledIntensityRightTailThreshold(sc.scalar(2.0))
+
+
+def cut_tails(
+    scaled_intensities: EstimatedScaledIntensities,
+    left_threshold: ScaledIntensityLeftTailThreshold = DEFAULT_LEFT_TAIL_THRESHOLD,
+    right_threshold: ScaledIntensityRightTailThreshold = DEFAULT_RIGHT_TAIL_THRESHOLD,
+) -> FilteredEstimatedScaledIntensities:
+    """Cut both tails of the estimated scaled intensities by the thresholds.
+
+    Parameters
+    ----------
+    scaled_intensities:
+        The scaled intensities to be filtered.
+
+    left_threshold:
+        The threshold below which the left tail is cut.
+
+    right_threshold:
+        The threshold above which the right tail is cut.
+
+    Returns
+    -------
+    :
+        The filtered scaled intensities with the tails cut.
+
+    """
+    return FilteredEstimatedScaledIntensities(
+        scaled_intensities[
+            (scaled_intensities.data > left_threshold)
+            & (scaled_intensities.data < right_threshold)
+        ].copy(deep=False)
+    )
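# Editor's note: a minimal sketch (not part of the changeset) of the tail cut
# implemented in `cut_tails` above.  With the default thresholds (0.1 and
# 2.0), only values strictly inside the open interval (0.1, 2.0) survive.
# The numbers are made up for illustration.
import scipp as sc

intensities = sc.DataArray(
    data=sc.array(dims=["wavelength"], values=[0.05, 0.5, 1.5, 3.0]),
    coords={"wavelength": sc.array(dims=["wavelength"], values=[1.0, 2.0, 3.0, 4.0])},
)
mask = (intensities.data > sc.scalar(0.1)) & (intensities.data < sc.scalar(2.0))
print(intensities[mask].data.values)  # -> [0.5 1.5]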
+
+
+@dataclass
+class FittingResult:
+    """Result of the fitting process."""
+
+    fitting_func: Callable[..., sc.DataArray]
+    """The fitting function used for the fit."""
+    params: Mapping
+    """Parameters of the fitting function."""
+    covariance: Mapping
+    """Covariance of :attr:`params`."""
+    fit_output: sc.DataArray
+    """The final output of the fitting function."""
+
+
+def polyval_wavelength(
+    wavelength: sc.Variable, *, out_unit: str, **kwargs
+) -> sc.DataArray:
+    """Polynomial helper for fitting.
+
+    The coefficients are adjusted so that the fitting result
+    has ``out_unit`` as its unit.
+
+    Parameters
+    ----------
+    wavelength:
+        The wavelength coordinate.
+    out_unit:
+        The unit of the output.
+    **kwargs:
+        The polynomial coefficients.
+
+    Returns
+    -------
+    :
+        The polynomial evaluated at the wavelength.
+
+    """
+    # We need float64 precision because the curve-fit routine
+    # depends on finite-difference estimates of the derivatives.
+    wavelength = wavelength.to(dtype='float64')
+    out = sc.zeros_like(wavelength)
+    out.unit = out_unit
+    xk = sc.ones_like(wavelength)
+    for arg_value in kwargs.values():
+        out += sc.values(arg_value) * xk * sc.scalar(1.0, unit=out.unit / xk.unit)
+        xk *= wavelength
+    return out
+
+
+WavelengthFittingPolynomialDegree = NewType("WavelengthFittingPolynomialDegree", int)
+DEFAULT_WAVELENGTH_FITTING_POLYNOMIAL_DEGREE = WavelengthFittingPolynomialDegree(7)
+
+
+def fit_wavelength_scale_factor_polynomial(
+    estimated_intensities: FilteredEstimatedScaledIntensities,
+    *,
+    n_degree: WavelengthFittingPolynomialDegree,
+) -> FittingResult:
+    """Fit the wavelength scale factor polynomial.
+
+    It uses :func:`polyval_wavelength` as the fitting function
+    and :func:`scipp.curve_fit` for the fitting process.
+    The initial guess for the polynomial coefficients is set to 1
+    for all degrees.
+    The unit of the coefficients is adjusted to make the fitting result
+    dimensionless.
+
+    Parameters
+    ----------
+    estimated_intensities:
+        The estimated scaled intensities to be fitted.
+    n_degree:
+        The degree of the polynomial to be fitted.
+
+    Returns
+    -------
+    :
+        The fitting result.
+ + """ + + from functools import partial + + fitting_func = partial(polyval_wavelength, out_unit="dimensionless") + p_result, cov_result = sc.curve_fit( + coords=["wavelength"], + f=fitting_func, + da=estimated_intensities, + p0={f"arg{i}": sc.scalar(1) for i in range(n_degree)}, + ) + data = fitting_func(estimated_intensities.coords["wavelength"], **p_result) + return FittingResult( + fitting_func=fitting_func, + params=p_result, + covariance=cov_result, + fit_output=sc.DataArray( + data=data.data, + coords={"wavelength": estimated_intensities.coords["wavelength"]}, + ), + ) + + +WavelengthScaleFactors = NewType("WavelengthScaleFactors", sc.DataArray) +"""The scale factors of `"wavelength"`.""" + + +def calculate_wavelength_scale_factor( + fitting_result: FittingResult, + reference_wavelength: SelectedReferenceWavelength, +) -> WavelengthScaleFactors: + """Calculate the scale factors along the `"wavelength"`.""" + + scaled_reference = fitting_result.fitting_func( + reference_wavelength, **fitting_result.params + ) + scale_factor = fitting_result.fit_output / scaled_reference + return WavelengthScaleFactors(scale_factor) + + +providers = ( + cut_tails, + get_wavelength_binned, + get_reference_wavelength, + get_reference_intensities, + estimate_scale_factor_per_hkl_asu_from_reference, + average_roughly_scaled_intensities, + fit_wavelength_scale_factor_polynomial, + calculate_wavelength_scale_factor, +) +"""Providers for scaling data.""" + +default_parameters = { + WavelengthBins: sc.linspace("wavelength", 1.8, 3.5, 250, unit="angstrom"), + ScaledIntensityLeftTailThreshold: DEFAULT_LEFT_TAIL_THRESHOLD, + ScaledIntensityRightTailThreshold: DEFAULT_RIGHT_TAIL_THRESHOLD, + WavelengthFittingPolynomialDegree: WavelengthFittingPolynomialDegree(7), +} +"""Default parameters for scaling data.""" diff --git a/packages/essnmx/src/ess/nmx/types.py b/packages/essnmx/src/ess/nmx/types.py new file mode 100644 index 00000000..cfc4301c --- /dev/null +++ b/packages/essnmx/src/ess/nmx/types.py @@ -0,0 +1,277 @@ +import enum +from dataclasses import dataclass, field +from typing import Literal, NewType + +import h5py +import numpy as np +import scipp as sc +import scippnexus as snx +from ess.reduce.time_of_flight.types import TofLookupTable +from scippneutron.metadata import RadiationProbe, SourceType + +from ._display_helper import to_datagroup + + +class Compression(enum.StrEnum): + """Compression type of the output file. + + These options are written as enum for future extensibility. + """ + + NONE = 'NONE' + GZIP = 'GZIP' + BITSHUFFLE_LZ4 = 'BITSHUFFLE_LZ4' + + +TofSimulationMinWavelength = NewType("TofSimulationMinWavelength", sc.Variable) +"""Minimum wavelength for tof simulation to calculate look up table.""" + +TofSimulationMaxWavelength = NewType("TofSimulationMaxWavelength", sc.Variable) +"""Maximum wavelength for tof simulation to calculate look up table.""" + + +class ControlMode(enum.StrEnum): + """Control mode of counting. + + Based on the NXlauetof definition of ``control`` (NXmonitor) field. 
+    """
+
+    monitor = 'monitor'
+    """Count to a preset value based on received monitor counts."""
+    timer = 'timer'
+    """Count to a preset value based on clock time."""
+
+
+def _unit_matrix() -> sc.Variable:
+    return sc.spatial.linear_transform(
+        value=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+        unit="dimensionless",
+    )
+
+
+def _uniform_unit_cell_length() -> sc.Variable:
+    return sc.vector([1.0, 1.0, 1.0], unit='dimensionless')
+
+
+def _cube_unit_cell_angle() -> sc.Variable:
+    return sc.vector([90.0, 90.0, 90.0], unit='deg')
+
+
+@dataclass(kw_only=True)
+class NMXSampleMetadata:
+    nx_class = snx.NXsample
+
+    crystal_rotation: sc.Variable
+    name: str
+    position: sc.Variable
+    # TODO: Remove these temporarily hard-coded default values.
+    orientation_matrix: sc.Variable = field(default_factory=_unit_matrix)
+    unit_cell_length: sc.Variable = field(default_factory=_uniform_unit_cell_length)
+    unit_cell_angle: sc.Variable = field(default_factory=_cube_unit_cell_angle)
+
+    @property
+    def unit_cell(self) -> np.ndarray:
+        """Unit cell parameters: a, b, c, alpha, beta, gamma."""
+        return np.concatenate(
+            [self.unit_cell_length.values, self.unit_cell_angle.values]
+        )
+
+    def __write_to_nexus_group__(self, group: h5py.Group):
+        cr_field = snx.create_field(group, 'crystal_rotation', self.crystal_rotation)
+        cr_field.attrs['long_name'] = 'crystal rotation in Phi (XYZ)'
+        snx.create_field(group, 'name', self.name)
+        snx.create_field(group, 'position', self.position)
+        snx.create_field(group, 'orientation_matrix', self.orientation_matrix)
+        unit_cell = snx.create_field(group, 'unit_cell', self.unit_cell)
+        unit_cell.attrs['length-unit'] = str(self.unit_cell_length.unit)
+        unit_cell.attrs['angle-unit'] = str(self.unit_cell_angle.unit)
+
+
+@dataclass(kw_only=True)
+class NMXSourceMetadata:
+    nx_class = snx.NXsource
+
+    position: sc.Variable
+    """Position of the source (relative to the sample)."""
+
+    # The following three fields mirror the corresponding fields of
+    # ``scippneutron.metadata.Source``.  However, NMX needs to store the
+    # source `position` as a vector, not only the name, type and probe,
+    # so essnmx cannot use ``scippneutron.metadata.Source`` as-is
+    # (that would require an unpacking function for vector-valued scalars).
+    # Therefore we use ``NMXSourceMetadata`` for now; its ``source_type``
+    # and ``probe`` fields have the same enum types as
+    # ``scippneutron.metadata.Source``.
+    name: Literal['European Spallation Source'] = "European Spallation Source"
+    source_type: SourceType = SourceType.SpallationNeutronSource
+    probe: RadiationProbe = RadiationProbe.Neutron
+
+    @property
+    def distance(self) -> sc.Variable:
+        return sc.norm(self.position)
+
+    def __write_to_nexus_group__(self, group: h5py.Group):
+        snx.create_field(group, 'name', self.name)
+        snx.create_field(group, 'type', self.source_type.value)
+        distance = snx.create_field(group, 'distance', self.distance)
+        distance.attrs['position'] = self.position.values
+        snx.create_field(group, 'probe', self.probe.value)
+
+
+def _zero_float_count() -> sc.Variable:
+    return sc.scalar(0.0, unit='count')
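# Editor's note: a small sketch (not part of the changeset) showing how the
# metadata dataclasses above are serialised through their
# ``__write_to_nexus_group__`` hook, using the ``NMXSampleMetadata`` class
# defined above.  The sample values are hypothetical, and the in-memory
# h5py file avoids touching the file system.
import h5py
import scipp as sc

sample = NMXSampleMetadata(
    crystal_rotation=sc.vector([0.0, 20.0, 0.0], unit='deg'),
    name='hypothetical-sample',
    position=sc.vector([0.0, 0.0, 0.0], unit='m'),
)
with h5py.File('in-memory', 'w', driver='core', backing_store=False) as f:
    group = f.create_group('sample')
    sample.__write_to_nexus_group__(group)
    # The default unit cell is a unit cube: a, b, c, alpha, beta, gamma.
    assert list(group['unit_cell'][()]) == [1.0, 1.0, 1.0, 90.0, 90.0, 90.0]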
+
+
+@dataclass(kw_only=True)
+class NMXMonitorMetadata:
+    nx_class = snx.NXmonitor
+    data: sc.DataArray
+    """Monitor counts."""
+
+    @property
+    def time_of_flight(self) -> sc.Variable:
+        return self.data.coords[self.tof_bin_coord]
+
+    tof_bin_coord: str = field(
+        default='tof',
+        metadata={
+            "description": "Name of the time-of-flight coordinate "
+            "in the monitor histogram."
+        },
+    )
+    mode: ControlMode = field(
+        default=ControlMode.monitor,
+        metadata={"description": "Mode of counting. One of `monitor` or `timer`."},
+    )
+    preset: sc.Variable = field(
+        default_factory=_zero_float_count,
+        metadata={"description": "Preset value of counting for the `mode`."},
+    )
+
+    def __write_to_nexus_group__(self, group: h5py.Group):
+        group.attrs['axes'] = self.data.dims
+        group.attrs['tof_bin_coord'] = self.tof_bin_coord
+        snx.create_field(group, 'mode', str(self.mode))
+        snx.create_field(group, 'preset', self.preset)
+        data_field = snx.create_field(group, 'data', self.data.data)
+        data_field.attrs['signal'] = 1
+        data_field.attrs['primary'] = 1
+        snx.create_field(group, 'time_of_flight', self.time_of_flight)
+
+
+@dataclass(kw_only=True)
+class NMXDetectorMetadata:
+    nx_class = snx.NXdetector
+
+    detector_name: str
+    x_pixel_size: sc.Variable
+    y_pixel_size: sc.Variable
+    origin: sc.Variable
+    """Center of the detector panel."""
+    fast_axis: sc.Variable
+    """Innermost dimension if the data is sorted by detector number.
+
+    The fast-axis index changes quickly as the detector number increases.
+
+    I.e. when the detector numbers grow as ``0, 1, 2, 3, 4, 5, 6, ...``
+    and the size of the fast axis is ``3``,
+    the fast-axis index will be ``0, 1, 2, 0, 1, 2, 0, ...``
+    for each detector number.
+
+    """
+    fast_axis_dim: str
+    slow_axis: sc.Variable
+    """Outermost dimension if the data is sorted by detector number.
+
+    The slow-axis index changes slowly as the detector number increases.
+
+    I.e. when the detector numbers grow as ``0, 1, 2, 3, 4, 5, 6, ...``
+    and the size of the fast axis is ``3``,
+    the slow-axis index will be ``0, 0, 0, 1, 1, 1, 2, ...``
+    for each detector number.
+
+    """
+    slow_axis_dim: str
+    distance: sc.Variable
+    first_pixel_position: sc.Variable
+    """First pixel position with respect to the sample.
+
+    Additional field for DIALS. It should be a 3D vector.
+    """
+    # TODO: Remove these hard-coded default values.
+    polar_angle: sc.Variable = field(default_factory=lambda: sc.scalar(0, unit='deg'))
+    azimuthal_angle: sc.Variable = field(
+        default_factory=lambda: sc.scalar(0, unit='deg')
+    )
+
+    def __write_to_nexus_group__(self, group: h5py.Group):
+        snx.create_field(group, 'x_pixel_size', self.x_pixel_size)
+        snx.create_field(group, 'y_pixel_size', self.y_pixel_size)
+        origin = snx.create_field(group, 'origin', self.origin)
+        origin.attrs['first_pixel_position'] = self.first_pixel_position.values
+        fast_axis = snx.create_field(group, 'fast_axis', self.fast_axis)
+        fast_axis.attrs['dim'] = self.fast_axis_dim
+        slow_axis = snx.create_field(group, 'slow_axis', self.slow_axis)
+        slow_axis.attrs['dim'] = self.slow_axis_dim
+        snx.create_field(group, 'distance', self.distance)
+        snx.create_field(group, 'polar_angle', self.polar_angle)
+        snx.create_field(group, 'azimuthal_angle', self.azimuthal_angle)
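# Editor's note: a tiny sketch (not part of the changeset) of the fast/slow
# axis convention documented above, for a hypothetical 2x3 panel (slow axis
# of size 2, fast axis of size 3) with detector numbers 0..5.
for detector_number in range(6):
    fast_index = detector_number % 3   # cycles 0, 1, 2, 0, 1, 2
    slow_index = detector_number // 3  # steps  0, 0, 0, 1, 1, 1
    print(detector_number, fast_index, slow_index)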
+
+
+@dataclass(kw_only=True)
+class NMXReducedDetector:
+    """Reduced detector data and metadata container.
+
+    In an output file, all metadata fields are stored on the same level as the
+    `data`.  In this container, however, the `data` and `metadata` are
+    separated by an extra hierarchy, because the `data` needs more control
+    over how it is stored (e.g. the compression option).  Also, the histogram
+    may need chunk-wise processing, so the metadata may need to be written in
+    advance, such that the `data` can be appended to the existing
+    `NXdetector` HDF5 group.
+
+    """
+
+    data: sc.DataArray | None = None
+    """3D histogram of the detector counts, or its placeholder."""
+    metadata: NMXDetectorMetadata
+    """NMX detector metadata."""
+
+
+@dataclass(kw_only=True)
+class NMXInstrument:
+    nx_class = snx.NXinstrument
+
+    detectors: sc.DataGroup[NMXReducedDetector]
+    name: str = "NMX"
+    source: NMXSourceMetadata
+
+
+@dataclass(kw_only=True)
+class NMXProgram:
+    nx_class = 'NXprogram'
+
+    program: str = 'essnmx'
+
+    def __write_to_nexus_group__(self, group: h5py.Group):
+        from ess.nmx import __version__ as essnmxversion
+
+        prog = snx.create_field(group, 'program', self.program)
+        prog.attrs['version'] = essnmxversion
+
+
+@dataclass(kw_only=True)
+class NMXLauetof:
+    nx_class = "NXlauetof"
+
+    control: NMXMonitorMetadata
+    definitions: Literal['NXlauetof'] = 'NXlauetof'
+    instrument: NMXInstrument
+    sample: NMXSampleMetadata
+    lookup_table: TofLookupTable | None = None
+    reducer: NMXProgram = field(default_factory=NMXProgram)
+    "Information about the reduction software."
+
+    def to_datagroup(self) -> sc.DataGroup:
+        return to_datagroup(self)
diff --git a/packages/essnmx/src/ess/nmx/workflows.py b/packages/essnmx/src/ess/nmx/workflows.py
new file mode 100644
index 00000000..2004bf05
--- /dev/null
+++ b/packages/essnmx/src/ess/nmx/workflows.py
@@ -0,0 +1,316 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
+from collections.abc import Iterable
+
+import sciline
+import scipp as sc
+import scippnexus as snx
+import tof
+from ess.reduce.nexus.types import (
+    EmptyDetector,
+    Filename,
+    NeXusComponent,
+    NeXusTransformation,
+    Position,
+    SampleRun,
+)
+from ess.reduce.time_of_flight import (
+    GenericTofWorkflow,
+    LtotalRange,
+    NumberOfSimulatedNeutrons,
+    SimulationResults,
+    SimulationSeed,
+    TofLookupTableWorkflow,
+)
+from ess.reduce.time_of_flight.lut import BeamlineComponentReading
+from ess.reduce.time_of_flight.types import TimeOfFlightLookupTableFilename
+from ess.reduce.workflow import register_workflow
+
+from .configurations import WorkflowConfig
+from .types import (
+    NMXDetectorMetadata,
+    NMXSampleMetadata,
+    NMXSourceMetadata,
+    TofSimulationMaxWavelength,
+    TofSimulationMinWavelength,
+)
+
+default_parameters = {
+    TofSimulationMaxWavelength: sc.scalar(3.6, unit='angstrom'),
+    TofSimulationMinWavelength: sc.scalar(1.8, unit='angstrom'),
+}
+
+
+def _simulate_fixed_wavelength_tof(
+    wmin: TofSimulationMinWavelength,
+    wmax: TofSimulationMaxWavelength,
+    neutrons: NumberOfSimulatedNeutrons,
+    seed: SimulationSeed,
+) -> SimulationResults:
+    """
+    Simulate a pulse of neutrons propagating through the instrument using the
+    ``tof`` package (https://tof.readthedocs.io).
+    This runs a simulation assuming there are no choppers in the instrument.
+
+    Parameters
+    ----------
+    wmin:
+        Minimum wavelength of the simulated neutrons.
+    wmax:
+        Maximum wavelength of the simulated neutrons.
+    neutrons:
+        Number of neutrons to simulate.
+    seed:
+        Random seed for the simulation.
+    """
+    source = tof.Source(
+        facility="ess",
+        neutrons=neutrons,
+        pulses=1,
+        seed=seed,
+        wmax=wmax,
+        wmin=wmin,
+    )
+    events = source.data.squeeze().flatten(to="event")
+
+    return SimulationResults(
+        readings={
+            "source": BeamlineComponentReading(
+                time_of_arrival=events.coords["birth_time"],
+                wavelength=events.coords["wavelength"],
+                weight=events.data,
+                distance=source.distance,
+            )
+        },
+        choppers=None,
+    )
+
+
+def _merge_panels(*da: sc.DataArray) -> sc.DataArray:
+    """Merge multiple DataArrays representing different panels into one."""
+    return sc.concat(da, dim='panel')
+
+
+def select_detector_names(
+    *, detector_ids: Iterable[int] = (0, 1, 2)
+) -> tuple[str, ...]:
+    import os
+
+    # Users can override the detector names via the ``NMX_DETECTOR_NAMES``
+    # environment variable, a comma-separated list of detector names, e.g.
+    # NMX_DETECTOR_NAMES=detector_panel_0,detector_panel_1,detector_panel_2
+    # The detector names are not expected to change from the defaults,
+    # but this option provides a minimum of flexibility.
+    detector_names_env = os.environ.get("NMX_DETECTOR_NAMES", None)
+    if detector_names_env is not None:
+        return tuple(
+            name
+            for i_name, name in enumerate(detector_names_env.split(','))
+            if i_name in detector_ids
+        )
+    else:
+        return tuple(f'detector_panel_{i}' for i in detector_ids)
+
+
+def assemble_sample_metadata(
+    crystal_rotation: Position[snx.NXcrystal, SampleRun],
+    sample_position: Position[snx.NXsample, SampleRun],
+    sample_component: NeXusComponent[snx.NXsample, SampleRun],
+) -> NMXSampleMetadata:
+    """Assemble sample metadata for the NMX reduction workflow."""
+    name = sample_component['name']
+    if isinstance(name, sc.Variable) and name.dtype == str:
+        sample_name = name.value
+    elif isinstance(name, str):
+        sample_name = name
+    else:
+        raise TypeError(f'Sample name {name} has an unexpected type: {type(name)}')
+
+    return NMXSampleMetadata(
+        name=sample_name,
+        crystal_rotation=crystal_rotation,
+        position=sample_position,
+    )
+
+
+def assemble_source_metadata(
+    source_position: Position[snx.NXsource, SampleRun],
+) -> NMXSourceMetadata:
+    """Assemble source metadata for the NMX reduction workflow."""
+    return NMXSourceMetadata(position=source_position)
+
+
+def _decide_fast_axis(da: sc.DataArray) -> str:
+    x_slice = da['x_pixel_offset', 0].coords['detector_number']
+    y_slice = da['y_pixel_offset', 0].coords['detector_number']
+
+    if (x_slice.max() < y_slice.max()).value:
+        return 'y'
+    elif (x_slice.max() > y_slice.max()).value:
+        return 'x'
+    else:
+        raise ValueError(
+            "Cannot decide the fast axis based on pixel offsets. "
+            "Please specify the fast axis explicitly."
+        )
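# Editor's note: a small sketch (not part of the changeset) of the
# environment-variable override implemented in `select_detector_names` above.
# The panel names below are hypothetical.
import os

os.environ["NMX_DETECTOR_NAMES"] = "panel_a,panel_b,panel_c"
assert select_detector_names(detector_ids=(0, 2)) == ("panel_a", "panel_c")
del os.environ["NMX_DETECTOR_NAMES"]
assert select_detector_names(detector_ids=(0, 2)) == (
    "detector_panel_0",
    "detector_panel_2",
)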
+
+
+def _decide_step(offsets: sc.Variable) -> sc.Variable:
+    """Decide the step size from the offsets, assuming at least two values."""
+    sorted_offsets = sc.sort(offsets, key=offsets.dim, order='ascending')
+    return sorted_offsets[1] - sorted_offsets[0]
+
+
+def _normalize_vector(vec: sc.Variable) -> sc.Variable:
+    return vec / sc.norm(vec)
+
+
+def _retrieve_crystal_rotation(
+    file_path: Filename[SampleRun],
+) -> Position[snx.NXcrystal, SampleRun]:
+    """Temporary provider to retrieve the crystal rotation from the NeXus file."""
+    from ess.reduce.nexus._nexus_loader import load_from_path
+    from ess.reduce.nexus.types import NeXusLocationSpec
+
+    spec = NeXusLocationSpec(
+        filename=file_path,
+        component_name='sample/crystal_rotation',
+    )
+    try:
+        rotation: snx.nxtransformations.Transform = load_from_path(location=spec)
+    except KeyError:
+        import warnings
+
+        warnings.warn(
+            "No crystal rotation found in the NeXus file under "
+            f"'entry/{spec.component_name}'. Returning zero rotation.",
+            RuntimeWarning,
+            stacklevel=1,
+        )
+        zero_rotation = sc.vector([0, 0, 0], unit='deg')
+        return Position[snx.NXcrystal, SampleRun](zero_rotation)
+    else:
+        # TODO: Verify that returning the rotation vector is sufficient here.
+        return Position[snx.NXcrystal, SampleRun](rotation.vector)
+
+
+def assemble_detector_metadata(
+    detector_component: NeXusComponent[snx.NXdetector, SampleRun],
+    transformation: NeXusTransformation[snx.NXdetector, SampleRun],
+    sample_position: Position[snx.NXsample, SampleRun],
+    source_position: Position[snx.NXsource, SampleRun],
+    empty_detector: EmptyDetector[SampleRun],
+) -> NMXDetectorMetadata:
+    """Assemble detector metadata for the NMX reduction workflow."""
+    positions = empty_detector.coords['position']
+    # The origin should be the center of the detector panel.
+    origin = positions.mean()
+    _fast_axis = _decide_fast_axis(empty_detector)
+    _slow_axis = 'y' if _fast_axis == 'x' else 'x'
+    t_unit = transformation.value.unit
+
+    axis_vectors = {
+        'x': positions['x_pixel_offset', 1]['y_pixel_offset', 0]
+        - positions['x_pixel_offset', 0]['y_pixel_offset', 0],
+        'y': positions['y_pixel_offset', 1]['x_pixel_offset', 0]
+        - positions['y_pixel_offset', 0]['x_pixel_offset', 0],
+    }
+
+    fast_axis_vector = axis_vectors[_fast_axis].to(unit=t_unit)
+    slow_axis_vector = axis_vectors[_slow_axis].to(unit=t_unit)
+    x_pixel_size = _decide_step(empty_detector.coords['x_pixel_offset'])
+    y_pixel_size = _decide_step(empty_detector.coords['y_pixel_offset'])
+    distance = sc.norm(origin - source_position.to(unit=origin.unit))
+
+    # We save the first pixel position so that DIALS can use it.
+    flattened = empty_detector.flatten(to='detector_number')
+    first_pixel_number = flattened.coords['detector_number'].min()
+    first_pixel_position = flattened['detector_number', first_pixel_number].coords[
+        'position'
+    ]
+    first_pixel_position_from_sample = first_pixel_position - sample_position
+
+    return NMXDetectorMetadata(
+        detector_name=detector_component['nexus_component_name'],
+        x_pixel_size=x_pixel_size,
+        y_pixel_size=y_pixel_size,
+        origin=origin,
+        fast_axis=_normalize_vector(fast_axis_vector),
+        fast_axis_dim=_fast_axis + '_pixel_offset',
+        slow_axis=_normalize_vector(slow_axis_vector),
+        slow_axis_dim=_slow_axis + '_pixel_offset',
+        distance=distance,
+        first_pixel_position=first_pixel_position_from_sample,
+    )
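# Editor's note: a minimal sketch (not part of the changeset) of the
# pixel-step heuristic in `_decide_step` above: given at least two offsets,
# the step is the difference between the two smallest values.  The offsets
# are made up for illustration.
import scipp as sc

offsets = sc.array(dims=['x_pixel_offset'], values=[0.8, 0.0, 0.4], unit='m')
sorted_offsets = sc.sort(offsets, key=offsets.dim, order='ascending')
print(sorted_offsets[1] - sorted_offsets[0])  # -> 0.4 m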
+
+
+@register_workflow
+def NMXWorkflow() -> sciline.Pipeline:
+    generic_wf = GenericTofWorkflow(run_types=[SampleRun], monitor_types=[])
+
+    generic_wf.insert(_retrieve_crystal_rotation)
+    generic_wf.insert(assemble_sample_metadata)
+    generic_wf.insert(assemble_source_metadata)
+    generic_wf.insert(assemble_detector_metadata)
+    for key, value in default_parameters.items():
+        generic_wf[key] = value
+
+    return generic_wf
+
+
+def _validate_mergeable_workflow(wf: sciline.Pipeline):
+    if wf.indices:
+        raise NotImplementedError("Only flat workflows can be merged.")
+
+
+def _merge_workflows(
+    base_wf: sciline.Pipeline, merged_wf: sciline.Pipeline
+) -> sciline.Pipeline:
+    _validate_mergeable_workflow(base_wf)
+    _validate_mergeable_workflow(merged_wf)
+
+    for key, spec in merged_wf.underlying_graph.nodes.items():
+        if 'value' in spec:
+            base_wf[key] = spec['value']
+        elif (provider_spec := spec.get('provider')) is not None:
+            base_wf.insert(provider_spec.func)
+
+    return base_wf
+
+
+def initialize_nmx_workflow(*, config: WorkflowConfig) -> sciline.Pipeline:
+    """Initialize the NMX workflow according to the workflow configuration.
+
+    If a TOF lookup table file path is provided in the configuration,
+    it is used directly. Otherwise, a TOF simulation workflow is added to
+    the NMX workflow to compute the lookup table on the fly, and all
+    parameters required for the TOF simulation are set in the workflow.
+
+    Parameters
+    ----------
+    config:
+        Workflow configuration for the NMX reduction.
+
+    """
+    wf = NMXWorkflow()
+    if config.tof_lookup_table_file_path is not None:
+        wf[TimeOfFlightLookupTableFilename] = config.tof_lookup_table_file_path
+    else:
+        wf = _merge_workflows(base_wf=wf, merged_wf=TofLookupTableWorkflow())
+        wf.insert(_simulate_fixed_wavelength_tof)
+        wmax = sc.scalar(config.tof_simulation_max_wavelength, unit='angstrom')
+        wmin = sc.scalar(config.tof_simulation_min_wavelength, unit='angstrom')
+        wf[TofSimulationMaxWavelength] = wmax
+        wf[TofSimulationMinWavelength] = wmin
+        wf[SimulationSeed] = config.tof_simulation_seed
+        ltotal_min = sc.scalar(value=config.tof_simulation_min_ltotal, unit='m')
+        ltotal_max = sc.scalar(value=config.tof_simulation_max_ltotal, unit='m')
+        wf[LtotalRange] = LtotalRange((ltotal_min, ltotal_max))
+
+    return wf
+
+
+__all__ = ['NMXWorkflow', 'initialize_nmx_workflow']
diff --git a/packages/essnmx/tests/conftest.py b/packages/essnmx/tests/conftest.py
new file mode 100644
index 00000000..8710000b
--- /dev/null
+++ b/packages/essnmx/tests/conftest.py
@@ -0,0 +1,4 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
+# Shared fixtures must live in a `conftest.py` under the `tests` directory,
+# otherwise pytest cannot discover them.
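# Editor's note: a short sketch (not part of the changeset) of the two
# initialisation paths implemented in `initialize_nmx_workflow` above.
# The lookup-table file name is hypothetical; the file is not opened at
# initialisation time, only stored as a workflow parameter.
from ess.nmx.configurations import WorkflowConfig
from ess.nmx.workflows import initialize_nmx_workflow

# Path 1: a precomputed lookup table file is used directly.
wf = initialize_nmx_workflow(
    config=WorkflowConfig(tof_lookup_table_file_path='nmx_tof_lookup_table.h5')
)

# Path 2: no file given, so the TOF lookup table workflow is merged in and
# the table is computed on the fly from the simulation parameters.
wf = initialize_nmx_workflow(config=WorkflowConfig())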
diff --git a/packages/essnmx/tests/executable_test.py b/packages/essnmx/tests/executable_test.py new file mode 100644 index 00000000..5edf8b02 --- /dev/null +++ b/packages/essnmx/tests/executable_test.py @@ -0,0 +1,529 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) + +import pathlib +import subprocess +import time +from contextlib import contextmanager +from enum import Enum + +import h5py +import pydantic +import pytest +import scipp as sc +import scippnexus as snx +from scipp.testing import assert_identical + +from ess.nmx._executable_helper import ( + InputConfig, + OutputConfig, + ReductionConfig, + WorkflowConfig, + build_reduction_argument_parser, + reduction_config_from_args, +) +from ess.nmx.configurations import TimeBinCoordinate, TimeBinUnit, to_command_arguments +from ess.nmx.executables import reduction +from ess.nmx.types import Compression, NMXLauetof + + +def _build_arg_list_from_pydantic_instance(*instances: pydantic.BaseModel) -> list[str]: + args = {} + for instance in instances: + args.update(instance.model_dump(mode='python')) + args = {f"--{k.replace('_', '-')}": v for k, v in args.items() if v is not None} + + arg_list = [] + for k, v in args.items(): + if not isinstance(v, bool): + arg_list.append(k) + if isinstance(v, list): + arg_list.extend(str(item) for item in v) + elif isinstance(v, Enum): + arg_list.append(v.name) + else: + arg_list.append(str(v)) + elif v is True: + arg_list.append(k) + + return arg_list + + +def _default_config() -> ReductionConfig: + """Helper to create a default ReductionConfig instance.""" + return ReductionConfig( + inputs=InputConfig(input_file=['']), + workflow=WorkflowConfig(), + output=OutputConfig(), + ) + + +def _check_non_default_config(testing_config: ReductionConfig) -> None: + """Helper to check that all values in the config are non-default.""" + default_config = _default_config() + testing_children = testing_config._children + default_children = default_config._children + for testing_child, default_child in zip( + testing_children, default_children, strict=True + ): + testing_model = testing_child.model_dump(mode='python') + default_model = default_child.model_dump(mode='python') + for key, testing_value in testing_model.items(): + if key == 'tof_lookup_table_file_path': + # This value may be None or default, so we skip the check. + continue + default_value = default_model[key] + assert testing_value != default_value, ( + f"Value for '{key}' is default: {testing_value}" + ) + + +def test_reduction_config() -> None: + """Test ReductionConfig argument parsing.""" + # Build config instances with non-default values. 
+    input_options = InputConfig(
+        input_file=['test-input.h5'],
+        swmr=True,
+        detector_ids=[0, 1, 2, 3],
+        iter_chunk=True,
+        chunk_size_pulse=10,
+        chunk_size_events=100000,
+    )
+    workflow_options = WorkflowConfig(
+        nbins=100,
+        min_time_bin=10,
+        max_time_bin=100_000,
+        time_bin_coordinate=TimeBinCoordinate.event_time_offset,
+        time_bin_unit=TimeBinUnit.us,
+        tof_simulation_num_neutrons=700_000,
+        tof_simulation_max_wavelength=5.0,
+        tof_simulation_min_wavelength=1.0,
+        tof_simulation_min_ltotal=140.0,
+        tof_simulation_max_ltotal=200.0,
+        tof_simulation_seed=12345,
+    )
+    output_options = OutputConfig(
+        output_file='test-output.h5',
+        compression=Compression.NONE,
+        verbose=True,
+        skip_file_output=True,
+        overwrite=True,
+    )
+    expected_config = ReductionConfig(
+        inputs=input_options, workflow=workflow_options, output=output_options
+    )
+    # Check that all values are non-default.
+    _check_non_default_config(expected_config)
+
+    # Build the argument list manually, not via `to_command_arguments`, to test it.
+    arg_list = _build_arg_list_from_pydantic_instance(
+        input_options, workflow_options, output_options
+    )
+    assert arg_list == to_command_arguments(config=expected_config, one_line=False)
+
+    # Parse the arguments and build a config from them.
+    parser = build_reduction_argument_parser()
+    args = parser.parse_args(arg_list)
+    config = reduction_config_from_args(args)
+    assert expected_config == config
+
+
+@pytest.fixture(scope="session")
+def small_nmx_nexus_path():
+    """Fixture providing the path to the small NMX NeXus file."""
+    from ess.nmx.data import get_small_nmx_nexus
+
+    return get_small_nmx_nexus()
+
+
+def _check_output_file(output_file_path: pathlib.Path, nbins: int):
+    detector_names = [f'detector_panel_{i}' for i in range(3)]
+    mandatory_fields = (
+        'data',
+        'distance',
+        'fast_axis',
+        'slow_axis',
+        'origin',
+        'x_pixel_size',
+        'y_pixel_size',
+    )
+    with snx.File(output_file_path, 'r') as f:
+        assert f['entry/instrument/name'][()] == 'NMX'
+        for name in detector_names:
+            det_gr = f[f'entry/instrument/{name}']
+            assert det_gr is not None
+            time_of_flight = det_gr['time_of_flight'][()]
+            assert len(time_of_flight) == nbins
+            assert all(field_name in det_gr for field_name in mandatory_fields)
+
+
+def test_executable_runs(small_nmx_nexus_path, tmp_path: pathlib.Path):
+    """Test that the executable runs and produces the expected output."""
+    output_file = tmp_path / "output.h5"
+    assert not output_file.exists()
+
+    nbins = 20  # Small number of bins for testing.
+    # The output has 1280x1280 pixels per detector per time bin.
+    commands = (
+        'essnmx-reduce',
+        '--input-file',
+        small_nmx_nexus_path,
+        '--nbins',
+        str(nbins),
+        '--output-file',
+        output_file.as_posix(),
+    )
+    # The command tuple is fixed above; no untrusted input reaches subprocess.
+    result = subprocess.run(  # noqa: S603 - We are not accepting arbitrary input here.
+ commands, text=True, capture_output=True, check=False + ) + assert result.returncode == 0 + assert output_file.exists() + _check_output_file(output_file, nbins=nbins) + + +@contextmanager +def known_warnings(): + with pytest.warns(RuntimeWarning, match="No crystal rotation*"): + yield + + +@pytest.fixture +def temp_output_file(tmp_path: pathlib.Path): + output_file_path = tmp_path / "scipp_output.h5" + yield output_file_path + if output_file_path.exists(): + output_file_path.unlink() + + +@pytest.fixture +def reduction_config( + small_nmx_nexus_path: pathlib.Path, temp_output_file: pathlib.Path +) -> ReductionConfig: + input_config = InputConfig(input_file=[small_nmx_nexus_path.as_posix()]) + # Compression option is not default (NONE) but + # the actual default compression option, BITSHUFFLE_LZ4, + # only properly works in linux so we set it to NONE here + # for convenience of testing on all platforms. + output_config = OutputConfig( + output_file=temp_output_file.as_posix(), + compression=Compression.NONE, + skip_file_output=True, # No need to write output file for most tests. + ) + return ReductionConfig(inputs=input_config, output=output_config) + + +def _retrieve_one_hist(results: NMXLauetof) -> sc.DataArray: + """Helper to retrieve the first DataArray from the results dictionary.""" + da = results.instrument.detectors['detector_panel_0'].data + if not isinstance(da, sc.DataArray): + raise TypeError("Histogram is not a DataArray.") + return da + + +def test_reduction_default_settings(reduction_config: ReductionConfig) -> None: + # Only check if reduction runs without errors with default settings. + with known_warnings(): + reduction(config=reduction_config) + + +def test_reduction_only_number_of_time_bins(reduction_config: ReductionConfig) -> None: + reduction_config.workflow.nbins = 20 + with known_warnings(): + hist = _retrieve_one_hist(reduction(config=reduction_config)) + + # Check that the number of time bins is as expected. + assert len(hist.coords['tof']) == 21 # nbins + 1 edges + + +def test_histogram_event_time_offset(reduction_config: ReductionConfig) -> None: + reduction_config.workflow.nbins = 20 + reduction_config.workflow.time_bin_coordinate = TimeBinCoordinate.event_time_offset + with known_warnings(): + hist = _retrieve_one_hist(reduction(config=reduction_config)) + + # Check that the number of time bins is as expected. 
+ assert len(hist.coords['event_time_offset']) == 21 # nbins + 1 edges + # Check if the histogram result is reasonable + zero = sc.scalar(0.0, unit='counts', dtype='float32', variance=0.0) + assert bool(hist.data.sum() > zero) + + +def test_histogram_invalid_min_max_raises(reduction_config: ReductionConfig) -> None: + reduction_config.workflow.min_time_bin = 120 + reduction_config.workflow.max_time_bin = 100 + with pytest.raises(ValueError, match='Cannot build a time bin edges coordinate'): + with known_warnings(): + reduction(config=reduction_config) + + +def test_histogram_invalid_min_max_raises_eto( + reduction_config: ReductionConfig, +) -> None: + reduction_config.workflow.time_bin_coordinate = TimeBinCoordinate.event_time_offset + reduction_config.workflow.min_time_bin = 50 + reduction_config.workflow.max_time_bin = 40 + with pytest.raises(ValueError, match='Cannot build a time bin edges coordinate'): + with known_warnings(): + reduction(config=reduction_config) + + +@pytest.mark.parametrize( + argnames="t_coord", + argvalues=[TimeBinCoordinate.time_of_flight, TimeBinCoordinate.event_time_offset], +) +def test_histogram_out_of_range_min_warns( + reduction_config: ReductionConfig, t_coord: TimeBinCoordinate +) -> None: + reduction_config.workflow.time_bin_coordinate = t_coord + reduction_config.workflow.nbins = 20 + reduction_config.workflow.min_time_bin = 1_000 + reduction_config.workflow.max_time_bin = 2_000 + with pytest.warns(UserWarning, match='is bigger than all'): + with known_warnings(): + results = reduction(config=reduction_config) + + for histogram in results.instrument.detectors.values(): + assert isinstance(histogram.data, sc.DataArray) + da = histogram.data + assert_identical( + da.data.sum(), sc.scalar(0.0, unit='counts', dtype='float32', variance=0.0) + ) + + +@pytest.mark.parametrize( + argnames="t_coord", + argvalues=[TimeBinCoordinate.time_of_flight, TimeBinCoordinate.event_time_offset], +) +def test_histogram_out_of_range_max_warns( + reduction_config: ReductionConfig, t_coord: TimeBinCoordinate +) -> None: + reduction_config.workflow.time_bin_coordinate = t_coord + reduction_config.workflow.nbins = 10 + reduction_config.workflow.min_time_bin = -1 + reduction_config.workflow.max_time_bin = 0 + with pytest.warns(UserWarning, match='is smaller than all'): + with known_warnings(): + results = reduction(config=reduction_config) + + for det in results.instrument.detectors.values(): + da = det.data + assert isinstance(da, sc.DataArray) + assert_identical( + da.data.sum(), sc.scalar(0.0, unit='counts', dtype='float32', variance=0.0) + ) + + +@pytest.fixture +def tof_lut_file_path(tmp_path: pathlib.Path): + """Fixture to provide the path to the small NMX NeXus file.""" + from dataclasses import is_dataclass + + from ess.reduce.time_of_flight import TimeOfFlightLookupTable + + from ess.nmx.workflows import initialize_nmx_workflow + + # Simply use the default workflow for testing. + workflow = initialize_nmx_workflow(config=WorkflowConfig()) + tof_lut: TimeOfFlightLookupTable = workflow.compute(TimeOfFlightLookupTable) + + # Change the tof range a bit for testing. 
+ if isinstance(tof_lut, sc.DataArray): + tof_lut *= 2 + elif is_dataclass(tof_lut): + tof_lut.array *= 2 + else: + raise TypeError("Unexpected type for TOF lookup table.") + + lut_file_path = tmp_path / "nmx_tof_lookup_table.h5" + tof_lut.save_hdf5(lut_file_path.as_posix()) + yield lut_file_path + if lut_file_path.exists(): + lut_file_path.unlink() + + +def test_reduction_with_tof_lut_file( + reduction_config: ReductionConfig, tof_lut_file_path: pathlib.Path +) -> None: + # Make sure the config uses no TOF lookup table file initially. + assert reduction_config.workflow.tof_lookup_table_file_path is None + with known_warnings(): + default_results = reduction(config=reduction_config) + + # Update config to use the TOF lookup table file. + reduction_config.workflow.tof_lookup_table_file_path = tof_lut_file_path.as_posix() + with known_warnings(): + results = reduction(config=reduction_config) + + default_hists = [det.data for det in default_results.instrument.detectors.values()] + hists = [det.data for det in results.instrument.detectors.values()] + + for default_hist, hist in zip(default_hists, hists, strict=True): + assert isinstance(default_hist, sc.DataArray) + assert isinstance(hist, sc.DataArray) + tof_edges_default = default_hist.coords['tof'] + tof_edges = hist.coords['tof'] + assert_identical(default_hist.data, hist.data) + assert_identical(tof_edges_default * 2, tof_edges) + + +def test_reduction_succeed_when_skipping_evenif_output_file_exists( + reduction_config: ReductionConfig, temp_output_file: pathlib.Path +) -> None: + # Make sure the file exists + temp_output_file.touch(exist_ok=True) + # Make sure the file output is skipped. + reduction_config.output.skip_file_output = True + + # Adjust workflow setting to finish fast. + reduction_config.workflow.nbins = 2 + reduction_config.workflow.time_bin_coordinate = TimeBinCoordinate.event_time_offset + with known_warnings(): + reduction(config=reduction_config) + + +def test_reduction_fails_fast_if_output_file_exists( + reduction_config: ReductionConfig, temp_output_file: pathlib.Path +) -> None: + # Make sure the file exists + temp_output_file.touch() + # Make sure file output is NOT skipped. + reduction_config.output.skip_file_output = False + + start = time.time() + with pytest.raises(FileExistsError): + reduction(config=reduction_config) + finish = time.time() + + # Check if the `reduction` call fails within 1 second. + # There is no special reason why it is 1 second. + # It should just fail as fast as possible. + assert finish - start < 1 + + +def test_reduction_compression_gzip( + reduction_config: ReductionConfig, tmp_path: pathlib.Path +) -> None: + reduction_config.output.skip_file_output = False + reduction_config.workflow.nbins = 5 # For faster test + file_paths: dict[Compression, pathlib.Path] = {} + + for compress_mode in (Compression.NONE, Compression.GZIP): + reduction_config.output.compression = compress_mode + cur_file_path = tmp_path / f'compress_{compress_mode}_output.hdf' + file_paths[compress_mode] = cur_file_path + assert not cur_file_path.exists() + reduction_config.output.output_file = cur_file_path.as_posix() + # Running the whole reduction instead of only saving the file on purpose. 
+ with known_warnings(): + reduction(config=reduction_config) + assert cur_file_path.exists() + + assert ( + file_paths[Compression.NONE].stat().st_size + > file_paths[Compression.GZIP].stat().st_size + ) + with h5py.File(file_paths[Compression.NONE]) as file: + for i in range(3): + assert file[f'entry/instrument/detector_panel_{i}/data'].chunks is None + + with h5py.File(file_paths[Compression.GZIP]) as file: + for i in range(3): + data_path = f'entry/instrument/detector_panel_{i}/data' + assert file[data_path].chunks == (1280, 1280, 1) + assert file[data_path].compression == 'gzip' + assert file[data_path].compression_opts == 4 + + +try: + # Just checking availability + import bitshuffle.h5 # noqa: F401 +except ImportError: + BITSHUFFLE_AVAILABLE = False +else: + BITSHUFFLE_AVAILABLE = True + + +@pytest.mark.skipif( + not BITSHUFFLE_AVAILABLE, + reason="Bitshuffle is not available in this environment.", +) +def test_reduction_compression_bitshuffle_smaller_than_gzip( + reduction_config: ReductionConfig, tmp_path: pathlib.Path +) -> None: + reduction_config.output.skip_file_output = False + reduction_config.workflow.nbins = 5 # For faster test + file_paths: dict[Compression, pathlib.Path] = {} + total_times: dict[Compression, pathlib.Path] = {} + + for compress_mode in (Compression.GZIP, Compression.BITSHUFFLE_LZ4): + reduction_config.output.compression = compress_mode + cur_file_path = tmp_path / f'compress_{compress_mode}_output.hdf' + file_paths[compress_mode] = cur_file_path + assert not cur_file_path.exists() + reduction_config.output.output_file = cur_file_path.as_posix() + # Running the whole reduction instead of only saving the file on purpose. + with known_warnings(): + start = time.time() + reduction(config=reduction_config) + end = time.time() + + assert cur_file_path.exists() + total_times[compress_mode] = end - start + + # GZIP is expected to have better compression ratio than BITSHUFFLE + assert ( + file_paths[Compression.BITSHUFFLE_LZ4].stat().st_size + > file_paths[Compression.GZIP].stat().st_size + ) + # BITSHUFFLE is expected to be faster than GZIP + assert total_times[Compression.BITSHUFFLE_LZ4] < total_times[Compression.GZIP] + + with h5py.File(file_paths[Compression.GZIP]) as file: + for i in range(3): + data_path = f'entry/instrument/detector_panel_{i}/data' + assert file[data_path].chunks == (1280, 1280, 1) + assert file[data_path].compression == 'gzip' + + with h5py.File(file_paths[Compression.BITSHUFFLE_LZ4]) as file: + for i in range(3): + data_path = f'entry/instrument/detector_panel_{i}/data' + assert file[data_path].chunks == (1280, 1280, 1) + # For some reason it doesn't write the compression. + # so we check the filter instead. 
+ # assert file[data_path].compression == 'bitshuffle' + assert '32008' in file[data_path]._filters + + +@pytest.mark.skipif( + BITSHUFFLE_AVAILABLE, + reason="Bitshuffle is available in this environment so it won't fall back.", +) +def test_reduction_compression_bitshuffle_fall_back_to_gzip( + reduction_config: ReductionConfig, temp_output_file: pathlib.Path +) -> None: + reduction_config.output.skip_file_output = False + reduction_config.workflow.nbins = 5 # For faster test + reduction_config.output.compression = Compression.BITSHUFFLE_LZ4 + reduction_config.output.output_file = temp_output_file.as_posix() + + with known_warnings(): + with pytest.warns(UserWarning, match='bitshuffle.h5'): + reduction(config=reduction_config) + + with h5py.File(temp_output_file) as file: + for i in range(3): + data_path = f'entry/instrument/detector_panel_{i}/data' + assert file[data_path].chunks == (1280, 1280, 1) + assert file[data_path].compression == 'gzip' + + +def test_reduction_duplicated_path_raises(reduction_config: ReductionConfig) -> None: + # Run with two files with same names. + reduction_config.inputs.input_file = reduction_config.inputs.input_file * 2 + with pytest.raises( + ValueError, match=r'Duplicated file paths or pattern found.*small_nmx_nexus.hdf' + ): + reduction(config=reduction_config) diff --git a/packages/essnmx/tests/mcstas/exporter_test.py b/packages/essnmx/tests/mcstas/exporter_test.py new file mode 100644 index 00000000..576c5f23 --- /dev/null +++ b/packages/essnmx/tests/mcstas/exporter_test.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) +import io + +import numpy as np +import pytest +import scipp as sc + +from ess.nmx.mcstas.nexus import export_as_nexus +from ess.nmx.mcstas.types import NMXReducedDataGroup + + +@pytest.fixture +def reduced_data() -> NMXReducedDataGroup: + rng = np.random.default_rng(42) + id_list = sc.array(dims=['event'], values=rng.integers(0, 12, size=100)) + t_list = sc.array(dims=['event'], values=rng.random(size=100, dtype=float)) + counts = ( + sc.DataArray( + data=sc.ones(dims=['event'], shape=[100]), + coords={'id': id_list, 't': t_list}, + ) + .group('id') + .hist(t=10) + ) + + return NMXReducedDataGroup( + sc.DataGroup( + dict( # noqa: C408 + counts=counts, + proton_charge=sc.scalar(1.0, unit='counts'), + crystal_rotation=sc.vector(value=[0.0, 20.0, 0.0], unit='deg'), + fast_axis=sc.vectors( + dims=['panel'], + values=[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], + unit='m', + ), + slow_axis=sc.vectors( + dims=['panel'], + values=[[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], + unit='m', + ), + origin_position=sc.vectors( + dims=['panel'], + values=[[-0.2, 0.0, 0.0], [0.0, 0.0, 0.0], [0.2, 0.0, 0.0]], + unit='m', + ), + sample_position=sc.vector(value=[0.0, 0.0, 0.0], unit='m'), + source_position=sc.vector(value=[-3, 0.0, -4], unit='m'), + sample_name=sc.scalar('Unit Test Sample'), + position=sc.zeros(dims=['panel', 'id'], shape=[3, 4], unit='m'), + ) + ) + ) + + +def _is_bitshuffle_available() -> bool: + import platform + + return not ( + platform.machine().startswith("arm") or platform.platform().startswith('win') + ) + + +def test_mcstas_reduction_export_to_bytestream( + reduced_data: NMXReducedDataGroup, +) -> None: + """Test export method.""" + import h5py + import numpy as np + import scipp as sc + + data_fields = [ + 'NXdetector', + 'NXsample', + 'NXsource', + 'NXinstrument', + 'definition', + 'name', + ] + + with io.BytesIO() as bio: + if not 
_is_bitshuffle_available(): + # bitshuffle does not build correctly on Windows and ARM machines + # We are keeping this test here to catch when it builds correctly + # in the future. + with pytest.warns( + DeprecationWarning, match='Please use ``export_as_nxlauetof`` instead.' + ): + with pytest.warns(UserWarning, match='bitshuffle.h5'): + export_as_nexus(reduced_data, bio) + else: + with pytest.warns( + DeprecationWarning, match='Please use ``export_as_nxlauetof`` instead.' + ): + export_as_nexus(reduced_data, bio) + + with h5py.File(bio, 'r') as f: + assert 'NMX_data' in f + nmx_data: h5py.Group = f.require_group('NMX_data') + for field in data_fields: + assert field in nmx_data + + nx_detector = nmx_data.require_group('NXdetector') + assert np.all( + nx_detector['fast_axis'][()] == reduced_data['fast_axis'].values + ) + assert np.all( + nx_detector['slow_axis'][()] == reduced_data['slow_axis'].values + ) + assert np.all( + nx_detector['origin'][()] == reduced_data['origin_position'].values + ) + + instrument_data = nmx_data.require_group('NXinstrument') + assert ( + instrument_data['proton_charge'][()] + == reduced_data['proton_charge'].value + ) + + det1_data = instrument_data.require_group('detector_1') + assert np.all(det1_data['counts'][()] == reduced_data['counts'].values) + assert np.all( + det1_data['pixel_id'][()] == reduced_data['counts'].coords['id'].values + ) + assert np.all( + det1_data['t_bin'][()] == reduced_data['counts'].coords['t'].values + ) + + nx_sample = nmx_data.require_group('NXsample') + sample_name: bytes = nx_sample['name'][()] + assert sample_name.decode() == reduced_data['sample_name'].value + + nx_source = nmx_data.require_group('NXsource') + assert ( + nx_source['distance'][()] + == sc.norm(reduced_data['source_position']).value + ) diff --git a/packages/essnmx/tests/mcstas/loader_test.py b/packages/essnmx/tests/mcstas/loader_test.py new file mode 100644 index 00000000..c9d090d3 --- /dev/null +++ b/packages/essnmx/tests/mcstas/loader_test.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) +import pathlib +import sys +from collections.abc import Generator + +import pytest +import scipp as sc +import scippnexus as snx +from scipp.testing import assert_allclose, assert_identical + +from ess.nmx import NMXMcStasWorkflow +from ess.nmx.data import get_small_mcstas +from ess.nmx.mcstas.load import bank_names_to_detector_names, load_crystal_rotation +from ess.nmx.mcstas.types import ( + DetectorBankPrefix, + DetectorIndex, + FilePath, + NMXRawEventCountsDataGroup, +) + +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent)) +from mcstas_description_examples import ( + no_detectors, + one_detector_no_filename, + two_detectors_same_filename, + two_detectors_two_filenames, +) + + +def check_nmxdata_properties( + dg: NMXRawEventCountsDataGroup, fast_axis, slow_axis +) -> None: + assert isinstance(dg, sc.DataGroup) + assert_allclose(dg['fast_axis'], fast_axis, atol=sc.scalar(0.005)) + assert_identical(dg['slow_axis'], slow_axis) + + +def check_scalar_properties_mcstas_3(dg: NMXRawEventCountsDataGroup): + """Test helper for NMXData loaded from McStas 3. + + Expected numbers are hard-coded based on the sample file. 
+ """ + assert_identical(dg['crystal_rotation'], sc.vector([0, 0, 0], unit='deg')) + assert_identical(dg['sample_position'], sc.vector(value=[0, 0, 0], unit='m')) + assert_identical( + dg['source_position'], sc.vector(value=[-0.53123, 0.0, -157.405], unit='m') + ) + assert dg['sample_name'] == sc.scalar("sampleMantid") + + +@pytest.mark.parametrize( + ('detector_index', 'fast_axis', 'slow_axis'), + [ + # Expected values are provided by the IDS + # based on the simulation settings of the sample file. + (0, (1.0, 0.0, -0.01), (0.0, 1.0, 0.0)), + (1, (-0.01, 0.0, -1.0), (0.0, 1.0, 0.0)), + (2, (0.01, 0.0, 1.0), (0.0, 1.0, 0.0)), + ], +) +def test_file_reader_mcstas3(detector_index, fast_axis, slow_axis) -> None: + file_path = get_small_mcstas() + + pl = NMXMcStasWorkflow() + pl[FilePath] = file_path + pl[DetectorIndex] = detector_index + dg, bank = pl.compute((NMXRawEventCountsDataGroup, DetectorBankPrefix)).values() + + entry_path = f"entry1/data/{bank}_dat_list_p_x_y_n_id_t" + with snx.File(file_path) as file: + raw_data = file[entry_path]["events"][()] + data_length = raw_data.sizes['dim_0'] + + check_scalar_properties_mcstas_3(dg) + assert dg['weights'].sizes['event'] == data_length + check_nmxdata_properties(dg, sc.vector(fast_axis), sc.vector(slow_axis)) + + +@pytest.fixture(params=[get_small_mcstas]) +def tmp_mcstas_file( + tmp_path: pathlib.Path, + request: pytest.FixtureRequest, +) -> Generator[pathlib.Path, None, None]: + import os + import shutil + + original_file_path = request.param() + + tmp_file = tmp_path / pathlib.Path('file.h5') + shutil.copy(original_file_path, tmp_file) + yield tmp_file + os.remove(tmp_file) + + +def test_file_reader_mcstas_additional_fields(tmp_mcstas_file: pathlib.Path) -> None: + """Check if additional fields names do not break the loader.""" + import h5py + + entry_path = "entry1/data/bank01_events_dat_list_p_x_y_n_id_t" + new_entry_path = entry_path + '_L' + + with h5py.File(tmp_mcstas_file, 'r+') as file: + dataset = file[entry_path] + del file[entry_path] + file[new_entry_path] = dataset + + pl = NMXMcStasWorkflow() + pl[FilePath] = str(tmp_mcstas_file) + pl[DetectorIndex] = 0 + dg = pl.compute(NMXRawEventCountsDataGroup) + + assert isinstance(dg, sc.DataGroup) + + +@pytest.fixture +def rotation_mission_tmp_file(tmp_mcstas_file: pathlib.Path) -> pathlib.Path: + import h5py + + param_keys = tuple(f"entry1/simulation/Param/XtalPhi{key}" for key in "XYZ") + + # Remove the rotation parameters from the file. + with h5py.File(tmp_mcstas_file, 'a') as file: + for key in param_keys: + del file[key] + + return tmp_mcstas_file + + +def test_missing_rotation(rotation_mission_tmp_file: FilePath) -> None: + with pytest.raises(KeyError, match="XtalPhiX"): + load_crystal_rotation(rotation_mission_tmp_file, None) + # McStasInstrument is not used due to error in the file. 
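# Editor's note: a complementary sketch (not part of the changeset) of the
# happy path of `load_crystal_rotation`: with the `XtalPhiX/Y/Z` parameters
# intact, the unmodified sample file yields the zero rotation also asserted
# in `check_scalar_properties_mcstas_3` above.  The exact return value is an
# assumption based on that helper.
def test_crystal_rotation_present_sketch() -> None:
    rotation = load_crystal_rotation(get_small_mcstas(), None)
    assert_identical(rotation, sc.vector([0, 0, 0], unit='deg'))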
+ + +def test_bank_names_to_detector_names_two_detectors(): + res = bank_names_to_detector_names(two_detectors_two_filenames) + assert len(res) == 2 + assert all(len(v) == 1 for v in res.values()) + + +def test_bank_names_to_detector_names_same_filename(): + res = bank_names_to_detector_names(two_detectors_same_filename) + assert len(res) == 1 + assert all(len(v) == 2 for v in res.values()) + + +def test_bank_names_to_detector_names_no_detectors(): + res = bank_names_to_detector_names(no_detectors) + assert len(res) == 0 + + +def test_bank_names_to_detector_names_no_filename(): + res = bank_names_to_detector_names(one_detector_no_filename) + assert len(res) == 1 + ((bank, (detector,)),) = res.items() + assert bank == detector diff --git a/packages/essnmx/tests/mcstas/mcstas_description_examples.py b/packages/essnmx/tests/mcstas/mcstas_description_examples.py new file mode 100644 index 00000000..40315ead --- /dev/null +++ b/packages/essnmx/tests/mcstas/mcstas_description_examples.py @@ -0,0 +1,84 @@ +# flake8: noqa: E501 + +no_detectors = """ +SPLIT 999 COMPONENT Xtal = Single_crystal( + order = 1, + p_transmit=0.001, + reflections = "Rubredoxin.lau", + xwidth = XtalSize_width, + yheight = XtalSize_height, + zdepth = XtalSize_depth, + mosaic = XtalMosaicity, + delta_d_d=1e-4) + AT (0, 0, deltaz) RELATIVE PREVIOUS + ROTATED (XtalPhiX,XtalPhiY, XtalPhiZ) RELATIVE armSample + EXTEND %{ + if (!SCATTERED) {ABSORB;} + %} + +COMPONENT Sphere = PSD_monitor_4PI( + nx = 360, ny = 360, filename = "4pi", radius = 0.2, + restore_neutron = 1) + +AT (0, 0, deltaz) RELATIVE armSample +""" + +two_detectors_two_filenames = """ +COMPONENT nD_Mantid_0 = Monitor_nD( + options ="mantid square x limits=[0 0.512] bins=1280 y limits=[0 0.512] bins=1280, neutron pixel min=1 t, list all neutrons", + xmin = 0, + xmax = 0.512, + ymin = 0, + ymax = 0.512, + restore_neutron = 1, + filename = "bank01_events.dat") + AT (-0.25, -0.25, 0.29) RELATIVE armSample + ROTATED (0, 0, 0) RELATIVE armSample + +COMPONENT nD_Mantid_1 = Monitor_nD( + options ="mantid square x limits=[0 0.512] bins=1280 y limits=[0 0.512] bins=1280, neutron pixel min=2000000 t, list all neutrons", + xmin = 0, + xmax = 0.512, + ymin = 0, + ymax = 0.512, + restore_neutron = 1, + filename = "bank02_events.dat") + AT (-0.29, -0.25, 0.25) RELATIVE armSample + ROTATED (0, 90, 0) RELATIVE armSample +""" + +one_detector_no_filename = """ +COMPONENT nD_Mantid_2 = Monitor_nD( + options ="mantid square x limits=[0 0.512] bins=1280 y limits=[0 0.512] bins=1280, neutron pixel min=2000000 t, list all neutrons", + xmin = 0, + xmax = 0.512, + ymin = 0, + ymax = 0.512, + restore_neutron = 1, + AT (-0.29, -0.25, 0.25) RELATIVE armSample + ROTATED (0, 90, 0) RELATIVE armSample +""" + +two_detectors_same_filename = """ +COMPONENT nD_Mantid_0 = Monitor_nD( + options ="mantid square x limits=[0 0.512] bins=1280 y limits=[0 0.512] bins=1280, neutron pixel min=1 t, list all neutrons", + xmin = 0, + xmax = 0.512, + ymin = 0, + ymax = 0.512, + restore_neutron = 1, + filename = "bank01_events.dat") + AT (-0.25, -0.25, 0.29) RELATIVE armSample + ROTATED (0, 0, 0) RELATIVE armSample + +COMPONENT nD_Mantid_1 = Monitor_nD( + options ="mantid square x limits=[0 0.512] bins=1280 y limits=[0 0.512] bins=1280, neutron pixel min=2000000 t, list all neutrons", + xmin = 0, + xmax = 0.512, + ymin = 0, + ymax = 0.512, + restore_neutron = 1, + filename = "bank01_events.dat") + AT (-0.29, -0.25, 0.25) RELATIVE armSample + ROTATED (0, 90, 0) RELATIVE armSample +""" diff --git 
diff --git a/packages/essnmx/tests/mcstas/mcstas_io_test.py b/packages/essnmx/tests/mcstas/mcstas_io_test.py
new file mode 100644
index 00000000..39eb8426
--- /dev/null
+++ b/packages/essnmx/tests/mcstas/mcstas_io_test.py
@@ -0,0 +1,41 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
+import pathlib
+
+import pytest
+import scipp as sc
+
+from ess.nmx.data import get_small_mcstas
+from ess.nmx.mcstas.load import load_raw_event_data, raw_event_data_chunk_generator
+
+
+@pytest.fixture(params=[get_small_mcstas])
+def mcstas_file_path(request: pytest.FixtureRequest) -> pathlib.Path:
+    return request.param()
+
+
+def test_generator_loading_at_once(mcstas_file_path) -> None:
+    from ess.nmx.mcstas.load import detector_name_from_index
+
+    detector_name = detector_name_from_index(0)
+    whole_chunk = next(
+        raw_event_data_chunk_generator(
+            mcstas_file_path, detector_name=detector_name, chunk_size=-1
+        )
+    )
+    loaded_data = load_raw_event_data(
+        mcstas_file_path, detector_name=detector_name, bank_prefix=None
+    )
+    assert sc.identical(whole_chunk, loaded_data)
+
+
+def test_generator_loading_warns_if_too_small(mcstas_file_path) -> None:
+    from ess.nmx.mcstas.load import detector_name_from_index
+
+    detector_name = detector_name_from_index(0)
+    with pytest.warns(UserWarning, match="The chunk size may be too small"):
+        next(
+            raw_event_data_chunk_generator(
+                mcstas_file_path, detector_name=detector_name, chunk_size=1
+            )
+        )
diff --git a/packages/essnmx/tests/mcstas/workflow_test.py b/packages/essnmx/tests/mcstas/workflow_test.py
new file mode 100644
index 00000000..6f78bb76
--- /dev/null
+++ b/packages/essnmx/tests/mcstas/workflow_test.py
@@ -0,0 +1,84 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+import pathlib
+
+import pandas as pd
+import pytest
+import sciline as sl
+import scipp as sc
+
+from ess.nmx import NMXMcStasWorkflow
+from ess.nmx.data import get_small_mcstas
+from ess.nmx.mcstas.reduction import merge_panels
+from ess.nmx.mcstas.types import (
+    DetectorIndex,
+    FilePath,
+    MaximumCounts,
+    NMXRawEventCountsDataGroup,
+    NMXReducedDataGroup,
+    TimeBinSteps,
+)
+
+
+@pytest.fixture(params=[get_small_mcstas])
+def mcstas_file_path(request: pytest.FixtureRequest) -> pathlib.Path:
+    return request.param()
+
+
+@pytest.fixture
+def mcstas_workflow(mcstas_file_path: pathlib.Path) -> sl.Pipeline:
+    wf = NMXMcStasWorkflow()
+    wf[FilePath] = mcstas_file_path
+    wf[TimeBinSteps] = 50
+    return wf
+
+
+@pytest.fixture
+def multi_bank_mcstas_workflow(mcstas_workflow: sl.Pipeline) -> sl.Pipeline:
+    pl = mcstas_workflow.copy()
+    pl[NMXReducedDataGroup] = (
+        pl[NMXReducedDataGroup]
+        .map(pd.DataFrame({DetectorIndex: range(3)}).rename_axis('panel'))
+        .reduce(index='panel', func=merge_panels)
+    )
+    return pl
+
+
+def test_pipeline_builder(
+    mcstas_workflow: sl.Pipeline, mcstas_file_path: pathlib.Path
+) -> None:
+    assert mcstas_workflow.get(FilePath).compute() == mcstas_file_path
+
+
+def test_pipeline_mcstas_loader(mcstas_workflow: sl.Pipeline) -> None:
+    """Test if the loader graph is complete."""
+    mcstas_workflow[DetectorIndex] = 0
+    nmx_data = mcstas_workflow.compute(NMXRawEventCountsDataGroup)
+    assert isinstance(nmx_data, sc.DataGroup)
+    assert nmx_data.sizes['id'] == 1280 * 1280
+
+
+def test_pipeline_mcstas_reduction(multi_bank_mcstas_workflow: sl.Pipeline) -> None:
+    """Test if the reduction graph is complete."""
+    from scipp.testing import assert_allclose, assert_identical
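+
+    # Editor's note (hedged): the reduction presumably rescales event weights so
+    # that the maximum reduced count equals default_parameters[MaximumCounts],
+    # and the recorded proton charge is proportional to the total counts; the
+    # assertions below check exactly those two relations.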
+    from ess.nmx.mcstas import default_parameters
+
+    nmx_reduced_data = multi_bank_mcstas_workflow.compute(NMXReducedDataGroup)
+    assert nmx_reduced_data.shape == (3, 1280 * 1280, 50)  # panel, pixel id, time bin
+    assert isinstance(nmx_reduced_data, sc.DataGroup)
+
+    # Check maximum value of weights.
+    assert_allclose(
+        nmx_reduced_data['counts'].max().data,
+        sc.scalar(default_parameters[MaximumCounts], unit='counts', dtype=float),
+        atol=sc.scalar(1e-10, unit='counts'),
+        rtol=sc.scalar(1e-8),
+    )
+    assert_identical(
+        nmx_reduced_data['proton_charge'],
+        sc.scalar(1e-4, unit='dimensionless')
+        * nmx_reduced_data['counts'].data.sum('id').sum('t'),
+    )
+    assert nmx_reduced_data.sizes['t'] == 50
diff --git a/packages/essnmx/tests/mtz_io_test.py b/packages/essnmx/tests/mtz_io_test.py
new file mode 100644
index 00000000..e585cbb6
--- /dev/null
+++ b/packages/essnmx/tests/mtz_io_test.py
@@ -0,0 +1,150 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
+import pathlib
+
+import gemmi
+import pytest
+import scipp as sc
+
+from ess.nmx import mtz_io
+from ess.nmx.data import get_small_mtz_samples
+from ess.nmx.mtz_io import (
+    DEFAULT_SPACE_GROUP_DESC,  # P 1
+    MtzDataFrame,
+    MTZFileIndex,
+    MTZFilePath,
+    NMXMtzDataArray,
+    NMXMtzDataFrame,
+    get_reciprocal_asu,
+    mtz_to_pandas,
+    nmx_mtz_dataframe_to_scipp_dataarray,
+    process_mtz_dataframe,
+    process_single_mtz_to_dataframe,
+    read_mtz_file,
+)
+
+
+@pytest.fixture(params=get_small_mtz_samples())
+def file_path(request) -> pathlib.Path:
+    return request.param
+
+
+def test_gemmi_mtz(file_path: pathlib.Path) -> None:
+    mtz = read_mtz_file(MTZFilePath(file_path))
+    assert mtz.spacegroup == gemmi.SpaceGroup("C 1 2 1")  # Hard-coded value
+    assert len(mtz.columns[0]) == 100  # Number of samples, hard-coded value
+
+
+@pytest.fixture
+def gemmi_mtz_object(file_path: pathlib.Path) -> gemmi.Mtz:
+    return read_mtz_file(MTZFilePath(file_path))
+
+
+def test_mtz_to_pandas_dataframe(gemmi_mtz_object: gemmi.Mtz) -> None:
+    df = mtz_to_pandas(gemmi_mtz_object)
+    assert set(df.columns) == set(gemmi_mtz_object.column_labels())
+    # Check that the test data are not all identical
+    first_column_name, second_column_name = df.columns[0:2]
+    assert not all(df[first_column_name] == df[second_column_name])
+
+    # Check that the data match the original columns
+    for column in gemmi_mtz_object.columns:
+        assert column.label in df.columns
+        assert all(df[column.label] == column.array)
+
+
+def test_process_mtz_to_pandas_dataframe(gemmi_mtz_object: gemmi.Mtz) -> None:
+    df = process_single_mtz_to_dataframe(gemmi_mtz_object)
+    for expected_column in ["hkl", "d", "resolution", *"HKL", "wavelength", "I", "SIGI"]:
+        assert expected_column in df.columns
+
+    for hkl_column in "HKL":
+        assert hkl_column in df.columns
+        assert df[hkl_column].dtype == int
+
+    assert "hkl_asu" not in df.columns  # hkl_asu is only computed on merged dataframes
+
+
+@pytest.fixture
+def mtz_list() -> list[gemmi.Mtz]:
+    return [
+        read_mtz_file(MTZFilePath(file_path)) for file_path in get_small_mtz_samples()
+    ]
+
+
+def test_get_space_group_with_spacegroup_desc() -> None:
+    assert (
+        mtz_io.get_space_group_from_description(DEFAULT_SPACE_GROUP_DESC).short_name()
+        == "P1"
+    )
+
+
+@pytest.fixture
+def conflicting_mtz_series(
+    mtz_list: list[gemmi.Mtz],
+) -> list[gemmi.Mtz]:
+    mtz_list[MTZFileIndex(0)].spacegroup = gemmi.SpaceGroup(DEFAULT_SPACE_GROUP_DESC)
+    # Make sure the space groups are different
+    assert (
+        mtz_list[MTZFileIndex(0)].spacegroup.short_name()
+        != mtz_list[MTZFileIndex(1)].spacegroup.short_name()
+    )
+
+    return mtz_list
+
+
+def test_get_unique_space_group_raises_on_conflict(
+    conflicting_mtz_series: list[gemmi.Mtz],
+) -> None:
+    reg = r"Multiple space groups found:.+P 1.+C 1 2 1"
+    space_groups = [
+        mtz_io.get_space_group_from_mtz(mtz) for mtz in conflicting_mtz_series
+    ]
+    with pytest.raises(ValueError, match=reg):
+        mtz_io.get_unique_space_group(*space_groups)
+
+
+@pytest.fixture
+def merged_mtz_dataframe(mtz_list: list[gemmi.Mtz]) -> MtzDataFrame:
+    """Merged data frame built from the sample MTZ files."""
+    reduced_mtz = [process_single_mtz_to_dataframe(mtz) for mtz in mtz_list]
+    return mtz_io.merge_mtz_dataframes(*reduced_mtz)
+
+
+@pytest.fixture
+def nmx_data_frame(
+    mtz_list: list[gemmi.Mtz],
+    merged_mtz_dataframe: MtzDataFrame,
+) -> NMXMtzDataFrame:
+    space_grs = [mtz_io.get_space_group_from_mtz(mtz) for mtz in mtz_list]
+    space_gr = mtz_io.get_unique_space_group(*space_grs)
+    reciprocal_asu = get_reciprocal_asu(space_gr)
+
+    return process_mtz_dataframe(
+        mtz_df=merged_mtz_dataframe,
+        reciprocal_asu=reciprocal_asu,
+        sg=space_gr,
+    )
+
+
+def test_process_merged_mtz_dataframe(
+    merged_mtz_dataframe: MtzDataFrame,
+    nmx_data_frame: NMXMtzDataFrame,
+) -> None:
+    assert "hkl_asu" not in merged_mtz_dataframe.columns
+    assert "hkl_asu" in nmx_data_frame.columns
+
+
+@pytest.fixture
+def nmx_data_array(nmx_data_frame: NMXMtzDataFrame) -> NMXMtzDataArray:
+    return nmx_mtz_dataframe_to_scipp_dataarray(nmx_data_frame)
+
+
+def test_to_scipp_dataarray(
+    nmx_data_array: NMXMtzDataArray,
+) -> None:
+    # Check the intended modification
+    for indices_coord_name in ("hkl", "hkl_asu"):
+        assert nmx_data_array.coords[indices_coord_name].dtype == str
+
+    assert sc.all(nmx_data_array.data > 0)
diff --git a/packages/essnmx/tests/nxlauetof_io_helper_test.py b/packages/essnmx/tests/nxlauetof_io_helper_test.py
new file mode 100644
index 00000000..46264f24
--- /dev/null
+++ b/packages/essnmx/tests/nxlauetof_io_helper_test.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
+import pathlib
+from contextlib import contextmanager
+
+import pytest
+from scipp.testing.assertions import assert_allclose, assert_identical
+
+from ess.nmx._nxlauetof_io import load_essnmx_nxlauetof
+from ess.nmx.configurations import InputConfig, OutputConfig, ReductionConfig
+from ess.nmx.executables import reduction
+from ess.nmx.types import Compression
+
+
+@pytest.fixture
+def temp_output_file(tmp_path: pathlib.Path):
+    output_file_path = tmp_path / "scipp_output.h5"
+    yield output_file_path
+    if output_file_path.exists():
+        output_file_path.unlink()
+
+
+@pytest.fixture
+def reduction_config(temp_output_file: pathlib.Path) -> ReductionConfig:
+    from ess.nmx.data import get_small_nmx_nexus
+
+    input_config = InputConfig(input_file=[get_small_nmx_nexus().as_posix()])
+    output_config = OutputConfig(
+        output_file=temp_output_file.as_posix(),
+        compression=Compression.NONE,
+        skip_file_output=False,
+    )
+    return ReductionConfig(inputs=input_config, output=output_config)
+
+
+@contextmanager
+def known_warnings():
+    with pytest.warns(RuntimeWarning, match="No crystal rotation"):
+        yield
+
+
+def test_loaded_data_same_as_in_memory_result(
+    reduction_config: ReductionConfig,
+) -> None:
+    with known_warnings():
+        result = reduction(config=reduction_config)
+    original_result_dg = result.to_datagroup()
+
+    # Adjust the original result to match the data group loaded from the file.
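+    # Editor's note (hedged): the NXlauetof file presumably stores neither the
+    # lookup table nor the derived detector coordinates, so they are removed
+    # from the in-memory result before the round-trip comparison below.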
+    original_result_dg.pop('lookup_table')
+    original_positions = {}
+    detectors = original_result_dg['instrument']['detectors']
+    for det_name, det in detectors.items():
+        # Remove coordinates that are neither stored in the file nor reconstructed.
+        det['data'].coords.pop('Ltotal')
+        det['data'].coords.pop('detector_number')
+        det['data'].coords.pop('x_pixel_offset')
+        det['data'].coords.pop('y_pixel_offset')
+        # Save the position coordinates to compare with allclose instead of equality.
+        original_positions[det_name] = det['data'].coords.pop('position')
+
+    loaded_dg = load_essnmx_nxlauetof(reduction_config.output.output_file)
+
+    loaded_detector_positions = {}
+    for det_name, loaded_det in loaded_dg['instrument']['detectors'].items():
+        loaded_detector_positions[det_name] = loaded_det['data'].coords.pop('position')
+
+    assert_identical(loaded_dg, original_result_dg)
+    # Use the x_pixel_size of the first panel to derive the absolute tolerance.
+    pixel_size = next(iter(detectors.values()))['metadata']['x_pixel_size']
+    atol = pixel_size / 10.0
+    for det_name, original_position in original_positions.items():
+        loaded_position = loaded_detector_positions[det_name]
+        assert_allclose(original_position, loaded_position, atol=atol)
diff --git a/packages/essnmx/tests/package_test.py b/packages/essnmx/tests/package_test.py
new file mode 100644
index 00000000..8e2cb7de
--- /dev/null
+++ b/packages/essnmx/tests/package_test.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
+
+"""Tests of package integrity.
+
+Note that additional imports need to be added for repositories that
+contain multiple packages.
+"""
+
+from ess import nmx as pkg
+
+
+def test_has_version():
+    assert hasattr(pkg, '__version__')
+
+
+# This is for CI package tests, which must run with minimal dependencies,
+# that is, without installing pytest. This block does not affect pytest runs.
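+# A minimal invocation sketch (editor's note; the exact CI command may differ):
+#   python packages/essnmx/tests/package_test.py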
+if __name__ == '__main__':
+    test_has_version()
diff --git a/packages/essnmx/tests/scaling_test.py b/packages/essnmx/tests/scaling_test.py
new file mode 100644
index 00000000..219603bc
--- /dev/null
+++ b/packages/essnmx/tests/scaling_test.py
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
+import pytest
+import scipp as sc
+
+from ess.nmx.scaling import (
+    ReferenceIntensities,
+    estimate_scale_factor_per_hkl_asu_from_reference,
+    get_reference_intensities,
+    get_reference_wavelength,
+)
+
+
+@pytest.fixture
+def nmx_data_array() -> sc.DataArray:
+    da = sc.DataArray(
+        data=sc.array(dims=["row"], values=[1, 2, 3, 4, 5, 3.1, 3.2]),
+        coords={
+            "wavelength": sc.array(dims=["row"], values=[1, 2, 3, 4, 5, 3, 3]),
+            "hkl_asu": sc.array(
+                dims=["row"],
+                values=[
+                    "[1, 2, 3]",
+                    "[4, 5, 6]",
+                    "[7, 8, 9]",
+                    "[10, 11, 12]",
+                    "[13, 14, 15]",
+                    "[7, 8, 9]",
+                    "[9, 8, 7]",
+                ],
+            ),
+        },
+    )
+    da.variances = (
+        sc.array(dims=["row"], values=[0.1, 0.2, 0.3, 0.4, 0.5, 0.31, 0.32]) ** 2
+    )
+    return da
+
+
+def test_get_reference_bin_middle(nmx_data_array: sc.DataArray) -> None:
+    """Test that the reference intensities come from the middle wavelength bin."""
+    binned = nmx_data_array.bin({"wavelength": 6})
+    reference_wavelength = get_reference_wavelength(binned, reference_wavelength=None)
+
+    ref_bin = get_reference_intensities(
+        binned,
+        reference_wavelength,
+    )
+    # Rows 2, 5 and 6 (wavelength == 3) fall into the selected middle bin.
+    selected_idx = (2, 5, 6)
+    assert all(
+        ref_bin.data.values == [nmx_data_array.data.values[idx] for idx in selected_idx]
+    )
+
+
+@pytest.fixture
+def reference_bin(nmx_data_array: sc.DataArray) -> ReferenceIntensities:
+    binned = nmx_data_array.bin({"wavelength": 6})
+    reference_wavelength = get_reference_wavelength(binned, reference_wavelength=None)
+
+    return get_reference_intensities(
+        binned,
+        reference_wavelength,
+    )
+
+
+def test_reference_bin_scale_factor(reference_bin: ReferenceIntensities) -> None:
+    """Test the estimated scale factor per hkl_asu group."""
+    scale_factor = estimate_scale_factor_per_hkl_asu_from_reference(reference_bin)
+    expected_groups = [[7, 8, 9], [9, 8, 7]]
+
+    assert len(scale_factor) == len(expected_groups)
+    assert scale_factor.dim == "hkl_asu"
+    for idx, group in enumerate(expected_groups):
+        assert scale_factor.coords['hkl_asu'][idx].value == str(group)
diff --git a/pixi.toml b/pixi.toml
index ec626813..558e0e53 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -27,6 +27,10 @@ essreduce = { path = "packages/essreduce", editable = true, extras = ["test"] }
[feature.essimaging.pypi-dependencies]
essimaging = { path = "packages/essimaging", editable = true, extras = ["test"] }
+# essnmx (depends on essreduce)
+[feature.essnmx.pypi-dependencies]
+essnmx = { path = "packages/essnmx", editable = true, extras = ["test"] }
+
# ==================== Lint feature ====================
[feature.lint.pypi-dependencies]
@@ -50,23 +54,29 @@ essreduce = { path = "packages/essreduce", editable = true, extras = ["test", "docs"] }
[feature.docs-essimaging.pypi-dependencies]
essimaging = { path = "packages/essimaging", editable = true, extras = ["test", "docs"] }
+[feature.docs-essnmx.pypi-dependencies]
+essnmx = { path = "packages/essnmx", editable = true, extras = ["test", "docs"] }
+
# ==================== Environments ====================
[environments]
# Default: all packages (for full dev setup)
-default = { features = ["essreduce", "essimaging"], solve-group = "default" }
+default = { features = ["essreduce", "essimaging", "essnmx"], solve-group = "default" }
# Per-package test environments (include workspace dep features)
essreduce = { features = ["essreduce"], solve-group = "default" }
essimaging = { features = ["essimaging", "essreduce"], solve-group = "default" }
+essnmx = { features = ["essnmx", "essreduce"], solve-group = "default" }

# Lower-bound test environments (separate resolution)
lb-essreduce = { features = ["essreduce"], solve-group = "lower-bound" }
lb-essimaging = { features = ["essimaging", "essreduce"], solve-group = "lower-bound" }
+lb-essnmx = { features = ["essnmx", "essreduce"], solve-group = "lower-bound" }

# Docs environments (package with docs extra + pandoc)
docs-essreduce = { features = ["docs-essreduce", "docs"], solve-group = "default" }
docs-essimaging = { features = ["docs-essimaging", "docs-essreduce", "docs"], solve-group = "default" }
+docs-essnmx = { features = ["docs-essnmx", "docs-essreduce", "docs"], solve-group = "default" }

# Lint environment (standalone, no package deps)
lint = { features = ["lint"], solve-group = "lint" }
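+
+# Editor's note (hedged usage sketch; task names may differ): run the essnmx
+# test suite in its dedicated environment with, e.g.,
+#   pixi run -e essnmx pytest packages/essnmx
+# assuming pytest is pulled in via the package's "test" extra.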