class StreamMonitor:
    """Collect throughput and packet-loss statistics for a stream of frames.

    Frames are handed over through :meth:`put` and consumed on a background
    thread, so the producer is never blocked by the bookkeeping.  The current
    numbers can be rendered at any time with :meth:`get_current_stats`.
    """

    # Minimum time between two recomputations of the displayed message rate.
    RATE_WINDOW_SECONDS = 1.0

    def __init__(self, console: Console):
        """Create an idle monitor; call :meth:`start` to begin consuming frames.

        Args:
            console: Console object kept on the instance as ``self.console``.
        """
        self.console = console
        self.queue = queue.Queue(maxsize=100)
        self.stop_event = threading.Event()
        self.monitor_thread = None
        self._reset_stats()

    def _reset_stats(self):
        """Reset every counter and timestamp to its initial value."""
        now = time.time()
        self.start_time = now
        self.last_update = now
        self.message_count = 0
        self.last_count = 0
        self.last_sequence = 0
        self.total_dropped = 0
        # ``last_sequence`` alone cannot say "no frame seen yet", because 0 is
        # a legitimate sequence number; track that fact explicitly.
        self._sequence_seen = False
        # Most recently computed message rate (messages per second).
        self._rate = 0.0

    def start(self):
        """Reset the statistics and start the monitoring thread."""
        self._reset_stats()
        self.stop_event.clear()
        self.monitor_thread = threading.Thread(target=self._monitor_loop)
        self.monitor_thread.start()

    def stop(self):
        """Signal the monitoring thread to stop and wait for it to exit."""
        self.stop_event.set()
        if self.monitor_thread:
            self.monitor_thread.join()

    def put(self, frame: BroadbandFrame):
        """Queue a frame for monitoring without ever blocking the caller.

        When the queue is full the frame is silently discarded: monitoring is
        best effort and must not slow down the data path.
        """
        try:
            self.queue.put(frame, block=False)
        except queue.Full:
            pass

    def _monitor_loop(self):
        """Consume queued frames until :meth:`stop` is called."""
        while not self.stop_event.is_set():
            try:
                frame = self.queue.get(timeout=0.1)
            except queue.Empty:
                continue
            self._update_stats(frame)

    def _update_stats(self, frame: BroadbandFrame):
        """Account for one frame: bump the counter and detect sequence gaps.

        Only forward gaps are counted as dropped frames.  A sequence number
        that does not advance (a duplicate, or a counter that went backwards)
        is treated as a resynchronisation point instead of being subtracted
        from the drop counter, which would otherwise go negative.
        """
        self.message_count += 1

        sequence = frame.sequence_number
        if self._sequence_seen:
            gap = sequence - (self.last_sequence + 1)
            if gap > 0:
                self.total_dropped += gap
        self._sequence_seen = True
        self.last_sequence = sequence

    def _message_rate(self, now: float) -> float:
        """Return the message rate, recomputed at most once per rate window.

        Between recomputations the last computed value is returned, so the
        displayed rate is not derived from an arbitrarily short interval.

        Args:
            now: Current time in seconds, on the same clock as ``time.time()``.
        """
        elapsed = now - self.last_update
        if elapsed >= self.RATE_WINDOW_SECONDS:
            # Read the counter once; the monitor thread may be updating it.
            count = self.message_count
            self._rate = (count - self.last_count) / elapsed
            self.last_count = count
            self.last_update = now
        return self._rate

    def get_current_stats(self) -> Text:
        """Return the current statistics as a styled one-line ``Text``."""
        current_time = time.time()
        rate = self._message_rate(current_time)

        # Loss is expressed relative to everything that should have arrived.
        total_expected = self.message_count + self.total_dropped
        loss_percent = (
            (self.total_dropped / total_expected * 100) if total_expected > 0 else 0
        )

        stats_text = Text()
        stats_text.append("Messages: ", style="bold")
        stats_text.append(f"{self.message_count:,}", style="cyan")
        stats_text.append(" | msgs/sec: ", style="bold")
        stats_text.append(f"{rate:.1f}/s", style="green")
        stats_text.append(" | Dropped: ", style="bold")
        stats_text.append(f"{self.total_dropped:,}", style="red")
        stats_text.append(" | Loss: ", style="bold")
        stats_text.append(f"{loss_percent:.2f}%", style="yellow")
        stats_text.append(" | Runtime: ", style="bold")
        stats_text.append(f"{current_time - self.start_time:.1f}s", style="blue")

        return stats_text
