From 9f147638edc9e168f3f1aec3fd68884c3a502c92 Mon Sep 17 00:00:00 2001 From: Gilbert Montague Date: Fri, 6 Jun 2025 16:15:24 -0700 Subject: [PATCH 1/4] feature: Add hdf5 reader for offline plotter --- requirements.txt | 1 + synapse/cli/offline_hdf5_plotter.py | 330 ++++++++++++++++++++++++++++ synapse/cli/offline_plot.py | 34 +-- 3 files changed, 349 insertions(+), 16 deletions(-) create mode 100644 synapse/cli/offline_hdf5_plotter.py diff --git a/requirements.txt b/requirements.txt index 33460f19..4cc5feff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ pandas>=2.2.0 protobuf>=5.29 numpy >=2.0.0 crcmod +h5py \ No newline at end of file diff --git a/synapse/cli/offline_hdf5_plotter.py b/synapse/cli/offline_hdf5_plotter.py new file mode 100644 index 00000000..ab8f5855 --- /dev/null +++ b/synapse/cli/offline_hdf5_plotter.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python3 + +import sys +import signal +import numpy as np +import pandas as pd +import pyqtgraph as pg +from pyqtgraph.Qt import QtWidgets, QtCore +from dataclasses import dataclass +from typing import List + +import h5py +from rich.console import Console +from rich.table import Table +from rich.progress import Progress + +BACKGROUND_COLOR = "#253252250" + + +@dataclass +class PlotData: + data: pd.DataFrame # DataFrame with samples x channels + sample_rate: float + channel_ids: List[int] + + @property + def num_samples(self) -> int: + return len(self.data) + + @property + def num_channels(self) -> int: + return len(self.data.columns) + + @property + def duration_seconds(self) -> float: + return self.num_samples / self.sample_rate + + @property + def time_array(self) -> np.ndarray: + return np.arange(self.num_samples) / self.sample_rate + + def filter_channels(self, channel_ids: List[int]) -> "PlotData": + # THey come in as a list of strings deliminated by commas + channel_ids = [int(ch) for ch in channel_ids.split(",")] + return PlotData( + data=self.data.loc[:, channel_ids], + sample_rate=self.sample_rate, + channel_ids=channel_ids, + ) + + +def compute_fft(data, sample_rate): + # Apply window function to reduce spectral leakage + window = np.hanning(len(data)) + windowed_data = data * window + + # Compute FFT + fft_values = np.fft.rfft(windowed_data) # Using rfft for real input + fft_freq = np.fft.rfftfreq(len(data), d=1 / sample_rate) + + # Convert to magnitude in dB + # Add small number to avoid log(0) + fft_magnitude_db = 20 * np.log10(np.abs(fft_values) + 1e-10) + + return fft_freq, fft_magnitude_db + + +def load_h5_data(data_file, console, time_range=None): + """Load HDF5 data and return PlotData object""" + console.print(f"Loading h5 data from {data_file}") + with h5py.File(data_file, "r") as f: + # Display file info + attributes = f.attrs + if attributes: + table = Table(title="Attributes") + table.add_column("Key") + table.add_column("Value") + for key, value in attributes.items(): + table.add_row(key, str(value)) + console.print(table) + + # Get channel information + channels = f["channels"] + channel_ids = channels["id"][:].tolist() + number_of_channels = len(channel_ids) + + sample_rate = float(attributes["sample_rate_hz"]) + console.print(f"Sample rate: {sample_rate} Hz") + console.print(f"Found {number_of_channels} channels") + + # Get frame data info + frame_data = f["frame_data"] + total_samples = len(frame_data) + samples_per_channel = total_samples // number_of_channels + + console.print( + f"Total duration: {samples_per_channel / sample_rate:.2f} seconds" + ) + + # Determine time range to load + start_index = 0 + end_index = total_samples + + if time_range: + if ":" in time_range: + start_time, end_time = map(float, time_range.split(":")) + else: + start_time, end_time = 0, float(time_range) + + start_index = int(start_time * sample_rate * number_of_channels) + end_index = int(end_time * sample_rate * number_of_channels) + console.print(f"Loading time range {start_time}s to {end_time}s") + else: + # Default: load first 10 seconds + console.print("[yellow]Loading first 10 seconds[/yellow]") + end_index = min(int(10 * sample_rate * number_of_channels), total_samples) + + # Load data subset + with console.status("Loading data...", spinner="dots"): + subset_length = end_index - start_index + actual_samples_per_channel = subset_length // number_of_channels + + data_slice = frame_data[ + start_index : start_index + + (actual_samples_per_channel * number_of_channels) + ] + reshaped_data = data_slice.reshape( + actual_samples_per_channel, number_of_channels + ) + + # Create DataFrame + df = pd.DataFrame(reshaped_data, columns=range(number_of_channels)) + + return PlotData(data=df, sample_rate=sample_rate, channel_ids=channel_ids) + + +def plot(plot_data, console): + """Create the plotting GUI for HDF5 data""" + app = QtWidgets.QApplication.instance() + if not app: + app = QtWidgets.QApplication(sys.argv) + + # Setup the window for the plot + pg.setConfigOption("background", BACKGROUND_COLOR) + + # To allow for resizing, we need to add a splitter + main_splitter = QtWidgets.QSplitter() + main_splitter.setOrientation(QtCore.Qt.Orientation.Horizontal) + + left_splitter = QtWidgets.QSplitter() + left_splitter.setOrientation(QtCore.Qt.Orientation.Vertical) + + # Add widgets so we can resize + time_plot_widget = pg.GraphicsLayoutWidget() + single_channel_plot_widget = pg.GraphicsLayoutWidget() + fft_plot_widget = pg.GraphicsLayoutWidget() + + # Main plot is all the channels + plot_all = time_plot_widget.addPlot(row=0, col=0, title="All Channels") + plot_all.setLabel("bottom", "Time (s)") + plot_all.setLabel("left", "Amplitude (counts)") + plot_all.addLegend() + plot_all.showGrid(x=True, y=True) + + # Create a list to hold the curves + curves = [] + + # Offset in counts for each channel + offset = 500 + + # Create time array + time_arr = plot_data.time_array + + # Create a curve for each channel + if len(plot_data.channel_ids) > 32: + console.print( + "[yellow] Creating curves for large datasets might take a while [/yellow]" + ) + console.print( + "[yellow] Consider using the --channels flag to limit the number of channels [/yellow]" + ) + with Progress(console=console) as progress: + task = progress.add_task("Creating curves...", total=len(plot_data.channel_ids)) + + for i, channel_id in enumerate(plot_data.channel_ids): + if i >= plot_data.num_channels: + break + + final_data = ( + plot_data.data.iloc[:, i].to_numpy().astype(np.float32) - offset * i + ) + + curve = plot_all.plot( + time_arr, + final_data, + pen=pg.intColor(i, hues=plot_data.num_channels), + name=f"Ch {channel_id}", + ) + curve.setDownsampling(auto=True) + curve.setClipToView(True) + curves.append(curve) + + progress.update(task, advance=1) + + # Create a single plot for a single channel + plot_single = single_channel_plot_widget.addPlot( + row=1, col=0, title="Single Channel" + ) + plot_single.setLabel("bottom", "Time (s)") + plot_single.setLabel("left", "Amplitude (counts)") + plot_single.showGrid(x=True, y=True) + + # Create a curve for the single channel + curve_single = plot_single.plot( + time_arr, + plot_data.data.iloc[:, 0].to_numpy(), + pen=pg.intColor(0, hues=plot_data.num_channels), + name=f"Ch {plot_data.channel_ids[0]}", + ) + curve_single.setDownsampling(auto=True) + curve_single.setClipToView(True) + + # Create an fft plot of the selected channel + fft_plot = fft_plot_widget.addPlot( + row=0, col=1, rowspan=2, title="FFT of Selected Channel" + ) + fft_plot.setLabel("bottom", "Frequency (Hz)") + fft_plot.setLabel("left", "Amplitude (dB)") + fft_plot.showGrid(x=True, y=True) + + # Splitters for the widgets + left_splitter.addWidget(time_plot_widget) + left_splitter.addWidget(single_channel_plot_widget) + main_splitter.addWidget(left_splitter) + main_splitter.addWidget(fft_plot_widget) + + # Log scale for frequency axis + fft_plot.setLogMode(x=True, y=False) + + # Enable auto-range on double click + fft_plot.autoBtn.clicked.connect(lambda: fft_plot.enableAutoRange()) + + # Function to update single channel display + def update_single_channel(channel_id): + # Get the index of the channel_id in channel_ids list + try: + channel_index = plot_data.channel_ids.index(int(channel_id)) + except ValueError: + return + + # Update time domain plot + curve_single.setData(time_arr, plot_data.data.iloc[:, channel_index].to_numpy()) + curve_single.setPen(pg.intColor(channel_index, hues=plot_data.num_channels)) + + # Update FFT plot + fft_plot.clear() + fft_freq, fft_magnitude = compute_fft( + plot_data.data.iloc[:, channel_index].to_numpy(), plot_data.sample_rate + ) + + # Plot FFT with improved visibility + curve_fft = fft_plot.plot( + fft_freq, + fft_magnitude, + pen=dict(color="w", width=2), + name=f"FFT of Ch {channel_id}", + ) + curve_fft.setClipToView(True) + + # Add grid lines + fft_plot.showGrid(x=True, y=True, alpha=0.3) + + # Auto-range on channel change + fft_plot.autoRange() + + # Initialize with first channel + update_single_channel(plot_data.channel_ids[0]) + + # Create a dropdown for channel selection + combo = QtWidgets.QComboBox() + combo.addItems([str(ch) for ch in plot_data.channel_ids]) + combo.currentIndexChanged.connect( + lambda: update_single_channel(int(combo.currentText())) + ) + combo.setFixedWidth(100) + + # Create a layout for our plot, fft, and controls + main_layout = QtWidgets.QVBoxLayout() + main_layout.addWidget(combo) + main_layout.addWidget(main_splitter) + + # And finally our main widget to show + main_widget = QtWidgets.QWidget() + main_widget.setLayout(main_layout) + main_widget.setWindowTitle("Synapsectl Data Viewer") + main_widget.resize(1280, 720) + main_widget.show() + + # Handle the case of Ctrl+C + def signal_handler(sig, frame): + print("Ctrl+C pressed. Exiting...") + QtWidgets.QApplication.quit() + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + app.exec() + + +def plot_h5(args): + """Main entry point for HDF5 plotting""" + console = Console() + + # Load the data + plot_data = load_h5_data(args.data, console, args.time) + + if plot_data is None: + console.print("[red]Failed to load data[/red]") + return + + console.print( + f"[green]Loaded {plot_data.num_samples:,} samples from {plot_data.num_channels} channels[/green]" + ) + console.print(f"[green]Duration: {plot_data.duration_seconds:.2f} seconds[/green]") + + # If the user has requested specific channels, filter the data + if args.channels: + plot_data = plot_data.filter_channels(args.channels) + + # Create the plot + plot(plot_data, console) diff --git a/synapse/cli/offline_plot.py b/synapse/cli/offline_plot.py index 4d554292..dda96803 100644 --- a/synapse/cli/offline_plot.py +++ b/synapse/cli/offline_plot.py @@ -9,6 +9,10 @@ import logging import signal +from offline_hdf5_plotter import plot_h5 + +from rich.console import Console + BACKGROUND_COLOR = "#253252250" @@ -138,19 +142,6 @@ def load_config(json_path): raise ValueError("Invalid JSON: No 'kElectricalBroadband' node found") -# Function to compute FFT -# NOTE(gilbert): This is the previous implementation of the FFT -# def compute_fft(data, sample_rate): -# fft_values = np.fft.fft(data) -# fft_freq = np.fft.fftfreq(len(data), d=1 / sample_rate) -# fft_values = np.abs(fft_values)[: len(fft_values) // 2] -# fft_freq = fft_freq[: len(fft_freq) // 2] -# fft_values /= max(fft_values) -# fft_values[1:] *= 2 -# fft_values[0] = 0 -# return fft_freq, fft_values - - def compute_fft(data, sample_rate): # Apply window function to reduce spectral leakage window = np.hanning(len(data)) @@ -170,6 +161,17 @@ def compute_fft(data, sample_rate): def plot(args): logger = setup_logging() + # NOTE(gilbert): we want to support the previous plotting code but we are moving to hdf5 saving and plotting + # Short circuit for now and just use the hdf5 plotting code + _, file_extension = os.path.splitext(args.data) + if file_extension == ".h5": + return plot_h5(args) + + console = Console() + console.print( + "[yellow bold]Legacy plotting is deprecated, please use the hdf5 files going forward[/yellow bold]" + ) + app = QtWidgets.QApplication.instance() if not app: app = QtWidgets.QApplication(sys.argv) @@ -264,11 +266,11 @@ def plot(args): sys.exit(1) full_time_arr = np.arange(len(data)) / sampling_freq - if end_time is not None: # FIX: Ensure end_time=0 is valid + if end_time is not None: mask = (full_time_arr >= start_time) & (full_time_arr <= end_time) - if np.any(mask): # FIX: Ensure non-empty selection - data = data.loc[mask] # FIX: Use .loc instead of .iloc + if np.any(mask): + data = data.loc[mask] time_arr = full_time_arr[mask] logger.info( f"Plotting {len(data)} samples from {time_arr[0]:.2f}s to {time_arr[-1]:.2f}s" From feb4cd7c9c075a5ec9226e16157c073e2f39b15a Mon Sep 17 00:00:00 2001 From: Gilbert Montague Date: Fri, 6 Jun 2025 16:21:14 -0700 Subject: [PATCH 2/4] Added some timing metrics --- synapse/cli/offline_hdf5_plotter.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/synapse/cli/offline_hdf5_plotter.py b/synapse/cli/offline_hdf5_plotter.py index ab8f5855..4bbfd644 100644 --- a/synapse/cli/offline_hdf5_plotter.py +++ b/synapse/cli/offline_hdf5_plotter.py @@ -8,7 +8,7 @@ from pyqtgraph.Qt import QtWidgets, QtCore from dataclasses import dataclass from typing import List - +import time import h5py from rich.console import Console from rich.table import Table @@ -40,7 +40,7 @@ def time_array(self) -> np.ndarray: return np.arange(self.num_samples) / self.sample_rate def filter_channels(self, channel_ids: List[int]) -> "PlotData": - # THey come in as a list of strings deliminated by commas + # They come in as a list of strings delimited by commas channel_ids = [int(ch) for ch in channel_ids.split(",")] return PlotData( data=self.data.loc[:, channel_ids], @@ -179,6 +179,7 @@ def plot(plot_data, console): console.print( "[yellow] Consider using the --channels flag to limit the number of channels [/yellow]" ) + start_time = time.time() with Progress(console=console) as progress: task = progress.add_task("Creating curves...", total=len(plot_data.channel_ids)) @@ -201,7 +202,10 @@ def plot(plot_data, console): curves.append(curve) progress.update(task, advance=1) - + end_time = time.time() + console.print( + f"Plotted {plot_data.num_channels} channels ({plot_data.num_samples:,} samples each channel) in {end_time - start_time:.2f} seconds" + ) # Create a single plot for a single channel plot_single = single_channel_plot_widget.addPlot( row=1, col=0, title="Single Channel" From 630bdbbeaa1f5e243ad68edb477df97f085159e6 Mon Sep 17 00:00:00 2001 From: Gilbert Montague Date: Fri, 6 Jun 2025 16:25:17 -0700 Subject: [PATCH 3/4] clearer plot language --- synapse/cli/offline_hdf5_plotter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/cli/offline_hdf5_plotter.py b/synapse/cli/offline_hdf5_plotter.py index 4bbfd644..4dfb6bd7 100644 --- a/synapse/cli/offline_hdf5_plotter.py +++ b/synapse/cli/offline_hdf5_plotter.py @@ -203,8 +203,9 @@ def plot(plot_data, console): progress.update(task, advance=1) end_time = time.time() + total_samples = plot_data.num_samples * plot_data.num_channels console.print( - f"Plotted {plot_data.num_channels} channels ({plot_data.num_samples:,} samples each channel) in {end_time - start_time:.2f} seconds" + f"Plotted {plot_data.num_channels} channels ({total_samples:,} total samples) in {end_time - start_time:.2f} seconds" ) # Create a single plot for a single channel plot_single = single_channel_plot_widget.addPlot( From 0c7629092ebd90356011882abae44d90bb725963 Mon Sep 17 00:00:00 2001 From: Gilbert Montague Date: Mon, 9 Jun 2025 15:02:54 -0700 Subject: [PATCH 4/4] Tweaks from plotting --- requirements.txt | 2 +- synapse/cli/offline_plot.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4cc5feff..e3dc3538 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,4 +20,4 @@ pandas>=2.2.0 protobuf>=5.29 numpy >=2.0.0 crcmod -h5py \ No newline at end of file +h5py diff --git a/synapse/cli/offline_plot.py b/synapse/cli/offline_plot.py index dda96803..32dc6998 100644 --- a/synapse/cli/offline_plot.py +++ b/synapse/cli/offline_plot.py @@ -163,9 +163,10 @@ def plot(args): # NOTE(gilbert): we want to support the previous plotting code but we are moving to hdf5 saving and plotting # Short circuit for now and just use the hdf5 plotting code - _, file_extension = os.path.splitext(args.data) - if file_extension == ".h5": - return plot_h5(args) + if args.data is not None: + _, file_extension = os.path.splitext(args.data) + if file_extension == ".h5": + return plot_h5(args) console = Console() console.print(