#!/usr/bin/env python3
"""
High-performance neural data visualization using a raster/sweep display.

Instead of a scrolling plot, samples are rendered column-by-column into a
fixed-size image (oscilloscope style), which is far cheaper to redraw.
"""

import threading
from collections import deque

import numpy as np
import pyqtgraph as pg
from pyqtgraph.Qt import QtCore, QtWidgets

from synapse.client.taps import Tap
from synapse.api.datatype_pb2 import BroadbandFrame as BroadbandFrameProto


class NeuralRasterDisplay(QtWidgets.QWidget):
    """Fast raster-style display for multi-channel neural data.

    Incoming samples are buffered, reduced to one image column per
    ``downsample`` samples, and painted into a (height x width) float32
    image shown through a pyqtgraph ImageItem with a per-channel colormap.
    """

    def __init__(self, n_channels=5, sample_rate=32000, display_seconds=5):
        super().__init__()

        self.n_channels = n_channels
        self.sample_rate = sample_rate
        self.display_seconds = display_seconds

        # Fixed raster dimensions in pixels.
        self.width = 1200
        self.height = 800
        self.channel_height = self.height // n_channels

        # One image column represents `downsample` consecutive samples.
        total_samples = sample_rate * display_seconds
        self.downsample = total_samples // self.width
        self.samples_per_pixel = self.downsample

        # Interleaved sample buffer (ch0, ch1, ..., chN-1, ch0, ...);
        # sized at 5x one column's worth so bursts do not overflow it.
        buffer_size = self.downsample * n_channels * 5
        self.data_buffer = deque(maxlen=buffer_size)

        # Raster image (rows = y pixels, cols = time); 0.0 maps to black.
        self.image_data = np.zeros((self.height, self.width), dtype=np.float32)

        # Column index the next rendered pixel column is written to.
        self.write_pos = 0

        # Debug counter of columns rendered so far.
        self.pixels_written = 0

        print(
            f"Setup: {n_channels} channels, {self.downsample} samples/pixel, buffer size: {buffer_size}"
        )

        self.setup_ui()

        # Repaint timer; only the image and sweep line are refreshed.
        self.timer = QtCore.QTimer()
        self.timer.timeout.connect(self.update_display)
        self.timer.start(16)  # ~60 FPS

        self.setWindowTitle("Neural Data Raster Display")
        self.resize(self.width + 100, self.height + 100)

    def setup_ui(self):
        """Create the plot, raster image, sweep line, and gain controls."""
        layout = QtWidgets.QVBoxLayout()
        self.setLayout(layout)

        self.graphics_widget = pg.GraphicsLayoutWidget()
        layout.addWidget(self.graphics_widget)

        self.plot = self.graphics_widget.addPlot()
        self.plot.setLabel("left", "Channel")
        self.plot.setLabel("bottom", "Time", units="s")

        self.img_item = pg.ImageItem()
        self.plot.addItem(self.img_item)
        self.img_item.setImage(self.image_data)

        # Red vertical line marking the current write column.
        self.sweep_line = pg.InfiniteLine(pos=0, angle=90, pen=pg.mkPen("r", width=2))
        self.plot.addItem(self.sweep_line)

        # Dashed dividers between channel bands.
        for i in range(1, self.n_channels):
            divider = pg.InfiniteLine(
                pos=i * self.channel_height,
                angle=0,
                pen=pg.mkPen("w", width=1, style=QtCore.Qt.PenStyle.DashLine),
            )
            self.plot.addItem(divider)

        # Dotted baseline through the center of each channel band.
        for i in range(self.n_channels):
            baseline = pg.InfiniteLine(
                pos=i * self.channel_height + self.channel_height // 2,
                angle=0,
                pen=pg.mkPen("gray", width=1, style=QtCore.Qt.PenStyle.DotLine),
            )
            self.plot.addItem(baseline)

        # Channel labels; Y axis is inverted below, so flip tick positions.
        ticks = [
            (
                self.height - (i * self.channel_height + self.channel_height / 2),
                f"Ch {i + 1}",
            )
            for i in range(self.n_channels)
        ]
        self.plot.getAxis("left").setTicks([ticks])

        # Time ticks; was hard-coded to 5 seconds, now uses display_seconds
        # (identical result for the default configuration).
        second_px = self.width / self.display_seconds
        time_ticks = [(i * second_px, f"{i}s") for i in range(self.display_seconds + 1)]
        self.plot.getAxis("bottom").setTicks([time_ticks])

        # Per-channel colormap: background at 0.0, each channel owning a
        # 0.16-wide band of the [0.2, 1.0] range with dark/medium/bright
        # shades of its color.
        channel_colors = [
            (255, 100, 100),  # red/pink, channel 0
            (100, 255, 100),  # green,    channel 1
            (100, 150, 255),  # blue,     channel 2
            (255, 255, 100),  # yellow,   channel 3
            (255, 100, 255),  # magenta,  channel 4
        ]

        colors = [(0, 0, 0)]  # black background
        positions = [0.0]
        for i in range(self.n_channels):
            base_pos = 0.2 + i * 0.16
            r, g, b = channel_colors[i % len(channel_colors)]
            colors.append((r // 3, g // 3, b // 3))  # dark (low intensity)
            positions.append(base_pos)
            colors.append((r // 2, g // 2, b // 2))  # medium
            positions.append(base_pos + 0.05)
            colors.append((r, g, b))  # bright (high intensity)
            positions.append(base_pos + 0.15)

        cmap = pg.ColorMap(pos=positions, color=colors)
        self.img_item.setLookupTable(cmap.getLookupTable(alpha=True))

        # ImageItem expects (x, y); our array is (y, x), hence the transpose.
        self.img_item.setImage(self.image_data.T)

        self.plot.setXRange(0, self.width)
        self.plot.setYRange(0, self.height)
        self.plot.invertY(True)  # channel 1 at the top

        # Gain controls.
        control_layout = QtWidgets.QHBoxLayout()
        self.gain_slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal)
        self.gain_slider.setRange(1, 100)
        self.gain_slider.setValue(50)
        self.gain_label = QtWidgets.QLabel("Gain: 50")
        self.gain_slider.valueChanged.connect(self.update_gain)

        control_layout.addWidget(QtWidgets.QLabel("Gain:"))
        control_layout.addWidget(self.gain_slider)
        control_layout.addWidget(self.gain_label)
        control_layout.addStretch()

        layout.addLayout(control_layout)

    def update_gain(self, value):
        """Reflect the slider value in the label; gain is read at render time."""
        self.gain_label.setText(f"Gain: {value}")

    def add_samples(self, channel_data):
        """Queue new samples and render any complete pixel columns.

        Args:
            channel_data: numpy array of shape (n_channels, n_samples).
        """
        # Interleave samples into the buffer (sample-major, channel-minor).
        for sample_idx in range(channel_data.shape[1]):
            for ch in range(self.n_channels):
                self.data_buffer.append(channel_data[ch, sample_idx])

        samples_per_column = self.downsample * self.n_channels
        while len(self.data_buffer) >= samples_per_column:
            # Pull one column's worth of interleaved samples off the buffer.
            column = np.array(
                [self.data_buffer.popleft() for _ in range(samples_per_column)]
            )
            # (downsample, n_channels) -> (n_channels, downsample)
            column = column.reshape(self.downsample, self.n_channels).T

            gain = self.gain_slider.value() / 50.0

            # Clear this column (black background), then draw each channel.
            self.image_data[:, self.write_pos] = 0
            for ch in range(self.n_channels):
                self._render_channel_column(ch, column[ch], gain)

            # Debug - print first few pixels.
            if self.pixels_written < 5:
                print(
                    f"Pixel {self.pixels_written}: processing column {self.write_pos}"
                )

            # Blank a few columns ahead of the sweep line for visibility.
            clear_pos = (self.write_pos + 2) % self.width
            for i in range(3):
                self.image_data[:, (clear_pos + i) % self.width] = 0

            self.write_pos = (self.write_pos + 1) % self.width
            self.pixels_written += 1

            if self.pixels_written % 200 == 0:
                print(
                    f"Written {self.pixels_written} pixels, buffer size: {len(self.data_buffer)}"
                )

    def _render_channel_column(self, ch, samples, gain):
        """Render one channel's samples into the current image column.

        Draws a filled min/max band (brighter near the mean), connecting
        segments between consecutive points, and highlighted sample points,
        all mapped into this channel's slice of the colormap.
        """
        channel_center = ch * self.channel_height + self.channel_height // 2
        channel_top = ch * self.channel_height + 5  # small margin
        channel_bottom = (ch + 1) * self.channel_height - 5

        # Each channel owns [color_base, color_base + color_range] of the LUT.
        color_base = 0.2 + ch * 0.16
        color_range = 0.15

        scaled = samples * gain / 1500.0  # empirical scale for visibility
        max_amplitude = (self.channel_height - 10) // 2

        # Reduce to at most 20 evenly spaced points for this column.
        n_points = min(20, len(scaled))
        if len(scaled) > n_points:
            indices = np.linspace(0, len(scaled) - 1, n_points, dtype=int)
            scaled = scaled[indices]

        # Convert samples to clamped pixel rows within the channel band.
        y_positions = []
        for sample in scaled:
            y_pos = channel_center + int(sample * max_amplitude)
            y_positions.append(max(channel_top, min(channel_bottom, y_pos)))

        if len(y_positions) > 1:
            min_y, max_y = min(y_positions), max(y_positions)
            # Hoisted out of the fill loop: mean is loop-invariant (was
            # recomputed for every y in the original).
            mean_y = np.mean(y_positions)
            span = max(1, max_y - min_y)

            # Main signal band, brighter near the mean.
            for y in range(min_y, max_y + 1):
                relative_intensity = max(0.2, 1.0 - abs(y - mean_y) / span)
                self.image_data[y, self.write_pos] = (
                    color_base + relative_intensity * color_range
                )

            # Connecting segments between consecutive points.
            connect_value = color_base + 0.8 * color_range
            for y1, y2 in zip(y_positions, y_positions[1:]):
                for y in range(min(y1, y2), max(y1, y2) + 1):
                    self.image_data[y, self.write_pos] = max(
                        connect_value, self.image_data[y, self.write_pos]
                    )

            # Maximum intensity on every 3rd sample point.
            for y_pos in y_positions[::3]:
                if channel_top <= y_pos <= channel_bottom:
                    self.image_data[y_pos, self.write_pos] = color_base + color_range

        elif len(y_positions) == 1:
            # Single point: draw it with a little vertical thickness.
            y_pos = y_positions[0]
            value = color_base + 0.6 * color_range
            for dy in (-1, 0, 1):
                if channel_top <= y_pos + dy <= channel_bottom:
                    self.image_data[y_pos + dy, self.write_pos] = value

    def update_display(self):
        """Refresh the raster image and move the sweep line."""
        # Transpose because ImageItem expects (x, y), our array is (y, x).
        self.img_item.setImage(self.image_data.T, autoLevels=True)
        self.sweep_line.setPos(self.write_pos)


class NeuralDataReceiver:
    """Streams BroadbandFrame messages from a device tap into the display."""

    def __init__(
        self,
        display,
        zmq_address="tcp://localhost:5555",
        device_uri="10.40.61.119",
        tap_name="broadband_source_2",
    ):
        """Connect to the device tap.

        Args:
            display: NeuralRasterDisplay to feed with samples.
            zmq_address: unused; kept for interface compatibility.
                NOTE(review): the Tap client below is what actually carries
                the data — confirm whether this parameter should be removed.
            device_uri: device address passed to Tap (previously hard-coded).
            tap_name: tap to connect to (previously hard-coded).
        """
        self.display = display
        self.tap = Tap(device_uri, verbose=True)
        self.tap.connect(tap_name)
        self.running = False
        self.thread = None

    def start(self):
        """Start the background reception thread."""
        self.running = True
        self.thread = threading.Thread(target=self._receive_loop, daemon=True)
        self.thread.start()

    def stop(self):
        """Signal the reception thread to stop and wait briefly for it."""
        self.running = False
        if self.thread:
            self.thread.join(timeout=1.0)

    def _receive_loop(self):
        """Receive frames and push the first five channels to the display."""
        while self.running:
            try:
                for message in self.tap.stream():
                    # Bug fix: honor stop() while inside the stream iterator;
                    # the original never re-checked self.running here, so
                    # stop() could not actually stop the stream.
                    if not self.running:
                        break
                    frame = BroadbandFrameProto()
                    frame.ParseFromString(message)
                    # Take the first five channels of the frame.
                    frame_data = np.array(frame.frame_data[:5]).reshape(5, -1)
                    self.display.add_samples(frame_data)
            except Exception as e:
                print(f"Error receiving data: {e}")


def main():
    """Create the Qt application, display, and receiver, then run."""
    import sys

    app = QtWidgets.QApplication.instance()
    if not app:
        app = QtWidgets.QApplication(sys.argv)

    display = NeuralRasterDisplay(n_channels=5, sample_rate=32000, display_seconds=5)
    display.show()

    receiver = NeuralDataReceiver(display)
    receiver.start()

    try:
        sys.exit(app.exec())
    except KeyboardInterrupt:
        receiver.stop()


if __name__ == "__main__":
    main()
class DiskWriter:
    """Background writer that appends frames to a flat binary file.

    Frames are queued by producers via put() and drained by a dedicated
    thread so disk latency never blocks the streaming path.
    """

    def __init__(self, output_dir: str, buffer_size: int = 1024 * 1024):
        self.output_dir = output_dir
        self.buffer_size = buffer_size
        # Bounded so a stalled disk cannot grow memory without limit.
        self.data_queue = queue.Queue(maxsize=1000)
        self.stop_event = threading.Event()
        self.writer_thread = None

    def start(self):
        """Start the writer thread."""
        self.writer_thread = threading.Thread(target=self._write_loop)
        self.writer_thread.start()

    def stop(self):
        """Stop the writer thread and wait for it to finish draining."""
        self.stop_event.set()
        if self.writer_thread:
            self.writer_thread.join()

    def put(self, data: BroadbandFrame):
        """Queue a frame for writing; drops the oldest frame when full."""
        try:
            self.data_queue.put(data, block=False)
        except queue.Full:
            # Make room by discarding the oldest entry, then retry once.
            try:
                self.data_queue.get_nowait()
            except queue.Empty:
                pass
            try:
                self.data_queue.put(data, block=False)
            except queue.Full:
                # Bug fix: the retry can still hit a full queue if another
                # producer refilled it; drop this frame instead of raising.
                pass

    def _write_loop(self):
        """Drain the queue, writing length-prefixed serialized frames."""
        filename = os.path.join(self.output_dir, f"data_{int(time.time())}.dat")
        with open(filename, "wb", buffering=self.buffer_size) as f:
            while not self.stop_event.is_set() or not self.data_queue.empty():
                try:
                    data = self.data_queue.get(timeout=1)
                except queue.Empty:
                    continue
                try:
                    # Bug fix: the original only print()ed the frame and never
                    # wrote anything to the opened file. Write each frame as a
                    # 4-byte little-endian length followed by the serialized
                    # protobuf payload.
                    payload = data.SerializeToString()
                    f.write(len(payload).to_bytes(4, "little"))
                    f.write(payload)
                except Exception as e:
                    print(f"Error writing data: {e}")
                    continue


class BroadbandFrameWriter:
    """Background writer that appends BroadbandFrame data to an HDF5 file.

    Layout mirrors the C++ recorder: /acquisition/{timestamp,
    sequence_number, ElectricalSeries} plus general/* metadata groups.
    """

    def __init__(self, output_dir: str):
        self.output_dir = output_dir
        self.data_queue = queue.Queue(maxsize=1000)
        self.stop_event = threading.Event()
        self.writer_thread = None

        # One timestamped HDF5 file per writer instance.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.filename = os.path.join(output_dir, f"broadband_data_{timestamp}.h5")
        self.file = h5py.File(self.filename, "w")

        # Growable 1-D datasets for per-frame metadata.
        self.timestamp_dataset = self.file.create_dataset(
            "/acquisition/timestamp", shape=(0,), maxshape=(None,), dtype="uint64"
        )
        self.sequence_dataset = self.file.create_dataset(
            "/acquisition/sequence_number", shape=(0,), maxshape=(None,), dtype="uint64"
        )
        # Growable 2-D dataset: one row of 256 samples per frame.
        # NOTE(review): 256 assumes a fixed device frame size — confirm
        # against the BroadbandFrame producer.
        self.frame_data_dataset = self.file.create_dataset(
            "/acquisition/ElectricalSeries",
            shape=(0, 256),
            maxshape=(None, 256),
            dtype="int32",
        )

    def set_attributes(
        self, sample_rate_hz: float, channels: list, session_description: str = ""
    ):
        """Set HDF5 attributes similar to the C++ implementation.

        Args:
            sample_rate_hz: acquisition sample rate stored as a file attribute.
            channels: channel IDs written to the electrodes group.
            session_description: optional free-form description.
        """
        self.file.attrs["sample_rate_hz"] = sample_rate_hz
        if session_description:
            self.file.attrs["session_description"] = session_description

        self.file.attrs["session_start_time"] = datetime.now().isoformat()

        device_group = self.file.create_group("general/device")
        device_group.attrs["device_type"] = "SciFi"

        # Electrodes group holding the channel IDs.
        electrodes_group = self.file.create_group(
            "general/extracellular_ephys/electrodes"
        )
        channel_ids = channels
        electrodes_group.create_dataset("id", data=channel_ids, dtype="uint32")

    def start(self):
        """Start the writer thread."""
        self.writer_thread = threading.Thread(target=self._write_loop)
        self.writer_thread.start()

    def stop(self):
        """Stop the writer thread, flush remaining data, and close the file."""
        self.stop_event.set()
        if self.writer_thread:
            self.writer_thread.join()
            self.flush()
            self.file.close()

    def put(self, frame: BroadbandFrame):
        """Queue a frame for writing; drops the oldest frame when full."""
        try:
            self.data_queue.put(frame, block=False)
        except queue.Full:
            # Make room by discarding the oldest entry, then retry once.
            try:
                self.data_queue.get_nowait()
            except queue.Empty:
                pass
            try:
                self.data_queue.put(frame, block=False)
            except queue.Full:
                # Bug fix: the retry can still hit a full queue if another
                # producer refilled it; drop this frame instead of raising.
                pass

    def _write_loop(self):
        """Drain the queue, appending one dataset row per frame."""
        while not self.stop_event.is_set() or not self.data_queue.empty():
            try:
                frame = self.data_queue.get(timeout=1)

                # Grow each dataset by one row. NOTE(review): per-frame
                # resize is simple but chunk-heavy; batching resizes would
                # be cheaper if throughput becomes a problem.
                current_size = self.timestamp_dataset.shape[0]
                new_size = current_size + 1

                self.timestamp_dataset.resize(new_size, axis=0)
                self.sequence_dataset.resize(new_size, axis=0)
                self.frame_data_dataset.resize(new_size, axis=0)

                self.timestamp_dataset[current_size] = frame.timestamp_ns
                self.sequence_dataset[current_size] = frame.sequence_number
                self.frame_data_dataset[current_size, :] = frame.frame_data

                # Flush periodically so a crash loses at most ~1000 frames.
                if new_size % 1000 == 0:
                    print(f"Flushed {new_size} frames")
                    self.flush()

            except queue.Empty:
                continue
            except Exception as e:
                print(f"Error writing data: {e}")
                continue

    def flush(self):
        """Flush all datasets and the file to disk."""
        self.timestamp_dataset.flush()
        self.sequence_dataset.flush()
        self.frame_data_dataset.flush()
        self.file.flush()
def add_commands(subparsers):
    """Register the `read` subcommand on the CLI parser."""
    read_parser = subparsers.add_parser(
        "read", help="Read from a device's Broadband Tap"
    )

    read_parser.add_argument(
        "config", type=str, help="Device configuration or manifest file"
    )

    # Output options, we will save as HDF5
    read_parser.add_argument("--output", type=str, help="Output directory")
    read_parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite existing files"
    )

    read_parser.set_defaults(func=read)


def configure_device(device, config, console):
    """Apply `config` to the device; return True when the device is usable.

    A device that is already running is left untouched so the caller can
    attach to its existing tap.
    """
    with console.status("Configuring device...", spinner="bouncingBall"):
        info = device.info()
        # Robustness fix: info() can fail and return None (the legacy code
        # removed in this patch checked for exactly that).
        if info is None:
            console.print("[bold red]Failed to get device info[/bold red]")
            return False
        if info.status.state == DeviceState.kRunning:
            console.log(
                "[bold yellow]Device is already running, reading from existing tap[/bold yellow]"
            )
            return True

        # Apply the configuration to the device.
        configure_status = device.configure_with_status(config)
        # Robustness fix: configure_with_status() may also return None.
        if configure_status is None:
            console.print("[bold red]Failed to configure device[/bold red]")
            return False
        if configure_status.code != StatusCode.kOk:
            console.print(
                f"[bold red]Failed to configure device: {configure_status.message}[/bold red]"
            )
            return False
        console.log("[green]Configured device[/green]")

    return True


def start_device(device, console):
    """Start the device unless it is already running; return success."""
    info = device.info()
    # Robustness fix: info() can fail and return None.
    if info is None:
        console.print("[bold red]Failed to get device info[/bold red]")
        return False
    if info.status.state == DeviceState.kRunning:
        return True

    with console.status("Starting device...", spinner="bouncingBall"):
        start_status = device.start_with_status()
        if start_status.code != StatusCode.kOk:
            console.print(
                f"[bold red]Failed to start device: {start_status.message}[/bold red]"
            )
            return False
    return True


def setup_output(args, console):
    """Ensure the requested output directory exists; return success."""
    if not args.output:
        console.print("[bold red]No output directory specified[/bold red]")
        return False

    # Create the output directory if it doesn't exist.
    os.makedirs(args.output, exist_ok=True)
    return True


def get_broadband_tap(args, device, console):
    """Connect to the first tap that streams BroadbandFrame messages.

    Returns the connected Tap, or None when no suitable tap exists.
    """
    read_tap = Tap(args.uri, args.verbose)
    taps = read_tap.list_taps()

    # Just take the first tap whose message type is BroadbandFrame.
    for t in taps:
        if "BroadbandFrame" in t.message_type:
            read_tap.connect(t.name)
            return read_tap

    console.print("[bold red]No BroadbandFrame tap found[/bold red]")
    return None
def read(args):
    """CLI entry point: stream BroadbandFrames from a device tap.

    Frames are persisted to HDF5 when --output is given, otherwise printed.
    """
    console = Console()

    # Make sure we can actually parse the configuration before touching
    # the hardware.
    try:
        config = load_device_config(args.config, console)
    except Exception as e:
        console.print(f"[bold red]Failed to load device configuration: {e}[/bold red]")
        return

    # Create the device object.
    device = syn.Device(args.uri, args.verbose)
    device_name = device.get_name()
    console.log(f"[green]Connected to {device_name}[/green]")

    # Apply the configuration to the device.
    if not configure_device(device, config, console):
        console.print("[bold red]Failed to configure device[/bold red]")
        return

    # If we got this far and they want to save things, we need to make sure
    # they have a place to save them.
    if args.output:
        if not setup_output(args, console):
            console.print("[bold red]Failed to setup output[/bold red]")
            return

    # Start the device.
    if not start_device(device, console):
        console.print("[bold red]Failed to start device[/bold red]")
        return

    # With the device running, get the tap for us to connect to.
    broadband_tap = get_broadband_tap(args, device, console)
    if not broadband_tap:
        console.print("[bold red]Failed to get broadband tap[/bold red]")
        return

    # Setup our HDF5 writer if output is requested.
    writer = None
    if args.output:
        writer = BroadbandFrameWriter(args.output)
        # TODO(review): sample rate and channel list are hard-coded here;
        # they should come from the BroadbandSource node in `config` —
        # confirm the config schema before wiring that up.
        writer.set_attributes(sample_rate_hz=32000, channels=list(range(32)))
        # Bug fix: start() is now guarded by `if args.output` — calling it
        # unconditionally would raise AttributeError when writer is None.
        writer.start()

    try:
        # Reuse one message object across the stream loop.
        frame = BroadbandFrame()
        with console.status("Streaming data...", spinner="bouncingBall"):
            for message in broadband_tap.stream():
                frame.ParseFromString(message)
                if writer:
                    writer.put(frame)
                else:
                    print(frame)
    except KeyboardInterrupt:
        console.print("\n[yellow]Stopping data collection...[/yellow]")
    finally:
        if writer:
            writer.stop()
            console.print(f"[green]Data saved to {args.output}[/green]")
synapse.utils.ndtp_types as ndtp_types +import synapse.cli.synapse_plotter as plotter +from synapse.utils.packet_monitor import PacketMonitor + +from rich.console import Console +from rich.live import Live +from rich.pretty import pprint + + +def add_commands(subparsers): + a = subparsers.add_parser("read", help="Read from a device's StreamOut node") + a.add_argument( + "--config", + type=str, + help="Configuration file", + ) + a.add_argument( + "--num_ch", type=int, help="Number of channels to read from, overrides config" + ) + a.add_argument("--bin", type=bool, help="Output binary format instead of JSON") + a.add_argument("--duration", type=int, help="Duration to read for in seconds") + a.add_argument("--node_id", type=int, help="ID of the StreamOut node to read from") + a.add_argument("--plot", action="store_true", help="Plot the data in real-time") + a.add_argument("--output", type=str, help="Name of the output directory and files") + a.add_argument( + "--overwrite", + action="store_true", + default=False, + help="Overwrite existing files", + ) + a.set_defaults(func=read) + + +def load_config_from_file(path): + with open(path, "r") as f: + data = f.read() + proto = Parse(data, DeviceConfiguration()) + return syn.Config.from_proto(proto) + + +def read(args): + console = Console() + if not args.config and not args.node_id: + console.print("[bold red]Either `--config` or `--node_id` must be specified.") + return + + output_base = args.output + timestamp = time.strftime("%Y%m%d-%H%M%S") + if not output_base: + output_base = f"synapse_data_{timestamp}" + else: + output_base = f"{output_base}_{timestamp}" + + # Check if the output directory exists, we will make the directory later after we know the config + if os.path.exists(output_base): + if not args.overwrite: + console.print( + f"[bold red]Output directory {output_base} already exists, please specify a different output directory or use `--overwrite` to overwrite existing files" + ) + return + else: + 
console.print(f"[bold yellow]Overwriting existing files in {output_base}") + device = syn.Device(args.uri, args.verbose) + with console.status( + "Reading from Synapse Device", spinner="bouncingBall", spinner_style="green" + ) as status: + status.update("Requesting device info") + info = device.info() + if not info: + console.print(f"[bold red]Failed to get device info from {args.uri}") + return + + console.log(f"Got info from: {info.name}") + if args.verbose: + pprint(info) + console.print("\n") + + status.update("Loading recording configuration") + + if args.config: + config = load_config_from_file(args.config) + if not config: + console.print(f"[bold red]Failed to load config from {args.config}") + return + stream_out = next( + (n for n in config.nodes if n.type == NodeType.kStreamOut), None + ) + if not stream_out: + console.print("[bold red]No StreamOut node found in config") + return + broadband = next( + (n for n in config.nodes if n.type == NodeType.kBroadbandSource), None + ) + if not broadband: + console.print("[bold red]No BroadbandSource node found in config") + return + signal = broadband.signal + if not signal: + console.print("[bold red]No signal configured for BroadbandSource node") + return + + if not signal.electrode: + console.print( + "[bold red]No electrode signal configured for BroadbandSource node" + ) + return + + num_ch = len(signal.electrode.channels) + if args.num_ch: + num_ch = args.num_ch + offset = 0 + channels = [] + for ch in range(offset, offset + num_ch): + channels.append(channel.Channel(ch, 2 * ch, 2 * ch + 1)) + + broadband.signal.electrode.channels = channels + + with console.status( + "Configuring device", spinner="bouncingBall", spinner_style="green" + ) as status: + configure_status = device.configure_with_status(config) + if configure_status is None: + console.print( + "[bold red]Failed to configure device. Run with `--verbose` for more information." 
+ ) + return + if configure_status.code == StatusCode.kInvalidConfiguration: + console.print("[bold red]Failed to configure device.") + console.print(f"[italic red]Why: {configure_status.message}") + console.print("[yellow]Is there a peripheral connected to the device?") + return + elif configure_status.code == StatusCode.kFailedPrecondition: + console.print("[bold red]Failed to configure device.") + console.print(f"[italic red]Why: {configure_status.message}") + console.print( + f"[yellow]If the device is already running, run `synapsectl stop {args.uri}` to stop the device and try again." + ) + return + console.print("[bold green]Device configured successfully") + + if info.status.state != DeviceState.kRunning: + print("Starting device...") + if not device.start(): + raise ValueError("Failed to start device") + + # Get the sample rate from the device + # We need to look at the node configuration with type kBroadbandSource for the sample rate + broadband = next( + (n for n in config.nodes if n.type == NodeType.kBroadbandSource), None + ) + assert broadband is not None, "No BroadbandSource node found in config" + + else: + # TODO(gilbert): Get rid of this giant if-else block + node = next( + ( + n + for n in info.configuration.nodes + if n.type == NodeType.kStreamOut and n.id == args.node_id + ), + None, + ) + if node is None: + console.print( + "[bold red]No StreamOut node found in device configuration; please configure the device with a StreamOut node." 
+ ) + return + + stream_out = syn.StreamOut.from_proto(node) + stream_out.device = device + + # We are ready to start streaming, make the output directory + os.makedirs(output_base, exist_ok=True) + + # Copy our config that was taken from the device to the output directory + device_info_after_config = device.info() + if not device_info_after_config: + console.print(f"[bold red]Failed to get device info from {args.uri}") + return + runtime_config = device_info_after_config.configuration + runtime_config_json = MessageToJson( + runtime_config, + always_print_fields_with_no_presence=True, + preserving_proto_field_name=True, + ) + output_config_path = os.path.join(output_base, "runtime_config.json") + with open(output_config_path, "w") as f: + f.write(runtime_config_json) + + console.print(f"[bold green]Streaming data to {output_base}") + + status_title = ( + f"Streaming data for {args.duration} seconds" + if args.duration + else "Streaming data indefinitely" + ) + console.print(status_title) + + q = queue.Queue() + plot_q = queue.Queue() if args.plot else None + + threads = [] + stop = threading.Event() + if args.bin: + threads.append( + threading.Thread(target=_binary_writer, args=(stop, q, num_ch, output_base)) + ) + else: + threads.append( + threading.Thread(target=_data_writer, args=(stop, q, output_base)) + ) + + if args.plot: + threads.append( + threading.Thread(target=_plot_data, args=(stop, plot_q, runtime_config)) + ) + + for thread in threads: + thread.start() + + try: + read_packets(stream_out, q, plot_q, stop, args.duration) + except KeyboardInterrupt: + pass + finally: + console.print("Stopping read...") + stop.set() + for thread in threads: + thread.join() + + if args.config: + console.print("Stopping device...") + if not device.stop(): + console.print("[red]Failed to stop device") + console.print("Stopped") + + console.print("[bold green]Streaming complete") + console.print("[cyan]================") + console.print(f"[cyan]Output directory: 
{output_base}/") + console.print(f"[cyan]Run `synapsectl plot --dir {output_base}/` to plot the data") + console.print("[cyan]================") + + +def read_packets( + node: syn.StreamOut, + q: queue.Queue, + plot_q: queue.Queue, + stop: threading.Event, + duration: Optional[int] = None, + num_ch: int = 32, +): + start = time.time() + + # Keep track of our statistics + monitor = PacketMonitor() + monitor.start_monitoring() + + with Live(monitor.generate_stat_table(), refresh_per_second=4) as live: + while not stop.is_set(): + read_ret = node.read() + if read_ret is None: + logging.error("Could not get a valid read from the node") + continue + + synapse_data, bytes_read = read_ret + if synapse_data is None or bytes_read == 0: + logging.error("Could not read data from node") + continue + header, data = synapse_data + monitor.process_packet(header, data, bytes_read) + live.update(monitor.generate_stat_table()) + + # Always add the data to the writer queues + q.put(data) + if plot_q: + plot_q.put(copy.deepcopy(data)) + + if duration and (time.time() - start) > duration: + break + + +def _binary_writer(stop, q: queue.Queue, num_ch, output_base): + filename = f"{output_base}.dat" + full_path = os.path.join(output_base, filename) + if filename: + fd = open(full_path, "wb") + + channel_data = [] + while not stop.is_set() or not q.empty(): + try: + data: ndtp_types.ElectricalBroadbandData = q.get(True, 1) + except queue.Empty: + continue + + try: + for ch_id, samples in data.samples: + channel_data.append([ch_id, samples]) + if len(channel_data) == num_ch: + channel_data.sort(key=itemgetter(0)) + channel_samples = [ch_data[1] for ch_data in channel_data] + frames = list(zip(*channel_samples)) + channel_data = [] + + for frame in frames: + for sample in frame: + fd.write( + int(sample).to_bytes(2, byteorder="little", signed=True) + ) + + except Exception as e: + print(f"Error processing data: {e}") + traceback.print_exc() + continue + + +def _data_writer(stop, q, 
output_base): + filename = f"{output_base}.jsonl" + full_path = os.path.join(output_base, filename) + if filename: + fd = open(full_path, "wb") + + while not stop.is_set() or not q.empty(): + try: + data = q.get(True, 1) + except queue.Empty: + continue + + try: + fd.write(json.dumps(data.to_list()).encode("utf-8")) + fd.write(b"\n") + + except Exception as e: + print(f"Error processing data: {e}") + traceback.print_exc() + continue + + +def _plot_data(stop, q, runtime_config): + """Plot streaming data from the synapse device.""" + # TODO(gilbert): Make these configurable + window_size_seconds = 3 + + # Find the first broadband source node + broadband_nodes = [ + node for node in runtime_config.nodes if node.type == NodeType.kBroadbandSource + ] + + # Make sure we have a broadband node + # TODO(gilbert): We should be able to support binned spikes here too + # but will need a refactor + if not broadband_nodes: + print("Could not find broadband source config. Cannot plot") + return + + broadband_source = broadband_nodes[0].broadband_source + electrode_config = broadband_source.signal.electrode + + if not electrode_config: + print( + "Could not find an electrode configuration for broadband node. 
Cannot plot" + ) + return + + # Get configuration parameters + sample_rate_hz = broadband_source.sample_rate_hz + channel_ids = [ch.id for ch in electrode_config.channels] + + # Start the plotter + plotter.plot_synapse_data(stop, q, sample_rate_hz, window_size_seconds, channel_ids) From df3d27de0bce86d6a93ab837950a5a822d2338d4 Mon Sep 17 00:00:00 2001 From: Gilbert Montague Date: Thu, 12 Jun 2025 17:31:01 -0700 Subject: [PATCH 3/8] Updated to use streaming --- synapse/cli/streaming.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/cli/streaming.py b/synapse/cli/streaming.py index c0746c6..cbb3b17 100644 --- a/synapse/cli/streaming.py +++ b/synapse/cli/streaming.py @@ -297,7 +297,7 @@ def read(args): writer = BroadbandFrameWriter(args.output) # Get sample rate and channels from config # broadband_node = next((n for n in config.nodes if n.type == NodeType.kBroadbandSource), None) - writer.set_attributes(sample_rate_hz=32000, channels=list(range(32))) + writer.set_attributes(sample_rate_hz=32000, channels=list(range(256))) writer.start() From 5fa66d83ada6f72c3b9ae3cf45fda0c32bc615c1 Mon Sep 17 00:00:00 2001 From: Gilbert Montague Date: Thu, 12 Jun 2025 17:34:03 -0700 Subject: [PATCH 4/8] Updated to use streaming --- synapse/cli/streaming.py | 45 ++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/synapse/cli/streaming.py b/synapse/cli/streaming.py index cbb3b17..9c46d2a 100644 --- a/synapse/cli/streaming.py +++ b/synapse/cli/streaming.py @@ -4,6 +4,7 @@ import time import h5py from datetime import datetime +import numpy as np import synapse as syn from synapse.api.status_pb2 import DeviceState, StatusCode @@ -67,12 +68,12 @@ def __init__(self, output_dir: str): self.data_queue = queue.Queue(maxsize=1000) self.stop_event = threading.Event() self.writer_thread = None - + # Create HDF5 file timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") self.filename = os.path.join(output_dir, 
f"broadband_data_{timestamp}.h5") self.file = h5py.File(self.filename, "w") - + # Create datasets self.timestamp_dataset = self.file.create_dataset( "/acquisition/timestamp", shape=(0,), maxshape=(None,), dtype="uint64" @@ -80,14 +81,14 @@ def __init__(self, output_dir: str): self.sequence_dataset = self.file.create_dataset( "/acquisition/sequence_number", shape=(0,), maxshape=(None,), dtype="uint64" ) - # Create frame data dataset with correct shape for 256 samples + # Create frame data dataset as a 1D array of variable length arrays self.frame_data_dataset = self.file.create_dataset( "/acquisition/ElectricalSeries", - shape=(0, 256), # Each frame has 256 samples - maxshape=(None, 256), - dtype="int32", + shape=(0,), + maxshape=(None,), + dtype=h5py.vlen_dtype(np.dtype('int32')) ) - + def set_attributes( self, sample_rate_hz: float, channels: list, session_description: str = "" ): @@ -96,26 +97,26 @@ def set_attributes( self.file.attrs["sample_rate_hz"] = sample_rate_hz if session_description: self.file.attrs["session_description"] = session_description - + # Set session start time self.file.attrs["session_start_time"] = datetime.now().isoformat() - + # Set device type device_group = self.file.create_group("general/device") device_group.attrs["device_type"] = "SciFi" - + # Create electrodes group and write channel IDs electrodes_group = self.file.create_group( "general/extracellular_ephys/electrodes" ) channel_ids = channels electrodes_group.create_dataset("id", data=channel_ids, dtype="uint32") - + def start(self): """Start the writer thread""" self.writer_thread = threading.Thread(target=self._write_loop) self.writer_thread.start() - + def stop(self): """Stop the writer thread and wait for it to finish""" self.stop_event.set() @@ -123,7 +124,7 @@ def stop(self): self.writer_thread.join() self.flush() self.file.close() - + def put(self, frame: BroadbandFrame): """Add frame to the write queue""" try: @@ -135,38 +136,38 @@ def put(self, frame: BroadbandFrame): 
self.data_queue.put(frame, block=False) except queue.Empty: pass - + def _write_loop(self): """Main writing loop that consumes data from the queue""" while not self.stop_event.is_set() or not self.data_queue.empty(): try: frame = self.data_queue.get(timeout=1) - + # Resize datasets current_size = self.timestamp_dataset.shape[0] new_size = current_size + 1 - + self.timestamp_dataset.resize(new_size, axis=0) self.sequence_dataset.resize(new_size, axis=0) self.frame_data_dataset.resize(new_size, axis=0) - + # Write data self.timestamp_dataset[current_size] = frame.timestamp_ns self.sequence_dataset[current_size] = frame.sequence_number - # Write frame data as a row in the 2D dataset - self.frame_data_dataset[current_size, :] = frame.frame_data - + # Write frame data as a variable length array + self.frame_data_dataset[current_size] = frame.frame_data + # Flush periodically if new_size % 1000 == 0: print(f"Flushed {new_size} frames") self.flush() - + except queue.Empty: continue except Exception as e: print(f"Error writing data: {e}") continue - + def flush(self): """Flush all datasets to disk""" self.timestamp_dataset.flush() From 83370f06600fcdd62fbbe522c38b7c000e2b339b Mon Sep 17 00:00:00 2001 From: Gilbert Montague Date: Thu, 12 Jun 2025 17:36:47 -0700 Subject: [PATCH 5/8] Updated to use streaming --- synapse/cli/streaming.py | 68 ++++++++++++++++++++++++++++------------ 1 file changed, 48 insertions(+), 20 deletions(-) diff --git a/synapse/cli/streaming.py b/synapse/cli/streaming.py index 9c46d2a..739ad20 100644 --- a/synapse/cli/streaming.py +++ b/synapse/cli/streaming.py @@ -81,14 +81,18 @@ def __init__(self, output_dir: str): self.sequence_dataset = self.file.create_dataset( "/acquisition/sequence_number", shape=(0,), maxshape=(None,), dtype="uint64" ) - # Create frame data dataset as a 1D array of variable length arrays + # Create frame data dataset as a flat array of samples self.frame_data_dataset = self.file.create_dataset( "/acquisition/ElectricalSeries", 
shape=(0,), maxshape=(None,), - dtype=h5py.vlen_dtype(np.dtype('int32')) + dtype="int32" ) + # Buffer for collecting frames before writing + self.frame_buffer = [] + self.buffer_size = 1000 # Number of frames to collect before writing + def set_attributes( self, sample_rate_hz: float, channels: list, session_description: str = "" ): @@ -142,25 +146,11 @@ def _write_loop(self): while not self.stop_event.is_set() or not self.data_queue.empty(): try: frame = self.data_queue.get(timeout=1) + self.frame_buffer.append(frame) - # Resize datasets - current_size = self.timestamp_dataset.shape[0] - new_size = current_size + 1 - - self.timestamp_dataset.resize(new_size, axis=0) - self.sequence_dataset.resize(new_size, axis=0) - self.frame_data_dataset.resize(new_size, axis=0) - - # Write data - self.timestamp_dataset[current_size] = frame.timestamp_ns - self.sequence_dataset[current_size] = frame.sequence_number - # Write frame data as a variable length array - self.frame_data_dataset[current_size] = frame.frame_data - - # Flush periodically - if new_size % 1000 == 0: - print(f"Flushed {new_size} frames") - self.flush() + # Write when buffer is full + if len(self.frame_buffer) >= self.buffer_size: + self._write_buffer() except queue.Empty: continue @@ -168,8 +158,46 @@ def _write_loop(self): print(f"Error writing data: {e}") continue + def _write_buffer(self): + """Write the buffered frames to disk""" + if not self.frame_buffer: + return + + # Get current sizes + current_timestamp_size = self.timestamp_dataset.shape[0] + current_frame_size = self.frame_data_dataset.shape[0] + num_frames = len(self.frame_buffer) + + # Resize datasets + new_timestamp_size = current_timestamp_size + num_frames + new_frame_size = current_frame_size + (num_frames * len(self.frame_buffer[0].frame_data)) + + self.timestamp_dataset.resize(new_timestamp_size, axis=0) + self.sequence_dataset.resize(new_timestamp_size, axis=0) + self.frame_data_dataset.resize(new_frame_size, axis=0) + + # Write data + 
for i, frame in enumerate(self.frame_buffer): + idx = current_timestamp_size + i + self.timestamp_dataset[idx] = frame.timestamp_ns + self.sequence_dataset[idx] = frame.sequence_number + + # Write frame data + frame_start = current_frame_size + (i * len(frame.frame_data)) + frame_end = frame_start + len(frame.frame_data) + self.frame_data_dataset[frame_start:frame_end] = frame.frame_data + + # Clear buffer + self.frame_buffer = [] + + # Flush to disk + self.flush() + print(f"Wrote {num_frames} frames") + def flush(self): """Flush all datasets to disk""" + if self.frame_buffer: + self._write_buffer() self.timestamp_dataset.flush() self.sequence_dataset.flush() self.frame_data_dataset.flush() From eb3ee31573415866d4ea09fbdafa3f170c1a1f3f Mon Sep 17 00:00:00 2001 From: Gilbert Montague Date: Fri, 13 Jun 2025 14:54:55 -0700 Subject: [PATCH 6/8] Updated to use streaming --- synapse/cli/streaming.py | 392 +++++++++++++++++++------ synapse/cli/streaming_bkup.py | 387 ------------------------- synapse/cli/synapse_plotter.py | 408 ++++++++++++++++++-------- synapse/cli/updated_stream_plot.py | 442 ----------------------------- synapse/client/taps.py | 76 ++++- 5 files changed, 662 insertions(+), 1043 deletions(-) delete mode 100644 synapse/cli/streaming_bkup.py delete mode 100644 synapse/cli/updated_stream_plot.py diff --git a/synapse/cli/streaming.py b/synapse/cli/streaming.py index 739ad20..4b270e3 100644 --- a/synapse/cli/streaming.py +++ b/synapse/cli/streaming.py @@ -4,7 +4,10 @@ import time import h5py from datetime import datetime -import numpy as np +from rich.live import Live +from rich.table import Table +from rich.console import Console +from rich.text import Text import synapse as syn from synapse.api.status_pb2 import DeviceState, StatusCode @@ -12,68 +15,125 @@ from synapse.utils.proto import load_device_config from synapse.api.datatype_pb2 import BroadbandFrame -from rich.console import Console - -class DiskWriter: - def __init__(self, output_dir: str, 
buffer_size: int = 1024 * 1024): - self.output_dir = output_dir - self.buffer_size = buffer_size - self.data_queue = queue.Queue(maxsize=1000) # Prevent unbounded memory growth +class StreamMonitor: + def __init__(self, console: Console): + self.console = console + self.start_time = time.time() + self.message_count = 0 + self.last_update = time.time() + self.last_count = 0 + self.last_sequence = 0 + self.total_dropped = 0 + self.queue = queue.Queue(maxsize=100) self.stop_event = threading.Event() - self.writer_thread = None + self.monitor_thread = None def start(self): - """Start the writer thread""" - self.writer_thread = threading.Thread(target=self._write_loop) - self.writer_thread.start() + """Start monitoring in separate thread""" + self.start_time = time.time() + self.last_update = self.start_time + self.message_count = 0 + self.last_count = 0 + self.last_sequence = 0 + self.total_dropped = 0 + self.stop_event.clear() + self.monitor_thread = threading.Thread(target=self._monitor_loop) + self.monitor_thread.start() def stop(self): - """Stop the writer thread and wait for it to finish""" + """Stop monitoring thread""" self.stop_event.set() - if self.writer_thread: - self.writer_thread.join() + if self.monitor_thread: + self.monitor_thread.join() - def put(self, data: BroadbandFrame): - """Add data to the write queue""" + def put(self, frame: BroadbandFrame): + """Add frame to monitoring queue (non-blocking)""" try: - self.data_queue.put(data, block=False) + self.queue.put(frame, block=False) except queue.Full: - # If queue is full, we'll drop the oldest data + # Drop frame if queue is full to prevent blocking + pass + + def _monitor_loop(self): + """Process frames for monitoring in separate thread""" + while not self.stop_event.is_set(): try: - self.data_queue.get_nowait() - self.data_queue.put(data, block=False) + frame = self.queue.get(timeout=0.1) + self._update_stats(frame) except queue.Empty: - pass + continue - def _write_loop(self): - """Main writing 
loop that consumes data from the queue""" - filename = os.path.join(self.output_dir, f"data_{int(time.time())}.dat") - with open(filename, "wb", buffering=self.buffer_size) as f: - while not self.stop_event.is_set() or not self.data_queue.empty(): - try: - data = self.data_queue.get(timeout=1) - # Write binary data directly - print(data) - except queue.Empty: - continue - except Exception as e: - print(f"Error writing data: {e}") - continue + def _update_stats(self, frame: BroadbandFrame): + """Update statistics from frame""" + self.message_count += 1 + + # Check for dropped packets + if self.last_sequence != 0: + expected_sequence = self.last_sequence + 1 + if frame.sequence_number != expected_sequence: + self.total_dropped += frame.sequence_number - expected_sequence + self.last_sequence = frame.sequence_number + + def get_current_stats(self) -> Text: + """Get current statistics as formatted text""" + current_time = time.time() + + # Calculate message rate + elapsed = current_time - self.last_update + if elapsed >= 1.0: # Update rate every second + rate = (self.message_count - self.last_count) / elapsed + self.last_count = self.message_count + self.last_update = current_time + else: + rate = ( + (self.message_count - self.last_count) + / (current_time - self.last_update) + if elapsed > 0 + else 0 + ) + + # Calculate packet loss percentage + total_expected = self.message_count + self.total_dropped + loss_percent = ( + (self.total_dropped / total_expected * 100) if total_expected > 0 else 0 + ) + + # Create styled text + stats_text = Text() + stats_text.append("Messages: ", style="bold") + stats_text.append(f"{self.message_count:,}", style="cyan") + stats_text.append(" | msgs/sec: ", style="bold") + stats_text.append(f"{rate:.1f}/s", style="green") + stats_text.append(" | Dropped: ", style="bold") + stats_text.append(f"{self.total_dropped:,}", style="red") + stats_text.append(" | Loss: ", style="bold") + stats_text.append(f"{loss_percent:.2f}%", style="yellow") + 
stats_text.append(" | Runtime: ", style="bold") + stats_text.append(f"{current_time - self.start_time:.1f}s", style="blue") + + return stats_text class BroadbandFrameWriter: def __init__(self, output_dir: str): self.output_dir = output_dir - self.data_queue = queue.Queue(maxsize=1000) + self.data_queue = queue.Queue(maxsize=2000) # Increased queue size self.stop_event = threading.Event() self.writer_thread = None - + + # Stats tracking + self.start_time = time.time() + self.frames_received = 0 + self.samples_received = 0 + self.last_sequence = 0 + self.dropped_frames = 0 + # Create HDF5 file timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") self.filename = os.path.join(output_dir, f"broadband_data_{timestamp}.h5") self.file = h5py.File(self.filename, "w") - + # Create datasets self.timestamp_dataset = self.file.create_dataset( "/acquisition/timestamp", shape=(0,), maxshape=(None,), dtype="uint64" @@ -83,16 +143,35 @@ def __init__(self, output_dir: str): ) # Create frame data dataset as a flat array of samples self.frame_data_dataset = self.file.create_dataset( - "/acquisition/ElectricalSeries", - shape=(0,), - maxshape=(None,), - dtype="int32" + "/acquisition/ElectricalSeries", shape=(0,), maxshape=(None,), dtype="int32" ) - + # Buffer for collecting frames before writing self.frame_buffer = [] - self.buffer_size = 1000 # Number of frames to collect before writing - + self.buffer_size = 500 # Reduced buffer size for more frequent writes + + def get_stats(self): + """Get current statistics""" + elapsed = time.time() - self.start_time + if elapsed == 0: + return { + "frames_per_sec": 0, + "samples_per_sec": 0, + "total_frames": 0, + "total_samples": 0, + "dropped_frames": 0, + "last_sequence": 0, + } + + return { + "frames_per_sec": self.frames_received / elapsed, + "samples_per_sec": self.samples_received / elapsed, + "total_frames": self.frames_received, + "total_samples": self.samples_received, + "dropped_frames": self.dropped_frames, + "last_sequence": 
self.last_sequence, + } + def set_attributes( self, sample_rate_hz: float, channels: list, session_description: str = "" ): @@ -101,26 +180,26 @@ def set_attributes( self.file.attrs["sample_rate_hz"] = sample_rate_hz if session_description: self.file.attrs["session_description"] = session_description - + # Set session start time self.file.attrs["session_start_time"] = datetime.now().isoformat() - + # Set device type device_group = self.file.create_group("general/device") device_group.attrs["device_type"] = "SciFi" - + # Create electrodes group and write channel IDs electrodes_group = self.file.create_group( "general/extracellular_ephys/electrodes" ) channel_ids = channels electrodes_group.create_dataset("id", data=channel_ids, dtype="uint32") - + def start(self): """Start the writer thread""" self.writer_thread = threading.Thread(target=self._write_loop) self.writer_thread.start() - + def stop(self): """Stop the writer thread and wait for it to finish""" self.stop_event.set() @@ -128,9 +207,20 @@ def stop(self): self.writer_thread.join() self.flush() self.file.close() - + def put(self, frame: BroadbandFrame): - """Add frame to the write queue""" + """Add frame to the write queue (non-blocking)""" + # Update stats + self.frames_received += 1 + self.samples_received += len(frame.frame_data) + + # Check for dropped frames + if self.last_sequence != 0: + expected_sequence = self.last_sequence + 1 + if frame.sequence_number != expected_sequence: + self.dropped_frames += frame.sequence_number - expected_sequence + self.last_sequence = frame.sequence_number + try: self.data_queue.put(frame, block=False) except queue.Full: @@ -140,60 +230,86 @@ def put(self, frame: BroadbandFrame): self.data_queue.put(frame, block=False) except queue.Empty: pass - + + def put_batch(self, frames: list): + """Add multiple frames to the write queue efficiently""" + for frame in frames: + self.frames_received += 1 + self.samples_received += len(frame.frame_data) + + # Check for dropped frames 
+ if self.last_sequence != 0: + expected_sequence = self.last_sequence + 1 + if frame.sequence_number != expected_sequence: + self.dropped_frames += frame.sequence_number - expected_sequence + self.last_sequence = frame.sequence_number + + # Try to add all frames to queue + for frame in frames: + try: + self.data_queue.put(frame, block=False) + except queue.Full: + # If queue is full, drop oldest and try again + try: + self.data_queue.get_nowait() + self.data_queue.put(frame, block=False) + except queue.Empty: + pass + def _write_loop(self): """Main writing loop that consumes data from the queue""" while not self.stop_event.is_set() or not self.data_queue.empty(): try: - frame = self.data_queue.get(timeout=1) + frame = self.data_queue.get(timeout=0.1) self.frame_buffer.append(frame) - + # Write when buffer is full if len(self.frame_buffer) >= self.buffer_size: self._write_buffer() - + except queue.Empty: continue except Exception as e: print(f"Error writing data: {e}") continue - + def _write_buffer(self): """Write the buffered frames to disk""" if not self.frame_buffer: return - + # Get current sizes current_timestamp_size = self.timestamp_dataset.shape[0] current_frame_size = self.frame_data_dataset.shape[0] num_frames = len(self.frame_buffer) - + # Resize datasets new_timestamp_size = current_timestamp_size + num_frames - new_frame_size = current_frame_size + (num_frames * len(self.frame_buffer[0].frame_data)) - + new_frame_size = current_frame_size + ( + num_frames * len(self.frame_buffer[0].frame_data) + ) + self.timestamp_dataset.resize(new_timestamp_size, axis=0) self.sequence_dataset.resize(new_timestamp_size, axis=0) self.frame_data_dataset.resize(new_frame_size, axis=0) - + # Write data for i, frame in enumerate(self.frame_buffer): idx = current_timestamp_size + i self.timestamp_dataset[idx] = frame.timestamp_ns self.sequence_dataset[idx] = frame.sequence_number - + # Write frame data frame_start = current_frame_size + (i * len(frame.frame_data)) 
frame_end = frame_start + len(frame.frame_data) self.frame_data_dataset[frame_start:frame_end] = frame.frame_data - + # Clear buffer self.frame_buffer = [] - + # Flush to disk self.flush() - print(f"Wrote {num_frames} frames") - + def flush(self): """Flush all datasets to disk""" if self.frame_buffer: @@ -204,6 +320,24 @@ def flush(self): self.file.flush() +def create_status_table(writer: BroadbandFrameWriter) -> Table: + """Create a status table for display""" + stats = writer.get_stats() + table = Table(title="Streaming Status") + + table.add_column("Metric", style="cyan") + table.add_column("Value", style="green") + + table.add_row("Frames/sec", f"{stats['frames_per_sec']:.1f}") + table.add_row("Samples/sec", f"{stats['samples_per_sec']:.1f}") + table.add_row("Total Frames", str(stats["total_frames"])) + table.add_row("Total Samples", str(stats["total_samples"])) + table.add_row("Dropped Frames", str(stats["dropped_frames"])) + table.add_row("Last Sequence", str(stats["last_sequence"])) + + return table + + def add_commands(subparsers): read_parser = subparsers.add_parser( "read", help="Read from a device's Broadband Tap" @@ -218,6 +352,17 @@ def add_commands(subparsers): read_parser.add_argument( "--overwrite", action="store_true", help="Overwrite existing files" ) + read_parser.add_argument( + "--plot", action="store_true", help="Show real-time plot of Broadband Data" + ) + read_parser.add_argument( + "--tap-name", + type=str, + help="Specific tap name to connect to (if not specified, will auto-select first BroadbandFrame tap)", + ) + read_parser.add_argument( + "--list-taps", action="store_true", help="List all available taps and exit" + ) read_parser.set_defaults(func=read) @@ -269,13 +414,51 @@ def setup_output(args, console): return True +def list_available_taps(args, device, console): + """List all available taps on the device""" + read_tap = Tap(args.uri, args.verbose) + taps = read_tap.list_taps() + + if not taps: + console.print("[bold red]No taps 
found on device[/bold red]") + return + + console.print("\n[bold cyan]Available Taps:[/bold cyan]") + console.print("=" * 50) + + for tap in taps: + console.print(f"[green]Name:[/green] {tap.name}") + console.print(f"[blue]Type:[/blue] {tap.message_type}") + console.print(f"[yellow]Endpoint:[/yellow] {tap.endpoint}") + console.print("-" * 30) + + console.print(f"\n[bold]Total: {len(taps)} taps available[/bold]") + + def get_broadband_tap(args, device, console): read_tap = Tap(args.uri, args.verbose) taps = read_tap.list_taps() - # Just get the first tap that has BroadbandFrame as the type + # If user specified a tap name, try to use it first + if hasattr(args, "tap_name") and args.tap_name: + console.log(f"[cyan]Looking for specified tap: {args.tap_name}[/cyan]") + for t in taps: + if t.name == args.tap_name: + console.log( + f"[green]Found specified tap: {args.tap_name} (type: {t.message_type})[/green]" + ) + read_tap.connect(t.name) + return read_tap + + console.print( + f"[yellow]Warning: Specified tap '{args.tap_name}' not found, falling back to auto-selection[/yellow]" + ) + + # Auto-select: get the first tap that has BroadbandFrame as the type + console.log("[cyan]Auto-selecting first BroadbandFrame tap[/cyan]") for t in taps: if "BroadbandFrame" in t.message_type: + console.log(f"[green]Found BroadbandFrame tap: {t.name}[/green]") read_tap.connect(t.name) return read_tap @@ -298,6 +481,11 @@ def read(args): device_name = device.get_name() console.log(f"[green]Connected to {device_name}[/green]") + # If user just wants to list taps, do that and exit + if hasattr(args, "list_taps") and args.list_taps: + list_available_taps(args, device, console) + return + # Apply the configuration to the device if not configure_device(device, config, console): console.print("[bold red]Failed to configure device[/bold red]") @@ -324,25 +512,67 @@ def read(args): writer = None if args.output: writer = BroadbandFrameWriter(args.output) - # Get sample rate and channels from 
config - # broadband_node = next((n for n in config.nodes if n.type == NodeType.kBroadbandSource), None) writer.set_attributes(sample_rate_hz=32000, channels=list(range(256))) - writer.start() + # Setup plotter if requested + plotter = None + if args.plot: + try: + from synapse.cli.synapse_plotter import create_broadband_plotter + + # Always make all 256 channels available, but start with only 5 selected + available_channels = list(range(256)) + plotter = create_broadband_plotter( + sample_rate_hz=32000, + window_size_seconds=5, + channel_ids=available_channels, + ) + plotter.start() + console.log( + f"[green]Started real-time plotter with all {len(available_channels)} channels available[/green]" + ) + except ImportError as e: + console.print( + f"[bold red]Failed to import plotter (missing dearpygui?): {e}[/bold red]" + ) + return + + # Setup stream monitor + monitor = StreamMonitor(console) + monitor.start() + try: - # Now we need to start the streaming - frame = BroadbandFrame() - with console.status("Streaming data...", spinner="bouncingBall"): - for message in broadband_tap.stream(): - frame.ParseFromString(message) - if writer: - writer.put(frame) - else: - print(frame) + # Use batch streaming for better throughput + with Live(monitor.get_current_stats(), refresh_per_second=4) as live: + for message_batch in broadband_tap.stream_batch(batch_size=10): + frames = [] + for message in message_batch: + frame = BroadbandFrame() + frame.ParseFromString(message) + frames.append(frame) + + # Send to monitor (non-blocking) + monitor.put(frame) + if plotter: + plotter.put(frame) + + # Batch write for better performance + if writer and frames: + writer.put_batch(frames) + + live.update(monitor.get_current_stats()) + except KeyboardInterrupt: console.print("\n[yellow]Stopping data collection...[/yellow]") finally: if writer: writer.stop() + if plotter: + plotter.stop() + if monitor: + monitor.stop() + if args.output: console.print(f"[green]Data saved to 
{args.output}[/green]") + if args.plot: + console.print("[green]Plotter stopped[/green]") diff --git a/synapse/cli/streaming_bkup.py b/synapse/cli/streaming_bkup.py deleted file mode 100644 index ae5c225..0000000 --- a/synapse/cli/streaming_bkup.py +++ /dev/null @@ -1,387 +0,0 @@ -import json -import queue -import threading -import time -import traceback -import os -import logging -from typing import Optional -from operator import itemgetter -import copy - -from google.protobuf.json_format import Parse, MessageToJson - -from synapse.api.node_pb2 import NodeType -from synapse.api.status_pb2 import DeviceState, StatusCode -from synapse.api.device_pb2 import DeviceConfiguration -import synapse as syn -import synapse.client.channel as channel -import synapse.utils.ndtp_types as ndtp_types -import synapse.cli.synapse_plotter as plotter -from synapse.utils.packet_monitor import PacketMonitor - -from rich.console import Console -from rich.live import Live -from rich.pretty import pprint - - -def add_commands(subparsers): - a = subparsers.add_parser("read", help="Read from a device's StreamOut node") - a.add_argument( - "--config", - type=str, - help="Configuration file", - ) - a.add_argument( - "--num_ch", type=int, help="Number of channels to read from, overrides config" - ) - a.add_argument("--bin", type=bool, help="Output binary format instead of JSON") - a.add_argument("--duration", type=int, help="Duration to read for in seconds") - a.add_argument("--node_id", type=int, help="ID of the StreamOut node to read from") - a.add_argument("--plot", action="store_true", help="Plot the data in real-time") - a.add_argument("--output", type=str, help="Name of the output directory and files") - a.add_argument( - "--overwrite", - action="store_true", - default=False, - help="Overwrite existing files", - ) - a.set_defaults(func=read) - - -def load_config_from_file(path): - with open(path, "r") as f: - data = f.read() - proto = Parse(data, DeviceConfiguration()) - return 
syn.Config.from_proto(proto) - - -def read(args): - console = Console() - if not args.config and not args.node_id: - console.print("[bold red]Either `--config` or `--node_id` must be specified.") - return - - output_base = args.output - timestamp = time.strftime("%Y%m%d-%H%M%S") - if not output_base: - output_base = f"synapse_data_{timestamp}" - else: - output_base = f"{output_base}_{timestamp}" - - # Check if the output directory exists, we will make the directory later after we know the config - if os.path.exists(output_base): - if not args.overwrite: - console.print( - f"[bold red]Output directory {output_base} already exists, please specify a different output directory or use `--overwrite` to overwrite existing files" - ) - return - else: - console.print(f"[bold yellow]Overwriting existing files in {output_base}") - device = syn.Device(args.uri, args.verbose) - with console.status( - "Reading from Synapse Device", spinner="bouncingBall", spinner_style="green" - ) as status: - status.update("Requesting device info") - info = device.info() - if not info: - console.print(f"[bold red]Failed to get device info from {args.uri}") - return - - console.log(f"Got info from: {info.name}") - if args.verbose: - pprint(info) - console.print("\n") - - status.update("Loading recording configuration") - - if args.config: - config = load_config_from_file(args.config) - if not config: - console.print(f"[bold red]Failed to load config from {args.config}") - return - stream_out = next( - (n for n in config.nodes if n.type == NodeType.kStreamOut), None - ) - if not stream_out: - console.print("[bold red]No StreamOut node found in config") - return - broadband = next( - (n for n in config.nodes if n.type == NodeType.kBroadbandSource), None - ) - if not broadband: - console.print("[bold red]No BroadbandSource node found in config") - return - signal = broadband.signal - if not signal: - console.print("[bold red]No signal configured for BroadbandSource node") - return - - if not 
signal.electrode: - console.print( - "[bold red]No electrode signal configured for BroadbandSource node" - ) - return - - num_ch = len(signal.electrode.channels) - if args.num_ch: - num_ch = args.num_ch - offset = 0 - channels = [] - for ch in range(offset, offset + num_ch): - channels.append(channel.Channel(ch, 2 * ch, 2 * ch + 1)) - - broadband.signal.electrode.channels = channels - - with console.status( - "Configuring device", spinner="bouncingBall", spinner_style="green" - ) as status: - configure_status = device.configure_with_status(config) - if configure_status is None: - console.print( - "[bold red]Failed to configure device. Run with `--verbose` for more information." - ) - return - if configure_status.code == StatusCode.kInvalidConfiguration: - console.print("[bold red]Failed to configure device.") - console.print(f"[italic red]Why: {configure_status.message}") - console.print("[yellow]Is there a peripheral connected to the device?") - return - elif configure_status.code == StatusCode.kFailedPrecondition: - console.print("[bold red]Failed to configure device.") - console.print(f"[italic red]Why: {configure_status.message}") - console.print( - f"[yellow]If the device is already running, run `synapsectl stop {args.uri}` to stop the device and try again." 
- ) - return - console.print("[bold green]Device configured successfully") - - if info.status.state != DeviceState.kRunning: - print("Starting device...") - if not device.start(): - raise ValueError("Failed to start device") - - # Get the sample rate from the device - # We need to look at the node configuration with type kBroadbandSource for the sample rate - broadband = next( - (n for n in config.nodes if n.type == NodeType.kBroadbandSource), None - ) - assert broadband is not None, "No BroadbandSource node found in config" - - else: - # TODO(gilbert): Get rid of this giant if-else block - node = next( - ( - n - for n in info.configuration.nodes - if n.type == NodeType.kStreamOut and n.id == args.node_id - ), - None, - ) - if node is None: - console.print( - "[bold red]No StreamOut node found in device configuration; please configure the device with a StreamOut node." - ) - return - - stream_out = syn.StreamOut.from_proto(node) - stream_out.device = device - - # We are ready to start streaming, make the output directory - os.makedirs(output_base, exist_ok=True) - - # Copy our config that was taken from the device to the output directory - device_info_after_config = device.info() - if not device_info_after_config: - console.print(f"[bold red]Failed to get device info from {args.uri}") - return - runtime_config = device_info_after_config.configuration - runtime_config_json = MessageToJson( - runtime_config, - always_print_fields_with_no_presence=True, - preserving_proto_field_name=True, - ) - output_config_path = os.path.join(output_base, "runtime_config.json") - with open(output_config_path, "w") as f: - f.write(runtime_config_json) - - console.print(f"[bold green]Streaming data to {output_base}") - - status_title = ( - f"Streaming data for {args.duration} seconds" - if args.duration - else "Streaming data indefinitely" - ) - console.print(status_title) - - q = queue.Queue() - plot_q = queue.Queue() if args.plot else None - - threads = [] - stop = threading.Event() 
- if args.bin: - threads.append( - threading.Thread(target=_binary_writer, args=(stop, q, num_ch, output_base)) - ) - else: - threads.append( - threading.Thread(target=_data_writer, args=(stop, q, output_base)) - ) - - if args.plot: - threads.append( - threading.Thread(target=_plot_data, args=(stop, plot_q, runtime_config)) - ) - - for thread in threads: - thread.start() - - try: - read_packets(stream_out, q, plot_q, stop, args.duration) - except KeyboardInterrupt: - pass - finally: - console.print("Stopping read...") - stop.set() - for thread in threads: - thread.join() - - if args.config: - console.print("Stopping device...") - if not device.stop(): - console.print("[red]Failed to stop device") - console.print("Stopped") - - console.print("[bold green]Streaming complete") - console.print("[cyan]================") - console.print(f"[cyan]Output directory: {output_base}/") - console.print(f"[cyan]Run `synapsectl plot --dir {output_base}/` to plot the data") - console.print("[cyan]================") - - -def read_packets( - node: syn.StreamOut, - q: queue.Queue, - plot_q: queue.Queue, - stop: threading.Event, - duration: Optional[int] = None, - num_ch: int = 32, -): - start = time.time() - - # Keep track of our statistics - monitor = PacketMonitor() - monitor.start_monitoring() - - with Live(monitor.generate_stat_table(), refresh_per_second=4) as live: - while not stop.is_set(): - read_ret = node.read() - if read_ret is None: - logging.error("Could not get a valid read from the node") - continue - - synapse_data, bytes_read = read_ret - if synapse_data is None or bytes_read == 0: - logging.error("Could not read data from node") - continue - header, data = synapse_data - monitor.process_packet(header, data, bytes_read) - live.update(monitor.generate_stat_table()) - - # Always add the data to the writer queues - q.put(data) - if plot_q: - plot_q.put(copy.deepcopy(data)) - - if duration and (time.time() - start) > duration: - break - - -def _binary_writer(stop, q: 
queue.Queue, num_ch, output_base): - filename = f"{output_base}.dat" - full_path = os.path.join(output_base, filename) - if filename: - fd = open(full_path, "wb") - - channel_data = [] - while not stop.is_set() or not q.empty(): - try: - data: ndtp_types.ElectricalBroadbandData = q.get(True, 1) - except queue.Empty: - continue - - try: - for ch_id, samples in data.samples: - channel_data.append([ch_id, samples]) - if len(channel_data) == num_ch: - channel_data.sort(key=itemgetter(0)) - channel_samples = [ch_data[1] for ch_data in channel_data] - frames = list(zip(*channel_samples)) - channel_data = [] - - for frame in frames: - for sample in frame: - fd.write( - int(sample).to_bytes(2, byteorder="little", signed=True) - ) - - except Exception as e: - print(f"Error processing data: {e}") - traceback.print_exc() - continue - - -def _data_writer(stop, q, output_base): - filename = f"{output_base}.jsonl" - full_path = os.path.join(output_base, filename) - if filename: - fd = open(full_path, "wb") - - while not stop.is_set() or not q.empty(): - try: - data = q.get(True, 1) - except queue.Empty: - continue - - try: - fd.write(json.dumps(data.to_list()).encode("utf-8")) - fd.write(b"\n") - - except Exception as e: - print(f"Error processing data: {e}") - traceback.print_exc() - continue - - -def _plot_data(stop, q, runtime_config): - """Plot streaming data from the synapse device.""" - # TODO(gilbert): Make these configurable - window_size_seconds = 3 - - # Find the first broadband source node - broadband_nodes = [ - node for node in runtime_config.nodes if node.type == NodeType.kBroadbandSource - ] - - # Make sure we have a broadband node - # TODO(gilbert): We should be able to support binned spikes here too - # but will need a refactor - if not broadband_nodes: - print("Could not find broadband source config. 
Cannot plot") - return - - broadband_source = broadband_nodes[0].broadband_source - electrode_config = broadband_source.signal.electrode - - if not electrode_config: - print( - "Could not find an electrode configuration for broadband node. Cannot plot" - ) - return - - # Get configuration parameters - sample_rate_hz = broadband_source.sample_rate_hz - channel_ids = [ch.id for ch in electrode_config.channels] - - # Start the plotter - plotter.plot_synapse_data(stop, q, sample_rate_hz, window_size_seconds, channel_ids) diff --git a/synapse/cli/synapse_plotter.py b/synapse/cli/synapse_plotter.py index 647fa9d..1c98dbe 100644 --- a/synapse/cli/synapse_plotter.py +++ b/synapse/cli/synapse_plotter.py @@ -1,8 +1,9 @@ import dearpygui.dearpygui as dpg import queue import time -from threading import Event +from threading import Event, Thread import numpy as np +from synapse.api.datatype_pb2 import BroadbandFrame class SynapsePlotter: @@ -18,33 +19,46 @@ def __init__(self, sample_rate: int, window_size: int, channel_ids): ch_id: idx for idx, ch_id in enumerate(self.channel_ids) } + # Track which channels are selected for plotting (start with first 5) + self.selected_channels = set(self.channel_ids[:5]) + # One ring buffer (of length BUFFER_SIZE) per channel self.data_buffers = [ np.zeros(self.buffer_size) for _ in range(self.num_channels) ] - # A time axis (0..WINDOW_SIZE) also size BUFFER_SIZE - self.time_buffer = np.linspace( - 0, self.window_size_seconds, self.buffer_size, endpoint=True - ) + # Timestamp buffer for each channel (in seconds, relative to start) + self.timestamp_buffers = [ + np.zeros(self.buffer_size) for _ in range(self.num_channels) + ] # A separate ring-buffer pointer for each channel - # TODO(gilbert): this is assuming that channels are truly independent self.buffer_positions = [0] * self.num_channels # Track which channel to display in the "zoom" (single channel) plot self.selected_channel_idx = 0 self.selected_channel_id = self.channel_ids[0] - # 
Track start time for display + # Track start time for display and timestamp conversion self.start_time = None + self.start_timestamp_ns = None + self.latest_data_time = 0 # Track the most recent data timestamp in seconds # Defaults for the zoomed channel plot - self.zoom_y_min = 0 + self.zoom_y_min = -4096 self.zoom_y_max = 4096 self.signal_separation = 1000 + # Dictionary to store line series for plotted channels + self.active_lines = {} + + # Queue and threading for BroadbandFrame processing + self.data_queue = queue.Queue(maxsize=2000) + self.stop_event = Event() + self.plot_thread = None + self.running = False + dpg.create_context() self.setup_gui() @@ -62,14 +76,38 @@ def setup_gui(self): pos=(0, 0), tag="control_window", ): + dpg.add_text("Select Channels to Plot:") + + # Add convenience buttons + with dpg.group(horizontal=True): + dpg.add_button(label="All", callback=self.select_all_channels, width=40) + dpg.add_button(label="None", callback=self.select_no_channels, width=40) + dpg.add_button( + label="First 5", callback=self.select_first_5_channels, width=50 + ) + + # Add checkboxes for each channel in a scrollable region + with dpg.child_window(height=150, tag="channel_selection"): + for ch_id in self.channel_ids: + is_selected = ch_id in self.selected_channels + dpg.add_checkbox( + label=f"Channel {ch_id}", + default_value=is_selected, + callback=self.channel_checkbox_callback, + user_data=ch_id, + tag=f"ch_checkbox_{ch_id}", + ) + + dpg.add_separator() dpg.add_text("Select Channel to Zoom:") dpg.add_combo( items=[str(ch_id) for ch_id in self.channel_ids], default_value=str(self.selected_channel_id), - callback=self.channel_selection_callback, - tag="channel_combo", + callback=self.zoom_channel_callback, + tag="zoom_channel_combo", width=80, ) + dpg.add_separator() dpg.add_text("Elapsed Time (s):") dpg.add_text("", tag="elapsed_time_text") @@ -120,24 +158,17 @@ def setup_gui(self): dpg.add_plot_legend() # Axes - dpg.add_plot_axis(dpg.mvXAxis, label="Time 
(s)") + dpg.add_plot_axis( + dpg.mvXAxis, label="Time (s)", tag="x_axis_all" + ) self.y_axis_all = dpg.add_plot_axis( dpg.mvYAxis, label="Amplitude", tag="y_axis_all" ) - dpg.set_axis_limits("y_axis_all", 0, 4096 * 10) - - # Create line series for each channel - self.lines_all = [] - for idx, ch_id in enumerate(self.channel_ids): - line_tag = f"all_line_ch{ch_id}" - line = dpg.add_line_series( - [], - [], - label=f"Ch {ch_id}", - parent=self.y_axis_all, - tag=line_tag, - ) - self.lines_all.append(line) + dpg.set_axis_limits("y_axis_all", -4096, 4096 * 10) + + # Create line series for initially selected channels + for ch_id in self.selected_channels: + self.create_line_series(ch_id) # Zoomed Channel Tab with dpg.tab(label="Zoomed Channel"): @@ -150,7 +181,9 @@ def setup_gui(self): dpg.add_plot_legend() # Axes - dpg.add_plot_axis(dpg.mvXAxis, label="Time (s)") + dpg.add_plot_axis( + dpg.mvXAxis, label="Time (s)", tag="x_axis_zoom" + ) self.y_axis_zoom = dpg.add_plot_axis( dpg.mvYAxis, label="Amplitude", tag="y_axis_zoom" ) @@ -164,13 +197,43 @@ def setup_gui(self): tag="zoomed_line", ) - def channel_selection_callback(self, sender, app_data, user_data): - """Called when user picks a channel from the combo.""" + def channel_checkbox_callback(self, sender, app_data, user_data): + """Called when user checks/unchecks a channel checkbox.""" + ch_id = user_data + if app_data: # Checked + self.selected_channels.add(ch_id) + self.create_line_series(ch_id) + else: # Unchecked + self.selected_channels.discard(ch_id) + self.remove_line_series(ch_id) + + def zoom_channel_callback(self, sender, app_data, user_data): + """Called when user picks a channel for zooming.""" self.selected_channel_id = int(app_data) self.selected_channel_idx = self.channel_to_index[self.selected_channel_id] # Update the label of the zoomed line dpg.configure_item("zoomed_line", label=f"Channel {self.selected_channel_id}") + def create_line_series(self, ch_id): + """Create a line series for the specified 
channel.""" + if ch_id not in self.active_lines: + line_tag = f"all_line_ch{ch_id}" + dpg.add_line_series( + [], + [], + label=f"Ch {ch_id}", + parent=self.y_axis_all, + tag=line_tag, + ) + self.active_lines[ch_id] = line_tag + + def remove_line_series(self, ch_id): + """Remove the line series for the specified channel.""" + if ch_id in self.active_lines: + line_tag = self.active_lines[ch_id] + dpg.delete_item(line_tag) + del self.active_lines[ch_id] + def set_zoom_y_min(self, sender, app_data): self.zoom_y_min = app_data @@ -180,38 +243,91 @@ def set_zoom_y_max(self, sender, app_data): def set_signal_separation(self, sender, app_data): self.signal_separation = app_data - def on_split_drag(self, sender, app_data): - """Handle dragging the splitter between plots.""" - # Get the main window height - main_window_height = dpg.get_item_height("main_window") - - # Calculate new heights based on drag position - mouse_pos = dpg.get_mouse_pos(local=False) - window_pos = dpg.get_item_pos("main_window") - relative_height = mouse_pos[1] - window_pos[1] - - # Ensure minimum heights for both windows - MIN_HEIGHT = 100 - if ( - relative_height < MIN_HEIGHT - or relative_height > main_window_height - MIN_HEIGHT - ): + def put(self, frame: BroadbandFrame): + """Add a BroadbandFrame to the processing queue""" + try: + self.data_queue.put(frame, block=False) + except queue.Full: + # If queue is full, drop multiple old frames and add the new one + dropped = 0 + while dropped < 5: # Drop up to 5 old frames + try: + self.data_queue.get_nowait() + dropped += 1 + except queue.Empty: + break + try: + self.data_queue.put(frame, block=False) + except queue.Full: + pass # Still full, drop this frame + + def put_batch(self, frames: list): + """Add multiple BroadbandFrames efficiently""" + for frame in frames: + try: + self.data_queue.put(frame, block=False) + except queue.Full: + # Drop old frames to make room + try: + self.data_queue.get_nowait() + self.data_queue.put(frame, block=False) + 
except queue.Empty: + pass + + def start(self): + """Start the plotter in a separate thread""" + if self.running: return - # Update the heights of both windows - dpg.configure_item("top_plot_window", height=relative_height) + self.running = True + self.stop_event.clear() + self.plot_thread = Thread(target=self._plot_thread_main) + self.plot_thread.start() + + def stop(self): + """Stop the plotter thread""" + if not self.running: + return + + self.running = False + self.stop_event.set() + if self.plot_thread: + self.plot_thread.join() + + def _plot_thread_main(self): + """Main plotting thread that runs the DearPyGui event loop""" + dpg.setup_dearpygui() + dpg.show_viewport() + + # Record start time + self.start_time = time.time() + + # Main loop + fps_limit = 30 + frame_duration = 1.0 / fps_limit + last_time = time.time() - def resize_callback(self, sender, app_data): - """Handle window resize events.""" - viewport_width = dpg.get_viewport_width() - viewport_height = dpg.get_viewport_height() + while dpg.is_dearpygui_running() and not self.stop_event.is_set(): + # Process multiple frames per iteration for better throughput + frames_processed = 0 + max_frames_per_iter = 10 - # Update main window size - main_window_width = viewport_width - 280 # Leave space for control window - main_window_height = viewport_height - 20 # Leave small margin - dpg.configure_item( - "main_window", width=main_window_width, height=main_window_height - ) + while frames_processed < max_frames_per_iter: + try: + frame = self.data_queue.get_nowait() + self.process_broadband_frame(frame) + frames_processed += 1 + except queue.Empty: + break + + # Throttle rendering to the fps limit + now = time.time() + if (now - last_time) >= frame_duration: + self.update_plot() + dpg.render_dearpygui_frame() + last_time = now + + dpg.destroy_context() def update_plot(self): """ @@ -220,28 +336,52 @@ def update_plot(self): """ # Downsample factor for performance # Note(gilbert): we should probably make this 
configurable, it is arbitrary - ds_factor = 10 + ds_factor = 4 + + # Get the current time window for x-axis limits based on latest data + # Use latest data timestamp instead of wall clock for better sync + current_data_time = self.latest_data_time + x_min = max(0, current_data_time - self.window_size_seconds) + x_max = current_data_time # ----------------------------- # Update "All Channels" Plot # ----------------------------- - for idx, ch_id in enumerate(self.channel_ids): + active_channel_idx = 0 + for ch_id in self.selected_channels: + if ch_id not in self.active_lines: + continue + + idx = self.channel_to_index[ch_id] pos = self.buffer_positions[idx] # Roll data so that index -1 corresponds to the newest sample rolled_y = np.roll(self.data_buffers[idx], -pos) - rolled_x = np.roll(self.time_buffer, -pos) + rolled_x = np.roll(self.timestamp_buffers[idx], -pos) # Downsample ds_x = rolled_x[::ds_factor] ds_y = rolled_y[::ds_factor] - # Apply vertical offset for each channel to avoid overlap - offset = idx * self.signal_separation + # Apply vertical offset for each active channel to avoid overlap + offset = active_channel_idx * self.signal_separation ds_y_offset = ds_y + offset - # Update the line series for channel ch - dpg.set_value(self.lines_all[idx], [ds_x.tolist(), ds_y_offset.tolist()]) + # Update the line series for this channel + line_tag = self.active_lines[ch_id] + dpg.set_value(line_tag, [ds_x.tolist(), ds_y_offset.tolist()]) + + active_channel_idx += 1 + + # Set X-axis limits to show current time window + dpg.set_axis_limits("x_axis_all", x_min, x_max) + + # Set Y-axis limits based on number of active channels + num_active = len(self.selected_channels) + if num_active > 0: + y_max_all = (num_active - 1) * self.signal_separation + 4096 + y_min_all = -4096 + dpg.set_axis_limits("y_axis_all", y_min_all, y_max_all) # ----------------------------- # Update Zoomed Channel Plot @@ -250,7 +390,7 @@ def update_plot(self): pos = self.buffer_positions[idx] 
rolled_y_ch = np.roll(self.data_buffers[idx], -pos) - rolled_x_ch = np.roll(self.time_buffer, -pos) + rolled_x_ch = np.roll(self.timestamp_buffers[idx], -pos) ds_x_ch = rolled_x_ch[::ds_factor] ds_y_ch = rolled_y_ch[::ds_factor] @@ -258,7 +398,8 @@ def update_plot(self): # Update the single "zoomed_line" series dpg.set_value("zoomed_line", [ds_x_ch.tolist(), ds_y_ch.tolist()]) - # Optionally set the zoomed plot Y-axis range for a closer look + # Set axis limits for both plots + dpg.set_axis_limits("x_axis_zoom", x_min, x_max) dpg.set_axis_limits("y_axis_zoom", self.zoom_y_min, self.zoom_y_max) # ----------------------------- @@ -268,68 +409,89 @@ def update_plot(self): elapsed = time.time() - self.start_time dpg.set_value("elapsed_time_text", f"{elapsed:.2f}") - def process_data(self, data): + def process_broadband_frame(self, frame: BroadbandFrame): """ - data = (channel, samples) - Write these samples into the ring buffer for that channel. + Process a BroadbandFrame and distribute the data to channel buffers. + Uses actual timestamps from the frame for proper time synchronization. """ - channel_id, samples = data - if channel_id not in self.channel_to_index: - print( - f"Warning: Received data for channel {channel_id} which is not in the configured channel list." 
- ) - return - - # Get our channel id mapped to our plotting index - idx = self.channel_to_index[channel_id] - num_samples = len(samples) - pos = self.buffer_positions[idx] - - end_pos = pos + num_samples - if end_pos <= self.buffer_size: - # Simple case: fits without wrap - self.data_buffers[idx][pos:end_pos] = samples - else: - # Wrap around case - first_part = self.buffer_size - pos - second_part = num_samples - first_part - self.data_buffers[idx][pos:] = samples[:first_part] - self.data_buffers[idx][:second_part] = samples[first_part:] - - self.buffer_positions[idx] = (pos + num_samples) % self.buffer_size - - def start(self, stop, data_queue): - """Run the DearPyGui event/render loop.""" - dpg.setup_dearpygui() - dpg.show_viewport() - - # Record start time - self.start_time = time.time() - - # Main loop - fps_limit = 60 - frame_duration = 1.0 / fps_limit - last_time = time.time() - - while dpg.is_dearpygui_running() and not stop.is_set(): - # Process any incoming data in the queue - while True: - try: - data = data_queue.get_nowait() - self.process_data(data.samples[0]) - except queue.Empty: - break - - # Throttle rendering to the fps limit - now = time.time() - if (now - last_time) >= frame_duration: - self.update_plot() - dpg.render_dearpygui_frame() - last_time = now - - dpg.destroy_context() + # Set start timestamp on first frame + if self.start_timestamp_ns is None: + self.start_timestamp_ns = frame.timestamp_ns + self.start_time = time.time() + + # Convert timestamp to seconds relative to start + relative_time_s = (frame.timestamp_ns - self.start_timestamp_ns) / 1e9 + + # Update latest data time + self.latest_data_time = relative_time_s + + # frame_data is a flat array with one sample per channel + # We assume the data is organized as: [ch0_sample, ch1_sample, ch2_sample, ...] 
+ frame_data = frame.frame_data + + # Distribute data to each channel buffer + for ch_idx, ch_id in enumerate(self.channel_ids): + if ch_idx < len(frame_data): + sample = frame_data[ch_idx] + + # Add sample to this channel's ring buffer + pos = self.buffer_positions[ch_idx] + self.data_buffers[ch_idx][pos] = sample + + # Add actual timestamp to this channel's timestamp buffer + self.timestamp_buffers[ch_idx][pos] = relative_time_s + + self.buffer_positions[ch_idx] = (pos + 1) % self.buffer_size + + def select_all_channels(self): + """Select all channels for plotting.""" + # Update internal state + old_selection = self.selected_channels.copy() + self.selected_channels = set(self.channel_ids) + + # Update checkboxes + for ch_id in self.channel_ids: + dpg.set_value(f"ch_checkbox_{ch_id}", True) + if ch_id not in old_selection: + self.create_line_series(ch_id) + + def select_no_channels(self): + """Deselect all channels.""" + # Update internal state + old_selection = self.selected_channels.copy() + self.selected_channels = set() + + # Update checkboxes and remove line series + for ch_id in old_selection: + dpg.set_value(f"ch_checkbox_{ch_id}", False) + self.remove_line_series(ch_id) + + def select_first_5_channels(self): + """Select only the first 5 channels.""" + # Update internal state + old_selection = self.selected_channels.copy() + self.selected_channels = set(self.channel_ids[:5]) + + # Update checkboxes + for ch_id in self.channel_ids: + should_be_selected = ch_id in self.selected_channels + dpg.set_value(f"ch_checkbox_{ch_id}", should_be_selected) + + if should_be_selected and ch_id not in old_selection: + self.create_line_series(ch_id) + elif not should_be_selected and ch_id in old_selection: + self.remove_line_series(ch_id) + + +# Factory function to create plotter with BroadbandFrame support +def create_broadband_plotter( + sample_rate_hz: int, window_size_seconds: int, channel_ids +): + """Create a SynapsePlotter configured for BroadbandFrame data""" + 
return SynapsePlotter(sample_rate_hz, window_size_seconds, channel_ids) +# Legacy function for backward compatibility def plot_synapse_data( stop: Event, data_queue: queue.Queue, diff --git a/synapse/cli/updated_stream_plot.py b/synapse/cli/updated_stream_plot.py deleted file mode 100644 index b19bef4..0000000 --- a/synapse/cli/updated_stream_plot.py +++ /dev/null @@ -1,442 +0,0 @@ -#!/usr/bin/env python3 -""" -High-performance neural data visualization using raster/sweep display -Much faster than scrolling plots - similar to oscilloscope displays -""" - -import numpy as np -import pyqtgraph as pg -from pyqtgraph.Qt import QtCore, QtWidgets -import threading -from collections import deque -from synapse.client.taps import Tap -from synapse.api.datatype_pb2 import BroadbandFrame as BroadbandFrameProto - -# Assuming you have your protobuf compiled -# from your_proto_pb2 import BroadbandFrame - - -class NeuralRasterDisplay(QtWidgets.QWidget): - """Fast raster-style display for neural data""" - - def __init__(self, n_channels=5, sample_rate=32000, display_seconds=5): - super().__init__() - - self.n_channels = n_channels - self.sample_rate = sample_rate - self.display_seconds = display_seconds - - # Display parameters - self.width = 1200 # pixels - self.height = 800 # pixels - self.channel_height = self.height // n_channels - - # Calculate downsampling - total_samples = sample_rate * display_seconds # 160,000 - self.downsample = total_samples // self.width # ~133 samples per pixel - self.samples_per_pixel = self.downsample - - # Data buffer for incoming samples - larger buffer to prevent overflow - buffer_size = self.downsample * n_channels * 5 # 5x larger buffer - self.data_buffer = deque(maxlen=buffer_size) - - # Image array for display (channels x width) - black background - self.image_data = np.zeros((self.height, self.width), dtype=np.float32) - - # Current write position - self.write_pos = 0 - - # Debug counter - self.pixels_written = 0 - - print( - f"Setup: 
{n_channels} channels, {self.downsample} samples/pixel, buffer size: {buffer_size}" - ) - - # Setup UI - self.setup_ui() - - # Update timer - only update the sweep line - self.timer = QtCore.QTimer() - self.timer.timeout.connect(self.update_display) - self.timer.start(16) # 60 FPS - - self.setWindowTitle("Neural Data Raster Display") - self.resize(self.width + 100, self.height + 100) - - def setup_ui(self): - """Create the display widget""" - layout = QtWidgets.QVBoxLayout() - self.setLayout(layout) - - # Graphics widget - self.graphics_widget = pg.GraphicsLayoutWidget() - layout.addWidget(self.graphics_widget) - - # Single plot with image item - self.plot = self.graphics_widget.addPlot() - self.plot.setLabel("left", "Channel") - self.plot.setLabel("bottom", "Time", units="s") - - # Image item for raster display - self.img_item = pg.ImageItem() - self.plot.addItem(self.img_item) - - # Set image dimensions - self.img_item.setImage(self.image_data) # No transpose needed - - # Sweep line - self.sweep_line = pg.InfiniteLine(pos=0, angle=90, pen=pg.mkPen("r", width=2)) - self.plot.addItem(self.sweep_line) - - # Channel divider lines and baselines - for i in range(1, self.n_channels): - y = i * self.channel_height - line = pg.InfiniteLine( - pos=y, - angle=0, - pen=pg.mkPen("w", width=1, style=QtCore.Qt.PenStyle.DashLine), - ) - self.plot.addItem(line) - - # Add baseline (center) line for each channel - for i in range(self.n_channels): - y = i * self.channel_height + self.channel_height // 2 - baseline = pg.InfiniteLine( - pos=y, - angle=0, - pen=pg.mkPen("gray", width=1, style=QtCore.Qt.PenStyle.DotLine), - ) - self.plot.addItem(baseline) - - # Channel labels - adjusted for inverted Y - self.plot.getAxis("left").setTicks( - [ - [ - ( - self.height - - (i * self.channel_height + self.channel_height / 2), - f"Ch {i + 1}", - ) - for i in range(self.n_channels) - ] - ] - ) - - # Time axis - time_ticks = [(i * self.width / 5, f"{i}s") for i in range(6)] - 
self.plot.getAxis("bottom").setTicks([time_ticks]) - - # Colormap - different color for each channel - # Each channel gets its own color range: ch0=0.0-0.2, ch1=0.2-0.4, etc. - channel_colors = [ - (255, 100, 100), # Red/pink for channel 0 - (100, 255, 100), # Green for channel 1 - (100, 150, 255), # Blue for channel 2 - (255, 255, 100), # Yellow for channel 3 - (255, 100, 255), # Magenta for channel 4 - ] - - # Create color positions and colors list - colors = [(0, 0, 0)] # Black background - positions = [0.0] - - for i in range(self.n_channels): - base_pos = 0.2 + ( - i * 0.16 - ) # Each channel gets 0.16 range (0.8 total / 5 channels) - # Add darker and brighter versions of each channel color - r, g, b = channel_colors[i % len(channel_colors)] - - # Darker version (low intensity) - colors.append((r // 3, g // 3, b // 3)) - positions.append(base_pos) - - # Medium version - colors.append((r // 2, g // 2, b // 2)) - positions.append(base_pos + 0.05) - - # Bright version (high intensity) - colors.append((r, g, b)) - positions.append(base_pos + 0.15) - - cmap = pg.ColorMap(pos=positions, color=colors) - self.img_item.setLookupTable(cmap.getLookupTable(alpha=True)) - - # Set initial black image with correct orientation - self.img_item.setImage(self.image_data.T) # Transpose for ImageItem - - # Scale the image to fit the plot - self.plot.setXRange(0, self.width) - self.plot.setYRange(0, self.height) - - # Invert Y axis so channel 1 is at top - self.plot.invertY(True) - - # Controls - control_layout = QtWidgets.QHBoxLayout() - - # Gain control - self.gain_slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) - self.gain_slider.setRange(1, 100) - self.gain_slider.setValue(50) - self.gain_label = QtWidgets.QLabel("Gain: 50") - self.gain_slider.valueChanged.connect(self.update_gain) - - control_layout.addWidget(QtWidgets.QLabel("Gain:")) - control_layout.addWidget(self.gain_slider) - control_layout.addWidget(self.gain_label) - control_layout.addStretch() - - 
layout.addLayout(control_layout) - - def update_gain(self, value): - """Update display gain""" - self.gain_label.setText(f"Gain: {value}") - # The gain will be applied when processing new data - - def add_samples(self, channel_data): - """Add new samples from protobuf frame - - Args: - channel_data: numpy array of shape (n_channels, n_samples) - """ - # Process each sample - for sample_idx in range(channel_data.shape[1]): - # Add sample from each channel to buffer - for ch in range(self.n_channels): - self.data_buffer.append(channel_data[ch, sample_idx]) - - # Process buffer when we have enough samples - while len(self.data_buffer) >= self.downsample * self.n_channels: - # Extract samples for one pixel column - pixel_samples = [] - for _ in range(self.downsample * self.n_channels): - pixel_samples.append(self.data_buffer.popleft()) - pixel_samples = np.array(pixel_samples) - - # Reshape to channels x samples - pixel_samples = pixel_samples.reshape(self.downsample, self.n_channels).T - - # Apply gain and center around channel middle - gain = self.gain_slider.value() / 50.0 # More sensitive gain - - # Clear the column first (black background) - self.image_data[:, self.write_pos] = 0 - - # Draw continuous waveform for each channel with different colors - for ch in range(self.n_channels): - channel_center = ch * self.channel_height + self.channel_height // 2 - channel_top = ch * self.channel_height + 5 # Leave some margin - channel_bottom = (ch + 1) * self.channel_height - 5 - - # Calculate color value range for this channel - # Channel 0: 0.2-0.35, Channel 1: 0.36-0.51, etc. 
- color_base = 0.2 + (ch * 0.16) - color_range = 0.15 # Range within each channel's color space - - # Scale the signal data appropriately - channel_data = pixel_samples[ch] * gain / 1500.0 # Scale for visibility - - # Use more of the channel height for better signal visibility - max_amplitude = ( - self.channel_height - 10 - ) // 2 # Use most of channel height - - # Downsample the data to create a smooth waveform representation - # Take several evenly spaced samples across the pixel column time - n_points = min(20, len(channel_data)) # Use up to 20 points per pixel - if len(channel_data) > n_points: - indices = np.linspace(0, len(channel_data) - 1, n_points, dtype=int) - channel_data = channel_data[indices] - - # Convert samples to pixel positions - y_positions = [] - for sample in channel_data: - y_offset = int(sample * max_amplitude) - y_pos = channel_center + y_offset - # Clamp to channel bounds - y_pos = max(channel_top, min(channel_bottom, y_pos)) - y_positions.append(y_pos) - - # Draw the waveform with channel-specific colors - if len(y_positions) > 1: - # Method 1: Fill between min and max (thick line effect) - min_y = min(y_positions) - max_y = max(y_positions) - - # Draw the main signal band - for y in range(min_y, max_y + 1): - # Intensity based on distance from the mean - mean_y = np.mean(y_positions) - distance_from_mean = abs(y - mean_y) - relative_intensity = max( - 0.2, 1.0 - distance_from_mean / max(1, max_y - min_y) - ) - - # Map to this channel's color range - color_value = color_base + (relative_intensity * color_range) - self.image_data[y, self.write_pos] = color_value - - # Method 2: Also draw connecting lines between consecutive points - for i in range(len(y_positions) - 1): - y1, y2 = y_positions[i], y_positions[i + 1] - # Draw line between consecutive points - start_y, end_y = min(y1, y2), max(y1, y2) - for y in range(start_y, end_y + 1): - # Higher intensity for connecting lines - color_value = color_base + (0.8 * color_range) - 
self.image_data[y, self.write_pos] = max( - color_value, self.image_data[y, self.write_pos] - ) - - # Highlight the actual sample points - for y_pos in y_positions[::3]: # Every 3rd point - if channel_top <= y_pos <= channel_bottom: - # Maximum intensity for sample points - color_value = color_base + color_range - self.image_data[y_pos, self.write_pos] = color_value - - elif len(y_positions) == 1: - # Single point - draw it with some thickness - y_pos = y_positions[0] - color_value = color_base + (0.6 * color_range) - for dy in range(-1, 2): - if channel_top <= y_pos + dy <= channel_bottom: - self.image_data[y_pos + dy, self.write_pos] = color_value - - # Debug - print first few pixels - if self.pixels_written < 5: - print( - f"Pixel {self.pixels_written}: processing column {self.write_pos}" - ) - - # Clear a few pixels ahead of the sweep line for visibility - clear_pos = (self.write_pos + 2) % self.width - for i in range(3): - self.image_data[:, (clear_pos + i) % self.width] = 0 - - # Move write position - self.write_pos = (self.write_pos + 1) % self.width - self.pixels_written += 1 - - # Progress update - if self.pixels_written % 200 == 0: - print( - f"Written {self.pixels_written} pixels, buffer size: {len(self.data_buffer)}" - ) - - def update_display(self): - """Update only the image and sweep line position""" - # Update image - transpose because ImageItem expects (X, Y) not (Y, X) - self.img_item.setImage(self.image_data.T, autoLevels=True) - - # Update sweep line position - self.sweep_line.setPos(self.write_pos) - - -class NeuralDataReceiver: - """Handles ZMQ reception and protobuf parsing""" - - def __init__(self, display, zmq_address="tcp://localhost:5555"): - self.display = display - self.tap = Tap("10.40.61.119", verbose=True) - self.tap.connect("broadband_source_2") - self.running = False - self.thread = None - - def start(self): - """Start receiving data""" - self.running = True - self.thread = threading.Thread(target=self._receive_loop, daemon=True) - 
self.thread.start() - - def stop(self): - """Stop receiving data""" - self.running = False - if self.thread: - self.thread.join(timeout=1.0) - - def _receive_loop(self): - """Main reception loop""" - # Setup ZMQ - - while self.running: - try: - # Receive raw message - for message in self.tap.stream(): - frame = BroadbandFrameProto() - frame.ParseFromString(message) - # Get the first five channels - frame_data = np.array(frame.frame_data[:5]).reshape(5, -1) - self.display.add_samples(frame_data) - - # # Parse protobuf - # # frame = BroadbandFrame() - # # frame.ParseFromString(raw_message) - - # # For testing without protobuf - simulate data - # frame_data = np.random.randn(5 * 32).astype(np.float32) * 100 - # frame_data = frame_data + np.sin(np.arange(5 * 32) * 0.1) * 200 - - # # Reshape to channels x samples - # n_samples = len(frame_data) // self.display.n_channels - # channel_data = frame_data.reshape(n_samples, self.display.n_channels).T - - # # Add to display - # self.display.add_samples(channel_data) - - except Exception as e: - print(f"Error receiving data: {e}") - - -def main(): - """Main application""" - import sys - - app = QtWidgets.QApplication.instance() - if not app: - app = QtWidgets.QApplication(sys.argv) - - # Create display - display = NeuralRasterDisplay(n_channels=5, sample_rate=32000, display_seconds=5) - display.show() - - # Create receiver (comment out if you want to test without ZMQ) - receiver = NeuralDataReceiver(display) - receiver.start() - - # For testing - simulate data - # def simulate_data(): - # while True: - # # Simulate 1ms of data (32 samples per channel) - # n_samples = 32 - # channel_data = np.zeros((5, n_samples)) - - # for ch in range(5): - # # Different frequency for each channel - # t = np.arange(n_samples) / 32000 - # channel_data[ch] = ( - # 100 * np.sin(2 * np.pi * (5 + ch * 2) * t) + - # 50 * np.random.randn(n_samples) - # ) - - # # Random spikes - # if np.random.rand() < 0.05: - # channel_data[ch, 
np.random.randint(n_samples)] = np.random.choice([-500, 500]) - - # display.add_samples(channel_data) - # time.sleep(0.001) # 1ms - - # sim_thread = threading.Thread(target=simulate_data, daemon=True) - # sim_thread.start() - - try: - sys.exit(app.exec()) - except KeyboardInterrupt: - # receiver.stop() - pass - - -if __name__ == "__main__": - main() diff --git a/synapse/client/taps.py b/synapse/client/taps.py index 8143c87..4feb38d 100644 --- a/synapse/client/taps.py +++ b/synapse/client/taps.py @@ -86,6 +86,17 @@ def connect(self, name: str) -> bool: # For producer taps (or unspecified), we need to subscribe and listen FROM the tap self.zmq_socket = self.zmq_context.socket(zmq.SUB) + # Optimize ZMQ for high-throughput data + # Increase receive buffer size significantly for high-speed data + self.zmq_socket.setsockopt( + zmq.RCVHWM, 10000 + ) # High water mark - buffer up to 10K messages + self.zmq_socket.setsockopt(zmq.RCVBUF, 16 * 1024 * 1024) # 16MB receive buffer + + # Set TCP keepalive for connection stability + self.zmq_socket.setsockopt(zmq.TCP_KEEPALIVE, 1) + self.zmq_socket.setsockopt(zmq.TCP_KEEPALIVE_IDLE, 60) + # Replace the endpoint with our device URI if needed endpoint = selected_tap.endpoint if "://" in endpoint: @@ -99,14 +110,13 @@ def connect(self, name: str) -> bool: try: self.zmq_socket.connect(endpoint) - # Give the socket a chance to connect - self.logger.info("Waiting for socket to connect...") - time.sleep(1) + # Reduce connection wait time to minimize startup delay + self.logger.info("Connecting to tap...") + time.sleep(0.1) # Reduced from 1 second # Only set subscription options for subscriber sockets if selected_tap.tap_type != TapType.TAP_TYPE_CONSUMER: self.zmq_socket.setsockopt(zmq.SUBSCRIBE, b"") - print("Subscribed to all messages") return True except zmq.ZMQError as e: @@ -167,11 +177,11 @@ def send(self, data: bytes) -> bool: self.logger.error(f"Error sending message: {e}") return False - def stream(self, timeout_ms: int = 1000) 
-> Generator[bytes, None, None]: - """Stream raw data from the tap. + def stream(self, timeout_ms: int = 100) -> Generator[bytes, None, None]: + """Stream raw data from the tap with optimizations for high-throughput data. Args: - timeout_ms (int, optional): Timeout between messages in milliseconds. Defaults to 1000. + timeout_ms (int, optional): Timeout between messages in milliseconds. Defaults to 100. Yields: Generator[bytes, None, None]: Stream of raw message data. @@ -180,16 +190,18 @@ def stream(self, timeout_ms: int = 1000) -> Generator[bytes, None, None]: self.logger.error("Not connected to any tap") return - # Set socket timeout + # Set a shorter timeout for high-frequency data self.zmq_socket.setsockopt(zmq.RCVTIMEO, timeout_ms) try: while True: try: - data = self.zmq_socket.recv() + # Use non-blocking receive with DONTWAIT for maximum throughput + data = self.zmq_socket.recv(zmq.DONTWAIT) yield data except zmq.Again: - # Timeout occurred, continue to next iteration + # No data available right now, yield control briefly + time.sleep(0.0001) # 0.1ms sleep to prevent busy waiting continue except KeyboardInterrupt: self.logger.info("Stream interrupted") @@ -199,6 +211,50 @@ def stream(self, timeout_ms: int = 1000) -> Generator[bytes, None, None]: # Don't close the socket here, let the user call disconnect() pass + def stream_batch( + self, batch_size: int = 10, timeout_ms: int = 100 + ) -> Generator[list, None, None]: + """Stream data in batches for improved throughput. + + Args: + batch_size (int, optional): Number of messages to batch together. Defaults to 10. + timeout_ms (int, optional): Timeout for each receive operation. Defaults to 100. + + Yields: + Generator[list, None, None]: Batches of raw message data. 
+ """ + if not self.zmq_socket: + self.logger.error("Not connected to any tap") + return + + self.zmq_socket.setsockopt(zmq.RCVTIMEO, timeout_ms) + + batch = [] + try: + while True: + try: + data = self.zmq_socket.recv(zmq.DONTWAIT) + batch.append(data) + + if len(batch) >= batch_size: + yield batch + batch = [] + except zmq.Again: + # No data available, yield partial batch if any + if batch: + yield batch + batch = [] + time.sleep(0.0001) + continue + except KeyboardInterrupt: + self.logger.info("Stream interrupted") + if batch: + yield batch + except zmq.ZMQError as e: + self.logger.error(f"Error streaming messages: {e}") + if batch: + yield batch + def disconnect(self): """Disconnect from the tap.""" self._cleanup() From 77b7460ceff839ac9988d3373e959fd17d2f8638 Mon Sep 17 00:00:00 2001 From: Gilbert Montague Date: Fri, 13 Jun 2025 15:19:12 -0700 Subject: [PATCH 7/8] Updated to use streaming --- synapse/cli/streaming.py | 95 +++++++++++++++++++++++++++++++++++----- 1 file changed, 83 insertions(+), 12 deletions(-) diff --git a/synapse/cli/streaming.py b/synapse/cli/streaming.py index 4b270e3..73d8fe8 100644 --- a/synapse/cli/streaming.py +++ b/synapse/cli/streaming.py @@ -426,13 +426,63 @@ def list_available_taps(args, device, console): console.print("\n[bold cyan]Available Taps:[/bold cyan]") console.print("=" * 50) + supported_count = 0 for tap in taps: - console.print(f"[green]Name:[/green] {tap.name}") + is_supported = tap.message_type == "synapse.BroadbandFrame" + if is_supported: + supported_count += 1 + console.print( + f"[green]Name:[/green] {tap.name} [bold green]✓ SUPPORTED[/bold green]" + ) + else: + console.print(f"[green]Name:[/green] {tap.name}") + console.print(f"[blue]Type:[/blue] {tap.message_type}") console.print(f"[yellow]Endpoint:[/yellow] {tap.endpoint}") + + if not is_supported: + console.print( + "[dim red]Note: Only synapse.BroadbandFrame taps are supported[/dim red]" + ) console.print("-" * 30) - console.print(f"\n[bold]Total: {len(taps)} 
taps available[/bold]") + console.print( + f"\n[bold]Total: {len(taps)} taps found, {supported_count} supported[/bold]" + ) + + +def detect_stream_parameters(broadband_tap, console): + """Detect sample rate and available channels from the first message""" + console.log("[cyan]Detecting stream parameters from first message...[/cyan]") + + try: + # Get the first message to detect parameters + first_message = broadband_tap.read(timeout_ms=5000) # 5 second timeout + if not first_message: + console.print( + "[bold red]Failed to receive first message for parameter detection[/bold red]" + ) + return None, None, None + + # Parse the first frame + first_frame = BroadbandFrame() + first_frame.ParseFromString(first_message) + + # Extract parameters + sample_rate = first_frame.sample_rate_hz + num_channels = len(first_frame.frame_data) + available_channels = list(range(num_channels)) + + console.log(f"[green]Detected sample rate: {sample_rate} Hz[/green]") + console.log( + f"[green]Detected {num_channels} channels (0-{num_channels - 1})[/green]" + ) + + return sample_rate, available_channels, first_frame + + except Exception as e: + console.print(f"[bold red]Error detecting stream parameters: {e}[/bold red]") + return None, None, None def get_broadband_tap(args, device, console): @@ -447,6 +497,12 @@ def get_broadband_tap(args, device, console): console.log( f"[green]Found specified tap: {args.tap_name} (type: {t.message_type})[/green]" ) + # Check if it's the correct type + if t.message_type != "synapse.BroadbandFrame": + console.print( + f"[bold red]Error: Specified tap '{args.tap_name}' has type '{t.message_type}', but only 'synapse.BroadbandFrame' is supported[/bold red]" + ) + return None read_tap.connect(t.name) return read_tap @@ -454,15 +510,15 @@ def get_broadband_tap(args, device, console): f"[yellow]Warning: Specified tap '{args.tap_name}' not found, falling back to auto-selection[/yellow]" ) - # Auto-select: get the first tap that has BroadbandFrame as the type - 
console.log("[cyan]Auto-selecting first BroadbandFrame tap[/cyan]") + # Auto-select: get the first tap that has exact synapse.BroadbandFrame type + console.log("[cyan]Auto-selecting first synapse.BroadbandFrame tap[/cyan]") for t in taps: - if "BroadbandFrame" in t.message_type: - console.log(f"[green]Found BroadbandFrame tap: {t.name}[/green]") + if t.message_type == "synapse.BroadbandFrame": + console.log(f"[green]Found synapse.BroadbandFrame tap: {t.name}[/green]") read_tap.connect(t.name) return read_tap - console.print("[bold red]No BroadbandFrame tap found[/bold red]") + console.print("[bold red]No synapse.BroadbandFrame tap found[/bold red]") return None @@ -508,11 +564,19 @@ def read(args): console.print("[bold red]Failed to get broadband tap[/bold red]") return + # Detect stream parameters from the first message + sample_rate, available_channels, first_frame = detect_stream_parameters( + broadband_tap, console + ) + if sample_rate is None: + console.print("[bold red]Failed to detect stream parameters[/bold red]") + return + # Setup our HDF5 writer if output is requested writer = None if args.output: writer = BroadbandFrameWriter(args.output) - writer.set_attributes(sample_rate_hz=32000, channels=list(range(256))) + writer.set_attributes(sample_rate_hz=sample_rate, channels=available_channels) writer.start() # Setup plotter if requested @@ -521,16 +585,14 @@ def read(args): try: from synapse.cli.synapse_plotter import create_broadband_plotter - # Always make all 256 channels available, but start with only 5 selected - available_channels = list(range(256)) plotter = create_broadband_plotter( - sample_rate_hz=32000, + sample_rate_hz=sample_rate, window_size_seconds=5, channel_ids=available_channels, ) plotter.start() console.log( - f"[green]Started real-time plotter with all {len(available_channels)} channels available[/green]" + f"[green]Started real-time plotter with {len(available_channels)} channels available[/green]" ) except ImportError as e: 
console.print( @@ -545,6 +607,15 @@ def read(args): try: # Use batch streaming for better throughput with Live(monitor.get_current_stats(), refresh_per_second=4) as live: + # Process the first frame that we already read for parameter detection + if first_frame: + if writer: + writer.put(first_frame) + if plotter: + plotter.put(first_frame) + monitor.put(first_frame) + + # Continue with batch streaming for remaining frames for message_batch in broadband_tap.stream_batch(batch_size=10): frames = [] for message in message_batch: From 328227fe10a14779ff90b016adfa128a66d8395d Mon Sep 17 00:00:00 2001 From: Gilbert Montague Date: Fri, 13 Jun 2025 15:23:41 -0700 Subject: [PATCH 8/8] bump version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 39453e4..1df76a2 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ setup( name="science-synapse", - version="2.2.6", + version="2.2.7", description="Client library and CLI for the Synapse API", author="Science Team", author_email="team@science.xyz",