class BroadbandFrameWriter:
    """Write broadband frames to an HDF5 file from a background thread.

    Producers call :meth:`put` / :meth:`put_batch`, which never block.  A
    writer thread started with :meth:`start` drains the queue, collects
    frames into a buffer and appends them to three one-dimensional datasets:
    ``/acquisition/timestamp``, ``/acquisition/sequence_number`` and
    ``/acquisition/ElectricalSeries`` (all samples, frame after frame).
    """

    def __init__(self, output_dir: str):
        """Create the output file and its (initially empty) datasets.

        Args:
            output_dir: Directory that receives ``broadband_data_<time>.h5``.
                It is created when it does not exist yet.
        """
        self.output_dir = output_dir
        self.data_queue = queue.Queue(maxsize=2000)
        self.stop_event = threading.Event()
        self.writer_thread = None

        # Stats tracking
        self.start_time = time.time()
        self.frames_received = 0
        self.samples_received = 0
        self.last_sequence = 0
        self.dropped_frames = 0
        # 0 is a legitimate sequence number, so "no frame seen yet" needs its
        # own flag instead of being inferred from ``last_sequence``.
        self._sequence_seen = False

        # An empty string means "current directory"; nothing to create then.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.filename = os.path.join(output_dir, f"broadband_data_{timestamp}.h5")
        self.file = h5py.File(self.filename, "w")

        # Resizable 1-D datasets; they grow as frames are written.
        self.timestamp_dataset = self.file.create_dataset(
            "/acquisition/timestamp", shape=(0,), maxshape=(None,), dtype="uint64"
        )
        self.sequence_dataset = self.file.create_dataset(
            "/acquisition/sequence_number", shape=(0,), maxshape=(None,), dtype="uint64"
        )
        self.frame_data_dataset = self.file.create_dataset(
            "/acquisition/ElectricalSeries", shape=(0,), maxshape=(None,), dtype="int32"
        )

        # Frames collected by the writer thread before they hit the disk.
        self.frame_buffer = []
        self.buffer_size = 500

    def get_stats(self):
        """Return a dict with rates, totals, drop count and last sequence."""
        elapsed = time.time() - self.start_time
        has_elapsed = elapsed > 0
        return {
            "frames_per_sec": self.frames_received / elapsed if has_elapsed else 0,
            "samples_per_sec": self.samples_received / elapsed if has_elapsed else 0,
            "total_frames": self.frames_received,
            "total_samples": self.samples_received,
            "dropped_frames": self.dropped_frames,
            "last_sequence": self.last_sequence,
        }

    def set_attributes(
        self, sample_rate_hz: float, channels: list, session_description: str = ""
    ):
        """Store session metadata in the HDF5 file.

        Args:
            sample_rate_hz: Written to the ``sample_rate_hz`` file attribute.
            channels: Channel ids, written as the ``id`` dataset of
                ``general/extracellular_ephys/electrodes``.
            session_description: Optional text; only written when non-empty.
        """
        self.file.attrs["sample_rate_hz"] = sample_rate_hz
        if session_description:
            self.file.attrs["session_description"] = session_description

        self.file.attrs["session_start_time"] = datetime.now().isoformat()

        device_group = self.file.create_group("general/device")
        device_group.attrs["device_type"] = "SciFi"

        electrodes_group = self.file.create_group(
            "general/extracellular_ephys/electrodes"
        )
        electrodes_group.create_dataset("id", data=channels, dtype="uint32")

    def start(self):
        """Start the writer thread."""
        self.writer_thread = threading.Thread(target=self._write_loop)
        self.writer_thread.start()

    def stop(self):
        """Stop the writer thread, write what is still buffered, close the file."""
        self.stop_event.set()
        if self.writer_thread:
            self.writer_thread.join()
        try:
            self.flush()
        finally:
            # Close the file even when the final flush fails.
            self.file.close()

    def _track_frame(self, frame: BroadbandFrame):
        """Update the receive counters and the drop counter for one frame.

        Only forward gaps in the sequence numbers count as dropped frames; a
        sequence number that does not advance is taken as a new reference
        point rather than being subtracted from the counter.
        """
        self.frames_received += 1
        self.samples_received += len(frame.frame_data)

        sequence = frame.sequence_number
        if self._sequence_seen:
            gap = sequence - (self.last_sequence + 1)
            if gap > 0:
                self.dropped_frames += gap
        self._sequence_seen = True
        self.last_sequence = sequence

    def _enqueue(self, frame: BroadbandFrame):
        """Queue a frame without blocking, discarding the oldest one if full."""
        try:
            self.data_queue.put(frame, block=False)
            return
        except queue.Full:
            pass

        # Make room by discarding the oldest queued frame, then retry once.
        # Either step can lose a race with another thread; the frame is then
        # dropped rather than blocking the caller.
        try:
            self.data_queue.get_nowait()
            self.data_queue.put(frame, block=False)
        except (queue.Empty, queue.Full):
            pass

    def put(self, frame: BroadbandFrame):
        """Account for a frame and add it to the write queue (non-blocking)."""
        self._track_frame(frame)
        self._enqueue(frame)

    def put_batch(self, frames: list):
        """Account for several frames and add them to the write queue."""
        for frame in frames:
            self._track_frame(frame)
            self._enqueue(frame)

    def _write_loop(self):
        """Drain the queue until stopped and empty, writing in buffer-sized batches."""
        while not self.stop_event.is_set() or not self.data_queue.empty():
            try:
                frame = self.data_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            self.frame_buffer.append(frame)
            if len(self.frame_buffer) < self.buffer_size:
                continue

            try:
                self._write_buffer()
            except Exception as e:
                # Keep the thread alive; the failed batch was already removed
                # from the buffer by _write_buffer.
                print(f"Error writing data: {e}")

    def _write_buffer(self):
        """Append all buffered frames to the datasets and flush to disk.

        The buffer is detached before anything is written, so a batch that
        fails is dropped instead of being retried on every later frame.
        Sample offsets are derived from the actual length of each frame, so
        frames of different lengths are stored back to back.
        """
        frames, self.frame_buffer = self.frame_buffer, []
        if not frames:
            return

        # Gather everything first: a malformed frame fails here, before any
        # dataset has been resized.
        timestamps = [frame.timestamp_ns for frame in frames]
        sequences = [frame.sequence_number for frame in frames]
        samples = []
        for frame in frames:
            samples.extend(frame.frame_data)

        row_start = self.timestamp_dataset.shape[0]
        row_end = row_start + len(frames)
        sample_start = self.frame_data_dataset.shape[0]
        sample_end = sample_start + len(samples)

        self.timestamp_dataset.resize(row_end, axis=0)
        self.sequence_dataset.resize(row_end, axis=0)
        self.timestamp_dataset[row_start:row_end] = timestamps
        self.sequence_dataset[row_start:row_end] = sequences

        if samples:
            self.frame_data_dataset.resize(sample_end, axis=0)
            self.frame_data_dataset[sample_start:sample_end] = samples

        self._flush_to_disk()

    def _flush_to_disk(self):
        """Flush the three datasets and the file itself."""
        self.timestamp_dataset.flush()
        self.sequence_dataset.flush()
        self.frame_data_dataset.flush()
        self.file.flush()

    def flush(self):
        """Write any buffered frames and flush all datasets to disk."""
        if self.frame_buffer:
            # _write_buffer flushes after writing.
            self._write_buffer()
        else:
            self._flush_to_disk()
