Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions src/tof/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def _add_rays(
cax: plt.Axes | None = None,
zorder: int = 1,
):
x, y = (a.reshape((-1, 2)) for a in (x, y))
coll = LineCollection(np.stack((x, y), axis=2), zorder=zorder)
if isinstance(color, str):
coll.set_color(color)
Expand Down Expand Up @@ -183,6 +182,7 @@ def plot(
rng = np.random.default_rng(seed)
# Make ids for neutrons per pulse, instead of using their id coord
ids = np.arange(self.source.neutrons)
rays = {"x": [], "y": [], "color": []}

for i in range(self._source.data.sizes["pulse"]):
component_data = furthest_component.data["pulse", i]
Expand All @@ -197,17 +197,9 @@ def plot(
replace=False,
)
x, y, c = _get_rays(components, pulse=i, inds=inds)
_add_rays(
ax=ax,
x=x,
y=y,
color=c,
cbar=cbar and (i == 0),
cmap=cmap,
vmin=wmin.value if vmin is None else vmin,
vmax=wmax.value if vmax is None else vmax,
cax=cax,
)
rays["x"].append(x)
rays["y"].append(y)
rays["color"].append(c)

# Plot blocked rays
inds = rng.choice(
Expand All @@ -226,10 +218,33 @@ def plot(
)
x[line_selection] = np.nan
y[line_selection] = np.nan
_add_rays(ax=ax, x=x, y=y, color="lightgray", zorder=-1)
_add_rays(
ax=ax,
x=x.reshape((-1, 2)),
y=y.reshape((-1, 2)),
color="lightgray",
zorder=-1,
)

# Plot pulse
self.source.plot_on_time_distance_diagram(ax, pulse=i)

# Add coloured rays in one go so that they share the same colorbar, thus
# enabling using zoom on the colorbar to select a wavelength range across all
# pulses.
if len(rays["x"]) > 0:
_add_rays(
ax=ax,
x=np.concatenate([r.reshape((-1, 2)) for r in rays["x"]], axis=0),
y=np.concatenate([r.reshape((-1, 2)) for r in rays["y"]], axis=0),
color=np.concatenate([r.ravel() for r in rays["color"]], axis=0),
cbar=cbar,
cmap=cmap,
vmin=wmin.value if vmin is None else vmin,
vmax=wmax.value if vmax is None else vmax,
cax=cax,
)

if furthest_component.toa.data.sum().value > 0:
toa_max = furthest_component.toa.max().value
else:
Expand Down
Loading