diff --git a/packages/essreduce/docs/user-guide/tof/dream.ipynb b/packages/essreduce/docs/user-guide/tof/dream.ipynb index 2703772c..bea0db22 100644 --- a/packages/essreduce/docs/user-guide/tof/dream.ipynb +++ b/packages/essreduce/docs/user-guide/tof/dream.ipynb @@ -26,7 +26,7 @@ "import scippnexus as snx\n", "from scippneutron.chopper import DiskChopper\n", "from ess.reduce.nexus.types import AnyRun, RawDetector, SampleRun, NeXusDetectorName\n", - "from ess.reduce.time_of_flight import *" + "from ess.reduce.unwrap import *" ] }, { @@ -201,7 +201,7 @@ "metadata": {}, "outputs": [], "source": [ - "from ess.reduce.time_of_flight.fakes import FakeBeamline\n", + "from ess.reduce.unwrap.fakes import FakeBeamline\n", "\n", "ess_beamline = FakeBeamline(\n", " choppers=disk_choppers,\n", @@ -299,14 +299,14 @@ "metadata": {}, "outputs": [], "source": [ - "wf = GenericTofWorkflow(run_types=[SampleRun], monitor_types=[])\n", + "wf = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[])\n", "\n", "wf[RawDetector[SampleRun]] = raw_data\n", "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "wf[NeXusDetectorName] = 'dream_detector'\n", "wf[LookupTableRelativeErrorThreshold] = {'dream_detector': float(\"inf\")}\n", "\n", - "wf.visualize(TofDetector[SampleRun])" + "wf.visualize(WavelengthDetector[SampleRun])" ] }, { @@ -342,14 +342,14 @@ "metadata": {}, "outputs": [], "source": [ - "lut_wf = TofLookupTableWorkflow()\n", + "lut_wf = LookupTableWorkflow()\n", "lut_wf[DiskChoppers[AnyRun]] = disk_choppers\n", "lut_wf[SourcePosition] = source_position\n", "lut_wf[LtotalRange] = (\n", " sc.scalar(5.0, unit=\"m\"),\n", " sc.scalar(80.0, unit=\"m\"),\n", ")\n", - "lut_wf.visualize(TofLookupTable)" + "lut_wf.visualize(LookupTable)" ] }, { @@ -381,20 +381,19 @@ "def to_event_time_offset(sim):\n", " # Compute event_time_offset at the detector\n", " eto = (\n", - " sim.time_of_arrival + ((Ltotal - sim.distance) / sim.speed).to(unit=\"us\")\n", + " sim.time_of_arrival + 
((lut_wf.compute(LtotalRange)[1] - sim.distance) / sim.speed).to(unit=\"us\")\n", " ) % sc.scalar(1e6 / 14.0, unit=\"us\")\n", - " # Compute time-of-flight at the detector\n", - " tof = (Ltotal / sim.speed).to(unit=\"us\")\n", + " # # Compute time-of-flight at the detector\n", + " # tof = (Ltotal / sim.speed).to(unit=\"us\")\n", " return sc.DataArray(\n", " data=sim.weight,\n", - " coords={\"wavelength\": sim.wavelength, \"event_time_offset\": eto, \"tof\": tof},\n", + " coords={\"wavelength\": sim.wavelength, \"event_time_offset\": eto},\n", " )\n", "\n", "\n", "events = to_event_time_offset(sim.readings[\"t0\"])\n", - "fig1 = events.hist(wavelength=300, event_time_offset=300).plot(norm=\"log\")\n", - "fig2 = events.hist(tof=300, event_time_offset=300).plot(norm=\"log\")\n", - "fig1 + fig2" + "fig = events.hist(wavelength=300, event_time_offset=300).plot(norm=\"log\")\n", + "fig" ] }, { @@ -414,10 +413,10 @@ "metadata": {}, "outputs": [], "source": [ - "table = lut_wf.compute(TofLookupTable)\n", + "table = lut_wf.compute(LookupTable)\n", "\n", "# Overlay mean on the figure above\n", - "table.array[\"distance\", -1].plot(ax=fig2.ax, color=\"C1\", ls=\"-\", marker=None)" + "table.array[\"distance\", -1].plot(ax=fig.ax, color=\"C1\", ls=\"-\", marker=None)" ] }, { @@ -456,11 +455,11 @@ "outputs": [], "source": [ "# Set the computed lookup table onto the original workflow\n", - "wf[TofLookupTable] = table\n", + "wf[LookupTable] = table\n", "\n", "# Compute time-of-flight of neutron events\n", - "tofs = wf.compute(TofDetector[SampleRun])\n", - "tofs" + "wavs = wf.compute(WavelengthDetector[SampleRun])\n", + "wavs" ] }, { @@ -478,7 +477,7 @@ "metadata": {}, "outputs": [], "source": [ - "tofs.bins.concat().hist(tof=300).plot()" + "# tofs.bins.concat().hist(tof=300).plot()" ] }, { @@ -498,17 +497,17 @@ "metadata": {}, "outputs": [], "source": [ - "from scippneutron.conversion.graph.beamline import beamline\n", - "from scippneutron.conversion.graph.tof import elastic\n", + 
"# from scippneutron.conversion.graph.beamline import beamline\n", + "# from scippneutron.conversion.graph.tof import elastic\n", "\n", - "# Perform coordinate transformation\n", - "graph = {**beamline(scatter=False), **elastic(\"tof\")}\n", - "wav_wfm = tofs.transform_coords(\"wavelength\", graph=graph)\n", + "# # Perform coordinate transformation\n", + "# graph = {**beamline(scatter=False), **elastic(\"tof\")}\n", + "# wav_wfm = tofs.transform_coords(\"wavelength\", graph=graph)\n", "\n", "# Define wavelength bin edges\n", - "wavs = sc.linspace(\"wavelength\", 0.8, 4.6, 201, unit=\"angstrom\")\n", + "edges = sc.linspace(\"wavelength\", 0.8, 4.6, 201, unit=\"angstrom\")\n", "\n", - "histogrammed = wav_wfm.hist(wavelength=wavs).squeeze()\n", + "histogrammed = wavs.hist(wavelength=edges).squeeze()\n", "histogrammed.plot()" ] }, @@ -536,7 +535,7 @@ "pp.plot(\n", " {\n", " \"wfm\": histogrammed,\n", - " \"ground_truth\": ground_truth.hist(wavelength=wavs),\n", + " \"ground_truth\": ground_truth.hist(wavelength=edges),\n", " }\n", ")" ] @@ -622,8 +621,8 @@ "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "\n", "# Compute tofs and wavelengths\n", - "tofs = wf.compute(TofDetector[SampleRun])\n", - "wav_wfm = tofs.transform_coords(\"wavelength\", graph=graph)\n", + "wav_wfm = wf.compute(WavelengthDetector[SampleRun])\n", + "# wav_wfm = tofs.transform_coords(\"wavelength\", graph=graph)\n", "\n", "# Compare in plot\n", "ground_truth = []\n", @@ -634,8 +633,8 @@ "figs = [\n", " pp.plot(\n", " {\n", - " \"wfm\": wav_wfm[\"detector_number\", i].bins.concat().hist(wavelength=wavs),\n", - " \"ground_truth\": ground_truth[i].hist(wavelength=wavs),\n", + " \"wfm\": wav_wfm[\"detector_number\", i].bins.concat().hist(wavelength=edges),\n", + " \"ground_truth\": ground_truth[i].hist(wavelength=edges),\n", " },\n", " title=f\"Pixel {i+1}\",\n", " )\n", @@ -747,7 +746,7 @@ "metadata": {}, "outputs": [], "source": [ - "table = lut_wf.compute(TofLookupTable)\n", + "table = 
lut_wf.compute(LookupTable)\n", "table.plot(ymin=65) / (sc.stddevs(table.array) / sc.values(table.array)).plot(norm=\"linear\", ymin=55, vmax=0.05)" ] }, @@ -771,11 +770,11 @@ "metadata": {}, "outputs": [], "source": [ - "wf[TofLookupTable] = table\n", + "wf[LookupTable] = table\n", "\n", "wf[LookupTableRelativeErrorThreshold] = {'dream_detector': 0.01}\n", "\n", - "masked_table = wf.compute(ErrorLimitedTofLookupTable[snx.NXdetector])\n", + "masked_table = wf.compute(ErrorLimitedLookupTable[snx.NXdetector])\n", "masked_table.plot(ymin=65)" ] }, @@ -804,9 +803,9 @@ "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "\n", "# Compute time-of-flight\n", - "tofs = wf.compute(TofDetector[SampleRun])\n", + "wav_wfm = wf.compute(WavelengthDetector[SampleRun])\n", "# Compute wavelength\n", - "wav_wfm = tofs.transform_coords(\"wavelength\", graph=graph)\n", + "# wav_wfm = tofs.transform_coords(\"wavelength\", graph=graph)\n", "\n", "# Compare to the true wavelengths\n", "ground_truth = ess_beamline.model_result[\"detector\"].data.flatten(to=\"event\")\n", @@ -814,8 +813,8 @@ "\n", "pp.plot(\n", " {\n", - " \"wfm\": wav_wfm.hist(wavelength=wavs).squeeze(),\n", - " \"ground_truth\": ground_truth.hist(wavelength=wavs),\n", + " \"wfm\": wav_wfm.hist(wavelength=edges).squeeze(),\n", + " \"ground_truth\": ground_truth.hist(wavelength=edges),\n", " }\n", ")" ] @@ -837,7 +836,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/packages/essreduce/src/ess/reduce/__init__.py b/packages/essreduce/src/ess/reduce/__init__.py index 44e84029..1ce5144f 100644 --- a/packages/essreduce/src/ess/reduce/__init__.py +++ b/packages/essreduce/src/ess/reduce/__init__.py @@ -3,7 +3,7 @@ import importlib.metadata -from . import nexus, normalization, time_of_flight, uncertainty +from . 
import nexus, normalization, uncertainty, unwrap try: __version__ = importlib.metadata.version("essreduce") @@ -12,4 +12,4 @@ del importlib -__all__ = ["nexus", "normalization", "time_of_flight", "uncertainty"] +__all__ = ["nexus", "normalization", "uncertainty", "unwrap"] diff --git a/packages/essreduce/src/ess/reduce/time_of_flight/__init__.py b/packages/essreduce/src/ess/reduce/unwrap/__init__.py similarity index 63% rename from packages/essreduce/src/ess/reduce/time_of_flight/__init__.py rename to packages/essreduce/src/ess/reduce/unwrap/__init__.py index dbb9fc49..ad6b78d3 100644 --- a/packages/essreduce/src/ess/reduce/time_of_flight/__init__.py +++ b/packages/essreduce/src/ess/reduce/unwrap/__init__.py @@ -2,15 +2,15 @@ # Copyright (c) 2025 Scipp contributors (https://github.com/scipp) """ -Utilities for computing real neutron time-of-flight from chopper settings and +Utilities for computing neutron wavelength from chopper settings and neutron time-of-arrival at the detectors. """ from ..nexus.types import DiskChoppers -from .eto_to_tof import providers from .lut import ( BeamlineComponentReading, DistanceResolution, + LookupTableWorkflow, LtotalRange, NumberOfSimulatedNeutrons, PulsePeriod, @@ -19,35 +19,38 @@ SimulationSeed, SourcePosition, TimeResolution, - TofLookupTableWorkflow, simulate_chopper_cascade_using_tof, ) +from .to_wavelength import providers from .types import ( DetectorLtotal, - ErrorLimitedTofLookupTable, + ErrorLimitedLookupTable, + LookupTable, + LookupTableFilename, LookupTableRelativeErrorThreshold, MonitorLtotal, PulseStrideOffset, - TimeOfFlightLookupTable, - TimeOfFlightLookupTableFilename, - ToaDetector, - TofDetector, - TofLookupTable, - TofLookupTableFilename, - TofMonitor, + # ToaDetector, + # TofDetector, + # TofLookupTable, + # TofLookupTableFilename, + # TofMonitor, WavelengthDetector, WavelengthMonitor, ) -from .workflow import GenericTofWorkflow +from .workflow import GenericUnwrapWorkflow __all__ = [ 
"BeamlineComponentReading", "DetectorLtotal", "DiskChoppers", "DistanceResolution", - "ErrorLimitedTofLookupTable", - "GenericTofWorkflow", + "ErrorLimitedLookupTable", + "GenericUnwrapWorkflow", + "LookupTable", + "LookupTableFilename", "LookupTableRelativeErrorThreshold", + "LookupTableWorkflow", "LtotalRange", "MonitorLtotal", "NumberOfSimulatedNeutrons", @@ -57,15 +60,7 @@ "SimulationResults", "SimulationSeed", "SourcePosition", - "TimeOfFlightLookupTable", - "TimeOfFlightLookupTableFilename", "TimeResolution", - "ToaDetector", - "TofDetector", - "TofLookupTable", - "TofLookupTableFilename", - "TofLookupTableWorkflow", - "TofMonitor", "WavelengthDetector", "WavelengthMonitor", "providers", diff --git a/packages/essreduce/src/ess/reduce/time_of_flight/fakes.py b/packages/essreduce/src/ess/reduce/unwrap/fakes.py similarity index 100% rename from packages/essreduce/src/ess/reduce/time_of_flight/fakes.py rename to packages/essreduce/src/ess/reduce/unwrap/fakes.py diff --git a/packages/essreduce/src/ess/reduce/time_of_flight/interpolator_numba.py b/packages/essreduce/src/ess/reduce/unwrap/interpolator_numba.py similarity index 100% rename from packages/essreduce/src/ess/reduce/time_of_flight/interpolator_numba.py rename to packages/essreduce/src/ess/reduce/unwrap/interpolator_numba.py diff --git a/packages/essreduce/src/ess/reduce/time_of_flight/interpolator_scipy.py b/packages/essreduce/src/ess/reduce/unwrap/interpolator_scipy.py similarity index 100% rename from packages/essreduce/src/ess/reduce/time_of_flight/interpolator_scipy.py rename to packages/essreduce/src/ess/reduce/unwrap/interpolator_scipy.py diff --git a/packages/essreduce/src/ess/reduce/time_of_flight/lut.py b/packages/essreduce/src/ess/reduce/unwrap/lut.py similarity index 94% rename from packages/essreduce/src/ess/reduce/time_of_flight/lut.py rename to packages/essreduce/src/ess/reduce/unwrap/lut.py index 92078feb..dab74663 100644 --- a/packages/essreduce/src/ess/reduce/time_of_flight/lut.py +++ 
b/packages/essreduce/src/ess/reduce/unwrap/lut.py @@ -11,7 +11,7 @@ import scipp as sc from ..nexus.types import AnyRun, DiskChoppers -from .types import TofLookupTable +from .types import LookupTable @dataclass @@ -136,7 +136,7 @@ class SimulationResults: """ -def _compute_mean_tof( +def _compute_mean_wavelength( simulation: BeamlineComponentReading, distance: sc.Variable, time_bins: sc.Variable, @@ -168,11 +168,11 @@ def _compute_mean_tof( toas = simulation.time_of_arrival + (travel_length / simulation.speed).to( unit=time_unit, copy=False ) - tofs = distance / simulation.speed + # tofs = distance / simulation.speed data = sc.DataArray( data=simulation.weight, - coords={"toa": toas, "tof": tofs.to(unit=time_unit, copy=False)}, + coords={"toa": toas, "wavelength": simulation.wavelength}, ) # Add the event_time_offset coordinate, wrapped to the frame_period @@ -189,27 +189,29 @@ def _compute_mean_tof( binned = data.bin(event_time_offset=time_bins) binned_sum = binned.bins.sum() - # Weighted mean of tof inside each bin - mean_tof = (binned.bins.data * binned.bins.coords["tof"]).bins.sum() / binned_sum - # Compute the variance of the tofs to track regions with large uncertainty + # Weighted mean of wavelength inside each bin + mean_wavelength = ( + binned.bins.data * binned.bins.coords["wavelength"] + ).bins.sum() / binned_sum + # Compute the variance of the wavelengths to track regions with large uncertainty variance = ( - binned.bins.data * (binned.bins.coords["tof"] - mean_tof) ** 2 + binned.bins.data * (binned.bins.coords["wavelength"] - mean_wavelength) ** 2 ).bins.sum() / binned_sum - mean_tof.variances = variance.values - return mean_tof + mean_wavelength.variances = variance.values + return mean_wavelength -def make_tof_lookup_table( +def make_wavelength_lookup_table( simulation: SimulationResults, ltotal_range: LtotalRange, distance_resolution: DistanceResolution, time_resolution: TimeResolution, pulse_period: PulsePeriod, pulse_stride: PulseStride, -) -> 
TofLookupTable: +) -> LookupTable: """ - Compute a lookup table for time-of-flight as a function of distance and + Compute a lookup table for wavelength as a function of distance and time-of-arrival. Parameters @@ -321,14 +323,14 @@ def make_tof_lookup_table( if simulation_reading is None: closest = sorted_simulation_results[-1] raise ValueError( - "Building the Tof lookup table failed: the requested position " + "Building the lookup table failed: the requested position " f"{dist.value} {dist.unit} is before the component with the lowest " "distance in the simulation. The first component in the beamline " f"has distance {closest.distance.value} {closest.distance.unit}." ) pieces.append( - _compute_mean_tof( + _compute_mean_wavelength( simulation=simulation_reading, distance=dist, time_bins=time_bins, @@ -355,7 +357,7 @@ def make_tof_lookup_table( }, ) - return TofLookupTable( + return LookupTable( array=table, pulse_period=pulse_period, pulse_stride=pulse_stride, @@ -442,13 +444,13 @@ def simulate_chopper_cascade_using_tof( return SimulationResults(readings=sim_readings, choppers=choppers) -def TofLookupTableWorkflow(): +def LookupTableWorkflow(): """ - Create a workflow for computing a time-of-flight lookup table from a + Create a workflow for computing a wavelength lookup table from a simulation of neutrons propagating through a chopper cascade. 
""" wf = sl.Pipeline( - (make_tof_lookup_table, simulate_chopper_cascade_using_tof), + (make_wavelength_lookup_table, simulate_chopper_cascade_using_tof), params={ PulsePeriod: 1.0 / sc.scalar(14.0, unit="Hz"), PulseStride: 1, diff --git a/packages/essreduce/src/ess/reduce/time_of_flight/resample.py b/packages/essreduce/src/ess/reduce/unwrap/resample.py similarity index 100% rename from packages/essreduce/src/ess/reduce/time_of_flight/resample.py rename to packages/essreduce/src/ess/reduce/unwrap/resample.py diff --git a/packages/essreduce/src/ess/reduce/time_of_flight/eto_to_tof.py b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py similarity index 69% rename from packages/essreduce/src/ess/reduce/time_of_flight/eto_to_tof.py rename to packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py index 7982680b..6b228e5d 100644 --- a/packages/essreduce/src/ess/reduce/time_of_flight/eto_to_tof.py +++ b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py @@ -15,7 +15,6 @@ import scippneutron as scn import scippnexus as snx from scippneutron._utils import elem_unit -from scippneutron.conversion.tof import wavelength_from_tof try: from .interpolator_numba import Interpolator as InterpolatorImpl @@ -38,23 +37,27 @@ from .resample import rebin_strictly_increasing from .types import ( DetectorLtotal, - ErrorLimitedTofLookupTable, + ErrorLimitedLookupTable, + LookupTable, LookupTableRelativeErrorThreshold, MonitorLtotal, PulseStrideOffset, - ToaDetector, - TofDetector, - TofLookupTable, - TofMonitor, WavelengthDetector, WavelengthMonitor, ) -class TofInterpolator: - def __init__(self, lookup: sc.DataArray, distance_unit: str, time_unit: str): +class WavelengthInterpolator: + def __init__( + self, + lookup: sc.DataArray, + distance_unit: str, + time_unit: str, + wavelength_unit: str = 'angstrom', + ): self._distance_unit = distance_unit self._time_unit = time_unit + self._wavelength_unit = wavelength_unit self._time_edges = ( lookup.coords["event_time_offset"] 
@@ -68,7 +71,7 @@ def __init__(self, lookup: sc.DataArray, distance_unit: str, time_unit: str): self._interpolator = InterpolatorImpl( time_edges=self._time_edges, distance_edges=self._distance_edges, - values=lookup.data.to(unit=self._time_unit, copy=False).values, + values=lookup.data.to(unit=self._wavelength_unit, copy=False).values, ) def __call__( @@ -100,16 +103,17 @@ def __call__( pulse_index=pulse_index.values if pulse_index is not None else None, pulse_period=pulse_period.value, ), - unit=self._time_unit, + unit=self._wavelength_unit, ) -def _time_of_flight_data_histogram( - da: sc.DataArray, lookup: ErrorLimitedTofLookupTable, ltotal: sc.Variable +def _compute_wavelength_histogram( + da: sc.DataArray, lookup: ErrorLimitedLookupTable, ltotal: sc.Variable ) -> sc.DataArray: # In NeXus, 'time_of_flight' is the canonical name in NXmonitor, but in some files, # it may be called 'tof' or 'frame_time'. - key = next(iter(set(da.coords.keys()) & {"time_of_flight", "tof", "frame_time"})) + possible_names = {"time_of_flight", "tof", "frame_time"} + key = next(iter(set(da.coords.keys()) & possible_names)) raw_eto = da.coords[key].to(dtype=float, copy=False) eto_unit = raw_eto.unit pulse_period = lookup.pulse_period.to(unit=eto_unit) @@ -126,19 +130,19 @@ def _time_of_flight_data_histogram( etos = rebinned.coords[key] # Create linear interpolator - interp = TofInterpolator( + interp = WavelengthInterpolator( lookup.array, distance_unit=ltotal.unit, time_unit=eto_unit ) - # Compute time-of-flight of the bin edges using the interpolator - tofs = interp( + # Compute wavelengths of the bin edges using the interpolator + wavs = interp( ltotal=ltotal.broadcast(sizes=etos.sizes), event_time_offset=etos, pulse_period=pulse_period, ) - return rebinned.assign_coords(tof=tofs).drop_coords( - list({key} & {"time_of_flight", "frame_time"}) + return rebinned.assign_coords(wavelength=wavs).drop_coords( + list({key} & possible_names) ) @@ -148,11 +152,11 @@ def 
_guess_pulse_stride_offset( event_time_offset: sc.Variable, pulse_period: sc.Variable, pulse_stride: int, - interp: TofInterpolator, + interp: WavelengthInterpolator, ) -> int: """ Using the minimum ``event_time_zero`` to calculate a reference time when computing - the time-of-flight for the neutron events makes the workflow depend on when the + the wavelength for the neutron events makes the workflow depend on when the first event was recorded. There is no straightforward way to know if we started recording at the beginning of a frame, or half-way through a frame, without looking at the chopper logs. This can be manually corrected using the pulse_stride_offset @@ -161,9 +165,9 @@ def _guess_pulse_stride_offset( Here, we perform a simple guess for the ``pulse_stride_offset`` if it is not provided. - We choose a few random events, compute the time-of-flight for every possible value + We choose a few random events, compute the wavelength for every possible value of pulse_stride_offset, and return the value that yields the least number of NaNs - in the computed time-of-flight. + in the computed wavelength. Parameters ---------- @@ -180,8 +184,8 @@ def _guess_pulse_stride_offset( interp: Interpolator for the lookup table. 
""" - tofs = {} - # Choose a few random events to compute the time-of-flight + wavs = {} + # Choose a few random events for which to compute the wavelength inds = np.random.choice( len(event_time_offset), min(5000, len(event_time_offset)), replace=False ) @@ -198,25 +202,25 @@ def _guess_pulse_stride_offset( ) for i in range(pulse_stride): pulse_inds = (pulse_index + i) % pulse_stride - tofs[i] = interp( + wavs[i] = interp( ltotal=ltotal, event_time_offset=etos, pulse_index=pulse_inds, pulse_period=pulse_period, ) # Find the entry in the list with the least number of nan values - return sorted(tofs, key=lambda x: sc.isnan(tofs[x]).sum())[0] + return sorted(wavs, key=lambda x: sc.isnan(wavs[x]).sum())[0] -def _prepare_tof_interpolation_inputs( +def _prepare_wavelength_interpolation_inputs( da: sc.DataArray, - lookup: ErrorLimitedTofLookupTable, + lookup: ErrorLimitedLookupTable, ltotal: sc.Variable, pulse_stride_offset: int | None, ) -> dict: """ - Prepare the inputs required for the time-of-flight interpolation. - This function is used when computing the time-of-flight for event data, and for + Prepare the inputs required for the wavelength interpolation. + This function is used when computing the wavelength for event data, and for computing the time-of-arrival for event data (as they both require guessing the pulse_stride_offset if not provided). @@ -225,8 +229,7 @@ def _prepare_tof_interpolation_inputs( da: Data array with event data. lookup: - Lookup table giving time-of-flight as a function of distance and time of - arrival. + Lookup table giving wavelength as a function of distance and time of arrival. ltotal: Total length of the flight path from the source to the detector. 
pulse_stride_offset: @@ -238,7 +241,7 @@ def _prepare_tof_interpolation_inputs( eto_unit = elem_unit(etos) # Create linear interpolator - interp = TofInterpolator( + interp = WavelengthInterpolator( lookup.array, distance_unit=ltotal.unit, time_unit=eto_unit ) @@ -302,21 +305,21 @@ def _prepare_tof_interpolation_inputs( } -def _time_of_flight_data_events( +def _compute_wavelength_events( da: sc.DataArray, - lookup: ErrorLimitedTofLookupTable, + lookup: ErrorLimitedLookupTable, ltotal: sc.Variable, pulse_stride_offset: int | None, ) -> sc.DataArray: - inputs = _prepare_tof_interpolation_inputs( + inputs = _prepare_wavelength_interpolation_inputs( da=da, lookup=lookup, ltotal=ltotal, pulse_stride_offset=pulse_stride_offset, ) - # Compute time-of-flight for all neutrons using the interpolator - tofs = inputs["interp"]( + # Compute wavelength for all neutrons using the interpolator + wavs = inputs["interp"]( ltotal=inputs["ltotal"], event_time_offset=inputs["eto"], pulse_index=inputs["pulse_index"], @@ -324,8 +327,8 @@ def _time_of_flight_data_events( ) parts = da.bins.constituents - parts["data"] = tofs - result = da.bins.assign_coords(tof=sc.bins(**parts, validate_indices=False)) + parts["data"] = wavs + result = da.bins.assign_coords(wavelength=sc.bins(**parts, validate_indices=False)) out = result.bins.drop_coords("event_time_offset") # The result may still have an 'event_time_zero' dimension (in the case of an @@ -363,6 +366,7 @@ def detector_ltotal_from_straight_line_approximation( gravity: Gravity vector. 
""" + # TODO: scatter=True should not be hard-coded here graph = { **scn.conversion.graph.beamline.beamline(scatter=True), 'source_position': lambda: source_position, @@ -403,10 +407,10 @@ def monitor_ltotal_from_straight_line_approximation( def _mask_large_uncertainty_in_lut( - table: TofLookupTable, error_threshold: float -) -> TofLookupTable: + table: LookupTable, error_threshold: float +) -> LookupTable: """ - Mask regions in the time-of-flight lookup table with large uncertainty using NaNs. + Mask regions in the lookup table with large uncertainty using NaNs. Parameters ---------- @@ -416,12 +420,10 @@ def _mask_large_uncertainty_in_lut( Threshold for the relative standard deviation (coefficient of variation) of the projected time-of-flight above which values are masked. """ - # TODO: The error threshold could be made dependent on the time-of-flight or - # distance, instead of being a single value for the whole table. da = table.array relative_error = sc.stddevs(da.data) / sc.values(da.data) mask = relative_error > sc.scalar(error_threshold) - return TofLookupTable( + return LookupTable( **{ **asdict(table), "array": sc.where(mask, sc.scalar(np.nan, unit=da.unit), da), @@ -430,25 +432,25 @@ def _mask_large_uncertainty_in_lut( def mask_large_uncertainty_in_lut_detector( - table: TofLookupTable, + table: LookupTable, error_threshold: LookupTableRelativeErrorThreshold, detector_name: NeXusDetectorName, -) -> ErrorLimitedTofLookupTable[snx.NXdetector]: +) -> ErrorLimitedLookupTable[snx.NXdetector]: """ - Mask regions in the time-of-flight lookup table with large uncertainty using NaNs. + Mask regions in the wavelength lookup table with large uncertainty using NaNs. Parameters ---------- table: - Lookup table with time-of-flight as a function of distance and time-of-arrival. + Lookup table with wavelength as a function of distance and time-of-arrival. 
error_threshold: Threshold for the relative standard deviation (coefficient of variation) of the - projected time-of-flight above which values are masked. + projected wavelength above which values are masked. detector_name: Name of the detector for which to apply the error threshold. This is used to get the correct error threshold from the dictionary of error thresholds. """ - return ErrorLimitedTofLookupTable[snx.NXdetector]( + return ErrorLimitedLookupTable[snx.NXdetector]( _mask_large_uncertainty_in_lut( table=table, error_threshold=error_threshold[detector_name] ) @@ -456,42 +458,42 @@ def mask_large_uncertainty_in_lut_detector( def mask_large_uncertainty_in_lut_monitor( - table: TofLookupTable, + table: LookupTable, error_threshold: LookupTableRelativeErrorThreshold, monitor_name: NeXusName[MonitorType], -) -> ErrorLimitedTofLookupTable[MonitorType]: +) -> ErrorLimitedLookupTable[MonitorType]: """ - Mask regions in the time-of-flight lookup table with large uncertainty using NaNs. + Mask regions in the wavelength lookup table with large uncertainty using NaNs. Parameters ---------- table: - Lookup table with time-of-flight as a function of distance and time-of-arrival. + Lookup table with wavelength as a function of distance and time-of-arrival. error_threshold: Threshold for the relative standard deviation (coefficient of variation) of the - projected time-of-flight above which values are masked. + projected wavelength above which values are masked. monitor_name: Name of the monitor for which to apply the error threshold. This is used to get the correct error threshold from the dictionary of error thresholds. 
""" - return ErrorLimitedTofLookupTable[MonitorType]( + return ErrorLimitedLookupTable[MonitorType]( _mask_large_uncertainty_in_lut( table=table, error_threshold=error_threshold[monitor_name] ) ) -def _compute_tof_data( +def _compute_wavelength_data( da: sc.DataArray, - lookup: ErrorLimitedTofLookupTable[Component], + lookup: ErrorLimitedLookupTable[Component], ltotal: sc.Variable, pulse_stride_offset: int, ) -> sc.DataArray: if da.bins is None: - data = _time_of_flight_data_histogram(da=da, lookup=lookup, ltotal=ltotal) - out = rebin_strictly_increasing(data, dim='tof') + data = _compute_wavelength_histogram(da=da, lookup=lookup, ltotal=ltotal) + out = rebin_strictly_increasing(data, dim='wavelength') else: - out = _time_of_flight_data_events( + out = _compute_wavelength_events( da=da, lookup=lookup, ltotal=ltotal, @@ -500,16 +502,16 @@ def _compute_tof_data( return out.assign_coords(Ltotal=ltotal) -def detector_time_of_flight_data( +def detector_wavelength_data( detector_data: RawDetector[RunType], - lookup: ErrorLimitedTofLookupTable[snx.NXdetector], + lookup: ErrorLimitedLookupTable[snx.NXdetector], ltotal: DetectorLtotal[RunType], pulse_stride_offset: PulseStrideOffset, -) -> TofDetector[RunType]: +) -> WavelengthDetector[RunType]: """ - Convert the time-of-arrival (event_time_offset) data to time-of-flight data using a + Convert the time-of-arrival (event_time_offset) data to wavelength data using a lookup table. - The output data will have two new coordinates: time-of-flight and Ltotal. + The output data will have two new coordinates: wavelength and Ltotal. Parameters ---------- @@ -517,7 +519,7 @@ def detector_time_of_flight_data( Raw detector data loaded from a NeXus file, e.g., NXdetector containing NXevent_data. lookup: - Lookup table giving time-of-flight as a function of distance and time of + Lookup table giving wavelength as a function of distance and time of arrival. ltotal: Total length of the flight path from the source to the detector. 
@@ -525,8 +527,8 @@ def detector_time_of_flight_data( When pulse-skipping, the offset of the first pulse in the stride. This is typically zero but can be a small integer < pulse_stride. """ - return TofDetector[RunType]( - _compute_tof_data( + return WavelengthDetector[RunType]( + _compute_wavelength_data( da=detector_data, lookup=lookup, ltotal=ltotal, @@ -535,16 +537,16 @@ def detector_time_of_flight_data( ) -def monitor_time_of_flight_data( +def monitor_wavelength_data( monitor_data: RawMonitor[RunType, MonitorType], - lookup: ErrorLimitedTofLookupTable[MonitorType], + lookup: ErrorLimitedLookupTable[MonitorType], ltotal: MonitorLtotal[RunType, MonitorType], pulse_stride_offset: PulseStrideOffset, -) -> TofMonitor[RunType, MonitorType]: +) -> WavelengthMonitor[RunType, MonitorType]: """ - Convert the time-of-arrival (event_time_offset) data to time-of-flight data using a + Convert the time-of-arrival (event_time_offset) data to wavelength data using a lookup table. - The output data will have two new coordinates: time-of-flight and Ltotal. + The output data will have two new coordinates: wavelength and Ltotal. Parameters ---------- @@ -552,7 +554,7 @@ def monitor_time_of_flight_data( Raw monitor data loaded from a NeXus file, e.g., NXmonitor containing NXevent_data. lookup: - Lookup table giving time-of-flight as a function of distance and time of + Lookup table giving wavelength as a function of distance and time of arrival. ltotal: Total length of the flight path from the source to the monitor. @@ -560,8 +562,8 @@ def monitor_time_of_flight_data( When pulse-skipping, the offset of the first pulse in the stride. This is typically zero but can be a small integer < pulse_stride. 
""" - return TofMonitor[RunType, MonitorType]( - _compute_tof_data( + return WavelengthMonitor[RunType, MonitorType]( + _compute_wavelength_data( da=monitor_data, lookup=lookup, ltotal=ltotal, @@ -570,109 +572,13 @@ def monitor_time_of_flight_data( ) -def detector_time_of_arrival_data( - detector_data: RawDetector[RunType], - lookup: ErrorLimitedTofLookupTable[snx.NXdetector], - ltotal: DetectorLtotal[RunType], - pulse_stride_offset: PulseStrideOffset, -) -> ToaDetector[RunType]: - """ - Convert the time-of-flight data to time-of-arrival data using a lookup table. - The output data will have a time-of-arrival coordinate. - The time-of-arrival is the time since the neutron was emitted from the source. - It is basically equal to event_time_offset + pulse_index * pulse_period. - - TODO: This is not actually the 'time-of-arrival' in the strict sense, as it is - still wrapped over the frame period. We should consider unwrapping it in the future - to get the true time-of-arrival. - Or give it a different name to avoid confusion. - - Parameters - ---------- - da: - Raw detector data loaded from a NeXus file, e.g., NXdetector containing - NXevent_data. - lookup: - Lookup table giving time-of-flight as a function of distance and time of - arrival. - ltotal: - Total length of the flight path from the source to the detector. - pulse_stride_offset: - When pulse-skipping, the offset of the first pulse in the stride. This is - typically zero but can be a small integer < pulse_stride. - """ - if detector_data.bins is None: - raise NotImplementedError( - "Computing time-of-arrival in histogram mode is not implemented yet." 
- ) - inputs = _prepare_tof_interpolation_inputs( - da=detector_data, - lookup=lookup, - ltotal=ltotal, - pulse_stride_offset=pulse_stride_offset, - ) - parts = detector_data.bins.constituents - parts["data"] = inputs["eto"] - # The pulse index is None if pulse_stride == 1 (i.e., no pulse skipping) - if inputs["pulse_index"] is not None: - parts["data"] = parts["data"] + inputs["pulse_index"] * inputs["pulse_period"] - result = detector_data.bins.assign_coords( - toa=sc.bins(**parts, validate_indices=False) - ) - return ToaDetector[RunType](result) - - -def _tof_to_wavelength(da: sc.DataArray) -> sc.DataArray: - """ - Convert time-of-flight data to wavelength data. - - Here we assume that the input data contains a Ltotal coordinate, which is required - for the conversion. - This coordinate is assigned in the ``_compute_tof_data`` function. - """ - return da.transform_coords( - 'wavelength', graph={"wavelength": wavelength_from_tof}, keep_intermediate=False - ) - - -def detector_wavelength_data( - detector_data: TofDetector[RunType], -) -> WavelengthDetector[RunType]: - """ - Convert time-of-flight coordinate of the detector data to wavelength. - - Parameters - ---------- - da: - Detector data with time-of-flight coordinate. - """ - return WavelengthDetector[RunType](_tof_to_wavelength(detector_data)) - - -def monitor_wavelength_data( - monitor_data: TofMonitor[RunType, MonitorType], -) -> WavelengthMonitor[RunType, MonitorType]: - """ - Convert time-of-flight coordinate of the monitor data to wavelength. - - Parameters - ---------- - da: - Monitor data with time-of-flight coordinate. - """ - return WavelengthMonitor[RunType, MonitorType](_tof_to_wavelength(monitor_data)) - - def providers() -> tuple[Callable]: """ Providers of the time-of-flight workflow. 
""" return ( - detector_time_of_flight_data, - monitor_time_of_flight_data, detector_ltotal_from_straight_line_approximation, monitor_ltotal_from_straight_line_approximation, - detector_time_of_arrival_data, detector_wavelength_data, monitor_wavelength_data, mask_large_uncertainty_in_lut_detector, diff --git a/packages/essreduce/src/ess/reduce/time_of_flight/types.py b/packages/essreduce/src/ess/reduce/unwrap/types.py similarity index 61% rename from packages/essreduce/src/ess/reduce/time_of_flight/types.py rename to packages/essreduce/src/ess/reduce/unwrap/types.py index 1972fa1d..86dd086c 100644 --- a/packages/essreduce/src/ess/reduce/time_of_flight/types.py +++ b/packages/essreduce/src/ess/reduce/unwrap/types.py @@ -10,22 +10,19 @@ from ..nexus.types import Component, MonitorType, RunType -TofLookupTableFilename = NewType("TofLookupTableFilename", str) -"""Filename of the time-of-flight lookup table.""" - -TimeOfFlightLookupTableFilename = TofLookupTableFilename -"""Filename of the time-of-flight lookup table (alias).""" +LookupTableFilename = NewType("LookupTableFilename", str) +"""Filename of the wavelength lookup table.""" @dataclass -class TofLookupTable: +class LookupTable: """ - Lookup table giving time-of-flight as a function of distance and time of arrival. + Lookup table giving wavelength as a function of distance and ``event_time_offset``. 
""" array: sc.DataArray - """The lookup table data array that maps (distance, time_of_arrival) to - time_of_flight.""" + """The lookup table data array that maps (distance, event_time_offset) to + wavelength.""" pulse_period: sc.Variable """Pulse period of the neutron source.""" pulse_stride: int @@ -33,7 +30,7 @@ class TofLookupTable: distance_resolution: sc.Variable """Resolution of the distance coordinate in the lookup table.""" time_resolution: sc.Variable - """Resolution of the time_of_arrival coordinate in the lookup table.""" + """Resolution of the event_time_offset coordinate in the lookup table.""" choppers: sc.DataGroup | None = None """Chopper parameters used when generating the lookup table, if any. This is made optional so we can still support old lookup tables without chopper info.""" @@ -47,14 +44,9 @@ def plot(self, *args, **kwargs) -> Any: return self.array.plot(*args, **kwargs) -TimeOfFlightLookupTable = TofLookupTable -"""Lookup table giving time-of-flight as a function of distance and time of arrival -(alias).""" - - -class ErrorLimitedTofLookupTable(sl.Scope[Component, TofLookupTable], TofLookupTable): +class ErrorLimitedLookupTable(sl.Scope[Component, LookupTable], LookupTable): """Lookup table that is masked with NaNs in regions where the standard deviation of - the time-of-flight is above a certain threshold.""" + the wavelength is above a certain threshold.""" PulseStrideOffset = NewType("PulseStrideOffset", int | None) @@ -66,7 +58,7 @@ class ErrorLimitedTofLookupTable(sl.Scope[Component, TofLookupTable], TofLookupT LookupTableRelativeErrorThreshold = NewType("LookupTableRelativeErrorThreshold", dict) """ Threshold for the relative standard deviation (coefficient of variation) of the -projected time-of-flight above which values are masked. +projected wavelength above which values are masked. The threshold can be different for different beamline components (monitors, detector banks, etc.). 
The dictionary should have the component names as keys and the corresponding thresholds as values. @@ -91,26 +83,6 @@ class MonitorLtotal(sl.Scope[RunType, MonitorType, sc.Variable], sc.Variable): """Total path length of neutrons from source to monitor.""" -class TofDetector(sl.Scope[RunType, sc.DataArray], sc.DataArray): - """Detector data with time-of-flight coordinate.""" - - -class ToaDetector(sl.Scope[RunType, sc.DataArray], sc.DataArray): - """Detector data with time-of-arrival coordinate. - - When the pulse stride is 1 (i.e., no pulse skipping), the time-of-arrival is the - same as the event_time_offset. When pulse skipping is used, the time-of-arrival is - the event_time_offset + pulse_offset * pulse_period. - This means that the time-of-arrival is basically the event_time_offset wrapped - over the frame period instead of the pulse period - (where frame_period = pulse_stride * pulse_period). - """ - - -class TofMonitor(sl.Scope[RunType, MonitorType, sc.DataArray], sc.DataArray): - """Monitor data with time-of-flight coordinate.""" - - class WavelengthDetector(sl.Scope[RunType, sc.DataArray], sc.DataArray): """Detector data with wavelength coordinate.""" diff --git a/packages/essreduce/src/ess/reduce/time_of_flight/workflow.py b/packages/essreduce/src/ess/reduce/unwrap/workflow.py similarity index 84% rename from packages/essreduce/src/ess/reduce/time_of_flight/workflow.py rename to packages/essreduce/src/ess/reduce/unwrap/workflow.py index 1cbd9fba..b493162f 100644 --- a/packages/essreduce/src/ess/reduce/time_of_flight/workflow.py +++ b/packages/essreduce/src/ess/reduce/unwrap/workflow.py @@ -6,12 +6,12 @@ import scipp as sc from ..nexus import GenericNeXusWorkflow -from . import eto_to_tof -from .types import PulseStrideOffset, TofLookupTable, TofLookupTableFilename +from . 
import to_wavelength +from .types import LookupTable, LookupTableFilename, PulseStrideOffset -def load_tof_lookup_table(filename: TofLookupTableFilename) -> TofLookupTable: - """Load a time-of-flight lookup table from an HDF5 file.""" +def load_lookup_table(filename: LookupTableFilename) -> LookupTable: + """Load a wavelength lookup table from an HDF5 file.""" table = sc.io.load_hdf5(filename) # Support old format where the metadata were stored as coordinates of the DataArray. @@ -38,19 +38,19 @@ def load_tof_lookup_table(filename: TofLookupTableFilename) -> TofLookupTable: if "error_threshold" in table: del table["error_threshold"] - return TofLookupTable(**table) + return LookupTable(**table) -def GenericTofWorkflow( +def GenericUnwrapWorkflow( *, run_types: Iterable[sciline.typing.Key], monitor_types: Iterable[sciline.typing.Key], ) -> sciline.Pipeline: """ - Generic workflow for computing the neutron time-of-flight for detector and monitor + Generic workflow for computing the neutron wavelength for detector and monitor data. - This workflow builds on the ``GenericNeXusWorkflow`` and computes time-of-flight + This workflow builds on the ``GenericNeXusWorkflow`` and computes wavelength from a lookup table that is created from the chopper settings, detector Ltotal and the neutron time-of-arrival. 
@@ -82,10 +82,10 @@ def GenericTofWorkflow( """ wf = GenericNeXusWorkflow(run_types=run_types, monitor_types=monitor_types) - for provider in eto_to_tof.providers(): + for provider in to_wavelength.providers(): wf.insert(provider) - wf.insert(load_tof_lookup_table) + wf.insert(load_lookup_table) # Default parameters wf[PulseStrideOffset] = None diff --git a/packages/essreduce/tests/time_of_flight/workflow_test.py b/packages/essreduce/tests/time_of_flight/workflow_test.py deleted file mode 100644 index a9848517..00000000 --- a/packages/essreduce/tests/time_of_flight/workflow_test.py +++ /dev/null @@ -1,269 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) -import numpy as np -import pytest -import sciline -import scipp as sc -import scippnexus as snx -from scipp.testing import assert_identical - -from ess.reduce import time_of_flight -from ess.reduce.nexus.types import ( - AnyRun, - DiskChoppers, - EmptyDetector, - NeXusData, - NeXusDetectorName, - Position, - RawDetector, - SampleRun, -) -from ess.reduce.time_of_flight import ( - GenericTofWorkflow, - TofLookupTableWorkflow, - fakes, -) - -sl = pytest.importorskip("sciline") - - -@pytest.fixture -def workflow() -> GenericTofWorkflow: - sizes = {'detector_number': 10} - calibrated_beamline = sc.DataArray( - data=sc.ones(sizes=sizes), - coords={ - "position": sc.spatial.as_vectors( - sc.zeros(sizes=sizes, unit='m'), - sc.zeros(sizes=sizes, unit='m'), - sc.linspace("detector_number", 79, 81, 10, unit='m'), - ), - "detector_number": sc.array( - dims=["detector_number"], values=np.arange(10), unit=None - ), - }, - ) - - events = sc.DataArray( - data=sc.ones(dims=["event"], shape=[1000]), - coords={ - "event_time_offset": sc.linspace( - "event", 0.0, 1000.0 / 14, num=1000, unit="ms" - ).to(unit="ns"), - "event_id": sc.array( - dims=["event"], values=np.arange(1000) % 10, unit=None - ), - }, - ) - nexus_data = sc.DataArray( - sc.bins( - 
begin=sc.array(dims=["pulse"], values=[0], unit=None), - data=events, - dim="event", - ) - ) - - wf = GenericTofWorkflow(run_types=[SampleRun], monitor_types=[]) - wf[NeXusDetectorName] = "detector" - wf[time_of_flight.LookupTableRelativeErrorThreshold] = {'detector': np.inf} - wf[EmptyDetector[SampleRun]] = calibrated_beamline - wf[NeXusData[snx.NXdetector, SampleRun]] = nexus_data - wf[Position[snx.NXsample, SampleRun]] = sc.vector([0, 0, 77], unit='m') - wf[Position[snx.NXsource, SampleRun]] = sc.vector([0, 0, 0], unit='m') - - return wf - - -def test_TofLookupTableWorkflow_can_compute_tof_lut(): - wf = TofLookupTableWorkflow() - wf[DiskChoppers[AnyRun]] = fakes.psc_choppers() - wf[time_of_flight.NumberOfSimulatedNeutrons] = 10_000 - wf[time_of_flight.LtotalRange] = ( - sc.scalar(75.0, unit="m"), - sc.scalar(85.0, unit="m"), - ) - wf[time_of_flight.SourcePosition] = fakes.source_position() - lut = wf.compute(time_of_flight.TofLookupTable) - assert lut.array is not None - assert lut.distance_resolution is not None - assert lut.time_resolution is not None - assert lut.pulse_stride is not None - assert lut.pulse_period is not None - assert lut.choppers is not None - - -@pytest.mark.parametrize("coord", ["tof", "wavelength"]) -def test_GenericTofWorkflow_with_tof_lut_from_tof_simulation(workflow, coord: str): - # Should be able to compute DetectorData without chopper and simulation params - # This contains event_time_offset (time-of-arrival). 
- _ = workflow.compute(RawDetector[SampleRun]) - # By default, the workflow tries to load the LUT from file - with pytest.raises(sciline.UnsatisfiedRequirement): - _ = workflow.compute(time_of_flight.TofLookupTable) - with pytest.raises(sciline.UnsatisfiedRequirement): - _ = workflow.compute(time_of_flight.TofDetector[SampleRun]) - - lut_wf = TofLookupTableWorkflow() - lut_wf[DiskChoppers[AnyRun]] = fakes.psc_choppers() - lut_wf[time_of_flight.NumberOfSimulatedNeutrons] = 10_000 - lut_wf[time_of_flight.LtotalRange] = ( - sc.scalar(75.0, unit="m"), - sc.scalar(85.0, unit="m"), - ) - lut_wf[time_of_flight.SourcePosition] = fakes.source_position() - table = lut_wf.compute(time_of_flight.TofLookupTable) - - workflow[time_of_flight.TofLookupTable] = table - - if coord == "tof": - detector = workflow.compute(time_of_flight.TofDetector[SampleRun]) - assert 'tof' in detector.bins.coords - else: - detector = workflow.compute(time_of_flight.WavelengthDetector[SampleRun]) - assert 'wavelength' in detector.bins.coords - - -@pytest.mark.parametrize("coord", ["tof", "wavelength"]) -def test_GenericTofWorkflow_with_tof_lut_from_file( - workflow, tmp_path: pytest.TempPathFactory, coord: str -): - lut_wf = TofLookupTableWorkflow() - lut_wf[DiskChoppers[AnyRun]] = fakes.psc_choppers() - lut_wf[time_of_flight.NumberOfSimulatedNeutrons] = 10_000 - lut_wf[time_of_flight.LtotalRange] = ( - sc.scalar(75.0, unit="m"), - sc.scalar(85.0, unit="m"), - ) - lut_wf[time_of_flight.SourcePosition] = fakes.source_position() - lut = lut_wf.compute(time_of_flight.TofLookupTable) - lut.save_hdf5(filename=tmp_path / "lut.h5") - - workflow[time_of_flight.TofLookupTableFilename] = (tmp_path / "lut.h5").as_posix() - - loaded_lut = workflow.compute(time_of_flight.TofLookupTable) - assert_identical(lut.array, loaded_lut.array) - assert_identical(lut.pulse_period, loaded_lut.pulse_period) - assert lut.pulse_stride == loaded_lut.pulse_stride - assert_identical(lut.distance_resolution, 
loaded_lut.distance_resolution) - assert_identical(lut.time_resolution, loaded_lut.time_resolution) - assert_identical(lut.choppers, loaded_lut.choppers) - - if coord == "tof": - detector = workflow.compute(time_of_flight.TofDetector[SampleRun]) - assert 'tof' in detector.bins.coords - else: - detector = workflow.compute(time_of_flight.WavelengthDetector[SampleRun]) - assert 'wavelength' in detector.bins.coords - - -def test_GenericTofWorkflow_with_tof_lut_from_file_old_format( - workflow, tmp_path: pytest.TempPathFactory -): - lut_wf = TofLookupTableWorkflow() - lut_wf[DiskChoppers[AnyRun]] = fakes.psc_choppers() - lut_wf[time_of_flight.NumberOfSimulatedNeutrons] = 10_000 - lut_wf[time_of_flight.LtotalRange] = ( - sc.scalar(75.0, unit="m"), - sc.scalar(85.0, unit="m"), - ) - lut_wf[time_of_flight.SourcePosition] = fakes.source_position() - lut = lut_wf.compute(time_of_flight.TofLookupTable) - old_lut = sc.DataArray( - data=lut.array.data, - coords={ - "distance": lut.array.coords["distance"], - "event_time_offset": lut.array.coords["event_time_offset"], - "pulse_period": lut.pulse_period, - "pulse_stride": sc.scalar(lut.pulse_stride, unit=None), - "distance_resolution": lut.distance_resolution, - "time_resolution": lut.time_resolution, - }, - ) - old_lut.save_hdf5(filename=tmp_path / "lut.h5") - - workflow[time_of_flight.TofLookupTableFilename] = (tmp_path / "lut.h5").as_posix() - loaded_lut = workflow.compute(time_of_flight.TofLookupTable) - assert_identical(lut.array, loaded_lut.array) - assert_identical(lut.pulse_period, loaded_lut.pulse_period) - assert lut.pulse_stride == loaded_lut.pulse_stride - assert_identical(lut.distance_resolution, loaded_lut.distance_resolution) - assert_identical(lut.time_resolution, loaded_lut.time_resolution) - assert loaded_lut.choppers is None # No chopper info in old format - - detector = workflow.compute(time_of_flight.TofDetector[SampleRun]) - assert 'tof' in detector.bins.coords - - -def 
test_GenericTofWorkflow_with_tof_lut_from_tof_simulation_using_alias(workflow): - # Should be able to compute DetectorData without chopper and simulation params - # This contains event_time_offset (time-of-arrival). - _ = workflow.compute(RawDetector[SampleRun]) - - lut_wf = TofLookupTableWorkflow() - lut_wf[DiskChoppers[AnyRun]] = fakes.psc_choppers() - lut_wf[time_of_flight.NumberOfSimulatedNeutrons] = 10_000 - lut_wf[time_of_flight.LtotalRange] = ( - sc.scalar(75.0, unit="m"), - sc.scalar(85.0, unit="m"), - ) - lut_wf[time_of_flight.SourcePosition] = fakes.source_position() - table = lut_wf.compute(time_of_flight.TimeOfFlightLookupTable) - - workflow[time_of_flight.TimeOfFlightLookupTable] = table - # Should now be able to compute DetectorData with chopper and simulation params - detector = workflow.compute(time_of_flight.TofDetector[SampleRun]) - assert 'tof' in detector.bins.coords - - -def test_GenericTofWorkflow_with_tof_lut_from_file_using_alias( - workflow, tmp_path: pytest.TempPathFactory -): - lut_wf = TofLookupTableWorkflow() - lut_wf[DiskChoppers[AnyRun]] = fakes.psc_choppers() - lut_wf[time_of_flight.NumberOfSimulatedNeutrons] = 10_000 - lut_wf[time_of_flight.LtotalRange] = ( - sc.scalar(75.0, unit="m"), - sc.scalar(85.0, unit="m"), - ) - lut_wf[time_of_flight.SourcePosition] = fakes.source_position() - lut = lut_wf.compute(time_of_flight.TimeOfFlightLookupTable) - lut.save_hdf5(filename=tmp_path / "lut.h5") - - workflow[time_of_flight.TimeOfFlightLookupTableFilename] = ( - tmp_path / "lut.h5" - ).as_posix() - loaded_lut = workflow.compute(time_of_flight.TimeOfFlightLookupTable) - assert_identical(lut.array, loaded_lut.array) - assert_identical(lut.pulse_period, loaded_lut.pulse_period) - assert lut.pulse_stride == loaded_lut.pulse_stride - assert_identical(lut.distance_resolution, loaded_lut.distance_resolution) - assert_identical(lut.time_resolution, loaded_lut.time_resolution) - assert_identical(lut.choppers, loaded_lut.choppers) - - detector = 
workflow.compute(time_of_flight.TofDetector[SampleRun]) - assert 'tof' in detector.bins.coords - - -@pytest.mark.parametrize("coord", ["tof", "wavelength"]) -def test_GenericTofWorkflow_assigns_Ltotal_coordinate(workflow, coord): - raw = workflow.compute(RawDetector[SampleRun]) - - assert "Ltotal" not in raw.coords - - lut_wf = TofLookupTableWorkflow() - lut_wf[DiskChoppers[AnyRun]] = fakes.psc_choppers() - lut_wf[time_of_flight.NumberOfSimulatedNeutrons] = 10_000 - lut_wf[time_of_flight.LtotalRange] = ( - sc.scalar(20.0, unit="m"), - sc.scalar(100.0, unit="m"), - ) - lut_wf[time_of_flight.SourcePosition] = fakes.source_position() - table = lut_wf.compute(time_of_flight.TofLookupTable) - workflow[time_of_flight.TofLookupTable] = table - - if coord == "tof": - result = workflow.compute(time_of_flight.TofDetector[SampleRun]) - else: - result = workflow.compute(time_of_flight.WavelengthDetector[SampleRun]) - - assert "Ltotal" in result.coords diff --git a/packages/essreduce/tests/time_of_flight/interpolator_test.py b/packages/essreduce/tests/unwrap/interpolator_test.py similarity index 96% rename from packages/essreduce/tests/time_of_flight/interpolator_test.py rename to packages/essreduce/tests/unwrap/interpolator_test.py index 5e1c013d..b7f14426 100644 --- a/packages/essreduce/tests/time_of_flight/interpolator_test.py +++ b/packages/essreduce/tests/unwrap/interpolator_test.py @@ -3,10 +3,10 @@ import numpy as np -from ess.reduce.time_of_flight.interpolator_numba import ( +from ess.reduce.unwrap.interpolator_numba import ( Interpolator as InterpolatorNumba, ) -from ess.reduce.time_of_flight.interpolator_scipy import ( +from ess.reduce.unwrap.interpolator_scipy import ( Interpolator as InterpolatorScipy, ) diff --git a/packages/essreduce/tests/time_of_flight/lut_test.py b/packages/essreduce/tests/unwrap/lut_test.py similarity index 69% rename from packages/essreduce/tests/time_of_flight/lut_test.py rename to packages/essreduce/tests/unwrap/lut_test.py index 
ee118dcb..495e57dd 100644 --- a/packages/essreduce/tests/time_of_flight/lut_test.py +++ b/packages/essreduce/tests/unwrap/lut_test.py @@ -4,30 +4,30 @@ import scipp as sc from scippneutron.chopper import DiskChopper -from ess.reduce import time_of_flight +from ess.reduce import unwrap from ess.reduce.nexus.types import AnyRun -from ess.reduce.time_of_flight import TofLookupTableWorkflow +from ess.reduce.unwrap import LookupTableWorkflow sl = pytest.importorskip("sciline") def test_lut_workflow_computes_table(): - wf = TofLookupTableWorkflow() - wf[time_of_flight.DiskChoppers[AnyRun]] = {} - wf[time_of_flight.SourcePosition] = sc.vector([0, 0, 0], unit='m') - wf[time_of_flight.NumberOfSimulatedNeutrons] = 100_000 - wf[time_of_flight.SimulationSeed] = 60 - wf[time_of_flight.PulseStride] = 1 + wf = LookupTableWorkflow() + wf[unwrap.DiskChoppers[AnyRun]] = {} + wf[unwrap.SourcePosition] = sc.vector([0, 0, 0], unit='m') + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 60 + wf[unwrap.PulseStride] = 1 lmin, lmax = sc.scalar(25.0, unit='m'), sc.scalar(35.0, unit='m') dres = sc.scalar(0.1, unit='m') tres = sc.scalar(333.0, unit='us') - wf[time_of_flight.LtotalRange] = lmin, lmax - wf[time_of_flight.DistanceResolution] = dres - wf[time_of_flight.TimeResolution] = tres + wf[unwrap.LtotalRange] = lmin, lmax + wf[unwrap.DistanceResolution] = dres + wf[unwrap.TimeResolution] = tres - table = wf.compute(time_of_flight.TofLookupTable) + table = wf.compute(unwrap.LookupTable) assert table.array.coords['distance'].min() < lmin assert table.array.coords['distance'].max() > lmax @@ -41,22 +41,22 @@ def test_lut_workflow_computes_table(): def test_lut_workflow_pulse_skipping(): - wf = TofLookupTableWorkflow() - wf[time_of_flight.DiskChoppers[AnyRun]] = {} - wf[time_of_flight.SourcePosition] = sc.vector([0, 0, 0], unit='m') - wf[time_of_flight.NumberOfSimulatedNeutrons] = 100_000 - wf[time_of_flight.SimulationSeed] = 62 - wf[time_of_flight.PulseStride] = 2 
+ wf = LookupTableWorkflow() + wf[unwrap.DiskChoppers[AnyRun]] = {} + wf[unwrap.SourcePosition] = sc.vector([0, 0, 0], unit='m') + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 62 + wf[unwrap.PulseStride] = 2 lmin, lmax = sc.scalar(55.0, unit='m'), sc.scalar(65.0, unit='m') dres = sc.scalar(0.1, unit='m') tres = sc.scalar(250.0, unit='us') - wf[time_of_flight.LtotalRange] = lmin, lmax - wf[time_of_flight.DistanceResolution] = dres - wf[time_of_flight.TimeResolution] = tres + wf[unwrap.LtotalRange] = lmin, lmax + wf[unwrap.DistanceResolution] = dres + wf[unwrap.TimeResolution] = tres - table = wf.compute(time_of_flight.TofLookupTable) + table = wf.compute(unwrap.LookupTable) assert table.array.coords['event_time_offset'].max() == 2 * sc.scalar( 1 / 14, unit='s' @@ -64,22 +64,22 @@ def test_lut_workflow_pulse_skipping(): def test_lut_workflow_non_exact_distance_range(): - wf = TofLookupTableWorkflow() - wf[time_of_flight.DiskChoppers[AnyRun]] = {} - wf[time_of_flight.SourcePosition] = sc.vector([0, 0, 0], unit='m') - wf[time_of_flight.NumberOfSimulatedNeutrons] = 100_000 - wf[time_of_flight.SimulationSeed] = 63 - wf[time_of_flight.PulseStride] = 1 + wf = LookupTableWorkflow() + wf[unwrap.DiskChoppers[AnyRun]] = {} + wf[unwrap.SourcePosition] = sc.vector([0, 0, 0], unit='m') + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 63 + wf[unwrap.PulseStride] = 1 lmin, lmax = sc.scalar(25.0, unit='m'), sc.scalar(35.0, unit='m') dres = sc.scalar(0.33, unit='m') tres = sc.scalar(250.0, unit='us') - wf[time_of_flight.LtotalRange] = lmin, lmax - wf[time_of_flight.DistanceResolution] = dres - wf[time_of_flight.TimeResolution] = tres + wf[unwrap.LtotalRange] = lmin, lmax + wf[unwrap.DistanceResolution] = dres + wf[unwrap.TimeResolution] = tres - table = wf.compute(time_of_flight.TofLookupTable) + table = wf.compute(unwrap.LookupTable) assert table.array.coords['distance'].min() < lmin assert 
table.array.coords['distance'].max() > lmax @@ -146,21 +146,21 @@ def _make_choppers(): def test_lut_workflow_computes_table_with_choppers(): - wf = TofLookupTableWorkflow() - wf[time_of_flight.DiskChoppers[AnyRun]] = _make_choppers() - wf[time_of_flight.SourcePosition] = sc.vector([0, 0, 0], unit='m') - wf[time_of_flight.NumberOfSimulatedNeutrons] = 100_000 - wf[time_of_flight.SimulationSeed] = 64 - wf[time_of_flight.PulseStride] = 1 - - wf[time_of_flight.LtotalRange] = ( + wf = LookupTableWorkflow() + wf[unwrap.DiskChoppers[AnyRun]] = _make_choppers() + wf[unwrap.SourcePosition] = sc.vector([0, 0, 0], unit='m') + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 64 + wf[unwrap.PulseStride] = 1 + + wf[unwrap.LtotalRange] = ( sc.scalar(35.0, unit='m'), sc.scalar(65.0, unit='m'), ) - wf[time_of_flight.DistanceResolution] = sc.scalar(0.1, unit='m') - wf[time_of_flight.TimeResolution] = sc.scalar(250.0, unit='us') + wf[unwrap.DistanceResolution] = sc.scalar(0.1, unit='m') + wf[unwrap.TimeResolution] = sc.scalar(250.0, unit='us') - table = wf.compute(time_of_flight.TofLookupTable) + table = wf.compute(unwrap.LookupTable) # At low distance, the rays are more focussed low_dist = table.array['distance', 2] @@ -180,21 +180,21 @@ def test_lut_workflow_computes_table_with_choppers(): def test_lut_workflow_computes_table_with_choppers_full_beamline_range(): - wf = TofLookupTableWorkflow() - wf[time_of_flight.DiskChoppers[AnyRun]] = _make_choppers() - wf[time_of_flight.SourcePosition] = sc.vector([0, 0, 0], unit='m') - wf[time_of_flight.NumberOfSimulatedNeutrons] = 100_000 - wf[time_of_flight.SimulationSeed] = 64 - wf[time_of_flight.PulseStride] = 1 - - wf[time_of_flight.LtotalRange] = ( + wf = LookupTableWorkflow() + wf[unwrap.DiskChoppers[AnyRun]] = _make_choppers() + wf[unwrap.SourcePosition] = sc.vector([0, 0, 0], unit='m') + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 64 + wf[unwrap.PulseStride] = 1 + + 
wf[unwrap.LtotalRange] = ( sc.scalar(5.0, unit='m'), sc.scalar(65.0, unit='m'), ) - wf[time_of_flight.DistanceResolution] = sc.scalar(0.1, unit='m') - wf[time_of_flight.TimeResolution] = sc.scalar(250.0, unit='us') + wf[unwrap.DistanceResolution] = sc.scalar(0.1, unit='m') + wf[unwrap.TimeResolution] = sc.scalar(250.0, unit='us') - table = wf.compute(time_of_flight.TofLookupTable) + table = wf.compute(unwrap.LookupTable) # Close to source: early times and large spread da = table.array['distance', 2] @@ -230,21 +230,21 @@ def test_lut_workflow_computes_table_with_choppers_full_beamline_range(): def test_lut_workflow_raises_for_distance_before_source(): - wf = TofLookupTableWorkflow() - wf[time_of_flight.DiskChoppers[AnyRun]] = {} - wf[time_of_flight.SourcePosition] = sc.vector([0, 0, 10], unit='m') - wf[time_of_flight.NumberOfSimulatedNeutrons] = 100_000 - wf[time_of_flight.SimulationSeed] = 65 - wf[time_of_flight.PulseStride] = 1 + wf = LookupTableWorkflow() + wf[unwrap.DiskChoppers[AnyRun]] = {} + wf[unwrap.SourcePosition] = sc.vector([0, 0, 10], unit='m') + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 65 + wf[unwrap.PulseStride] = 1 # Setting the starting point at zero will make a table that would cover a range # from -0.2m to 65.0m - wf[time_of_flight.LtotalRange] = ( + wf[unwrap.LtotalRange] = ( sc.scalar(0.0, unit='m'), sc.scalar(65.0, unit='m'), ) - wf[time_of_flight.DistanceResolution] = sc.scalar(0.1, unit='m') - wf[time_of_flight.TimeResolution] = sc.scalar(250.0, unit='us') + wf[unwrap.DistanceResolution] = sc.scalar(0.1, unit='m') + wf[unwrap.TimeResolution] = sc.scalar(250.0, unit='us') - with pytest.raises(ValueError, match="Building the Tof lookup table failed"): - _ = wf.compute(time_of_flight.TofLookupTable) + with pytest.raises(ValueError, match="Building the lookup table failed"): + _ = wf.compute(unwrap.LookupTable) diff --git a/packages/essreduce/tests/time_of_flight/resample_tests.py 
b/packages/essreduce/tests/unwrap/resample_tests.py similarity index 99% rename from packages/essreduce/tests/time_of_flight/resample_tests.py rename to packages/essreduce/tests/unwrap/resample_tests.py index 3984d166..79aab8b3 100644 --- a/packages/essreduce/tests/time_of_flight/resample_tests.py +++ b/packages/essreduce/tests/unwrap/resample_tests.py @@ -6,7 +6,7 @@ import scipp as sc from scipp.testing import assert_identical -from ess.reduce.time_of_flight import resample +from ess.reduce.unwrap import resample class TestFindStrictlyIncreasingSections: diff --git a/packages/essreduce/tests/time_of_flight/unwrap_test.py b/packages/essreduce/tests/unwrap/unwrap_test.py similarity index 67% rename from packages/essreduce/tests/time_of_flight/unwrap_test.py rename to packages/essreduce/tests/unwrap/unwrap_test.py index 3d486ed9..708d3dc3 100644 --- a/packages/essreduce/tests/time_of_flight/unwrap_test.py +++ b/packages/essreduce/tests/unwrap/unwrap_test.py @@ -4,10 +4,8 @@ import pytest import scipp as sc from scippneutron.chopper import DiskChopper -from scippneutron.conversion.graph.beamline import beamline as beamline_graph -from scippneutron.conversion.graph.tof import elastic as elastic_graph -from ess.reduce import time_of_flight +from ess.reduce import unwrap from ess.reduce.nexus.types import ( AnyRun, FrameMonitor0, @@ -17,26 +15,19 @@ RawMonitor, SampleRun, ) -from ess.reduce.time_of_flight import ( - GenericTofWorkflow, - PulsePeriod, - TofLookupTableWorkflow, - fakes, -) +from ess.reduce.unwrap import GenericUnwrapWorkflow, LookupTableWorkflow, fakes sl = pytest.importorskip("sciline") def make_lut_workflow(choppers, neutrons, seed, pulse_stride): - lut_wf = TofLookupTableWorkflow() - lut_wf[time_of_flight.DiskChoppers[AnyRun]] = choppers - lut_wf[time_of_flight.SourcePosition] = fakes.source_position() - lut_wf[time_of_flight.NumberOfSimulatedNeutrons] = neutrons - lut_wf[time_of_flight.SimulationSeed] = seed - lut_wf[time_of_flight.PulseStride] = 
pulse_stride - lut_wf[time_of_flight.SimulationResults] = lut_wf.compute( - time_of_flight.SimulationResults - ) + lut_wf = LookupTableWorkflow() + lut_wf[unwrap.DiskChoppers[AnyRun]] = choppers + lut_wf[unwrap.SourcePosition] = fakes.source_position() + lut_wf[unwrap.NumberOfSimulatedNeutrons] = neutrons + lut_wf[unwrap.SimulationSeed] = seed + lut_wf[unwrap.PulseStride] = pulse_stride + lut_wf[unwrap.SimulationResults] = lut_wf.compute(unwrap.SimulationResults) return lut_wf @@ -75,26 +66,26 @@ def _make_workflow_event_mode( ) mon, ref = beamline.get_monitor("detector") - pl = GenericTofWorkflow(run_types=[SampleRun], monitor_types=[FrameMonitor0]) + pl = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[FrameMonitor0]) if detector_or_monitor == "detector": pl[NeXusDetectorName] = "detector" pl[RawDetector[SampleRun]] = mon - pl[time_of_flight.DetectorLtotal[SampleRun]] = distance + pl[unwrap.DetectorLtotal[SampleRun]] = distance else: pl[NeXusName[FrameMonitor0]] = "monitor" pl[RawMonitor[SampleRun, FrameMonitor0]] = mon - pl[time_of_flight.MonitorLtotal[SampleRun, FrameMonitor0]] = distance + pl[unwrap.MonitorLtotal[SampleRun, FrameMonitor0]] = distance - pl[time_of_flight.LookupTableRelativeErrorThreshold] = { + pl[unwrap.LookupTableRelativeErrorThreshold] = { 'detector': error_threshold, 'monitor': error_threshold, } - pl[time_of_flight.PulseStrideOffset] = pulse_stride_offset + pl[unwrap.PulseStrideOffset] = pulse_stride_offset lut_wf = lut_workflow.copy() - lut_wf[time_of_flight.LtotalRange] = distance, distance + lut_wf[unwrap.LtotalRange] = distance, distance - pl[time_of_flight.TofLookupTable] = lut_wf.compute(time_of_flight.TofLookupTable) + pl[unwrap.LookupTable] = lut_wf.compute(unwrap.LookupTable) return pl, ref @@ -116,35 +107,34 @@ def _make_workflow_histogram_mode( ).to(unit=mon.bins.coords["event_time_offset"].bins.unit) ).rename(event_time_offset=dim) - pl = GenericTofWorkflow(run_types=[SampleRun], monitor_types=[FrameMonitor0]) + pl 
= GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[FrameMonitor0]) if detector_or_monitor == "detector": pl[NeXusDetectorName] = "detector" pl[RawDetector[SampleRun]] = mon - pl[time_of_flight.DetectorLtotal[SampleRun]] = distance + pl[unwrap.DetectorLtotal[SampleRun]] = distance else: pl[NeXusName[FrameMonitor0]] = "monitor" pl[RawMonitor[SampleRun, FrameMonitor0]] = mon - pl[time_of_flight.MonitorLtotal[SampleRun, FrameMonitor0]] = distance + pl[unwrap.MonitorLtotal[SampleRun, FrameMonitor0]] = distance - pl[time_of_flight.LookupTableRelativeErrorThreshold] = { + pl[unwrap.LookupTableRelativeErrorThreshold] = { 'detector': error_threshold, 'monitor': error_threshold, } lut_wf = lut_workflow.copy() - lut_wf[time_of_flight.LtotalRange] = distance, distance + lut_wf[unwrap.LtotalRange] = distance, distance - pl[time_of_flight.TofLookupTable] = lut_wf.compute(time_of_flight.TofLookupTable) + pl[unwrap.LookupTable] = lut_wf.compute(unwrap.LookupTable) return pl, ref -def _validate_result_events(tofs, ref, percentile, diff_threshold, rtol): - assert "event_time_offset" not in tofs.coords +def _validate_result_events(wavs, ref, percentile, diff_threshold, rtol): + assert "event_time_offset" not in wavs.coords + assert "tof" not in wavs.coords - # Convert to wavelength - graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value + wavs = wavs.bins.concat().value diff = abs( (wavs.coords["wavelength"] - ref.coords["wavelength"]) @@ -159,12 +149,13 @@ def _validate_result_events(tofs, ref, percentile, diff_threshold, rtol): assert sc.isclose(ref.data.sum(), nevents, rtol=sc.scalar(rtol)) -def _validate_result_histogram_mode(tofs, ref, percentile, diff_threshold, rtol): - assert "time_of_flight" not in tofs.coords - assert "frame_time" not in tofs.coords +def _validate_result_histogram_mode(wavs, ref, percentile, diff_threshold, rtol): + assert "tof" not in wavs.coords + assert 
"time_of_flight" not in wavs.coords + assert "frame_time" not in wavs.coords - graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - wavs = tofs.transform_coords("wavelength", graph=graph) + # graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} + # wavs = tofs.transform_coords("wavelength", graph=graph) ref = ref.hist(wavelength=wavs.coords["wavelength"]) # We divide by the maximum to avoid large relative differences at the edges of the # frames where the counts are low. @@ -172,7 +163,7 @@ def _validate_result_histogram_mode(tofs, ref, percentile, diff_threshold, rtol) assert np.nanpercentile(diff.values, percentile) < diff_threshold # Make sure that we have not lost too many events (we lose some because they may be # given a NaN tof from the lookup). - assert sc.isclose(ref.data.nansum(), tofs.data.nansum(), rtol=sc.scalar(rtol)) + assert sc.isclose(ref.data.nansum(), wavs.data.nansum(), rtol=sc.scalar(rtol)) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) @@ -197,12 +188,12 @@ def test_unwrap_with_no_choppers(detector_or_monitor) -> None: ) if detector_or_monitor == "detector": - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) else: - tofs = pl.compute(time_of_flight.TofMonitor[SampleRun, FrameMonitor0]) + wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - tofs=tofs, ref=ref, percentile=96, diff_threshold=1.0, rtol=0.02 + wavs=wavs, ref=ref, percentile=96, diff_threshold=1.0, rtol=0.02 ) @@ -225,12 +216,12 @@ def test_standard_unwrap(dist, detector_or_monitor, lut_workflow_psc_choppers) - ) if detector_or_monitor == "detector": - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) else: - tofs = pl.compute(time_of_flight.TofMonitor[SampleRun, FrameMonitor0]) + wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) 
_validate_result_events( - tofs=tofs, ref=ref, percentile=100, diff_threshold=0.02, rtol=0.05 + wavs=wavs, ref=ref, percentile=100, diff_threshold=0.02, rtol=0.05 ) @@ -255,12 +246,12 @@ def test_standard_unwrap_histogram_mode( ) if detector_or_monitor == "detector": - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) else: - tofs = pl.compute(time_of_flight.TofMonitor[SampleRun, FrameMonitor0]) + wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_histogram_mode( - tofs=tofs, ref=ref, percentile=96, diff_threshold=0.4, rtol=0.05 + wavs=wavs, ref=ref, percentile=96, diff_threshold=0.4, rtol=0.05 ) @@ -281,12 +272,12 @@ def test_pulse_skipping_unwrap( ) if detector_or_monitor == "detector": - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) else: - tofs = pl.compute(time_of_flight.TofMonitor[SampleRun, FrameMonitor0]) + wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - tofs=tofs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 + wavs=wavs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 ) @@ -310,12 +301,12 @@ def test_pulse_skipping_unwrap_180_phase_shift(detector_or_monitor) -> None: ) if detector_or_monitor == "detector": - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) else: - tofs = pl.compute(time_of_flight.TofMonitor[SampleRun, FrameMonitor0]) + wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - tofs=tofs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 + wavs=wavs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 ) @@ -335,12 +326,12 @@ def test_pulse_skipping_stride_offset_guess_gives_expected_result( ) if detector_or_monitor == "detector": - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) + wavs = 
pl.compute(unwrap.WavelengthDetector[SampleRun]) else: - tofs = pl.compute(time_of_flight.TofMonitor[SampleRun, FrameMonitor0]) + wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - tofs=tofs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 + wavs=wavs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 ) @@ -375,12 +366,12 @@ def test_pulse_skipping_unwrap_when_all_neutrons_arrive_after_second_pulse( ) if detector_or_monitor == "detector": - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) else: - tofs = pl.compute(time_of_flight.TofMonitor[SampleRun, FrameMonitor0]) + wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - tofs=tofs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 + wavs=wavs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 ) @@ -403,18 +394,18 @@ def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing( lut_wf = make_lut_workflow( choppers=choppers, neutrons=300_000, seed=1234, pulse_stride=2 ) - lut_wf[time_of_flight.LtotalRange] = distance, distance + lut_wf[unwrap.LtotalRange] = distance, distance - pl = GenericTofWorkflow(run_types=[SampleRun], monitor_types=[FrameMonitor0]) + pl = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[FrameMonitor0]) # Skip first pulse = half of the first frame a = mon.group('event_time_zero')['event_time_zero', 1:] a.bins.coords['event_time_zero'] = sc.bins_like(a, a.coords['event_time_zero']) concatenated = a.bins.concat('event_time_zero') - pl[time_of_flight.TofLookupTable] = lut_wf.compute(time_of_flight.TofLookupTable) - pl[time_of_flight.PulseStrideOffset] = 1 # Start the stride at the second pulse - pl[time_of_flight.LookupTableRelativeErrorThreshold] = { + pl[unwrap.LookupTable] = lut_wf.compute(unwrap.LookupTable) + pl[unwrap.PulseStrideOffset] = 1 # Start the stride at the second pulse + 
pl[unwrap.LookupTableRelativeErrorThreshold] = { 'detector': np.inf, 'monitor': np.inf, } @@ -422,17 +413,17 @@ def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing( if detector_or_monitor == "detector": pl[NeXusDetectorName] = "detector" pl[RawDetector[SampleRun]] = concatenated - pl[time_of_flight.DetectorLtotal[SampleRun]] = distance - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) + pl[unwrap.DetectorLtotal[SampleRun]] = distance + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) else: pl[NeXusName[FrameMonitor0]] = "monitor" pl[RawMonitor[SampleRun, FrameMonitor0]] = concatenated - pl[time_of_flight.MonitorLtotal[SampleRun, FrameMonitor0]] = distance - tofs = pl.compute(time_of_flight.TofMonitor[SampleRun, FrameMonitor0]) + pl[unwrap.MonitorLtotal[SampleRun, FrameMonitor0]] = distance + wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) # Convert to wavelength - graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value + # graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} + wavs = wavs.bins.concat().value # Bin the events in toa starting from the pulse period to skip the first pulse. ref = ( ref.bin( @@ -459,14 +450,14 @@ def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing( # All errors should be small assert np.nanpercentile(diff.values, 100) < 0.05 # Make sure that we have not lost too many events (we lose some because they may be - # given a NaN tof from the lookup). + # given a NaN wavelength from the lookup). 
if detector_or_monitor == "detector": target = RawDetector[SampleRun] else: target = RawMonitor[SampleRun, FrameMonitor0] assert sc.isclose( pl.compute(target).data.nansum(), - tofs.data.nansum(), + wavs.data.nansum(), rtol=sc.scalar(1.0e-3), ) @@ -491,12 +482,12 @@ def test_pulse_skipping_stride_3(detector_or_monitor) -> None: ) if detector_or_monitor == "detector": - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) else: - tofs = pl.compute(time_of_flight.TofMonitor[SampleRun, FrameMonitor0]) + wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - tofs=tofs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 + wavs=wavs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 ) @@ -515,12 +506,12 @@ def test_pulse_skipping_unwrap_histogram_mode( ) if detector_or_monitor == "detector": - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) else: - tofs = pl.compute(time_of_flight.TofMonitor[SampleRun, FrameMonitor0]) + wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_histogram_mode( - tofs=tofs, ref=ref, percentile=96, diff_threshold=0.4, rtol=0.05 + wavs=wavs, ref=ref, percentile=96, diff_threshold=0.4, rtol=0.05 ) @@ -548,72 +539,10 @@ def test_unwrap_int(dtype, detector_or_monitor, lut_workflow_psc_choppers) -> No pl[target] = mon if detector_or_monitor == "detector": - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) else: - tofs = pl.compute(time_of_flight.TofMonitor[SampleRun, FrameMonitor0]) + wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - tofs=tofs, ref=ref, percentile=100, diff_threshold=0.02, rtol=0.05 - ) - - -def test_compute_toa(): - distance = sc.scalar(80.0, unit="m") - choppers = fakes.psc_choppers() - - lut_wf = 
make_lut_workflow( - choppers=choppers, neutrons=500_000, seed=1234, pulse_stride=1 - ) - - pl, _ = _make_workflow_event_mode( - distance=distance, - choppers=choppers, - lut_workflow=lut_wf, - seed=2, - pulse_stride_offset=0, - error_threshold=0.1, - detector_or_monitor="detector", - ) - - toas = pl.compute(time_of_flight.ToaDetector[SampleRun]) - - assert "toa" in toas.bins.coords - raw = pl.compute(RawDetector[SampleRun]) - assert sc.allclose(toas.bins.coords["toa"], raw.bins.coords["event_time_offset"]) - - -def test_compute_toa_pulse_skipping(): - distance = sc.scalar(100.0, unit="m") - choppers = fakes.pulse_skipping_choppers() - - lut_wf = make_lut_workflow( - choppers=choppers, neutrons=500_000, seed=1234, pulse_stride=2 - ) - - pl, _ = _make_workflow_event_mode( - distance=distance, - choppers=choppers, - lut_workflow=lut_wf, - seed=2, - pulse_stride_offset=1, - error_threshold=0.1, - detector_or_monitor="detector", - ) - - raw = pl.compute(RawDetector[SampleRun]) - - toas = pl.compute(time_of_flight.ToaDetector[SampleRun]) - - assert "toa" in toas.bins.coords - pulse_period = lut_wf.compute(PulsePeriod) - hist = toas.bins.concat().hist( - toa=sc.array( - dims=["toa"], - values=[0, pulse_period.value, pulse_period.value * 2], - unit=pulse_period.unit, - ).to(unit=toas.bins.coords["toa"].unit) + wavs=wavs, ref=ref, percentile=100, diff_threshold=0.02, rtol=0.05 ) - # There should be counts in both bins - n = raw.sum().value - assert hist.data[0].value > n / 5 - assert hist.data[1].value > n / 5 diff --git a/packages/essreduce/tests/time_of_flight/wfm_test.py b/packages/essreduce/tests/unwrap/wfm_test.py similarity index 84% rename from packages/essreduce/tests/time_of_flight/wfm_test.py rename to packages/essreduce/tests/unwrap/wfm_test.py index 67432966..d150640f 100644 --- a/packages/essreduce/tests/time_of_flight/wfm_test.py +++ b/packages/essreduce/tests/unwrap/wfm_test.py @@ -5,12 +5,10 @@ import pytest import scipp as sc from scippneutron.chopper 
import DiskChopper -from scippneutron.conversion.graph.beamline import beamline as beamline_graph -from scippneutron.conversion.graph.tof import elastic as elastic_graph -from ess.reduce import time_of_flight +from ess.reduce import unwrap from ess.reduce.nexus.types import AnyRun, NeXusDetectorName, RawDetector, SampleRun -from ess.reduce.time_of_flight import GenericTofWorkflow, TofLookupTableWorkflow, fakes +from ess.reduce.unwrap import GenericUnwrapWorkflow, LookupTableWorkflow, fakes sl = pytest.importorskip("sciline") @@ -111,15 +109,13 @@ def dream_source_position() -> sc.Variable: @pytest.fixture(scope="module") def lut_workflow_dream_choppers() -> sl.Pipeline: - lut_wf = TofLookupTableWorkflow() - lut_wf[time_of_flight.DiskChoppers[AnyRun]] = dream_choppers() - lut_wf[time_of_flight.SourcePosition] = dream_source_position() - lut_wf[time_of_flight.NumberOfSimulatedNeutrons] = 100_000 - lut_wf[time_of_flight.SimulationSeed] = 432 - lut_wf[time_of_flight.PulseStride] = 1 - lut_wf[time_of_flight.SimulationResults] = lut_wf.compute( - time_of_flight.SimulationResults - ) + lut_wf = LookupTableWorkflow() + lut_wf[unwrap.DiskChoppers[AnyRun]] = dream_choppers() + lut_wf[unwrap.SourcePosition] = dream_source_position() + lut_wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + lut_wf[unwrap.SimulationSeed] = 432 + lut_wf[unwrap.PulseStride] = 1 + lut_wf[unwrap.SimulationResults] = lut_wf.compute(unwrap.SimulationResults) return lut_wf @@ -129,16 +125,16 @@ def setup_workflow( lut_workflow: sl.Pipeline, error_threshold: float = 0.1, ) -> sl.Pipeline: - pl = GenericTofWorkflow(run_types=[SampleRun], monitor_types=[]) + pl = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[]) pl[RawDetector[SampleRun]] = raw_data - pl[time_of_flight.DetectorLtotal[SampleRun]] = ltotal + pl[unwrap.DetectorLtotal[SampleRun]] = ltotal pl[NeXusDetectorName] = "detector" - pl[time_of_flight.LookupTableRelativeErrorThreshold] = {"detector": error_threshold} + 
pl[unwrap.LookupTableRelativeErrorThreshold] = {"detector": error_threshold} lut_wf = lut_workflow.copy() - lut_wf[time_of_flight.LtotalRange] = ltotal.min(), ltotal.max() + lut_wf[unwrap.LtotalRange] = ltotal.min(), ltotal.max() - pl[time_of_flight.TofLookupTable] = lut_wf.compute(time_of_flight.TofLookupTable) + pl[unwrap.LookupTable] = lut_wf.compute(unwrap.LookupTable) return pl @@ -193,11 +189,7 @@ def test_dream_wfm( raw_data=raw, ltotal=ltotal, lut_workflow=lut_workflow_dream_choppers ) - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) - - # Convert to wavelength - graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - wavs = tofs.transform_coords("wavelength", graph=graph) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) for da in wavs.flatten(to='pixel'): x = sc.sort(da.value, key='id') @@ -211,15 +203,13 @@ def test_dream_wfm( @pytest.fixture(scope="module") def lut_workflow_dream_choppers_time_overlap(): - lut_wf = TofLookupTableWorkflow() - lut_wf[time_of_flight.DiskChoppers[AnyRun]] = dream_choppers_with_frame_overlap() - lut_wf[time_of_flight.SourcePosition] = dream_source_position() - lut_wf[time_of_flight.NumberOfSimulatedNeutrons] = 100_000 - lut_wf[time_of_flight.SimulationSeed] = 432 - lut_wf[time_of_flight.PulseStride] = 1 - lut_wf[time_of_flight.SimulationResults] = lut_wf.compute( - time_of_flight.SimulationResults - ) + lut_wf = LookupTableWorkflow() + lut_wf[unwrap.DiskChoppers[AnyRun]] = dream_choppers_with_frame_overlap() + lut_wf[unwrap.SourcePosition] = dream_source_position() + lut_wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + lut_wf[unwrap.SimulationSeed] = 432 + lut_wf[unwrap.PulseStride] = 1 + lut_wf[unwrap.SimulationResults] = lut_wf.compute(unwrap.SimulationResults) return lut_wf @@ -280,11 +270,7 @@ def test_dream_wfm_with_subframe_time_overlap( error_threshold=0.01, ) - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) - - # Convert to wavelength - graph = {**beamline_graph(scatter=False), 
**elastic_graph("tof")} - wavs = tofs.transform_coords("wavelength", graph=graph) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) for da in wavs.flatten(to='pixel'): x = sc.sort(da.value, key='id') @@ -403,15 +389,13 @@ def v20_source_position(): @pytest.fixture(scope="module") def lut_workflow_v20_choppers(): - lut_wf = TofLookupTableWorkflow() - lut_wf[time_of_flight.DiskChoppers[AnyRun]] = v20_choppers() - lut_wf[time_of_flight.SourcePosition] = v20_source_position() - lut_wf[time_of_flight.NumberOfSimulatedNeutrons] = 300_000 - lut_wf[time_of_flight.SimulationSeed] = 431 - lut_wf[time_of_flight.PulseStride] = 1 - lut_wf[time_of_flight.SimulationResults] = lut_wf.compute( - time_of_flight.SimulationResults - ) + lut_wf = LookupTableWorkflow() + lut_wf[unwrap.DiskChoppers[AnyRun]] = v20_choppers() + lut_wf[unwrap.SourcePosition] = v20_source_position() + lut_wf[unwrap.NumberOfSimulatedNeutrons] = 300_000 + lut_wf[unwrap.SimulationSeed] = 431 + lut_wf[unwrap.PulseStride] = 1 + lut_wf[unwrap.SimulationResults] = lut_wf.compute(unwrap.SimulationResults) return lut_wf @@ -463,11 +447,7 @@ def test_v20_compute_wavelengths_from_wfm( raw_data=raw, ltotal=ltotal, lut_workflow=lut_workflow_v20_choppers ) - tofs = pl.compute(time_of_flight.TofDetector[SampleRun]) - - # Convert to wavelength - graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - wavs = tofs.transform_coords("wavelength", graph=graph) + wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) for da in wavs.flatten(to='pixel'): x = sc.sort(da.value, key='id') diff --git a/packages/essreduce/tests/unwrap/workflow_test.py b/packages/essreduce/tests/unwrap/workflow_test.py new file mode 100644 index 00000000..890166a2 --- /dev/null +++ b/packages/essreduce/tests/unwrap/workflow_test.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +import numpy as np +import pytest +import sciline +import scipp as sc +import 
scippnexus as snx +from scipp.testing import assert_identical + +from ess.reduce import unwrap +from ess.reduce.nexus.types import ( + AnyRun, + DiskChoppers, + EmptyDetector, + NeXusData, + NeXusDetectorName, + Position, + RawDetector, + SampleRun, +) +from ess.reduce.unwrap import ( + GenericUnwrapWorkflow, + LookupTableWorkflow, + fakes, +) + +sl = pytest.importorskip("sciline") + + +@pytest.fixture +def workflow() -> GenericUnwrapWorkflow: + sizes = {'detector_number': 10} + calibrated_beamline = sc.DataArray( + data=sc.ones(sizes=sizes), + coords={ + "position": sc.spatial.as_vectors( + sc.zeros(sizes=sizes, unit='m'), + sc.zeros(sizes=sizes, unit='m'), + sc.linspace("detector_number", 79, 81, 10, unit='m'), + ), + "detector_number": sc.array( + dims=["detector_number"], values=np.arange(10), unit=None + ), + }, + ) + + events = sc.DataArray( + data=sc.ones(dims=["event"], shape=[1000]), + coords={ + "event_time_offset": sc.linspace( + "event", 0.0, 1000.0 / 14, num=1000, unit="ms" + ).to(unit="ns"), + "event_id": sc.array( + dims=["event"], values=np.arange(1000) % 10, unit=None + ), + }, + ) + nexus_data = sc.DataArray( + sc.bins( + begin=sc.array(dims=["pulse"], values=[0], unit=None), + data=events, + dim="event", + ) + ) + + wf = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[]) + wf[NeXusDetectorName] = "detector" + wf[unwrap.LookupTableRelativeErrorThreshold] = {'detector': np.inf} + wf[EmptyDetector[SampleRun]] = calibrated_beamline + wf[NeXusData[snx.NXdetector, SampleRun]] = nexus_data + wf[Position[snx.NXsample, SampleRun]] = sc.vector([0, 0, 77], unit='m') + wf[Position[snx.NXsource, SampleRun]] = sc.vector([0, 0, 0], unit='m') + + return wf + + +def test_LookupTableWorkflow_can_compute_lut(): + wf = LookupTableWorkflow() + wf[DiskChoppers[AnyRun]] = fakes.psc_choppers() + wf[unwrap.NumberOfSimulatedNeutrons] = 10_000 + wf[unwrap.LtotalRange] = ( + sc.scalar(75.0, unit="m"), + sc.scalar(85.0, unit="m"), + ) + wf[unwrap.SourcePosition] 
= fakes.source_position() + lut = wf.compute(unwrap.LookupTable) + assert lut.array is not None + assert lut.distance_resolution is not None + assert lut.time_resolution is not None + assert lut.pulse_stride is not None + assert lut.pulse_period is not None + assert lut.choppers is not None + + +def test_GenericUnwrapWorkflow_with_lut_from_tof_simulation(workflow): + # Should be able to compute DetectorData without chopper and simulation params + # This contains event_time_offset (time-of-arrival). + _ = workflow.compute(RawDetector[SampleRun]) + # By default, the workflow tries to load the LUT from file + with pytest.raises(sciline.UnsatisfiedRequirement): + _ = workflow.compute(unwrap.LookupTable) + with pytest.raises(sciline.UnsatisfiedRequirement): + _ = workflow.compute(unwrap.WavelengthDetector[SampleRun]) + + lut_wf = LookupTableWorkflow() + lut_wf[DiskChoppers[AnyRun]] = fakes.psc_choppers() + lut_wf[unwrap.NumberOfSimulatedNeutrons] = 10_000 + lut_wf[unwrap.LtotalRange] = ( + sc.scalar(75.0, unit="m"), + sc.scalar(85.0, unit="m"), + ) + lut_wf[unwrap.SourcePosition] = fakes.source_position() + table = lut_wf.compute(unwrap.LookupTable) + + workflow[unwrap.LookupTable] = table + detector = workflow.compute(unwrap.WavelengthDetector[SampleRun]) + assert 'wavelength' in detector.bins.coords + + +def test_GenericUnwrapWorkflow_with_lut_from_file( + workflow, tmp_path: pytest.TempPathFactory +): + lut_wf = LookupTableWorkflow() + lut_wf[DiskChoppers[AnyRun]] = fakes.psc_choppers() + lut_wf[unwrap.NumberOfSimulatedNeutrons] = 10_000 + lut_wf[unwrap.LtotalRange] = ( + sc.scalar(75.0, unit="m"), + sc.scalar(85.0, unit="m"), + ) + lut_wf[unwrap.SourcePosition] = fakes.source_position() + lut = lut_wf.compute(unwrap.LookupTable) + lut.save_hdf5(filename=tmp_path / "lut.h5") + + workflow[unwrap.LookupTableFilename] = (tmp_path / "lut.h5").as_posix() + + loaded_lut = workflow.compute(unwrap.LookupTable) + assert_identical(lut.array, loaded_lut.array) + 
assert_identical(lut.pulse_period, loaded_lut.pulse_period) + assert lut.pulse_stride == loaded_lut.pulse_stride + assert_identical(lut.distance_resolution, loaded_lut.distance_resolution) + assert_identical(lut.time_resolution, loaded_lut.time_resolution) + assert_identical(lut.choppers, loaded_lut.choppers) + + detector = workflow.compute(unwrap.WavelengthDetector[SampleRun]) + assert 'wavelength' in detector.bins.coords + + +def test_GenericUnwrapWorkflow_with_lut_from_file_old_format( + workflow, tmp_path: pytest.TempPathFactory +): + lut_wf = LookupTableWorkflow() + lut_wf[DiskChoppers[AnyRun]] = fakes.psc_choppers() + lut_wf[unwrap.NumberOfSimulatedNeutrons] = 10_000 + lut_wf[unwrap.LtotalRange] = ( + sc.scalar(75.0, unit="m"), + sc.scalar(85.0, unit="m"), + ) + lut_wf[unwrap.SourcePosition] = fakes.source_position() + lut = lut_wf.compute(unwrap.LookupTable) + old_lut = sc.DataArray( + data=lut.array.data, + coords={ + "distance": lut.array.coords["distance"], + "event_time_offset": lut.array.coords["event_time_offset"], + "pulse_period": lut.pulse_period, + "pulse_stride": sc.scalar(lut.pulse_stride, unit=None), + "distance_resolution": lut.distance_resolution, + "time_resolution": lut.time_resolution, + }, + ) + old_lut.save_hdf5(filename=tmp_path / "lut.h5") + + workflow[unwrap.LookupTableFilename] = (tmp_path / "lut.h5").as_posix() + loaded_lut = workflow.compute(unwrap.LookupTable) + assert_identical(lut.array, loaded_lut.array) + assert_identical(lut.pulse_period, loaded_lut.pulse_period) + assert lut.pulse_stride == loaded_lut.pulse_stride + assert_identical(lut.distance_resolution, loaded_lut.distance_resolution) + assert_identical(lut.time_resolution, loaded_lut.time_resolution) + assert loaded_lut.choppers is None # No chopper info in old format + + detector = workflow.compute(unwrap.WavelengthDetector[SampleRun]) + assert 'wavelength' in detector.bins.coords