def create_status_table(writer: BroadbandFrameWriter) -> Table:
    """Build a two-column table showing the writer's current statistics.

    Args:
        writer: Object whose ``get_stats()`` result supplies the values.

    Returns:
        A table titled "Streaming Status" with one row per statistic.
    """
    snapshot = writer.get_stats()

    status = Table(title="Streaming Status")
    status.add_column("Metric", style="cyan")
    status.add_column("Value", style="green")

    # Rates are shown with one decimal; counters are shown as-is.
    rows = (
        ("Frames/sec", f"{snapshot['frames_per_sec']:.1f}"),
        ("Samples/sec", f"{snapshot['samples_per_sec']:.1f}"),
        ("Total Frames", str(snapshot["total_frames"])),
        ("Total Samples", str(snapshot["total_samples"])),
        ("Dropped Frames", str(snapshot["dropped_frames"])),
        ("Last Sequence", str(snapshot["last_sequence"])),
    )
    for label, value in rows:
        status.add_row(label, value)

    return status
def configure_device(device, config, console):
    """Apply ``config`` to the device unless it is already running.

    Args:
        device: Object providing ``info()`` and ``configure_with_status()``.
        config: Configuration passed to ``configure_with_status()``.
        console: Console used for the status spinner and for messages.

    Returns:
        True when the device is already running or was configured
        successfully; False when the device did not answer or rejected the
        configuration.
    """
    with console.status("Configuring device...", spinner="bouncingBall"):
        info = device.info()
        # A device that cannot be reached yields no info object; report that
        # instead of failing on the attribute access below.
        if not info:
            console.print("[bold red]Failed to get device info[/bold red]")
            return False

        if info.status.state == DeviceState.kRunning:
            console.log(
                "[bold yellow]Device is already running, reading from existing tap[/bold yellow]"
            )
            return True

        configure_status = device.configure_with_status(config)
        if configure_status is None:
            console.print(
                "[bold red]Failed to configure device: no response from device[/bold red]"
            )
            return False
        if configure_status.code != StatusCode.kOk:
            console.print(
                f"[bold red]Failed to configure device: {configure_status.message}[/bold red]"
            )
            return False

        console.log("[green]Configured device[/green]")
    return True
def start_device(device, console):
    """Start the device unless it is already running.

    Args:
        device: Object providing ``info()`` and ``start_with_status()``.
        console: Console used for the status spinner and for messages.

    Returns:
        True when the device is running afterwards; False when the device
        did not answer or reported a failure while starting.
    """
    info = device.info()
    # A device that cannot be reached yields no info object; report that
    # instead of failing on the attribute access below.
    if not info:
        console.print("[bold red]Failed to get device info[/bold red]")
        return False
    if info.status.state == DeviceState.kRunning:
        return True

    with console.status("Starting device...", spinner="bouncingBall"):
        start_status = device.start_with_status()
        if start_status is None:
            console.print(
                "[bold red]Failed to start device: no response from device[/bold red]"
            )
            return False
        if start_status.code != StatusCode.kOk:
            console.print(
                f"[bold red]Failed to start device: {start_status.message}[/bold red]"
            )
            return False
    return True
def setup_output(args, console):
    """Make sure the output directory named by ``args.output`` exists.

    Args:
        args: Namespace whose ``output`` attribute names the directory.
        console: Console used to report problems.

    Returns:
        True when the directory exists afterwards; False when no directory
        was given or it could not be created.
    """
    if not args.output:
        console.print("[bold red]No output directory specified[/bold red]")
        return False

    try:
        os.makedirs(args.output, exist_ok=True)
    except OSError as e:
        # Callers check the boolean result, so report the failure here
        # rather than letting it escape as a traceback.
        console.print(
            f"[bold red]Could not create output directory {args.output}: {e}[/bold red]"
        )
        return False
    return True
def list_available_taps(args, device, console):
    """Print every tap reported for the device and mark the supported ones.

    A tap counts as supported when its message type is
    ``synapse.BroadbandFrame``.  ``device`` is accepted for signature
    compatibility but not used; the taps are queried through a ``Tap``
    built from ``args.uri`` and ``args.verbose``.
    """
    supported_type = "synapse.BroadbandFrame"

    tap_client = Tap(args.uri, args.verbose)
    available = tap_client.list_taps()
    if not available:
        console.print("[bold red]No taps found on device[/bold red]")
        return

    console.print("\n[bold cyan]Available Taps:[/bold cyan]")
    console.print("=" * 50)

    supported_total = 0
    for entry in available:
        supported = entry.message_type == supported_type

        # Name line, with a marker appended for supported taps.
        name_line = f"[green]Name:[/green] {entry.name}"
        if supported:
            supported_total += 1
            name_line += " [bold green]✓ SUPPORTED[/bold green]"
        console.print(name_line)

        console.print(f"[blue]Type:[/blue] {entry.message_type}")
        console.print(f"[yellow]Endpoint:[/yellow] {entry.endpoint}")

        if not supported:
            console.print(
                "[dim red]Note: Only synapse.BroadbandFrame taps are supported[/dim red]"
            )
        console.print("-" * 30)

    console.print(
        f"\n[bold]Total: {len(available)} taps found, {supported_total} supported[/bold]"
    )
def detect_stream_parameters(broadband_tap, console):
    """Read one message from the tap and derive the stream parameters.

    Args:
        broadband_tap: Object whose ``read(timeout_ms=...)`` returns one
            serialized frame, or a falsy value when nothing arrived.
        console: Console used for progress and error messages.

    Returns:
        ``(sample_rate, channel_indices, first_frame)`` on success, where
        ``channel_indices`` is ``0 .. len(frame_data) - 1``; otherwise
        ``(None, None, None)``.
    """
    failure = (None, None, None)
    console.log("[cyan]Detecting stream parameters from first message...[/cyan]")

    try:
        # Wait up to five seconds for the first message.
        payload = broadband_tap.read(timeout_ms=5000)
        if not payload:
            console.print(
                "[bold red]Failed to receive first message for parameter detection[/bold red]"
            )
            return failure

        first_frame = BroadbandFrame()
        first_frame.ParseFromString(payload)

        sample_rate = first_frame.sample_rate_hz
        channel_count = len(first_frame.frame_data)
        channel_indices = list(range(channel_count))

        console.log(f"[green]Detected sample rate: {sample_rate} Hz[/green]")
        console.log(
            f"[green]Detected {channel_count} channels (0-{channel_count - 1})[/green]"
        )
    except Exception as e:
        console.print(f"[bold red]Error detecting stream parameters: {e}[/bold red]")
        return failure

    return sample_rate, channel_indices, first_frame
def get_broadband_tap(args, device, console):
    """Connect to a ``synapse.BroadbandFrame`` tap and return the tap client.

    When ``args.tap_name`` is set, that tap is used if it exists and has the
    right message type; a wrong type is an error, a missing tap falls back
    to auto-selection.  Auto-selection takes the first tap whose message
    type is ``synapse.BroadbandFrame``.  ``device`` is accepted for
    signature compatibility but not used.

    Returns:
        The connected ``Tap``, or None when no usable tap was found.
    """
    wanted_type = "synapse.BroadbandFrame"

    read_tap = Tap(args.uri, args.verbose)
    taps = read_tap.list_taps()

    requested = getattr(args, "tap_name", None)
    if requested:
        console.log(f"[cyan]Looking for specified tap: {requested}[/cyan]")
        named = next((t for t in taps if t.name == requested), None)

        if named is None:
            console.print(
                f"[yellow]Warning: Specified tap '{requested}' not found, falling back to auto-selection[/yellow]"
            )
        else:
            console.log(
                f"[green]Found specified tap: {requested} (type: {named.message_type})[/green]"
            )
            # An explicitly requested tap of the wrong type is not replaced
            # by another one.
            if named.message_type != wanted_type:
                console.print(
                    f"[bold red]Error: Specified tap '{requested}' has type '{named.message_type}', but only 'synapse.BroadbandFrame' is supported[/bold red]"
                )
                return None
            read_tap.connect(named.name)
            return read_tap

    console.log("[cyan]Auto-selecting first synapse.BroadbandFrame tap[/cyan]")
    first_match = next((t for t in taps if t.message_type == wanted_type), None)
    if first_match is None:
        console.print("[bold red]No synapse.BroadbandFrame tap found[/bold red]")
        return None

    console.log(f"[green]Found synapse.BroadbandFrame tap: {first_match.name}[/green]")
    read_tap.connect(first_match.name)
    return read_tap
def read(args):
    """Stream broadband frames from a device tap to disk, plot and monitor.

    Loads the device configuration, configures and starts the device,
    connects to a broadband tap and then streams frames until interrupted.
    Frames go to an HDF5 writer when ``args.output`` is set and to a
    real-time plotter when ``args.plot`` is set; a statistics line is shown
    in all cases.

    Args:
        args: Parsed CLI namespace.  Reads ``config``, ``uri``, ``verbose``,
            ``output`` and ``plot``, plus ``list_taps`` and ``tap_name``
            when present.
    """
    console = Console()

    try:
        config = load_device_config(args.config, console)
    except Exception as e:
        console.print(f"[bold red]Failed to load device configuration: {e}[/bold red]")
        return

    device = syn.Device(args.uri, args.verbose)
    device_name = device.get_name()
    console.log(f"[green]Connected to {device_name}[/green]")

    # If the user just wants to list taps, do that and exit.
    if getattr(args, "list_taps", False):
        list_available_taps(args, device, console)
        return

    if not configure_device(device, config, console):
        console.print("[bold red]Failed to configure device[/bold red]")
        return

    # Saving was requested, so make sure there is a place to save to.
    if args.output and not setup_output(args, console):
        console.print("[bold red]Failed to setup output[/bold red]")
        return

    if not start_device(device, console):
        console.print("[bold red]Failed to start device[/bold red]")
        return

    # With the device running, get the tap to read from.
    broadband_tap = get_broadband_tap(args, device, console)
    if not broadband_tap:
        console.print("[bold red]Failed to get broadband tap[/bold red]")
        return

    sample_rate, available_channels, first_frame = detect_stream_parameters(
        broadband_tap, console
    )
    if sample_rate is None:
        console.print("[bold red]Failed to detect stream parameters[/bold red]")
        return

    # Everything started below runs on its own thread.  Creating it inside
    # the try block guarantees that every exit path, including the early
    # return on a plotter import failure, stops what was already started.
    writer = None
    plotter = None
    monitor = None
    try:
        if args.output:
            writer = BroadbandFrameWriter(args.output)
            writer.set_attributes(
                sample_rate_hz=sample_rate, channels=available_channels
            )
            writer.start()

        if args.plot:
            try:
                from synapse.cli.synapse_plotter import create_broadband_plotter

                new_plotter = create_broadband_plotter(
                    sample_rate_hz=sample_rate,
                    window_size_seconds=5,
                    channel_ids=available_channels,
                )
                new_plotter.start()
                # Only register the plotter for cleanup once it is running.
                plotter = new_plotter
                console.log(
                    f"[green]Started real-time plotter with {len(available_channels)} channels available[/green]"
                )
            except ImportError as e:
                console.print(
                    f"[bold red]Failed to import plotter (missing dearpygui?): {e}[/bold red]"
                )
                return

        monitor = StreamMonitor(console)
        monitor.start()

        with Live(monitor.get_current_stats(), refresh_per_second=4) as live:
            # The frame consumed for parameter detection is part of the data.
            if first_frame is not None:
                if writer:
                    writer.put(first_frame)
                if plotter:
                    plotter.put(first_frame)
                monitor.put(first_frame)

            # Batch streaming for the remaining frames.
            for message_batch in broadband_tap.stream_batch(batch_size=10):
                frames = []
                for message in message_batch:
                    frame = BroadbandFrame()
                    frame.ParseFromString(message)
                    frames.append(frame)

                    monitor.put(frame)
                    if plotter:
                        plotter.put(frame)

                if writer and frames:
                    writer.put_batch(frames)

                live.update(monitor.get_current_stats())

    except KeyboardInterrupt:
        console.print("\n[yellow]Stopping data collection...[/yellow]")
    finally:
        # Stop each component independently so that one failing shutdown
        # does not leave the other threads running.
        for component in (writer, plotter, monitor):
            if component is None:
                continue
            try:
                component.stop()
            except Exception as e:
                console.print(f"[bold red]Error during shutdown: {e}[/bold red]")

        if writer is not None:
            console.print(f"[green]Data saved to {args.output}[/green]")
        if plotter is not None:
            console.print("[green]Plotter stopped[/green]")
self.channel_ids[0] - # Track start time for display + # Track start time for display and timestamp conversion self.start_time = None + self.start_timestamp_ns = None + self.latest_data_time = 0 # Track the most recent data timestamp in seconds # Defaults for the zoomed channel plot - self.zoom_y_min = 0 + self.zoom_y_min = -4096 self.zoom_y_max = 4096 self.signal_separation = 1000 + # Dictionary to store line series for plotted channels + self.active_lines = {} + + # Queue and threading for BroadbandFrame processing + self.data_queue = queue.Queue(maxsize=2000) + self.stop_event = Event() + self.plot_thread = None + self.running = False + dpg.create_context() self.setup_gui() @@ -62,14 +76,38 @@ def setup_gui(self): pos=(0, 0), tag="control_window", ): + dpg.add_text("Select Channels to Plot:") + + # Add convenience buttons + with dpg.group(horizontal=True): + dpg.add_button(label="All", callback=self.select_all_channels, width=40) + dpg.add_button(label="None", callback=self.select_no_channels, width=40) + dpg.add_button( + label="First 5", callback=self.select_first_5_channels, width=50 + ) + + # Add checkboxes for each channel in a scrollable region + with dpg.child_window(height=150, tag="channel_selection"): + for ch_id in self.channel_ids: + is_selected = ch_id in self.selected_channels + dpg.add_checkbox( + label=f"Channel {ch_id}", + default_value=is_selected, + callback=self.channel_checkbox_callback, + user_data=ch_id, + tag=f"ch_checkbox_{ch_id}", + ) + + dpg.add_separator() dpg.add_text("Select Channel to Zoom:") dpg.add_combo( items=[str(ch_id) for ch_id in self.channel_ids], default_value=str(self.selected_channel_id), - callback=self.channel_selection_callback, - tag="channel_combo", + callback=self.zoom_channel_callback, + tag="zoom_channel_combo", width=80, ) + dpg.add_separator() dpg.add_text("Elapsed Time (s):") dpg.add_text("", tag="elapsed_time_text") @@ -120,24 +158,17 @@ def setup_gui(self): dpg.add_plot_legend() # Axes - 
dpg.add_plot_axis(dpg.mvXAxis, label="Time (s)") + dpg.add_plot_axis( + dpg.mvXAxis, label="Time (s)", tag="x_axis_all" + ) self.y_axis_all = dpg.add_plot_axis( dpg.mvYAxis, label="Amplitude", tag="y_axis_all" ) - dpg.set_axis_limits("y_axis_all", 0, 4096 * 10) - - # Create line series for each channel - self.lines_all = [] - for idx, ch_id in enumerate(self.channel_ids): - line_tag = f"all_line_ch{ch_id}" - line = dpg.add_line_series( - [], - [], - label=f"Ch {ch_id}", - parent=self.y_axis_all, - tag=line_tag, - ) - self.lines_all.append(line) + dpg.set_axis_limits("y_axis_all", -4096, 4096 * 10) + + # Create line series for initially selected channels + for ch_id in self.selected_channels: + self.create_line_series(ch_id) # Zoomed Channel Tab with dpg.tab(label="Zoomed Channel"): @@ -150,7 +181,9 @@ def setup_gui(self): dpg.add_plot_legend() # Axes - dpg.add_plot_axis(dpg.mvXAxis, label="Time (s)") + dpg.add_plot_axis( + dpg.mvXAxis, label="Time (s)", tag="x_axis_zoom" + ) self.y_axis_zoom = dpg.add_plot_axis( dpg.mvYAxis, label="Amplitude", tag="y_axis_zoom" ) @@ -164,13 +197,43 @@ def setup_gui(self): tag="zoomed_line", ) - def channel_selection_callback(self, sender, app_data, user_data): - """Called when user picks a channel from the combo.""" + def channel_checkbox_callback(self, sender, app_data, user_data): + """Called when user checks/unchecks a channel checkbox.""" + ch_id = user_data + if app_data: # Checked + self.selected_channels.add(ch_id) + self.create_line_series(ch_id) + else: # Unchecked + self.selected_channels.discard(ch_id) + self.remove_line_series(ch_id) + + def zoom_channel_callback(self, sender, app_data, user_data): + """Called when user picks a channel for zooming.""" self.selected_channel_id = int(app_data) self.selected_channel_idx = self.channel_to_index[self.selected_channel_id] # Update the label of the zoomed line dpg.configure_item("zoomed_line", label=f"Channel {self.selected_channel_id}") + def create_line_series(self, ch_id): 
+ """Create a line series for the specified channel.""" + if ch_id not in self.active_lines: + line_tag = f"all_line_ch{ch_id}" + dpg.add_line_series( + [], + [], + label=f"Ch {ch_id}", + parent=self.y_axis_all, + tag=line_tag, + ) + self.active_lines[ch_id] = line_tag + + def remove_line_series(self, ch_id): + """Remove the line series for the specified channel.""" + if ch_id in self.active_lines: + line_tag = self.active_lines[ch_id] + dpg.delete_item(line_tag) + del self.active_lines[ch_id] + def set_zoom_y_min(self, sender, app_data): self.zoom_y_min = app_data @@ -180,38 +243,91 @@ def set_zoom_y_max(self, sender, app_data): def set_signal_separation(self, sender, app_data): self.signal_separation = app_data - def on_split_drag(self, sender, app_data): - """Handle dragging the splitter between plots.""" - # Get the main window height - main_window_height = dpg.get_item_height("main_window") - - # Calculate new heights based on drag position - mouse_pos = dpg.get_mouse_pos(local=False) - window_pos = dpg.get_item_pos("main_window") - relative_height = mouse_pos[1] - window_pos[1] - - # Ensure minimum heights for both windows - MIN_HEIGHT = 100 - if ( - relative_height < MIN_HEIGHT - or relative_height > main_window_height - MIN_HEIGHT - ): + def put(self, frame: BroadbandFrame): + """Add a BroadbandFrame to the processing queue""" + try: + self.data_queue.put(frame, block=False) + except queue.Full: + # If queue is full, drop multiple old frames and add the new one + dropped = 0 + while dropped < 5: # Drop up to 5 old frames + try: + self.data_queue.get_nowait() + dropped += 1 + except queue.Empty: + break + try: + self.data_queue.put(frame, block=False) + except queue.Full: + pass # Still full, drop this frame + + def put_batch(self, frames: list): + """Add multiple BroadbandFrames efficiently""" + for frame in frames: + try: + self.data_queue.put(frame, block=False) + except queue.Full: + # Drop old frames to make room + try: + self.data_queue.get_nowait() + 
self.data_queue.put(frame, block=False) + except queue.Empty: + pass + + def start(self): + """Start the plotter in a separate thread""" + if self.running: return - # Update the heights of both windows - dpg.configure_item("top_plot_window", height=relative_height) + self.running = True + self.stop_event.clear() + self.plot_thread = Thread(target=self._plot_thread_main) + self.plot_thread.start() + + def stop(self): + """Stop the plotter thread""" + if not self.running: + return + + self.running = False + self.stop_event.set() + if self.plot_thread: + self.plot_thread.join() + + def _plot_thread_main(self): + """Main plotting thread that runs the DearPyGui event loop""" + dpg.setup_dearpygui() + dpg.show_viewport() + + # Record start time + self.start_time = time.time() + + # Main loop + fps_limit = 30 + frame_duration = 1.0 / fps_limit + last_time = time.time() - def resize_callback(self, sender, app_data): - """Handle window resize events.""" - viewport_width = dpg.get_viewport_width() - viewport_height = dpg.get_viewport_height() + while dpg.is_dearpygui_running() and not self.stop_event.is_set(): + # Process multiple frames per iteration for better throughput + frames_processed = 0 + max_frames_per_iter = 10 - # Update main window size - main_window_width = viewport_width - 280 # Leave space for control window - main_window_height = viewport_height - 20 # Leave small margin - dpg.configure_item( - "main_window", width=main_window_width, height=main_window_height - ) + while frames_processed < max_frames_per_iter: + try: + frame = self.data_queue.get_nowait() + self.process_broadband_frame(frame) + frames_processed += 1 + except queue.Empty: + break + + # Throttle rendering to the fps limit + now = time.time() + if (now - last_time) >= frame_duration: + self.update_plot() + dpg.render_dearpygui_frame() + last_time = now + + dpg.destroy_context() def update_plot(self): """ @@ -220,28 +336,52 @@ def update_plot(self): """ # Downsample factor for performance # 
Note(gilbert): we should probably make this configurable, it is arbitrary - ds_factor = 10 + ds_factor = 4 + + # Get the current time window for x-axis limits based on latest data + # Use latest data timestamp instead of wall clock for better sync + current_data_time = self.latest_data_time + x_min = max(0, current_data_time - self.window_size_seconds) + x_max = current_data_time # ----------------------------- # Update "All Channels" Plot # ----------------------------- - for idx, ch_id in enumerate(self.channel_ids): + active_channel_idx = 0 + for ch_id in self.selected_channels: + if ch_id not in self.active_lines: + continue + + idx = self.channel_to_index[ch_id] pos = self.buffer_positions[idx] # Roll data so that index -1 corresponds to the newest sample rolled_y = np.roll(self.data_buffers[idx], -pos) - rolled_x = np.roll(self.time_buffer, -pos) + rolled_x = np.roll(self.timestamp_buffers[idx], -pos) # Downsample ds_x = rolled_x[::ds_factor] ds_y = rolled_y[::ds_factor] - # Apply vertical offset for each channel to avoid overlap - offset = idx * self.signal_separation + # Apply vertical offset for each active channel to avoid overlap + offset = active_channel_idx * self.signal_separation ds_y_offset = ds_y + offset - # Update the line series for channel ch - dpg.set_value(self.lines_all[idx], [ds_x.tolist(), ds_y_offset.tolist()]) + # Update the line series for this channel + line_tag = self.active_lines[ch_id] + dpg.set_value(line_tag, [ds_x.tolist(), ds_y_offset.tolist()]) + + active_channel_idx += 1 + + # Set X-axis limits to show current time window + dpg.set_axis_limits("x_axis_all", x_min, x_max) + + # Set Y-axis limits based on number of active channels + num_active = len(self.selected_channels) + if num_active > 0: + y_max_all = (num_active - 1) * self.signal_separation + 4096 + y_min_all = -4096 + dpg.set_axis_limits("y_axis_all", y_min_all, y_max_all) # ----------------------------- # Update Zoomed Channel Plot @@ -250,7 +390,7 @@ def 
update_plot(self): pos = self.buffer_positions[idx] rolled_y_ch = np.roll(self.data_buffers[idx], -pos) - rolled_x_ch = np.roll(self.time_buffer, -pos) + rolled_x_ch = np.roll(self.timestamp_buffers[idx], -pos) ds_x_ch = rolled_x_ch[::ds_factor] ds_y_ch = rolled_y_ch[::ds_factor] @@ -258,7 +398,8 @@ def update_plot(self): # Update the single "zoomed_line" series dpg.set_value("zoomed_line", [ds_x_ch.tolist(), ds_y_ch.tolist()]) - # Optionally set the zoomed plot Y-axis range for a closer look + # Set axis limits for both plots + dpg.set_axis_limits("x_axis_zoom", x_min, x_max) dpg.set_axis_limits("y_axis_zoom", self.zoom_y_min, self.zoom_y_max) # ----------------------------- @@ -268,68 +409,89 @@ def update_plot(self): elapsed = time.time() - self.start_time dpg.set_value("elapsed_time_text", f"{elapsed:.2f}") - def process_data(self, data): + def process_broadband_frame(self, frame: BroadbandFrame): """ - data = (channel, samples) - Write these samples into the ring buffer for that channel. + Process a BroadbandFrame and distribute the data to channel buffers. + Uses actual timestamps from the frame for proper time synchronization. """ - channel_id, samples = data - if channel_id not in self.channel_to_index: - print( - f"Warning: Received data for channel {channel_id} which is not in the configured channel list." 
- ) - return - - # Get our channel id mapped to our plotting index - idx = self.channel_to_index[channel_id] - num_samples = len(samples) - pos = self.buffer_positions[idx] - - end_pos = pos + num_samples - if end_pos <= self.buffer_size: - # Simple case: fits without wrap - self.data_buffers[idx][pos:end_pos] = samples - else: - # Wrap around case - first_part = self.buffer_size - pos - second_part = num_samples - first_part - self.data_buffers[idx][pos:] = samples[:first_part] - self.data_buffers[idx][:second_part] = samples[first_part:] - - self.buffer_positions[idx] = (pos + num_samples) % self.buffer_size - - def start(self, stop, data_queue): - """Run the DearPyGui event/render loop.""" - dpg.setup_dearpygui() - dpg.show_viewport() - - # Record start time - self.start_time = time.time() - - # Main loop - fps_limit = 60 - frame_duration = 1.0 / fps_limit - last_time = time.time() - - while dpg.is_dearpygui_running() and not stop.is_set(): - # Process any incoming data in the queue - while True: - try: - data = data_queue.get_nowait() - self.process_data(data.samples[0]) - except queue.Empty: - break - - # Throttle rendering to the fps limit - now = time.time() - if (now - last_time) >= frame_duration: - self.update_plot() - dpg.render_dearpygui_frame() - last_time = now - - dpg.destroy_context() + # Set start timestamp on first frame + if self.start_timestamp_ns is None: + self.start_timestamp_ns = frame.timestamp_ns + self.start_time = time.time() + + # Convert timestamp to seconds relative to start + relative_time_s = (frame.timestamp_ns - self.start_timestamp_ns) / 1e9 + + # Update latest data time + self.latest_data_time = relative_time_s + + # frame_data is a flat array with one sample per channel + # We assume the data is organized as: [ch0_sample, ch1_sample, ch2_sample, ...] 
+ frame_data = frame.frame_data + + # Distribute data to each channel buffer + for ch_idx, ch_id in enumerate(self.channel_ids): + if ch_idx < len(frame_data): + sample = frame_data[ch_idx] + + # Add sample to this channel's ring buffer + pos = self.buffer_positions[ch_idx] + self.data_buffers[ch_idx][pos] = sample + + # Add actual timestamp to this channel's timestamp buffer + self.timestamp_buffers[ch_idx][pos] = relative_time_s + + self.buffer_positions[ch_idx] = (pos + 1) % self.buffer_size + + def select_all_channels(self): + """Select all channels for plotting.""" + # Update internal state + old_selection = self.selected_channels.copy() + self.selected_channels = set(self.channel_ids) + + # Update checkboxes + for ch_id in self.channel_ids: + dpg.set_value(f"ch_checkbox_{ch_id}", True) + if ch_id not in old_selection: + self.create_line_series(ch_id) + + def select_no_channels(self): + """Deselect all channels.""" + # Update internal state + old_selection = self.selected_channels.copy() + self.selected_channels = set() + + # Update checkboxes and remove line series + for ch_id in old_selection: + dpg.set_value(f"ch_checkbox_{ch_id}", False) + self.remove_line_series(ch_id) + + def select_first_5_channels(self): + """Select only the first 5 channels.""" + # Update internal state + old_selection = self.selected_channels.copy() + self.selected_channels = set(self.channel_ids[:5]) + + # Update checkboxes + for ch_id in self.channel_ids: + should_be_selected = ch_id in self.selected_channels + dpg.set_value(f"ch_checkbox_{ch_id}", should_be_selected) + + if should_be_selected and ch_id not in old_selection: + self.create_line_series(ch_id) + elif not should_be_selected and ch_id in old_selection: + self.remove_line_series(ch_id) + + +# Factory function to create plotter with BroadbandFrame support +def create_broadband_plotter( + sample_rate_hz: int, window_size_seconds: int, channel_ids +): + """Create a SynapsePlotter configured for BroadbandFrame data""" + 
return SynapsePlotter(sample_rate_hz, window_size_seconds, channel_ids) +# Legacy function for backward compatibility def plot_synapse_data( stop: Event, data_queue: queue.Queue, diff --git a/synapse/client/taps.py b/synapse/client/taps.py index 8143c87..4feb38d 100644 --- a/synapse/client/taps.py +++ b/synapse/client/taps.py @@ -86,6 +86,17 @@ def connect(self, name: str) -> bool: # For producer taps (or unspecified), we need to subscribe and listen FROM the tap self.zmq_socket = self.zmq_context.socket(zmq.SUB) + # Optimize ZMQ for high-throughput data + # Increase receive buffer size significantly for high-speed data + self.zmq_socket.setsockopt( + zmq.RCVHWM, 10000 + ) # High water mark - buffer up to 10K messages + self.zmq_socket.setsockopt(zmq.RCVBUF, 16 * 1024 * 1024) # 16MB receive buffer + + # Set TCP keepalive for connection stability + self.zmq_socket.setsockopt(zmq.TCP_KEEPALIVE, 1) + self.zmq_socket.setsockopt(zmq.TCP_KEEPALIVE_IDLE, 60) + # Replace the endpoint with our device URI if needed endpoint = selected_tap.endpoint if "://" in endpoint: @@ -99,14 +110,13 @@ def connect(self, name: str) -> bool: try: self.zmq_socket.connect(endpoint) - # Give the socket a chance to connect - self.logger.info("Waiting for socket to connect...") - time.sleep(1) + # Reduce connection wait time to minimize startup delay + self.logger.info("Connecting to tap...") + time.sleep(0.1) # Reduced from 1 second # Only set subscription options for subscriber sockets if selected_tap.tap_type != TapType.TAP_TYPE_CONSUMER: self.zmq_socket.setsockopt(zmq.SUBSCRIBE, b"") - print("Subscribed to all messages") return True except zmq.ZMQError as e: @@ -167,11 +177,11 @@ def send(self, data: bytes) -> bool: self.logger.error(f"Error sending message: {e}") return False - def stream(self, timeout_ms: int = 1000) -> Generator[bytes, None, None]: - """Stream raw data from the tap. 
+ def stream(self, timeout_ms: int = 100) -> Generator[bytes, None, None]: + """Stream raw data from the tap with optimizations for high-throughput data. Args: - timeout_ms (int, optional): Timeout between messages in milliseconds. Defaults to 1000. + timeout_ms (int, optional): Timeout between messages in milliseconds. Defaults to 100. Yields: Generator[bytes, None, None]: Stream of raw message data. @@ -180,16 +190,18 @@ def stream(self, timeout_ms: int = 1000) -> Generator[bytes, None, None]: self.logger.error("Not connected to any tap") return - # Set socket timeout + # Set a shorter timeout for high-frequency data self.zmq_socket.setsockopt(zmq.RCVTIMEO, timeout_ms) try: while True: try: - data = self.zmq_socket.recv() + # Use non-blocking receive with DONTWAIT for maximum throughput + data = self.zmq_socket.recv(zmq.DONTWAIT) yield data except zmq.Again: - # Timeout occurred, continue to next iteration + # No data available right now, yield control briefly + time.sleep(0.0001) # 0.1ms sleep to prevent busy waiting continue except KeyboardInterrupt: self.logger.info("Stream interrupted") @@ -199,6 +211,50 @@ def stream(self, timeout_ms: int = 1000) -> Generator[bytes, None, None]: # Don't close the socket here, let the user call disconnect() pass + def stream_batch( + self, batch_size: int = 10, timeout_ms: int = 100 + ) -> Generator[list, None, None]: + """Stream data in batches for improved throughput. + + Args: + batch_size (int, optional): Number of messages to batch together. Defaults to 10. + timeout_ms (int, optional): Timeout for each receive operation. Defaults to 100. + + Yields: + Generator[list, None, None]: Batches of raw message data. 
+ """ + if not self.zmq_socket: + self.logger.error("Not connected to any tap") + return + + self.zmq_socket.setsockopt(zmq.RCVTIMEO, timeout_ms) + + batch = [] + try: + while True: + try: + data = self.zmq_socket.recv(zmq.DONTWAIT) + batch.append(data) + + if len(batch) >= batch_size: + yield batch + batch = [] + except zmq.Again: + # No data available, yield partial batch if any + if batch: + yield batch + batch = [] + time.sleep(0.0001) + continue + except KeyboardInterrupt: + self.logger.info("Stream interrupted") + if batch: + yield batch + except zmq.ZMQError as e: + self.logger.error(f"Error streaming messages: {e}") + if batch: + yield batch + def disconnect(self): """Disconnect from the tap.""" self._cleanup()