From 34e0cc7e7cd5eef83ec429757a6eae82c7efb816 Mon Sep 17 00:00:00 2001 From: Youssef Date: Sat, 31 Jan 2026 10:22:10 -0500 Subject: [PATCH 1/5] Add Lab Streaming Layer (LSL) support for multiple devices and WebSocket bridge - Introduced LSL device profiles for 8, 16, 32, 64 channels, and specific manufacturers (Brain Products, BioSemi, g.tec, Cognionics, ANT Neuro, NIRx). - Updated stream adapter registry to include LSL streams with appropriate factory methods and descriptions. - Implemented a Python WebSocket bridge for LSL, allowing real-time streaming of EEG and fNIRS data to web clients. - Added functionality for automatic stream discovery, connection management, and sample forwarding. - Included a simulator for testing without hardware, generating simulated EEG data. --- EEG_INTEGRATION.md | 48 +- README.md | 104 +++- scripts/lsl_ws_bridge.py | 918 ++++++++++++++++++++++++++++++++++ src/devices/deviceProfiles.ts | 316 ++++++++++++ src/streams/index.ts | 70 +++ 5 files changed, 1451 insertions(+), 5 deletions(-) create mode 100644 scripts/lsl_ws_bridge.py diff --git a/EEG_INTEGRATION.md b/EEG_INTEGRATION.md index 617e4ce..20ee9c3 100644 --- a/EEG_INTEGRATION.md +++ b/EEG_INTEGRATION.md @@ -20,9 +20,17 @@ | **PiEEG** | ardEEG | 8 | 250 Hz | N/A* | Serial (Arduino) | | **PiEEG** | MicroBCI | 8 | 250 Hz | N/A* | BLE (STM32) | | **Cerelog** | ESP-EEG | 8 | 250 Hz | N/A* | WiFi (TCP) | +| **LSL** | Generic (8-64ch) | 8-64 | Variable | -2** | Lab Streaming Layer | +| **LSL** | Brain Products | 32+ | Up to 25kHz | N/A* | LSL | +| **LSL** | BioSemi ActiveTwo | 32+ | Up to 16kHz | N/A* | LSL | +| **LSL** | g.tec | 16+ | Up to 38kHz | N/A* | LSL | +| **LSL** | Cognionics | 20-30 | 500 Hz | N/A* | LSL | +| **LSL** | ANT Neuro | 32+ | 2048 Hz | 29 | LSL | +| **LSL** | NIRx fNIRS | 16+ | 10 Hz | N/A* | LSL | | **Brainflow** | Synthetic | 8 | 250 Hz | -1 | Virtual | *Requires WebSocket bridge (included). +**LSL streams can use Brainflow Streaming Board ID -2 for forwarding. ## ⚠️ Browser Connectivity @@ -30,9 +38,9 @@ For hardware devices, you need a **WebSocket bridge** that runs locally and proxies the device data to the browser. This project includes: -1. **pieeg_ws_bridge.py** - For PiEEG devices (SPI/BrainFlow → WebSocket) -2. **cerelog_ws_bridge.py** - For Cerelog ESP-EEG (TCP → WebSocket) -3. Community bridges for other devices (see Bridge Setup section) +1. **lsl_ws_bridge.py** - For any LSL source (130+ devices → WebSocket) +2. **pieeg_ws_bridge.py** - For PiEEG devices (SPI/BrainFlow → WebSocket) +3. **cerelog_ws_bridge.py** - For Cerelog ESP-EEG (TCP → WebSocket) ### Bridge Architecture ``` @@ -103,6 +111,40 @@ pip install muselsl # Stream via LSL, then use LSL-to-WebSocket bridge muselsl stream + +# In another terminal, run the LSL bridge +python scripts/lsl_ws_bridge.py --stream "Muse" +``` + +#### Lab Streaming Layer (130+ Devices) +```bash +# The LSL bridge supports any LSL-compatible device: +# Brain Products, BioSemi, g.tec, ANT Neuro, Cognionics, NIRx, etc. + +# Install dependencies +pip install websockets pylsl numpy + +# Auto-discover and connect to first EEG stream +python scripts/lsl_ws_bridge.py + +# Connect to specific stream by name +python scripts/lsl_ws_bridge.py --stream "OpenBCI_EEG" + +# List available LSL streams on your network +python scripts/lsl_ws_bridge.py --list + +# Run with simulated data (for testing) +python scripts/lsl_ws_bridge.py --simulate + +# Connect in PhantomLoop to ws://localhost:8767 +``` + +**WebSocket Commands:** +```json +{"command": "discover"} +{"command": "connect", "name": "OpenBCI_EEG", "stream_type": "EEG"} +{"command": "disconnect"} +{"command": "ping"} ``` #### Emotiv Insight/EPOC diff --git a/README.md b/README.md index 727d67e..3cfc47d 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ PhantomLoop is one component of the **Phantom Stack**, an integrated ecosystem f ## ✨ Key Features - **🔌 Universal Stream Architecture** – Connect to any multichannel data source (EEG, spikes, simulated) -- **🧠 15+ Device Support** – OpenBCI, Muse, Emotiv, NeuroSky, PiEEG, Cerelog ESP-EEG, and more +- **🧠 25+ Device Support** – OpenBCI, Muse, Emotiv, NeuroSky, PiEEG, LSL (130+ devices), and more - **⚡ Real-Time Performance** – 40Hz streaming with <50ms end-to-end latency - **🤖 AI-Powered Decoders** – TensorFlow.js models with WebGPU/WebGL acceleration - **📝 Monaco Code Editor** – Write custom decoders with VS Code-quality IntelliSense @@ -69,6 +69,13 @@ PhantomLoop supports **any multichannel time-series source** through a unified a | **PiEEG** | ardEEG | 8 | 250 Hz | Serial (Arduino) | | **PiEEG** | MicroBCI | 8 | 250 Hz | BLE (STM32) | | **Cerelog** | ESP-EEG | 8 | 250 Hz | WiFi (TCP) | +| **LSL** | Generic (8-64ch) | 8-64 | Variable | Lab Streaming Layer | +| **LSL** | Brain Products | 32+ | Up to 25kHz | LSL (via Connector) | +| **LSL** | BioSemi ActiveTwo | 32+ | Up to 16kHz | LSL | +| **LSL** | g.tec | 16+ | Up to 38kHz | LSL (g.NEEDaccess) | +| **LSL** | Cognionics | 20-30 | 500 Hz | LSL | +| **LSL** | ANT Neuro | 32+ | 2048 Hz | LSL | +| **LSL** | NIRx fNIRS | 16+ | 10 Hz | LSL | | **Brainflow** | Synthetic | 8 | 250 Hz | Virtual | > ⚠️ **Note:** Browsers cannot connect directly to TCP/Serial/BLE. Hardware devices require a WebSocket bridge (Python scripts included). @@ -96,6 +103,35 @@ PhantomLoop supports **any multichannel time-series source** through a unified a ⚠️ **Safety:** PiEEG must be powered by battery only (5V). Never connect to mains power! +### 🌐 Lab Streaming Layer (LSL) Integration + +[Lab Streaming Layer (LSL)](https://labstreaminglayer.org) is the universal protocol for streaming EEG and biosignal data in research settings. PhantomLoop supports **130+ LSL-compatible devices** through the included WebSocket bridge. + +**Key Features:** +- Real-time stream discovery on local network +- Sub-millisecond time synchronization +- Multi-stream support (EEG, markers, motion) +- Automatic reconnection on stream loss + +**LSL-Compatible Devices:** + +| Manufacturer | Devices | Notes | +|--------------|---------|-------| +| **Brain Products** | actiCHamp, LiveAmp, BrainVision | Via LSL Connector app | +| **BioSemi** | ActiveTwo 32-256ch | Research gold standard | +| **g.tec** | g.USBamp, g.Nautilus, g.HIamp | Via g.NEEDaccess | +| **ANT Neuro** | eego sport, eego mylab | Mobile & lab EEG | +| **Cognionics** | Quick-20, Quick-30, Mobile-72 | Dry electrode systems | +| **OpenBCI** | All models | Via OpenBCI GUI LSL | +| **Muse** | Muse 1/2/S | Via muse-lsl | +| **Emotiv** | EPOC, Insight, EPOC Flex | Via EmotivPRO LSL | +| **NIRx** | NIRSport, NIRScout | fNIRS devices | +| **Tobii** | Pro Glasses, Screen-based | Eye tracking | +| **Neurosity** | Notion, Crown | Consumer EEG | +| **BrainAccess** | HALO, MINI, MIDI | Affordable research EEG | + +📚 Full device list: [labstreaminglayer.org](https://labstreaminglayer.org/#checks:certified) + --- PhantomLoop streams neural data from PhantomLink (MC_Maze dataset, 142 channels @ 40Hz) and visualizes **ground truth cursor movements** alongside **your decoder's predictions**. Built for BCI researchers who need to rapidly prototype, test, and compare decoding algorithms. @@ -140,7 +176,24 @@ python scripts/pieeg_ws_bridge.py --rate 250 --gain 24 # 5. Select "PiEEG" in the device selector ``` -**Option 3: Cerelog ESP-EEG (WiFi)** +**Option 3: Lab Streaming Layer (130+ Devices)** +```bash +# 1. Start your LSL source (OpenBCI GUI, muse-lsl, BrainVision, etc.) +# 2. Run the LSL WebSocket bridge +pip install websockets pylsl numpy +python scripts/lsl_ws_bridge.py + +# 3. In PhantomLoop, connect to ws://localhost:8767 +# 4. Select "LSL Stream" in the device selector + +# Advanced: Connect to specific stream by name +python scripts/lsl_ws_bridge.py --stream "OpenBCI_EEG" + +# List available LSL streams on your network +python scripts/lsl_ws_bridge.py --list +``` + +**Option 4: Cerelog ESP-EEG (WiFi)** ```bash # 1. Connect to ESP-EEG WiFi: SSID: CERELOG_EEG, Password: cerelog123 # 2. Run the WebSocket bridge @@ -158,9 +211,56 @@ Since browsers cannot directly access hardware (SPI, Serial, BLE, TCP), PhantomL | Script | Device | Port | Mode | |--------|--------|------|------| +| `lsl_ws_bridge.py` | Any LSL source (130+ devices) | 8767 | LSL Inlet → WebSocket | | `pieeg_ws_bridge.py` | PiEEG (Raspberry Pi) | 8766 | SPI / BrainFlow / Simulation | | `cerelog_ws_bridge.py` | Cerelog ESP-EEG | 8765 | TCP-to-WebSocket | +### LSL Bridge + +The LSL bridge connects to any Lab Streaming Layer source and forwards data to the browser: + +```bash +# Auto-discover and connect to first EEG stream +python scripts/lsl_ws_bridge.py + +# Connect to specific stream by name +python scripts/lsl_ws_bridge.py --stream "OpenBCI_EEG" + +# Connect to Muse via muse-lsl +python scripts/lsl_ws_bridge.py --stream "Muse" --type EEG + +# List available streams on network +python scripts/lsl_ws_bridge.py --list + +# Run with simulated data for testing (no hardware) +python scripts/lsl_ws_bridge.py --simulate + +# Custom port +python scripts/lsl_ws_bridge.py --port 8768 +``` + +**WebSocket Commands:** +```json +{"command": "discover"} +{"command": "connect", "name": "OpenBCI_EEG", "stream_type": "EEG"} +{"command": "disconnect"} +{"command": "ping"} +``` + +**Response: Stream Metadata** +```json +{ + "type": "metadata", + "stream": { + "name": "OpenBCI_EEG", + "stream_type": "EEG", + "channel_count": 8, + "sampling_rate": 250.0, + "channel_labels": ["Fp1", "Fp2", "C3", "C4", "P3", "P4", "O1", "O2"] + } +} +``` + ### PiEEG Bridge ```bash diff --git a/scripts/lsl_ws_bridge.py b/scripts/lsl_ws_bridge.py new file mode 100644 index 0000000..9397b2c --- /dev/null +++ b/scripts/lsl_ws_bridge.py @@ -0,0 +1,918 @@ +#!/usr/bin/env python3 +""" +WebSocket Bridge for Lab Streaming Layer (LSL) + +This script bridges the gap between browsers (WebSocket only) and any +LSL-compatible EEG device or software source. + +Lab Streaming Layer (LSL) is an open-source networked middleware ecosystem +for real-time streaming of time series data (EEG, fNIRS, eye tracking, etc.) + +Supported LSL Sources (130+ devices): +- OpenBCI (Cyton, Ganglion, Ultracortex) +- Muse (1, 2, S) via muse-lsl +- Emotiv (EPOC, Insight) via native LSL +- Brain Products (actiCHamp, LiveAmp) +- g.tec (g.USBamp, g.Nautilus) +- ANT Neuro (eego sport) +- BioSemi (ActiveTwo) +- NIRx (NIRSport, NIRScout) - fNIRS +- Cognionics (Quick-20, Quick-30) +- Neurosity (Notion, Crown) +- BrainAccess (HALO, MINI, MIDI) +- Eye trackers (Tobii, SR Research, Pupil Labs) +- Motion capture (OptiTrack, PhaseSpace) +- Any custom LSL outlet + +Usage: + 1. Start your LSL stream source (OpenBCI GUI, muse-lsl, etc.) + + 2. Run this bridge to discover and relay streams: + python lsl_ws_bridge.py + + 3. Run with specific stream name: + python lsl_ws_bridge.py --stream "OpenBCI_EEG" + + 4. In PhantomLoop, connect to: ws://localhost:8767 + +Requirements: + pip install websockets pylsl numpy + +Protocol: + - Discovers LSL streams on the network + - Pulls samples from inlet and forwards via WebSocket + - Supports multiple concurrent inlets + - Maintains LSL timestamp synchronization + +Documentation: + https://labstreaminglayer.readthedocs.io/ + https://github.com/labstreaminglayer/pylsl +""" + +import asyncio +import json +import sys +import time +import struct +import argparse +from typing import Optional, List, Dict, Set, Tuple, Any +from dataclasses import dataclass, asdict +from enum import Enum +import threading +import queue +import logging + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s' +) +logger = logging.getLogger(__name__) + +try: + import websockets + from websockets.server import WebSocketServerProtocol +except ImportError: + logger.error("websockets library required. Install with: pip install websockets") + sys.exit(1) + +try: + import numpy as np +except ImportError: + logger.error("numpy library required. Install with: pip install numpy") + sys.exit(1) + +try: + from pylsl import StreamInlet, StreamInfo, resolve_streams, resolve_byprop, local_clock + LSL_AVAILABLE = True +except ImportError: + logger.warning("pylsl not available. Install with: pip install pylsl") + LSL_AVAILABLE = False + + +# ============================================================================ +# LSL CONSTANTS +# ============================================================================ + +class LSLChannelFormat(Enum): + """LSL channel format types""" + FLOAT32 = 1 + DOUBLE64 = 2 + STRING = 3 + INT32 = 4 + INT16 = 5 + INT8 = 6 + INT64 = 7 + UNDEFINED = 0 + + +# Common LSL stream types +LSL_STREAM_TYPES = { + 'EEG': 'Electroencephalography', + 'EMG': 'Electromyography', + 'ECG': 'Electrocardiography', + 'EOG': 'Electrooculography', + 'fNIRS': 'Functional Near-Infrared Spectroscopy', + 'Gaze': 'Eye tracking gaze data', + 'Markers': 'Event markers', + 'Audio': 'Audio signal', + 'Video': 'Video frames', + 'MoCap': 'Motion capture', + 'Accelerometer': 'Accelerometer data', + 'Gyroscope': 'Gyroscope data', + 'PPG': 'Photoplethysmography', + 'GSR': 'Galvanic skin response', +} + + +# ============================================================================ +# DATA STRUCTURES +# ============================================================================ + +@dataclass +class StreamMetadata: + """Metadata about an LSL stream""" + name: str + stream_type: str + channel_count: int + sampling_rate: float + channel_format: str + source_id: str + hostname: str + uid: str + version: float + created_at: float + channel_labels: List[str] + channel_types: List[str] + channel_units: List[str] + manufacturer: str + model: str + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +@dataclass +class Sample: + """A single sample from an LSL stream""" + timestamp: float + lsl_timestamp: float + channels: List[float] + stream_id: str + + +# ============================================================================ +# LSL INLET MANAGER +# ============================================================================ + +class LSLInletManager: + """Manages LSL stream discovery and inlet connections""" + + def __init__(self): + self.inlets: Dict[str, StreamInlet] = {} + self.metadata: Dict[str, StreamMetadata] = {} + self.sample_queues: Dict[str, queue.Queue] = {} + self.running = False + self.reader_threads: Dict[str, threading.Thread] = {} + + def discover_streams( + self, + stream_type: Optional[str] = None, + stream_name: Optional[str] = None, + timeout: float = 2.0 + ) -> List[StreamMetadata]: + """Discover available LSL streams on the network""" + if not LSL_AVAILABLE: + return [] + + logger.info(f"Discovering LSL streams (timeout={timeout}s)...") + + if stream_name: + streams = resolve_byprop('name', stream_name, timeout=timeout) + elif stream_type: + streams = resolve_byprop('type', stream_type, timeout=timeout) + else: + streams = resolve_streams(timeout) + + metadata_list = [] + for stream_info in streams: + metadata = self._parse_stream_info(stream_info) + metadata_list.append(metadata) + logger.info(f" Found: {metadata.name} ({metadata.stream_type}) " + f"{metadata.channel_count}ch @ {metadata.sampling_rate}Hz") + + if not metadata_list: + logger.warning("No LSL streams found on the network") + + return metadata_list + + def _parse_stream_info(self, info: 'StreamInfo') -> StreamMetadata: + """Extract metadata from LSL StreamInfo""" + # Get channel info from XML description + channel_labels = [] + channel_types = [] + channel_units = [] + manufacturer = "" + model = "" + + try: + desc = info.desc() + + # Try to get manufacturer/model + acq = desc.child("acquisition") + if not acq.empty(): + manufacturer = acq.child_value("manufacturer") or "" + model = acq.child_value("model") or "" + + # Get channel metadata + channels = desc.child("channels") + if not channels.empty(): + ch = channels.child("channel") + while not ch.empty(): + channel_labels.append(ch.child_value("label") or f"Ch{len(channel_labels)+1}") + channel_types.append(ch.child_value("type") or info.type()) + channel_units.append(ch.child_value("unit") or "µV") + ch = ch.next_sibling("channel") + except Exception as e: + logger.warning(f"Could not parse stream description: {e}") + + # Fill in missing channel labels + num_channels = info.channel_count() + while len(channel_labels) < num_channels: + channel_labels.append(f"Ch{len(channel_labels)+1}") + while len(channel_types) < num_channels: + channel_types.append(info.type()) + while len(channel_units) < num_channels: + channel_units.append("µV") + + # Get channel format name + format_names = { + 1: 'float32', 2: 'double64', 3: 'string', + 4: 'int32', 5: 'int16', 6: 'int8', 7: 'int64', 0: 'undefined' + } + channel_format = format_names.get(info.channel_format(), 'unknown') + + return StreamMetadata( + name=info.name(), + stream_type=info.type(), + channel_count=num_channels, + sampling_rate=info.nominal_srate(), + channel_format=channel_format, + source_id=info.source_id(), + hostname=info.hostname(), + uid=info.uid(), + version=info.version(), + created_at=info.created_at(), + channel_labels=channel_labels, + channel_types=channel_types, + channel_units=channel_units, + manufacturer=manufacturer, + model=model, + ) + + def connect_stream( + self, + stream_name: Optional[str] = None, + stream_type: str = "EEG", + buffer_length: float = 360.0, # 6 minutes max buffer + max_chunklen: int = 0, # 0 = variable chunk size + ) -> Optional[str]: + """Connect to an LSL stream and start reading samples""" + if not LSL_AVAILABLE: + logger.error("pylsl not available") + return None + + # Resolve the stream + logger.info(f"Resolving LSL stream: name={stream_name}, type={stream_type}") + + if stream_name: + streams = resolve_byprop('name', stream_name, timeout=5.0) + else: + streams = resolve_byprop('type', stream_type, timeout=5.0) + + if not streams: + logger.error(f"Could not find LSL stream: {stream_name or stream_type}") + return None + + # Use first matching stream + stream_info = streams[0] + stream_id = stream_info.uid() + + if stream_id in self.inlets: + logger.info(f"Already connected to stream: {stream_id}") + return stream_id + + # Create inlet + logger.info(f"Creating inlet for: {stream_info.name()}") + inlet = StreamInlet( + stream_info, + max_buflen=buffer_length, + max_chunklen=max_chunklen, + recover=True, # Auto-recover if stream is temporarily lost + ) + + # Parse and store metadata + metadata = self._parse_stream_info(stream_info) + + self.inlets[stream_id] = inlet + self.metadata[stream_id] = metadata + self.sample_queues[stream_id] = queue.Queue(maxsize=10000) + + # Start reader thread + self._start_reader_thread(stream_id) + + logger.info(f"Connected to stream: {metadata.name} ({metadata.channel_count}ch)") + return stream_id + + def _start_reader_thread(self, stream_id: str): + """Start background thread to read samples from inlet""" + if stream_id in self.reader_threads: + return + + thread = threading.Thread( + target=self._reader_loop, + args=(stream_id,), + daemon=True, + name=f"lsl-reader-{stream_id[:8]}" + ) + self.reader_threads[stream_id] = thread + self.running = True + thread.start() + + def _reader_loop(self, stream_id: str): + """Background loop to continuously read samples from inlet""" + inlet = self.inlets.get(stream_id) + sample_queue = self.sample_queues.get(stream_id) + + if not inlet or not sample_queue: + return + + logger.info(f"Reader thread started for stream: {stream_id[:8]}...") + + while self.running and stream_id in self.inlets: + try: + # Pull chunk of samples (more efficient than single samples) + samples, timestamps = inlet.pull_chunk(timeout=0.1, max_samples=32) + + if samples: + for sample, ts in zip(samples, timestamps): + try: + sample_obj = Sample( + timestamp=time.time(), + lsl_timestamp=ts, + channels=list(sample), + stream_id=stream_id, + ) + sample_queue.put_nowait(sample_obj) + except queue.Full: + # Drop oldest samples if queue is full + try: + sample_queue.get_nowait() + sample_queue.put_nowait(sample_obj) + except queue.Empty: + pass + except Exception as e: + logger.warning(f"Error reading from stream {stream_id[:8]}: {e}") + time.sleep(0.1) + + logger.info(f"Reader thread stopped for stream: {stream_id[:8]}") + + def get_samples(self, stream_id: str, max_samples: int = 100) -> List[Sample]: + """Get pending samples from a stream (non-blocking)""" + sample_queue = self.sample_queues.get(stream_id) + if not sample_queue: + return [] + + samples = [] + for _ in range(max_samples): + try: + sample = sample_queue.get_nowait() + samples.append(sample) + except queue.Empty: + break + + return samples + + def disconnect_stream(self, stream_id: str): + """Disconnect from a specific stream""" + if stream_id in self.inlets: + # Stop reader thread by removing inlet + inlet = self.inlets.pop(stream_id) + inlet.close_stream() + + if stream_id in self.metadata: + del self.metadata[stream_id] + if stream_id in self.sample_queues: + del self.sample_queues[stream_id] + if stream_id in self.reader_threads: + del self.reader_threads[stream_id] + + logger.info(f"Disconnected from stream: {stream_id[:8]}") + + def disconnect_all(self): + """Disconnect from all streams""" + self.running = False + stream_ids = list(self.inlets.keys()) + for stream_id in stream_ids: + self.disconnect_stream(stream_id) + + def get_time_correction(self, stream_id: str) -> float: + """Get clock offset for a stream (for synchronization)""" + inlet = self.inlets.get(stream_id) + if inlet: + try: + return inlet.time_correction() + except Exception: + return 0.0 + return 0.0 + + +# ============================================================================ +# WEBSOCKET SERVER +# ============================================================================ + +class LSLWebSocketBridge: + """WebSocket server that bridges LSL streams to browser clients""" + + def __init__( + self, + host: str = "0.0.0.0", + port: int = 8767, + stream_name: Optional[str] = None, + stream_type: str = "EEG", + auto_connect: bool = True, + ): + self.host = host + self.port = port + self.stream_name = stream_name + self.stream_type = stream_type + self.auto_connect = auto_connect + + self.inlet_manager = LSLInletManager() + self.connected_clients: Set[WebSocketServerProtocol] = set() + self.active_stream_id: Optional[str] = None + self.running = False + + # Message queue for broadcasting + self.broadcast_queue: asyncio.Queue = asyncio.Queue() + + async def start(self): + """Start the WebSocket server""" + self.running = True + + # Start the broadcast task + broadcast_task = asyncio.create_task(self._broadcast_loop()) + + # Start LSL connection if auto_connect + if self.auto_connect: + await self._auto_connect_lsl() + + # Start sample forwarding task + forward_task = asyncio.create_task(self._forward_samples()) + + # Start WebSocket server + logger.info(f"Starting LSL WebSocket bridge on ws://{self.host}:{self.port}") + + async with websockets.serve(self._handle_client, self.host, self.port): + logger.info("LSL WebSocket bridge is running") + logger.info("Press Ctrl+C to stop") + + try: + await asyncio.Future() # Run forever + except asyncio.CancelledError: + pass + + self.running = False + broadcast_task.cancel() + forward_task.cancel() + self.inlet_manager.disconnect_all() + + async def _auto_connect_lsl(self): + """Automatically connect to first available LSL stream""" + # Run discovery in thread pool to avoid blocking + loop = asyncio.get_event_loop() + streams = await loop.run_in_executor( + None, + lambda: self.inlet_manager.discover_streams( + stream_type=self.stream_type if not self.stream_name else None, + stream_name=self.stream_name, + timeout=5.0 + ) + ) + + if streams: + # Connect to first stream + stream_id = await loop.run_in_executor( + None, + lambda: self.inlet_manager.connect_stream( + stream_name=self.stream_name, + stream_type=self.stream_type, + ) + ) + if stream_id: + self.active_stream_id = stream_id + logger.info(f"Auto-connected to stream: {stream_id[:8]}") + else: + logger.warning("No LSL streams found. Waiting for streams...") + # Schedule periodic retry + asyncio.create_task(self._retry_connect()) + + async def _retry_connect(self): + """Periodically retry LSL connection""" + while self.running and not self.active_stream_id: + await asyncio.sleep(5.0) + await self._auto_connect_lsl() + + async def _forward_samples(self): + """Forward samples from LSL inlet to broadcast queue""" + while self.running: + if self.active_stream_id: + samples = self.inlet_manager.get_samples(self.active_stream_id, max_samples=50) + for sample in samples: + metadata = self.inlet_manager.metadata.get(self.active_stream_id) + msg = self._format_sample_message(sample, metadata) + await self.broadcast_queue.put(msg) + + await asyncio.sleep(0.01) # ~100Hz update rate + + def _format_sample_message( + self, + sample: Sample, + metadata: Optional[StreamMetadata] + ) -> bytes: + """Format sample as binary message for WebSocket transmission""" + # Binary format for efficiency: + # [4 bytes: packet type] + # [8 bytes: timestamp (double)] + # [8 bytes: LSL timestamp (double)] + # [4 bytes: channel count (uint32)] + # [N * 4 bytes: channel values (float32)] + + num_channels = len(sample.channels) + packet = struct.pack( + f'>I d d I {num_channels}f', + 0x01, # Packet type: sample + sample.timestamp, + sample.lsl_timestamp, + num_channels, + *sample.channels + ) + return packet + + async def _broadcast_loop(self): + """Broadcast messages to all connected clients""" + while self.running: + try: + msg = await asyncio.wait_for( + self.broadcast_queue.get(), + timeout=0.1 + ) + + if self.connected_clients: + # Send to all connected clients + disconnected = set() + for client in self.connected_clients: + try: + await client.send(msg) + except websockets.exceptions.ConnectionClosed: + disconnected.add(client) + + # Remove disconnected clients + self.connected_clients -= disconnected + + except asyncio.TimeoutError: + continue + except Exception as e: + logger.warning(f"Broadcast error: {e}") + + async def _handle_client(self, websocket: WebSocketServerProtocol): + """Handle a new WebSocket client connection""" + client_addr = websocket.remote_address + logger.info(f"Client connected: {client_addr}") + + self.connected_clients.add(websocket) + + # Send stream metadata on connect + if self.active_stream_id: + metadata = self.inlet_manager.metadata.get(self.active_stream_id) + if metadata: + await self._send_metadata(websocket, metadata) + + try: + async for message in websocket: + await self._handle_message(websocket, message) + except websockets.exceptions.ConnectionClosed: + pass + finally: + self.connected_clients.discard(websocket) + logger.info(f"Client disconnected: {client_addr}") + + async def _send_metadata(self, websocket: WebSocketServerProtocol, metadata: StreamMetadata): + """Send stream metadata to client""" + msg = { + "type": "metadata", + "stream": metadata.to_dict(), + "lsl_time": local_clock() if LSL_AVAILABLE else time.time(), + } + await websocket.send(json.dumps(msg)) + + async def _handle_message(self, websocket: WebSocketServerProtocol, message: str): + """Handle incoming client message""" + try: + data = json.loads(message) + command = data.get("command") + + if command == "discover": + # Discover available streams + loop = asyncio.get_event_loop() + streams = await loop.run_in_executor( + None, + lambda: self.inlet_manager.discover_streams( + stream_type=data.get("type"), + stream_name=data.get("name"), + timeout=data.get("timeout", 5.0) + ) + ) + + response = { + "type": "streams", + "streams": [s.to_dict() for s in streams], + } + await websocket.send(json.dumps(response)) + + elif command == "connect": + # Connect to specific stream + stream_name = data.get("name") + stream_type = data.get("stream_type", "EEG") + + loop = asyncio.get_event_loop() + stream_id = await loop.run_in_executor( + None, + lambda: self.inlet_manager.connect_stream( + stream_name=stream_name, + stream_type=stream_type, + ) + ) + + if stream_id: + self.active_stream_id = stream_id + metadata = self.inlet_manager.metadata.get(stream_id) + if metadata: + await self._send_metadata(websocket, metadata) + + response = {"type": "connected", "stream_id": stream_id} + else: + response = {"type": "error", "message": "Could not connect to stream"} + + await websocket.send(json.dumps(response)) + + elif command == "disconnect": + # Disconnect from active stream + if self.active_stream_id: + self.inlet_manager.disconnect_stream(self.active_stream_id) + self.active_stream_id = None + + response = {"type": "disconnected"} + await websocket.send(json.dumps(response)) + + elif command == "configure": + # Handle device configuration (for compatibility) + logger.info(f"Configure command received: {data}") + response = {"type": "configured", "status": "ok"} + await websocket.send(json.dumps(response)) + + elif command == "ping": + response = { + "type": "pong", + "timestamp": time.time(), + "lsl_time": local_clock() if LSL_AVAILABLE else time.time(), + } + await websocket.send(json.dumps(response)) + + else: + response = {"type": "error", "message": f"Unknown command: {command}"} + await websocket.send(json.dumps(response)) + + except json.JSONDecodeError: + logger.warning(f"Invalid JSON message: {message[:100]}") + except Exception as e: + logger.error(f"Error handling message: {e}") + await websocket.send(json.dumps({"type": "error", "message": str(e)})) + + +# ============================================================================ +# LSL SIMULATION (for testing without hardware) +# ============================================================================ + +class LSLSimulator: + """Simulates an LSL outlet for testing without hardware""" + + def __init__( + self, + name: str = "PhantomLoop_Simulated_EEG", + stream_type: str = "EEG", + channel_count: int = 8, + sampling_rate: float = 250.0, + ): + self.name = name + self.stream_type = stream_type + self.channel_count = channel_count + self.sampling_rate = sampling_rate + self.running = False + self.outlet = None + + def start(self): + """Start simulated LSL outlet""" + if not LSL_AVAILABLE: + logger.error("pylsl not available for simulation") + return + + from pylsl import StreamOutlet, StreamInfo as LSLStreamInfo + + # Create stream info + info = LSLStreamInfo( + self.name, + self.stream_type, + self.channel_count, + self.sampling_rate, + 'float32', + 'phantomloop-sim-001' + ) + + # Add channel descriptions + desc = info.desc() + channels = desc.append_child("channels") + for i in range(self.channel_count): + ch = channels.append_child("channel") + ch.append_child_value("label", f"Ch{i+1}") + ch.append_child_value("type", "EEG") + ch.append_child_value("unit", "µV") + + # Add acquisition info + acq = desc.append_child("acquisition") + acq.append_child_value("manufacturer", "PhantomLoop") + acq.append_child_value("model", "Simulated EEG") + + # Create outlet + self.outlet = StreamOutlet(info) + self.running = True + + logger.info(f"Started simulated LSL outlet: {self.name}") + + # Start streaming thread + thread = threading.Thread(target=self._stream_loop, daemon=True) + thread.start() + + def _stream_loop(self): + """Generate and push simulated EEG samples""" + sample_interval = 1.0 / self.sampling_rate + phase = np.zeros(self.channel_count) + + while self.running and self.outlet: + # Generate simulated EEG (alpha waves + noise) + sample = [] + for ch in range(self.channel_count): + # Alpha rhythm (8-12 Hz) + pink noise + alpha = 20 * np.sin(2 * np.pi * 10 * phase[ch]) + noise = np.random.randn() * 10 + sample.append(float(alpha + noise)) + phase[ch] += sample_interval + + # Push sample + self.outlet.push_sample(sample) + + time.sleep(sample_interval) + + def stop(self): + """Stop the simulator""" + self.running = False + self.outlet = None + + +# ============================================================================ +# MAIN +# ============================================================================ + +def main(): + parser = argparse.ArgumentParser( + description="LSL to WebSocket Bridge for PhantomLoop", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Auto-discover and connect to first EEG stream + python lsl_ws_bridge.py + + # Connect to specific stream by name + python lsl_ws_bridge.py --stream "OpenBCI_EEG" + + # Connect to a Muse stream + python lsl_ws_bridge.py --stream "Muse" --type EEG + + # Run with simulated data for testing + python lsl_ws_bridge.py --simulate + + # List available streams + python lsl_ws_bridge.py --list + + # Custom port + python lsl_ws_bridge.py --port 8768 + """ + ) + + parser.add_argument( + '--host', + default='0.0.0.0', + help='Host to bind WebSocket server (default: 0.0.0.0)' + ) + parser.add_argument( + '--port', '-p', + type=int, + default=8767, + help='WebSocket server port (default: 8767)' + ) + parser.add_argument( + '--stream', '-s', + help='LSL stream name to connect to' + ) + parser.add_argument( + '--type', '-t', + default='EEG', + help='LSL stream type to search for (default: EEG)' + ) + parser.add_argument( + '--list', '-l', + action='store_true', + help='List available LSL streams and exit' + ) + parser.add_argument( + '--simulate', + action='store_true', + help='Start a simulated LSL stream for testing' + ) + parser.add_argument( + '--no-auto-connect', + action='store_true', + help='Do not auto-connect to streams on startup' + ) + + args = parser.parse_args() + + if not LSL_AVAILABLE: + logger.error("pylsl is not installed. Install with: pip install pylsl") + logger.info("On Windows, you may also need to install liblsl:") + logger.info(" pip install pylsl") + logger.info(" # or download from: https://github.com/sccn/liblsl/releases") + sys.exit(1) + + # List streams mode + if args.list: + manager = LSLInletManager() + streams = manager.discover_streams(timeout=5.0) + + if streams: + print("\nAvailable LSL Streams:") + print("-" * 60) + for s in streams: + print(f" Name: {s.name}") + print(f" Type: {s.stream_type}") + print(f" Channels: {s.channel_count}") + print(f" Rate: {s.sampling_rate} Hz") + print(f" Format: {s.channel_format}") + print(f" Source: {s.source_id}") + print(f" Host: {s.hostname}") + print("-" * 60) + else: + print("\nNo LSL streams found on the network.") + print("Make sure your LSL source is running (OpenBCI GUI, muse-lsl, etc.)") + + return + + # Start simulator if requested + simulator = None + if args.simulate: + logger.info("Starting LSL simulator...") + simulator = LSLSimulator() + simulator.start() + # Give simulator time to start + time.sleep(0.5) + + # Start bridge + bridge = LSLWebSocketBridge( + host=args.host, + port=args.port, + stream_name=args.stream, + stream_type=args.type, + auto_connect=not args.no_auto_connect, + ) + + try: + asyncio.run(bridge.start()) + except KeyboardInterrupt: + logger.info("\nShutting down...") + finally: + if simulator: + simulator.stop() + + +if __name__ == "__main__": + main() diff --git a/src/devices/deviceProfiles.ts b/src/devices/deviceProfiles.ts index 0c4fba0..9c17633 100644 --- a/src/devices/deviceProfiles.ts +++ b/src/devices/deviceProfiles.ts @@ -814,6 +814,322 @@ export const DEVICE_PROFILES: Record = { setupUrl: 'https://pieeg.com/microbci/', }, + // ------------------------------------------------------------------------- + // Lab Streaming Layer (LSL) Generic Profiles + // Supports 130+ devices via the LSL protocol + // https://labstreaminglayer.org + // ------------------------------------------------------------------------- + 'lsl-generic-8': { + id: 'lsl-generic-8', + name: 'LSL Stream (8-Channel)', + manufacturer: 'Generic', + model: 'Lab Streaming Layer', + channelCount: 8, + samplingRates: [128, 250, 256, 500, 512, 1000, 1024, 2000], + defaultSamplingRate: 256, + resolution: 32, // LSL typically uses float32 + protocols: ['lsl'], + defaultProtocol: 'lsl', + capabilities: { + hasImpedanceMeasurement: false, + hasAccelerometer: false, + hasGyroscope: false, + hasBattery: false, + hasAuxChannels: false, + supportsMarkers: true, // LSL supports marker streams + supportsBrainflow: true, // Can use BrainFlow streaming board + }, + defaultMontage: { + channelCount: 8, + labels: ['Ch1', 'Ch2', 'Ch3', 'Ch4', 'Ch5', 'Ch6', 'Ch7', 'Ch8'], + positions: getPositions(['Fp1', 'Fp2', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2']), + }, + protocolConfig: { + streamType: 'EEG', + bufferLength: 360.0, // 6 minutes + recover: true, + }, + description: 'Generic 8-channel LSL stream. Supports any LSL-compatible EEG device.', + setupUrl: 'https://labstreaminglayer.readthedocs.io/', + }, + + 'lsl-generic-16': { + id: 'lsl-generic-16', + name: 'LSL Stream (16-Channel)', + manufacturer: 'Generic', + model: 'Lab Streaming Layer', + channelCount: 16, + samplingRates: [128, 250, 256, 500, 512, 1000, 1024, 2000], + defaultSamplingRate: 256, + resolution: 32, + protocols: ['lsl'], + defaultProtocol: 'lsl', + capabilities: { + hasImpedanceMeasurement: false, + hasAccelerometer: false, + hasGyroscope: false, + hasBattery: false, + hasAuxChannels: false, + supportsMarkers: true, + supportsBrainflow: true, + }, + defaultMontage: { + channelCount: 16, + labels: ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T7', 'C3', 'Cz', 'C4', 'T8', 'P3', 'Pz', 'P4', 'O1'], + positions: getPositions(['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T7', 'C3', 'Cz', 'C4', 'T8', 'P3', 'Pz', 'P4', 'O1']), + }, + protocolConfig: { + streamType: 'EEG', + bufferLength: 360.0, + recover: true, + }, + description: 'Generic 16-channel LSL stream for research-grade EEG devices.', + setupUrl: 'https://labstreaminglayer.readthedocs.io/', + }, + + 'lsl-generic-32': { + id: 'lsl-generic-32', + name: 'LSL Stream (32-Channel)', + manufacturer: 'Generic', + model: 'Lab Streaming Layer', + channelCount: 32, + samplingRates: [128, 250, 256, 500, 512, 1000, 1024, 2000, 2048], + defaultSamplingRate: 256, + resolution: 32, + protocols: ['lsl'], + defaultProtocol: 'lsl', + capabilities: { + hasImpedanceMeasurement: false, + hasAccelerometer: false, + hasGyroscope: false, + hasBattery: false, + hasAuxChannels: false, + supportsMarkers: true, + supportsBrainflow: true, + }, + defaultMontage: { + channelCount: 32, + labels: [ + 'Fp1', 'Fp2', 'AF3', 'AF4', 'F7', 'F3', 'Fz', 'F4', 'F8', + 'FC5', 'FC1', 'FC2', 'FC6', 'T7', 'C3', 'Cz', 'C4', 'T8', + 'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8', + 'PO3', 'PO4', 'O1', 'Oz', 'O2' + ], + positions: getPositions([ + 'Fp1', 'Fp2', 'AF3', 'AF4', 'F7', 'F3', 'Fz', 'F4', 'F8', + 'FC5', 'FC1', 'FC2', 'FC6', 'T7', 'C3', 'Cz', 'C4', 'T8', + 'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8', + 'PO3', 'PO4', 'O1', 'Oz', 'O2' + ]), + }, + protocolConfig: { + streamType: 'EEG', + bufferLength: 360.0, + recover: true, + }, + description: 'Generic 32-channel LSL stream for high-density EEG.', + setupUrl: 'https://labstreaminglayer.readthedocs.io/', + }, + + 'lsl-generic-64': { + id: 'lsl-generic-64', + name: 'LSL Stream (64-Channel)', + manufacturer: 'Generic', + model: 'Lab Streaming Layer', + channelCount: 64, + samplingRates: [128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4096], + defaultSamplingRate: 256, + resolution: 32, + protocols: ['lsl'], + defaultProtocol: 'lsl', + capabilities: { + hasImpedanceMeasurement: false, + hasAccelerometer: false, + hasGyroscope: false, + hasBattery: false, + hasAuxChannels: false, + supportsMarkers: true, + supportsBrainflow: true, + }, + protocolConfig: { + streamType: 'EEG', + bufferLength: 360.0, + recover: true, + }, + description: 'Generic 64-channel LSL stream for research-grade high-density EEG.', + setupUrl: 'https://labstreaminglayer.readthedocs.io/', + }, + + 'lsl-brainproducts': { + id: 'lsl-brainproducts', + name: 'Brain Products (LSL)', + manufacturer: 'BrainProducts', + model: 'actiCHamp / LiveAmp', + channelCount: 32, + samplingRates: [250, 500, 1000, 2000, 5000, 10000, 25000], + defaultSamplingRate: 500, + resolution: 24, + protocols: ['lsl'], + defaultProtocol: 'lsl', + capabilities: { + hasImpedanceMeasurement: true, + hasAccelerometer: true, + hasGyroscope: true, + hasBattery: true, + hasAuxChannels: true, + supportsMarkers: true, + supportsBrainflow: false, + }, + protocolConfig: { + streamType: 'EEG', + streamName: 'LiveAmp*', + }, + description: 'Brain Products actiCHamp or LiveAmp via LSL Connector app.', + setupUrl: 'https://github.com/labstreaminglayer/App-BrainProducts', + }, + + 'lsl-biosemi': { + id: 'lsl-biosemi', + name: 'BioSemi ActiveTwo (LSL)', + manufacturer: 'Generic', + model: 'BioSemi ActiveTwo', + channelCount: 32, + samplingRates: [256, 512, 1024, 2048, 4096, 8192, 16384], + defaultSamplingRate: 512, + resolution: 24, + protocols: ['lsl'], + defaultProtocol: 'lsl', + capabilities: { + hasImpedanceMeasurement: false, + hasAccelerometer: false, + hasGyroscope: false, + hasBattery: false, + hasAuxChannels: true, + supportsMarkers: true, + supportsBrainflow: false, + }, + protocolConfig: { + streamType: 'EEG', + streamName: 'BioSemi*', + }, + description: 'BioSemi ActiveTwo research EEG via LSL connector.', + setupUrl: 'https://github.com/labstreaminglayer/App-BioSemi', + }, + + 'lsl-gtec': { + id: 'lsl-gtec', + name: 'g.tec (LSL)', + manufacturer: 'G.Tec', + model: 'g.USBamp / g.Nautilus', + channelCount: 16, + samplingRates: [256, 512, 1200, 2400, 4800, 9600, 19200, 38400], + defaultSamplingRate: 512, + resolution: 24, + protocols: ['lsl'], + defaultProtocol: 'lsl', + capabilities: { + hasImpedanceMeasurement: true, + hasAccelerometer: false, + hasGyroscope: false, + hasBattery: true, + hasAuxChannels: true, + supportsMarkers: true, + supportsBrainflow: true, + }, + brainflowBoardId: BRAINFLOW_BOARD_IDS.ANT_NEURO_EE_411, // Use ANT board as proxy + protocolConfig: { + streamType: 'EEG', + streamName: 'g.Tec*', + }, + description: 'g.tec amplifiers via g.NEEDaccess LSL connector.', + setupUrl: 'https://github.com/labstreaminglayer/App-g.Tec', + }, + + 'lsl-cognionics': { + id: 'lsl-cognionics', + name: 'Cognionics Quick-20 (LSL)', + manufacturer: 'Cognionics', + model: 'Quick-20 / Quick-30', + channelCount: 20, + samplingRates: [250, 500], + defaultSamplingRate: 500, + resolution: 24, + protocols: ['lsl'], + defaultProtocol: 'lsl', + capabilities: { + hasImpedanceMeasurement: true, + hasAccelerometer: true, + hasGyroscope: true, + hasBattery: true, + hasAuxChannels: false, + supportsMarkers: true, + supportsBrainflow: false, + }, + protocolConfig: { + streamType: 'EEG', + streamName: 'CGX*', + }, + description: 'Cognionics dry-electrode EEG via LSL connector.', + setupUrl: 'https://github.com/labstreaminglayer/App-Cognionics', + }, + + 'lsl-antneuro': { + id: 'lsl-antneuro', + name: 'ANT Neuro eego (LSL)', + manufacturer: 'ANT Neuro', + model: 'eego sport / mylab', + channelCount: 32, + samplingRates: [256, 512, 1024, 2048], + defaultSamplingRate: 512, + resolution: 24, + protocols: ['lsl'], + defaultProtocol: 'lsl', + brainflowBoardId: BRAINFLOW_BOARD_IDS.ANT_NEURO_EE_411, + capabilities: { + hasImpedanceMeasurement: true, + hasAccelerometer: false, + hasGyroscope: false, + hasBattery: true, + hasAuxChannels: true, + supportsMarkers: true, + supportsBrainflow: true, + }, + protocolConfig: { + streamType: 'EEG', + streamName: 'ANT*', + }, + description: 'ANT Neuro eego via LSL connector.', + setupUrl: 'https://www.ant-neuro.com/', + }, + + 'lsl-nirx': { + id: 'lsl-nirx', + name: 'NIRx fNIRS (LSL)', + manufacturer: 'Generic', + model: 'NIRSport / NIRScout', + channelCount: 16, + samplingRates: [7.8125, 10, 15.625], + defaultSamplingRate: 10, + resolution: 16, + protocols: ['lsl'], + defaultProtocol: 'lsl', + capabilities: { + hasImpedanceMeasurement: false, + hasAccelerometer: true, + hasGyroscope: true, + hasBattery: true, + hasAuxChannels: false, + supportsMarkers: true, + supportsBrainflow: false, + }, + protocolConfig: { + streamType: 'fNIRS', + streamName: 'NIRx*', + }, + description: 'NIRx fNIRS systems with LSL support.', + setupUrl: 'https://nirx.net/software', + }, + // ------------------------------------------------------------------------- // Brainflow Testing // ------------------------------------------------------------------------- diff --git a/src/streams/index.ts b/src/streams/index.ts index 5c8cc44..665be8f 100644 --- a/src/streams/index.ts +++ b/src/streams/index.ts @@ -115,6 +115,71 @@ export const streamAdapterRegistry: StreamAdapterRegistry = { defaultUrl: 'ws://localhost:8765', }, + // ------------------------------------------------------------------------- + // Lab Streaming Layer (LSL) - Universal Protocol + // Supports 130+ devices via pylsl bridge + // ------------------------------------------------------------------------- + 'lsl-generic-8': { + name: 'LSL Stream (8-Channel)', + description: 'Generic 8-channel LSL stream - any LSL-compatible EEG device', + factory: (opts) => createUniversalEEGAdapter({ deviceId: 'lsl-generic-8', ...opts }), + defaultUrl: 'ws://localhost:8767', + }, + 'lsl-generic-16': { + name: 'LSL Stream (16-Channel)', + description: 'Generic 16-channel LSL stream for research-grade EEG', + factory: (opts) => createUniversalEEGAdapter({ deviceId: 'lsl-generic-16', ...opts }), + defaultUrl: 'ws://localhost:8767', + }, + 'lsl-generic-32': { + name: 'LSL Stream (32-Channel)', + description: 'Generic 32-channel LSL stream for high-density EEG', + factory: (opts) => createUniversalEEGAdapter({ deviceId: 'lsl-generic-32', ...opts }), + defaultUrl: 'ws://localhost:8767', + }, + 'lsl-generic-64': { + name: 'LSL Stream (64-Channel)', + description: 'Generic 64-channel LSL stream for research-grade high-density EEG', + factory: (opts) => createUniversalEEGAdapter({ deviceId: 'lsl-generic-64', ...opts }), + defaultUrl: 'ws://localhost:8767', + }, + 'lsl-brainproducts': { + name: 'Brain Products (LSL)', + description: 'Brain Products actiCHamp/LiveAmp via LSL (up to 25kHz)', + factory: (opts) => createUniversalEEGAdapter({ deviceId: 'lsl-brainproducts', ...opts }), + defaultUrl: 'ws://localhost:8767', + }, + 'lsl-biosemi': { + name: 'BioSemi ActiveTwo (LSL)', + description: 'BioSemi ActiveTwo research EEG via LSL (up to 16kHz)', + factory: (opts) => createUniversalEEGAdapter({ deviceId: 'lsl-biosemi', ...opts }), + defaultUrl: 'ws://localhost:8767', + }, + 'lsl-gtec': { + name: 'g.tec (LSL)', + description: 'g.tec g.USBamp/g.Nautilus via g.NEEDaccess LSL', + factory: (opts) => createUniversalEEGAdapter({ deviceId: 'lsl-gtec', ...opts }), + defaultUrl: 'ws://localhost:8767', + }, + 'lsl-cognionics': { + name: 'Cognionics Quick-20/30 (LSL)', + description: 'Cognionics dry-electrode EEG via LSL', + factory: (opts) => createUniversalEEGAdapter({ deviceId: 'lsl-cognionics', ...opts }), + defaultUrl: 'ws://localhost:8767', + }, + 'lsl-antneuro': { + name: 'ANT Neuro eego (LSL)', + description: 'ANT Neuro eego sport/mylab via LSL', + factory: (opts) => createUniversalEEGAdapter({ deviceId: 'lsl-antneuro', ...opts }), + defaultUrl: 'ws://localhost:8767', + }, + 'lsl-nirx': { + name: 'NIRx fNIRS (LSL)', + description: 'NIRx NIRSport/NIRScout fNIRS via LSL', + factory: (opts) => createUniversalEEGAdapter({ deviceId: 'lsl-nirx', ...opts }), + defaultUrl: 'ws://localhost:8767', + }, + // ------------------------------------------------------------------------- // Brainflow Generic // ------------------------------------------------------------------------- @@ -166,6 +231,11 @@ export function listAdaptersByCategory() { 'Consumer EEG': ['neurosky-mindwave', 'muse-2', 'muse-s'], 'Research EEG': ['emotiv-insight', 'emotiv-epoc-x'], 'Custom Hardware': ['esp-eeg', 'cerelog-esp-eeg'], + 'Lab Streaming Layer': [ + 'lsl-generic-8', 'lsl-generic-16', 'lsl-generic-32', 'lsl-generic-64', + 'lsl-brainproducts', 'lsl-biosemi', 'lsl-gtec', 'lsl-cognionics', + 'lsl-antneuro', 'lsl-nirx' + ], 'Testing': ['brainflow-synthetic'], }; From ab084b2ea301444569a6cff4607709b0f300fdf7 Mon Sep 17 00:00:00 2001 From: Youssef Date: Sat, 31 Jan 2026 10:30:10 -0500 Subject: [PATCH 2/5] feat: add WebSocket bridge with real-time DSP for EEG signal hygiene --- README.md | 84 +++ scripts/pieeg_ws_bridge_dsp.py | 1155 ++++++++++++++++++++++++++++++++ 2 files changed, 1239 insertions(+) create mode 100644 scripts/pieeg_ws_bridge_dsp.py diff --git a/README.md b/README.md index 3cfc47d..2231ea7 100644 --- a/README.md +++ b/README.md @@ -213,6 +213,7 @@ Since browsers cannot directly access hardware (SPI, Serial, BLE, TCP), PhantomL |--------|--------|------|------| | `lsl_ws_bridge.py` | Any LSL source (130+ devices) | 8767 | LSL Inlet → WebSocket | | `pieeg_ws_bridge.py` | PiEEG (Raspberry Pi) | 8766 | SPI / BrainFlow / Simulation | +| `pieeg_ws_bridge_dsp.py` | PiEEG + Signal Hygiene | 8766 | SPI + Real-Time DSP | | `cerelog_ws_bridge.py` | Cerelog ESP-EEG | 8765 | TCP-to-WebSocket | ### LSL Bridge @@ -287,6 +288,89 @@ python scripts/pieeg_ws_bridge.py # Auto-detects non-Pi systems {"command": "set_sample_rate", "rate": 500} ``` +### PiEEG Bridge with DSP (Signal Hygiene) + +The DSP-enhanced bridge applies real-time digital signal processing before streaming, removing common EEG artifacts at the source: + +**Signal Hygiene Pipeline:** +``` +Raw ADS1299 → DC Block → Notch (50/60 Hz) → Bandpass (0.5-45 Hz) → Artifact Reject → CAR → WebSocket +``` + +| Filter | Purpose | Default | +|--------|---------|--------| +| DC Blocker | Removes electrode drift | α=0.995 (~0.8 Hz) | +| Notch Filter | Removes powerline + harmonics | 60 Hz (3 harmonics) | +| Bandpass | Isolates EEG band | 0.5-45 Hz (order 4) | +| Artifact Rejection | Blanks amplitude spikes | ±150 µV threshold | +| CAR | Common Average Reference | Disabled by default | + +```bash +# Basic usage with 60 Hz notch (Americas, Asia) +python scripts/pieeg_ws_bridge_dsp.py --notch 60 + +# European 50 Hz with custom bandpass +python scripts/pieeg_ws_bridge_dsp.py --notch 50 --highpass 1.0 --lowpass 40 + +# Full signal hygiene with artifact rejection and CAR +python scripts/pieeg_ws_bridge_dsp.py --notch 60 --car --artifact-threshold 150 + +# Minimal processing (DC block only) +python scripts/pieeg_ws_bridge_dsp.py --no-notch --no-bandpass + +# High sample rate with adjusted filters +python scripts/pieeg_ws_bridge_dsp.py --sample-rate 500 --notch 60 --lowpass 100 +``` + +**All DSP Options:** +```bash +# Network +--host 0.0.0.0 # WebSocket bind address +--port 8766 # WebSocket port + +# Hardware +--sample-rate 250 # 250, 500, 1000, 2000 Hz +--gain 24 # PGA gain: 1, 2, 4, 6, 8, 12, 24 +--channels 8 # Number of channels + +# Notch Filter +--notch 60 # Powerline frequency (50 or 60 Hz) +--notch-harmonics 3 # Filter fundamental + N harmonics +--notch-q 30 # Quality factor (higher = narrower) +--no-notch # Disable notch filter + +# Bandpass Filter +--highpass 0.5 # High-pass cutoff (Hz) +--lowpass 45 # Low-pass cutoff (Hz) +--filter-order 4 # Butterworth order +--no-bandpass # Disable bandpass filter + +# DC Blocking +--dc-alpha 0.995 # DC blocker pole (0.99-0.999) +--no-dc-block # Disable DC blocking + +# Artifact Rejection +--artifact-threshold 150 # Threshold in µV +--no-artifact # Disable artifact rejection + +# Common Average Reference +--car # Enable CAR +--car-exclude "0,7" # Exclude channels from CAR + +# Smoothing +--smooth 0.3 # Exponential smoothing alpha (0 = disabled) +``` + +**Extended Packet Format:** + +DSP packets include artifact flags per sample: +``` +Header: magic(2) + type(1) + samples(2) + channels(1) + timestamp(8) +Data: [float32 × channels + artifact_byte] × samples +``` +- `type = 0x02` indicates DSP-processed data +- `artifact_byte` is a bitmask of channels with blanked artifacts + ### Cerelog Bridge ```bash diff --git a/scripts/pieeg_ws_bridge_dsp.py b/scripts/pieeg_ws_bridge_dsp.py new file mode 100644 index 0000000..4227aca --- /dev/null +++ b/scripts/pieeg_ws_bridge_dsp.py @@ -0,0 +1,1155 @@ +#!/usr/bin/env python3 +""" +WebSocket Bridge for PiEEG with Real-Time DSP Signal Hygiene + +This enhanced bridge applies real-time digital signal processing to clean +EEG signals before WebSocket transmission. Runs on Raspberry Pi with PiEEG. + +Signal Hygiene Pipeline: + 1. DC Blocking - Removes electrode drift (high-pass @ 0.1 Hz) + 2. Notch Filter - Removes powerline interference (50/60 Hz + harmonics) + 3. Bandpass Filter - Isolates EEG band (0.5-45 Hz default) + 4. Artifact Reject - Flags/blanks samples exceeding threshold + 5. CAR - Common Average Reference for spatial filtering + 6. Smoothing - Optional exponential moving average + +Usage: + # Basic usage with 60 Hz notch (Americas, Asia) + python pieeg_ws_bridge_dsp.py --notch 60 + + # European 50 Hz with custom bandpass + python pieeg_ws_bridge_dsp.py --notch 50 --highpass 1.0 --lowpass 40 + + # Full signal hygiene with artifact rejection + python pieeg_ws_bridge_dsp.py --notch 60 --car --artifact-threshold 150 + + # Disable specific filters + python pieeg_ws_bridge_dsp.py --no-notch --no-bandpass + +Requirements: + pip install websockets spidev RPi.GPIO numpy scipy + + For simulation mode (development): + pip install websockets numpy scipy + +License: MIT +""" + +import asyncio +import struct +import json +import sys +import time +import argparse +from typing import Optional, List, Dict, Tuple, Callable +from dataclasses import dataclass, field, asdict +from enum import IntEnum +from collections import deque +import numpy as np + +try: + from scipy import signal as scipy_signal + SCIPY_AVAILABLE = True +except ImportError: + print("Warning: scipy not available. Install with: pip install scipy") + print(" DSP filters will be disabled.") + SCIPY_AVAILABLE = False + +try: + import websockets +except ImportError: + print("Error: websockets library required. Install with: pip install websockets") + sys.exit(1) + +# Check if running on Raspberry Pi +IS_RASPBERRY_PI = False +try: + import spidev + import RPi.GPIO as GPIO + IS_RASPBERRY_PI = True +except ImportError: + print("Warning: Not running on Raspberry Pi. Using simulation mode.") + + +# ============================================================================ +# DSP FILTER CLASSES +# ============================================================================ + +class DCBlocker: + """ + First-order IIR DC blocking filter. + Removes DC offset and very slow drifts from electrode polarization. + + Transfer function: H(z) = (1 - z^-1) / (1 - α*z^-1) + where α controls the cutoff frequency. + """ + + def __init__(self, alpha: float = 0.995, num_channels: int = 8): + """ + Args: + alpha: Pole location (0.99 = ~1.6 Hz, 0.995 = ~0.8 Hz, 0.999 = ~0.16 Hz) + num_channels: Number of EEG channels + """ + self.alpha = alpha + self.num_channels = num_channels + self.x_prev = np.zeros(num_channels) + self.y_prev = np.zeros(num_channels) + + def process(self, sample: np.ndarray) -> np.ndarray: + """Process single sample (all channels)""" + y = sample - self.x_prev + self.alpha * self.y_prev + self.x_prev = sample.copy() + self.y_prev = y.copy() + return y + + def reset(self): + """Reset filter state""" + self.x_prev.fill(0) + self.y_prev.fill(0) + + +class IIRFilter: + """ + Real-time IIR filter using scipy's sosfilt with state preservation. + Supports lowpass, highpass, bandpass, and bandstop (notch) configurations. + """ + + def __init__(self, sos: np.ndarray, num_channels: int = 8): + """ + Args: + sos: Second-order sections from scipy.signal.butter/iirnotch + num_channels: Number of EEG channels + """ + self.sos = sos + self.num_channels = num_channels + # State: (n_sections, 2) per channel + self.zi = np.zeros((num_channels, sos.shape[0], 2)) + + def process(self, sample: np.ndarray) -> np.ndarray: + """Process single sample through filter""" + output = np.zeros(self.num_channels) + for ch in range(self.num_channels): + # Filter single sample, update state + out, self.zi[ch] = scipy_signal.sosfilt( + self.sos, + [sample[ch]], + zi=self.zi[ch] + ) + output[ch] = out[0] + return output + + def process_batch(self, samples: np.ndarray) -> np.ndarray: + """Process batch of samples [num_samples, num_channels]""" + output = np.zeros_like(samples) + for ch in range(self.num_channels): + output[:, ch], self.zi[ch] = scipy_signal.sosfilt( + self.sos, + samples[:, ch], + zi=self.zi[ch] + ) + return output + + def reset(self): + """Reset filter state""" + self.zi.fill(0) + + @classmethod + def create_notch(cls, freq: float, fs: float, Q: float = 30.0, + num_channels: int = 8) -> 'IIRFilter': + """Create notch filter at specified frequency""" + b, a = scipy_signal.iirnotch(freq, Q, fs) + sos = scipy_signal.tf2sos(b, a) + return cls(sos, num_channels) + + @classmethod + def create_bandpass(cls, lowcut: float, highcut: float, fs: float, + order: int = 4, num_channels: int = 8) -> 'IIRFilter': + """Create Butterworth bandpass filter""" + nyq = fs / 2 + low = lowcut / nyq + high = highcut / nyq + sos = scipy_signal.butter(order, [low, high], btype='band', output='sos') + return cls(sos, num_channels) + + @classmethod + def create_highpass(cls, cutoff: float, fs: float, order: int = 4, + num_channels: int = 8) -> 'IIRFilter': + """Create Butterworth highpass filter""" + nyq = fs / 2 + normalized_cutoff = cutoff / nyq + sos = scipy_signal.butter(order, normalized_cutoff, btype='high', output='sos') + return cls(sos, num_channels) + + @classmethod + def create_lowpass(cls, cutoff: float, fs: float, order: int = 4, + num_channels: int = 8) -> 'IIRFilter': + """Create Butterworth lowpass filter""" + nyq = fs / 2 + normalized_cutoff = cutoff / nyq + sos = scipy_signal.butter(order, normalized_cutoff, btype='low', output='sos') + return cls(sos, num_channels) + + +class NotchFilterBank: + """ + Bank of notch filters for powerline interference removal. + Includes fundamental frequency and harmonics. + """ + + def __init__(self, fundamental: float, fs: float, num_harmonics: int = 3, + Q: float = 30.0, num_channels: int = 8): + """ + Args: + fundamental: Powerline frequency (50 or 60 Hz) + fs: Sampling frequency + num_harmonics: Number of harmonics to filter (1=fundamental only) + Q: Quality factor (higher = narrower notch) + num_channels: Number of EEG channels + """ + self.filters: List[IIRFilter] = [] + nyq = fs / 2 + + for i in range(1, num_harmonics + 1): + freq = fundamental * i + if freq < nyq: # Only add if below Nyquist + notch = IIRFilter.create_notch(freq, fs, Q, num_channels) + self.filters.append(notch) + print(f" ✓ Notch filter @ {freq} Hz (Q={Q})") + + def process(self, sample: np.ndarray) -> np.ndarray: + """Apply all notch filters in series""" + output = sample + for filt in self.filters: + output = filt.process(output) + return output + + def reset(self): + """Reset all filter states""" + for filt in self.filters: + filt.reset() + + +class ArtifactRejector: + """ + Simple amplitude-based artifact detection and rejection. + Flags or blanks samples exceeding threshold. + """ + + def __init__(self, threshold_uv: float = 150.0, blanking_samples: int = 5, + num_channels: int = 8): + """ + Args: + threshold_uv: Amplitude threshold in microvolts + blanking_samples: Number of samples to blank after artifact + num_channels: Number of EEG channels + """ + self.threshold = threshold_uv + self.blanking_samples = blanking_samples + self.num_channels = num_channels + self.blanking_counter = np.zeros(num_channels, dtype=int) + self.last_good = np.zeros(num_channels) + self.artifact_count = 0 + self.total_samples = 0 + + def process(self, sample: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Process sample and return (cleaned_sample, artifact_flags). + + Returns: + cleaned: Sample with artifacts replaced by last good value + flags: Boolean array indicating artifact on each channel + """ + self.total_samples += 1 + flags = np.zeros(self.num_channels, dtype=bool) + cleaned = sample.copy() + + for ch in range(self.num_channels): + # Check if in blanking period + if self.blanking_counter[ch] > 0: + cleaned[ch] = self.last_good[ch] + self.blanking_counter[ch] -= 1 + flags[ch] = True + # Check for new artifact + elif np.abs(sample[ch]) > self.threshold: + cleaned[ch] = self.last_good[ch] + self.blanking_counter[ch] = self.blanking_samples + flags[ch] = True + self.artifact_count += 1 + else: + self.last_good[ch] = sample[ch] + + return cleaned, flags + + def get_artifact_rate(self) -> float: + """Get percentage of samples flagged as artifacts""" + if self.total_samples == 0: + return 0.0 + return (self.artifact_count / self.total_samples) * 100 + + def reset(self): + """Reset artifact detector state""" + self.blanking_counter.fill(0) + self.last_good.fill(0) + self.artifact_count = 0 + self.total_samples = 0 + + +class CommonAverageReference: + """ + Common Average Reference (CAR) spatial filter. + Subtracts the mean across all channels from each channel. + Helps remove common-mode noise and reference artifacts. + """ + + def __init__(self, num_channels: int = 8, exclude_channels: List[int] = None): + """ + Args: + num_channels: Number of EEG channels + exclude_channels: Channels to exclude from average (e.g., bad channels) + """ + self.num_channels = num_channels + self.exclude = set(exclude_channels or []) + + def process(self, sample: np.ndarray) -> np.ndarray: + """Apply common average reference""" + # Calculate mean excluding bad channels + mask = np.ones(self.num_channels, dtype=bool) + for ch in self.exclude: + if 0 <= ch < self.num_channels: + mask[ch] = False + + if mask.sum() == 0: + return sample # All channels excluded, return unchanged + + avg = np.mean(sample[mask]) + return sample - avg + + +class ExponentialSmoother: + """ + Exponential moving average for optional signal smoothing. + Lower alpha = more smoothing (more lag). + """ + + def __init__(self, alpha: float = 0.3, num_channels: int = 8): + """ + Args: + alpha: Smoothing factor (0-1). Higher = less smoothing. + num_channels: Number of EEG channels + """ + self.alpha = alpha + self.num_channels = num_channels + self.ema = np.zeros(num_channels) + self.initialized = False + + def process(self, sample: np.ndarray) -> np.ndarray: + """Apply exponential smoothing""" + if not self.initialized: + self.ema = sample.copy() + self.initialized = True + return sample + + self.ema = self.alpha * sample + (1 - self.alpha) * self.ema + return self.ema.copy() + + def reset(self): + """Reset smoother state""" + self.ema.fill(0) + self.initialized = False + + +# ============================================================================ +# DSP PIPELINE +# ============================================================================ + +@dataclass +class DSPConfig: + """Configuration for the DSP pipeline""" + # Sampling + sample_rate: float = 250.0 + num_channels: int = 8 + + # DC Blocking + dc_block_enabled: bool = True + dc_block_alpha: float = 0.995 # ~0.8 Hz cutoff at 250 SPS + + # Notch Filter + notch_enabled: bool = True + notch_freq: float = 60.0 # 50 for Europe, 60 for Americas + notch_harmonics: int = 3 # Filter 60, 120, 180 Hz + notch_q: float = 30.0 # Quality factor + + # Bandpass Filter + bandpass_enabled: bool = True + highpass_freq: float = 0.5 # Hz + lowpass_freq: float = 45.0 # Hz + filter_order: int = 4 + + # Artifact Rejection + artifact_enabled: bool = True + artifact_threshold: float = 150.0 # µV + artifact_blanking: int = 5 # samples + + # Common Average Reference + car_enabled: bool = False + car_exclude_channels: List[int] = field(default_factory=list) + + # Smoothing + smoothing_enabled: bool = False + smoothing_alpha: float = 0.3 + + +class DSPPipeline: + """ + Complete real-time DSP pipeline for EEG signal hygiene. + Processes samples through configurable filter chain. + """ + + def __init__(self, config: DSPConfig): + self.config = config + self.filters: List[Tuple[str, Callable]] = [] + self.artifact_flags = np.zeros(config.num_channels, dtype=bool) + + self._build_pipeline() + + def _build_pipeline(self): + """Build filter chain based on configuration""" + cfg = self.config + print("\n📊 Building DSP Pipeline:") + print(f" Sample Rate: {cfg.sample_rate} Hz") + print(f" Channels: {cfg.num_channels}") + + # 1. DC Blocking (first to remove drift before other filters) + if cfg.dc_block_enabled: + dc_blocker = DCBlocker(cfg.dc_block_alpha, cfg.num_channels) + self.filters.append(("DC Block", dc_blocker.process)) + print(f" ✓ DC Blocker (α={cfg.dc_block_alpha})") + + if not SCIPY_AVAILABLE: + print(" ⚠ Scipy not available - skipping IIR filters") + else: + # 2. Notch Filter Bank + if cfg.notch_enabled: + notch_bank = NotchFilterBank( + cfg.notch_freq, + cfg.sample_rate, + cfg.notch_harmonics, + cfg.notch_q, + cfg.num_channels + ) + self.filters.append(("Notch", notch_bank.process)) + + # 3. Bandpass Filter + if cfg.bandpass_enabled: + bandpass = IIRFilter.create_bandpass( + cfg.highpass_freq, + cfg.lowpass_freq, + cfg.sample_rate, + cfg.filter_order, + cfg.num_channels + ) + self.filters.append(("Bandpass", bandpass.process)) + print(f" ✓ Bandpass {cfg.highpass_freq}-{cfg.lowpass_freq} Hz (order {cfg.filter_order})") + + # 4. Artifact Rejection + if cfg.artifact_enabled: + self.artifact_rejector = ArtifactRejector( + cfg.artifact_threshold, + cfg.artifact_blanking, + cfg.num_channels + ) + # Don't add to filter chain - handle separately for flags + print(f" ✓ Artifact Rejection (±{cfg.artifact_threshold} µV)") + else: + self.artifact_rejector = None + + # 5. Common Average Reference + if cfg.car_enabled: + car = CommonAverageReference(cfg.num_channels, cfg.car_exclude_channels) + self.filters.append(("CAR", car.process)) + print(f" ✓ Common Average Reference") + + # 6. Smoothing (optional, last) + if cfg.smoothing_enabled: + smoother = ExponentialSmoother(cfg.smoothing_alpha, cfg.num_channels) + self.filters.append(("Smooth", smoother.process)) + print(f" ✓ Exponential Smoothing (α={cfg.smoothing_alpha})") + + print(f"\n Pipeline: {' → '.join([name for name, _ in self.filters])}") + if self.artifact_rejector: + print(f" + Artifact flagging\n") + + def process(self, sample: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Process single sample through entire pipeline. + + Args: + sample: Raw sample array [num_channels] + + Returns: + processed: Cleaned sample + artifact_flags: Boolean flags for each channel + """ + output = sample.astype(np.float64) + + # Apply filter chain + for name, filter_fn in self.filters: + output = filter_fn(output) + + # Artifact detection (after filtering for better threshold accuracy) + if self.artifact_rejector: + output, self.artifact_flags = self.artifact_rejector.process(output) + else: + self.artifact_flags = np.zeros(self.config.num_channels, dtype=bool) + + return output, self.artifact_flags + + def process_batch(self, samples: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Process batch of samples. + + Args: + samples: Raw samples [num_samples, num_channels] + + Returns: + processed: Cleaned samples + artifact_flags: Boolean flags [num_samples, num_channels] + """ + num_samples = samples.shape[0] + output = np.zeros_like(samples, dtype=np.float64) + flags = np.zeros((num_samples, self.config.num_channels), dtype=bool) + + for i in range(num_samples): + output[i], flags[i] = self.process(samples[i]) + + return output, flags + + def get_stats(self) -> Dict: + """Get pipeline statistics""" + stats = { + "num_filters": len(self.filters), + "filter_chain": [name for name, _ in self.filters] + } + if self.artifact_rejector: + stats["artifact_rate_percent"] = self.artifact_rejector.get_artifact_rate() + return stats + + +# ============================================================================ +# ADS1299 CONSTANTS (from original bridge) +# ============================================================================ + +class ADS1299Register(IntEnum): + ID = 0x00 + CONFIG1 = 0x01 + CONFIG2 = 0x02 + CONFIG3 = 0x03 + CH1SET = 0x05 + BIAS_SENSP = 0x0D + BIAS_SENSN = 0x0E + + +class ADS1299Command(IntEnum): + WAKEUP = 0x02 + STANDBY = 0x04 + RESET = 0x06 + START = 0x08 + STOP = 0x0A + RDATAC = 0x10 + SDATAC = 0x11 + RREG = 0x20 + WREG = 0x40 + + +class ADS1299Gain(IntEnum): + GAIN_1 = 0x00 + GAIN_2 = 0x10 + GAIN_4 = 0x20 + GAIN_6 = 0x30 + GAIN_8 = 0x40 + GAIN_12 = 0x50 + GAIN_24 = 0x60 + + +class ADS1299SampleRate(IntEnum): + SPS_16000 = 0x00 + SPS_8000 = 0x01 + SPS_4000 = 0x02 + SPS_2000 = 0x03 + SPS_1000 = 0x04 + SPS_500 = 0x05 + SPS_250 = 0x06 + + +VREF = 4.5 +NUM_CHANNELS = 8 +STATUS_BYTES = 3 +BYTES_PER_CHANNEL = 3 +DEFAULT_DRDY_PIN = 17 +DEFAULT_RESET_PIN = 27 + + +# ============================================================================ +# PIEEG DEVICE WITH DSP +# ============================================================================ + +@dataclass +class PiEEGDSPConfig: + """Extended configuration for PiEEG with DSP""" + # Hardware + spi_bus: int = 0 + spi_device: int = 0 + spi_speed: int = 2000000 + drdy_pin: int = DEFAULT_DRDY_PIN + reset_pin: int = DEFAULT_RESET_PIN + sample_rate: ADS1299SampleRate = ADS1299SampleRate.SPS_250 + gain: ADS1299Gain = ADS1299Gain.GAIN_24 + num_channels: int = 8 + daisy_chain: bool = False + + # DSP + dsp_config: DSPConfig = field(default_factory=DSPConfig) + + +class PiEEGDeviceDSP: + """PiEEG device with integrated DSP pipeline""" + + def __init__(self, config: PiEEGDSPConfig): + self.config = config + self.spi = None + self.is_streaming = False + self.sample_count = 0 + self.start_time = 0.0 + + # Calculate scale factor + gain_values = {0x00: 1, 0x10: 2, 0x20: 4, 0x30: 6, 0x40: 8, 0x50: 12, 0x60: 24} + gain = gain_values.get(config.gain, 24) + self.scale_uv = (2 * VREF / gain) / (2**24) * 1e6 + + # DSP Pipeline + config.dsp_config.sample_rate = self._get_sample_rate_hz() + config.dsp_config.num_channels = config.num_channels + self.dsp = DSPPipeline(config.dsp_config) + + def _get_sample_rate_hz(self) -> int: + rates = {0: 16000, 1: 8000, 2: 4000, 3: 2000, 4: 1000, 5: 500, 6: 250} + return rates.get(self.config.sample_rate, 250) + + def _get_gain_value(self) -> int: + gains = {0x00: 1, 0x10: 2, 0x20: 4, 0x30: 6, 0x40: 8, 0x50: 12, 0x60: 24} + return gains.get(self.config.gain, 24) + + def setup_gpio(self): + GPIO.setmode(GPIO.BCM) + GPIO.setwarnings(False) + GPIO.setup(self.config.drdy_pin, GPIO.IN, pull_up_down=GPIO.PUD_UP) + GPIO.setup(self.config.reset_pin, GPIO.OUT) + GPIO.output(self.config.reset_pin, GPIO.HIGH) + + def setup_spi(self): + self.spi = spidev.SpiDev() + self.spi.open(self.config.spi_bus, self.config.spi_device) + self.spi.max_speed_hz = self.config.spi_speed + self.spi.mode = 0b01 + + def reset(self): + GPIO.output(self.config.reset_pin, GPIO.LOW) + time.sleep(0.001) + GPIO.output(self.config.reset_pin, GPIO.HIGH) + time.sleep(0.1) + + def send_command(self, cmd: int): + self.spi.xfer2([cmd]) + time.sleep(0.000004) + + def write_register(self, reg: int, value: int): + self.send_command(ADS1299Command.SDATAC) + self.spi.xfer2([ADS1299Command.WREG | reg, 0x00, value]) + time.sleep(0.000004) + + def read_register(self, reg: int) -> int: + self.send_command(ADS1299Command.SDATAC) + result = self.spi.xfer2([ADS1299Command.RREG | reg, 0x00, 0x00]) + return result[2] + + def configure(self): + self.send_command(ADS1299Command.SDATAC) + time.sleep(0.001) + + device_id = self.read_register(ADS1299Register.ID) + if (device_id & 0x1F) != 0x1E: + print(f"Warning: Unexpected device ID: 0x{device_id:02X}") + else: + print(f"✓ ADS1299 detected (ID: 0x{device_id:02X})") + + config1 = self.config.sample_rate + if self.config.daisy_chain: + config1 |= 0xC0 + else: + config1 |= 0x90 + self.write_register(ADS1299Register.CONFIG1, config1) + + self.write_register(ADS1299Register.CONFIG2, 0xC0) + self.write_register(ADS1299Register.CONFIG3, 0xEC) + time.sleep(0.15) + + for ch in range(8): + self.write_register(ADS1299Register.CH1SET + ch, self.config.gain | 0x00) + + self.write_register(ADS1299Register.BIAS_SENSP, 0xFF) + self.write_register(ADS1299Register.BIAS_SENSN, 0xFF) + + print(f"✓ ADS1299 configured: {self._get_sample_rate_hz()} SPS, Gain: {self._get_gain_value()}x") + + def start_streaming(self): + self.send_command(ADS1299Command.START) + time.sleep(0.001) + self.send_command(ADS1299Command.RDATAC) + self.is_streaming = True + self.sample_count = 0 + self.start_time = time.time() + print("✓ DSP Streaming started") + + def stop_streaming(self): + self.send_command(ADS1299Command.SDATAC) + self.send_command(ADS1299Command.STOP) + self.is_streaming = False + + elapsed = time.time() - self.start_time + if elapsed > 0: + actual_rate = self.sample_count / elapsed + stats = self.dsp.get_stats() + print(f"✓ Streaming stopped. {self.sample_count} samples, {actual_rate:.1f} SPS") + if "artifact_rate_percent" in stats: + print(f" Artifact rate: {stats['artifact_rate_percent']:.2f}%") + + def wait_for_drdy(self, timeout: float = 0.1) -> bool: + start = time.time() + while GPIO.input(self.config.drdy_pin) == GPIO.HIGH: + if time.time() - start > timeout: + return False + time.sleep(0.0001) + return True + + def read_sample(self) -> Optional[Tuple[List[float], List[bool]]]: + """Read and process one sample through DSP pipeline""" + if not self.wait_for_drdy(): + return None + + num_bytes = STATUS_BYTES + (self.config.num_channels * BYTES_PER_CHANNEL) + data = self.spi.xfer2([0x00] * num_bytes) + + # Parse raw channels + raw_channels = np.zeros(self.config.num_channels) + for ch in range(self.config.num_channels): + offset = STATUS_BYTES + (ch * BYTES_PER_CHANNEL) + value = (data[offset] << 16) | (data[offset + 1] << 8) | data[offset + 2] + if value & 0x800000: + value -= 0x1000000 + raw_channels[ch] = value * self.scale_uv + + # Apply DSP pipeline + processed, artifact_flags = self.dsp.process(raw_channels) + + self.sample_count += 1 + return processed.tolist(), artifact_flags.tolist() + + def connect(self) -> bool: + try: + self.setup_gpio() + self.setup_spi() + self.reset() + self.configure() + return True + except Exception as e: + print(f"✗ Connection failed: {e}") + return False + + def disconnect(self): + if self.is_streaming: + self.stop_streaming() + if self.spi: + self.spi.close() + GPIO.cleanup() + + +# ============================================================================ +# SIMULATOR WITH DSP +# ============================================================================ + +class PiEEGSimulatorDSP: + """Simulates PiEEG with DSP for development""" + + def __init__(self, config: PiEEGDSPConfig): + self.config = config + self.is_streaming = False + self.sample_count = 0 + self.start_time = 0.0 + self._last_sample_time = 0.0 + self._sample_interval = 1.0 / self._get_sample_rate_hz() + + # Simulate various EEG rhythms + self._phase = np.zeros(8) + # Mix of alpha (8-12 Hz), beta (12-30 Hz), theta (4-8 Hz) + self._freqs = [10.0, 10.5, 22.0, 11.0, 6.0, 9.0, 25.0, 10.0] + self._amps = [30.0, 35.0, 10.0, 25.0, 20.0, 30.0, 8.0, 28.0] + + # Simulate powerline interference + self._powerline_phase = 0.0 + self._powerline_freq = config.dsp_config.notch_freq + + # DSP Pipeline + config.dsp_config.sample_rate = self._get_sample_rate_hz() + config.dsp_config.num_channels = config.num_channels + self.dsp = DSPPipeline(config.dsp_config) + + def _get_sample_rate_hz(self) -> int: + rates = {0: 16000, 1: 8000, 2: 4000, 3: 2000, 4: 1000, 5: 500, 6: 250} + return rates.get(self.config.sample_rate, 250) + + def connect(self) -> bool: + print(f"✓ [SIMULATION] PiEEG+DSP simulator ({self._get_sample_rate_hz()} SPS)") + return True + + def disconnect(self): + self.is_streaming = False + + def start_streaming(self): + self.is_streaming = True + self.sample_count = 0 + self.start_time = time.time() + self._last_sample_time = self.start_time + print("✓ [SIMULATION] DSP Streaming started") + + def stop_streaming(self): + self.is_streaming = False + elapsed = time.time() - self.start_time + if elapsed > 0: + stats = self.dsp.get_stats() + print(f"✓ [SIMULATION] Stopped. {self.sample_count} samples, {self.sample_count/elapsed:.1f} SPS") + if "artifact_rate_percent" in stats: + print(f" Artifact rate: {stats['artifact_rate_percent']:.2f}%") + + def read_sample(self) -> Optional[Tuple[List[float], List[bool]]]: + """Generate simulated noisy EEG and process through DSP""" + now = time.time() + + if now - self._last_sample_time < self._sample_interval * 0.9: + return None + + self._last_sample_time = now + dt = self._sample_interval + + # Generate raw signal with realistic artifacts + raw_channels = np.zeros(self.config.num_channels) + + for i in range(self.config.num_channels): + self._phase[i] += 2 * np.pi * self._freqs[i] * dt + + # EEG rhythm + eeg = self._amps[i] * np.sin(self._phase[i]) + + # Add powerline interference (before DSP removes it) + self._powerline_phase += 2 * np.pi * self._powerline_freq * dt + powerline = 15 * np.sin(self._powerline_phase) # 15 µV artifact + + # Add DC drift + dc_drift = 50 * np.sin(2 * np.pi * 0.05 * now) # Slow 0.05 Hz drift + + # Add random noise + noise = np.random.normal(0, 3) + + # Occasional large artifact (muscle twitch simulation) + if np.random.random() < 0.001: # 0.1% chance + noise += np.random.choice([-1, 1]) * 200 # ±200 µV spike + + raw_channels[i] = eeg + powerline + dc_drift + noise + + # Apply DSP pipeline + processed, artifact_flags = self.dsp.process(raw_channels) + + self.sample_count += 1 + return processed.tolist(), artifact_flags.tolist() + + +# ============================================================================ +# WEBSOCKET BRIDGE WITH DSP +# ============================================================================ + +class PiEEGBridgeDSP: + """WebSocket bridge with integrated DSP""" + + def __init__(self, config: PiEEGDSPConfig, ws_host: str = "0.0.0.0", + ws_port: int = 8766): + self.config = config + self.ws_host = ws_host + self.ws_port = ws_port + self.clients: set = set() + self.streaming = False + + # Select device implementation + if IS_RASPBERRY_PI: + self.device = PiEEGDeviceDSP(config) + else: + self.device = PiEEGSimulatorDSP(config) + + async def stream_task(self): + """Background task to read, process, and broadcast samples""" + sample_buffer = [] + artifact_buffer = [] + last_send = time.time() + + while self.streaming: + result = self.device.read_sample() + if result: + sample, artifacts = result + sample_buffer.append(sample) + artifact_buffer.append(artifacts) + + now = time.time() + if len(sample_buffer) >= 10 or (now - last_send) > 0.02: + if self.clients and sample_buffer: + packet = self._pack_samples_dsp(sample_buffer, artifact_buffer, now) + + await asyncio.gather( + *[client.send(packet) for client in self.clients], + return_exceptions=True + ) + sample_buffer.clear() + artifact_buffer.clear() + last_send = now + + await asyncio.sleep(0.0001) + + def _pack_samples_dsp(self, samples: List[List[float]], + artifacts: List[List[bool]], + timestamp: float) -> bytes: + """Pack DSP-processed samples with artifact flags""" + num_samples = len(samples) + num_channels = len(samples[0]) if samples else 8 + + # Header: magic (2) + type (1) + num_samples (2) + num_channels (1) + timestamp (8) + # Type: 0x01 = raw, 0x02 = DSP processed + header = struct.pack('>HBHBD', 0xEEEE, 0x02, num_samples, num_channels, timestamp) + + # Data: [float32 × channels + artifact_byte] × samples + data = b'' + for i, sample in enumerate(samples): + data += struct.pack(f'>{num_channels}f', *sample) + # Pack artifact flags as single byte (1 bit per channel) + artifact_byte = sum((1 << ch) for ch, flag in enumerate(artifacts[i]) if flag) + data += struct.pack('B', artifact_byte) + + return header + data + + async def handle_client(self, websocket): + """Handle WebSocket client connection""" + client_addr = websocket.remote_address + print(f"→ Client connected: {client_addr}") + self.clients.add(websocket) + + try: + dsp_cfg = self.config.dsp_config + await websocket.send(json.dumps({ + "type": "device_info", + "device": "PiEEG-DSP", + "channels": self.config.num_channels, + "sample_rate": self.device._get_sample_rate_hz(), + "dsp": { + "enabled": True, + "dc_block": dsp_cfg.dc_block_enabled, + "notch_freq": dsp_cfg.notch_freq if dsp_cfg.notch_enabled else None, + "bandpass": [dsp_cfg.highpass_freq, dsp_cfg.lowpass_freq] if dsp_cfg.bandpass_enabled else None, + "artifact_threshold": dsp_cfg.artifact_threshold if dsp_cfg.artifact_enabled else None, + "car": dsp_cfg.car_enabled + }, + "simulation": not IS_RASPBERRY_PI + })) + + async for message in websocket: + try: + cmd = json.loads(message) + await self._handle_command(websocket, cmd) + except json.JSONDecodeError: + pass + + except websockets.exceptions.ConnectionClosed: + pass + finally: + self.clients.discard(websocket) + print(f"← Client disconnected: {client_addr}") + + async def _handle_command(self, websocket, cmd: dict): + """Handle client commands""" + cmd_type = cmd.get("type", "") + + if cmd_type == "start": + if not self.streaming: + self.device.start_streaming() + self.streaming = True + asyncio.create_task(self.stream_task()) + await websocket.send(json.dumps({"type": "status", "streaming": True})) + + elif cmd_type == "stop": + self.streaming = False + self.device.stop_streaming() + await websocket.send(json.dumps({"type": "status", "streaming": False})) + + elif cmd_type == "get_stats": + stats = self.device.dsp.get_stats() + await websocket.send(json.dumps({"type": "stats", **stats})) + + async def run(self): + """Start the WebSocket server""" + if not self.device.connect(): + print("✗ Failed to connect to device") + return + + print(f"\n🧠 PiEEG DSP Bridge") + print(f" WebSocket: ws://{self.ws_host}:{self.ws_port}") + print(f" Mode: {'Hardware' if IS_RASPBERRY_PI else 'Simulation'}") + + try: + async with websockets.serve(self.handle_client, self.ws_host, self.ws_port): + print("\n✓ Waiting for connections... (Ctrl+C to stop)\n") + await asyncio.Future() # Run forever + except KeyboardInterrupt: + print("\n\n⏹ Shutting down...") + finally: + self.streaming = False + self.device.disconnect() + + +# ============================================================================ +# CLI +# ============================================================================ + +def parse_args(): + parser = argparse.ArgumentParser( + description="PiEEG WebSocket Bridge with Real-Time DSP", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic usage with 60 Hz notch (Americas) + python pieeg_ws_bridge_dsp.py --notch 60 + + # European 50 Hz with custom bandpass + python pieeg_ws_bridge_dsp.py --notch 50 --highpass 1.0 --lowpass 40 + + # Full signal hygiene + python pieeg_ws_bridge_dsp.py --notch 60 --car --artifact-threshold 150 + + # Minimal processing (DC block only) + python pieeg_ws_bridge_dsp.py --no-notch --no-bandpass + """ + ) + + # Network + parser.add_argument("--host", default="0.0.0.0", help="WebSocket bind address") + parser.add_argument("--port", type=int, default=8766, help="WebSocket port") + + # Hardware + parser.add_argument("--sample-rate", type=int, choices=[250, 500, 1000, 2000], + default=250, help="Sample rate in Hz") + parser.add_argument("--gain", type=int, choices=[1, 2, 4, 6, 8, 12, 24], + default=24, help="PGA gain") + parser.add_argument("--channels", type=int, default=8, help="Number of channels") + + # DSP: Notch + notch_group = parser.add_mutually_exclusive_group() + notch_group.add_argument("--notch", type=float, default=60.0, + help="Powerline frequency for notch filter (50 or 60 Hz)") + notch_group.add_argument("--no-notch", action="store_true", + help="Disable notch filter") + parser.add_argument("--notch-harmonics", type=int, default=3, + help="Number of harmonics to filter") + parser.add_argument("--notch-q", type=float, default=30.0, + help="Notch filter Q factor") + + # DSP: Bandpass + parser.add_argument("--highpass", type=float, default=0.5, + help="High-pass cutoff frequency (Hz)") + parser.add_argument("--lowpass", type=float, default=45.0, + help="Low-pass cutoff frequency (Hz)") + parser.add_argument("--no-bandpass", action="store_true", + help="Disable bandpass filter") + parser.add_argument("--filter-order", type=int, default=4, + help="Butterworth filter order") + + # DSP: DC Block + parser.add_argument("--no-dc-block", action="store_true", + help="Disable DC blocking filter") + parser.add_argument("--dc-alpha", type=float, default=0.995, + help="DC blocker alpha (0.99-0.999)") + + # DSP: Artifact Rejection + parser.add_argument("--artifact-threshold", type=float, default=150.0, + help="Artifact threshold in µV (0 to disable)") + parser.add_argument("--no-artifact", action="store_true", + help="Disable artifact rejection") + + # DSP: CAR + parser.add_argument("--car", action="store_true", + help="Enable Common Average Reference") + parser.add_argument("--car-exclude", type=str, default="", + help="Channels to exclude from CAR (comma-separated, e.g., '0,7')") + + # DSP: Smoothing + parser.add_argument("--smooth", type=float, default=0.0, + help="Smoothing alpha (0 to disable, 0.1-0.5 typical)") + + return parser.parse_args() + + +def main(): + args = parse_args() + + # Map sample rate to enum + rate_map = {250: 6, 500: 5, 1000: 4, 2000: 3} + gain_map = {1: 0x00, 2: 0x10, 4: 0x20, 6: 0x30, 8: 0x40, 12: 0x50, 24: 0x60} + + # Parse CAR exclude channels + car_exclude = [] + if args.car_exclude: + try: + car_exclude = [int(x.strip()) for x in args.car_exclude.split(",")] + except ValueError: + print("Warning: Invalid --car-exclude format, ignoring") + + # Build DSP config + dsp_config = DSPConfig( + sample_rate=float(args.sample_rate), + num_channels=args.channels, + dc_block_enabled=not args.no_dc_block, + dc_block_alpha=args.dc_alpha, + notch_enabled=not args.no_notch, + notch_freq=args.notch, + notch_harmonics=args.notch_harmonics, + notch_q=args.notch_q, + bandpass_enabled=not args.no_bandpass, + highpass_freq=args.highpass, + lowpass_freq=args.lowpass, + filter_order=args.filter_order, + artifact_enabled=not args.no_artifact and args.artifact_threshold > 0, + artifact_threshold=args.artifact_threshold, + car_enabled=args.car, + car_exclude_channels=car_exclude, + smoothing_enabled=args.smooth > 0, + smoothing_alpha=args.smooth if args.smooth > 0 else 0.3 + ) + + # Build device config + device_config = PiEEGDSPConfig( + sample_rate=rate_map.get(args.sample_rate, 6), + gain=gain_map.get(args.gain, 0x60), + num_channels=args.channels, + dsp_config=dsp_config + ) + + # Run bridge + bridge = PiEEGBridgeDSP(device_config, args.host, args.port) + asyncio.run(bridge.run()) + + +if __name__ == "__main__": + main() From 1906ebbd234db7fa663b0c1c0446360f66d13571 Mon Sep 17 00:00:00 2001 From: Youssef Date: Sat, 31 Jan 2026 17:28:01 -0500 Subject: [PATCH 3/5] Add unit tests for LSL and PiEEG bridges - Implemented comprehensive unit tests for lsl_ws_bridge.py covering: - LSL stream discovery simulation - LSL channel format parsing - Stream metadata handling - WebSocket message formats - LSLSimulator class functionality - Added unit tests for pieeg_ws_bridge.py including: - ADS1299 constants and enums validation - PiEEGConfig dataclass tests - PiEEGSimulator functionality - Sample packing/unpacking tests - WebSocket message handling verification --- scripts/pytest.ini | 13 + scripts/requirements-dev.txt | 24 + scripts/tests/__init__.py | 1 + scripts/tests/conftest.py | 136 ++++++ scripts/tests/test_dsp_filters.py | 727 +++++++++++++++++++++++++++++ scripts/tests/test_lsl_bridge.py | 541 +++++++++++++++++++++ scripts/tests/test_pieeg_bridge.py | 455 ++++++++++++++++++ 7 files changed, 1897 insertions(+) create mode 100644 scripts/pytest.ini create mode 100644 scripts/requirements-dev.txt create mode 100644 scripts/tests/__init__.py create mode 100644 scripts/tests/conftest.py create mode 100644 scripts/tests/test_dsp_filters.py create mode 100644 scripts/tests/test_lsl_bridge.py create mode 100644 scripts/tests/test_pieeg_bridge.py diff --git a/scripts/pytest.ini b/scripts/pytest.ini new file mode 100644 index 0000000..318715b --- /dev/null +++ b/scripts/pytest.ini @@ -0,0 +1,13 @@ +[pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = -v --tb=short +markers = + slow: marks tests as slow + hardware: marks tests requiring hardware + integration: marks integration tests +filterwarnings = + ignore::DeprecationWarning + ignore::PendingDeprecationWarning diff --git a/scripts/requirements-dev.txt b/scripts/requirements-dev.txt new file mode 100644 index 0000000..9a46bfd --- /dev/null +++ b/scripts/requirements-dev.txt @@ -0,0 +1,24 @@ +# Bridge Script Dependencies +# Install with: pip install -r requirements-dev.txt + +# Core dependencies +websockets>=11.0 +numpy>=1.24.0 + +# DSP (optional but recommended) +scipy>=1.10.0 + +# Hardware (Raspberry Pi only) +# spidev>=3.6 +# RPi.GPIO>=0.7.1 + +# LSL (optional) +# pylsl>=1.16.0 + +# BrainFlow (optional) +# brainflow>=5.10.0 + +# Testing +pytest>=7.4.0 +pytest-asyncio>=0.21.0 +pytest-cov>=4.1.0 diff --git a/scripts/tests/__init__.py b/scripts/tests/__init__.py new file mode 100644 index 0000000..7494afb --- /dev/null +++ b/scripts/tests/__init__.py @@ -0,0 +1 @@ +# Bridge tests package diff --git a/scripts/tests/conftest.py b/scripts/tests/conftest.py new file mode 100644 index 0000000..83a7f3e --- /dev/null +++ b/scripts/tests/conftest.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +""" +Pytest configuration and shared fixtures for bridge tests. +""" + +import sys +import os +import pytest +import numpy as np + +# Add scripts directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +# ============================================================================ +# FIXTURES +# ============================================================================ + +@pytest.fixture +def sample_rate(): + """Default sample rate for tests""" + return 250.0 + + +@pytest.fixture +def num_channels(): + """Default number of channels""" + return 8 + + +@pytest.fixture +def eeg_sample(num_channels): + """Generate a realistic EEG sample""" + # Alpha wave + noise + alpha = 30.0 * np.sin(np.random.random() * 2 * np.pi) + noise = np.random.normal(0, 5, num_channels) + return alpha + noise + + +@pytest.fixture +def eeg_batch(num_channels, sample_rate): + """Generate 1 second of EEG data""" + num_samples = int(sample_rate) + t = np.arange(num_samples) / sample_rate + + # Generate alpha waves for each channel + data = np.zeros((num_samples, num_channels)) + for ch in range(num_channels): + freq = 10 + np.random.random() # 10-11 Hz + phase = np.random.random() * 2 * np.pi + amplitude = 25 + np.random.random() * 10 # 25-35 µV + data[:, ch] = amplitude * np.sin(2 * np.pi * freq * t + phase) + data[:, ch] += np.random.normal(0, 3, num_samples) # Noise + + return data + + +@pytest.fixture +def noisy_eeg_sample(num_channels): + """Generate EEG sample with powerline interference""" + t = np.random.random() + + # EEG + alpha = 30.0 * np.sin(2 * np.pi * 10 * t) + + # 60 Hz interference + powerline = 15.0 * np.sin(2 * np.pi * 60 * t) + + # DC offset + dc = 50.0 + + # Noise + noise = np.random.normal(0, 5, num_channels) + + return alpha + powerline + dc + noise + + +@pytest.fixture +def artifact_sample(num_channels): + """Generate sample with large artifact""" + sample = np.random.normal(0, 10, num_channels) + # Add artifact to channel 0 + sample[0] = 500.0 # Way above normal EEG + return sample + + +# ============================================================================ +# MOCK CLASSES +# ============================================================================ + +class MockWebSocket: + """Mock WebSocket for testing""" + + def __init__(self): + self.sent_messages = [] + self.received_messages = [] + self.is_open = True + self.remote_address = ("127.0.0.1", 12345) + + async def send(self, message): + self.sent_messages.append(message) + + async def recv(self): + if self.received_messages: + return self.received_messages.pop(0) + raise Exception("No messages") + + def close(self): + self.is_open = False + + def add_message(self, message): + """Add message to receive queue""" + self.received_messages.append(message) + + +@pytest.fixture +def mock_websocket(): + """Provide mock WebSocket instance""" + return MockWebSocket() + + +# ============================================================================ +# MARKERS +# ============================================================================ + +def pytest_configure(config): + """Register custom markers""" + config.addinivalue_line( + "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')" + ) + config.addinivalue_line( + "markers", "hardware: marks tests requiring hardware (deselect with '-m \"not hardware\"')" + ) + config.addinivalue_line( + "markers", "integration: marks integration tests" + ) diff --git a/scripts/tests/test_dsp_filters.py b/scripts/tests/test_dsp_filters.py new file mode 100644 index 0000000..bb873f5 --- /dev/null +++ b/scripts/tests/test_dsp_filters.py @@ -0,0 +1,727 @@ +#!/usr/bin/env python3 +""" +Unit tests for DSP filters in pieeg_ws_bridge_dsp.py + +Tests cover: +- DC Blocker filter +- IIR Filter (notch, bandpass, highpass, lowpass) +- Notch Filter Bank +- Artifact Rejector +- Common Average Reference +- Exponential Smoother +- Full DSP Pipeline + +Run with: pytest scripts/tests/test_dsp_filters.py -v +""" + +import sys +import os +import pytest +import numpy as np + +# Add parent directory to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Import DSP components +from pieeg_ws_bridge_dsp import ( + DCBlocker, + IIRFilter, + NotchFilterBank, + ArtifactRejector, + CommonAverageReference, + ExponentialSmoother, + DSPConfig, + DSPPipeline, + SCIPY_AVAILABLE, +) + + +# ============================================================================ +# DC BLOCKER TESTS +# ============================================================================ + +class TestDCBlocker: + """Tests for DCBlocker filter""" + + def test_initialization(self): + """Test DCBlocker initializes with correct parameters""" + blocker = DCBlocker(alpha=0.995, num_channels=8) + assert blocker.alpha == 0.995 + assert blocker.num_channels == 8 + assert blocker.x_prev.shape == (8,) + assert blocker.y_prev.shape == (8,) + + def test_removes_dc_offset(self): + """Test that DC blocker removes constant offset""" + blocker = DCBlocker(alpha=0.99, num_channels=1) + + # Apply 1000 samples with DC offset of 100 + dc_offset = 100.0 + outputs = [] + for _ in range(1000): + sample = np.array([dc_offset]) + output = blocker.process(sample) + outputs.append(output[0]) + + # After settling, output should be near zero + assert abs(outputs[-1]) < 5.0, "DC offset should be removed" + + def test_passes_ac_signal(self): + """Test that AC signals pass through with minimal attenuation""" + blocker = DCBlocker(alpha=0.995, num_channels=1) + fs = 250.0 + freq = 10.0 # 10 Hz signal (well above DC blocker cutoff) + + # Generate 2 seconds of 10 Hz sine wave + t = np.arange(0, 2.0, 1/fs) + signal = 50.0 * np.sin(2 * np.pi * freq * t) + + outputs = [] + for sample in signal: + output = blocker.process(np.array([sample])) + outputs.append(output[0]) + + # Check last 250 samples (1 second, after settling) + output_signal = np.array(outputs[-250:]) + input_signal = signal[-250:] + + # RMS should be similar (within 10%) + input_rms = np.sqrt(np.mean(input_signal**2)) + output_rms = np.sqrt(np.mean(output_signal**2)) + attenuation = abs(output_rms - input_rms) / input_rms + + assert attenuation < 0.1, f"10 Hz signal attenuated by {attenuation*100:.1f}%" + + def test_reset(self): + """Test that reset clears filter state""" + blocker = DCBlocker(alpha=0.995, num_channels=4) + + # Process some samples + for _ in range(10): + blocker.process(np.array([100, 200, 300, 400])) + + # Reset + blocker.reset() + + assert np.all(blocker.x_prev == 0) + assert np.all(blocker.y_prev == 0) + + def test_multichannel(self): + """Test that multichannel processing works independently""" + blocker = DCBlocker(alpha=0.99, num_channels=3) + + # Different DC offsets per channel + for _ in range(500): + sample = np.array([100.0, 200.0, 300.0]) + output = blocker.process(sample) + + # All channels should converge to near zero + assert np.all(np.abs(output) < 10.0) + + +# ============================================================================ +# IIR FILTER TESTS +# ============================================================================ + +@pytest.mark.skipif(not SCIPY_AVAILABLE, reason="scipy not installed") +class TestIIRFilter: + """Tests for IIRFilter class""" + + def test_notch_filter_creation(self): + """Test notch filter creation""" + notch = IIRFilter.create_notch(freq=60.0, fs=250.0, Q=30.0, num_channels=8) + assert notch.sos is not None + assert notch.num_channels == 8 + assert notch.zi.shape[0] == 8 + + def test_notch_attenuates_target_frequency(self): + """Test that notch filter attenuates the target frequency""" + fs = 250.0 + notch_freq = 60.0 + notch = IIRFilter.create_notch(notch_freq, fs, Q=30.0, num_channels=1) + + # Generate 60 Hz signal + t = np.arange(0, 2.0, 1/fs) + signal_60hz = 50.0 * np.sin(2 * np.pi * notch_freq * t) + + outputs = [] + for sample in signal_60hz: + output = notch.process(np.array([sample])) + outputs.append(output[0]) + + # Check last second (after filter settles) + output_signal = np.array(outputs[-250:]) + input_signal = signal_60hz[-250:] + + input_rms = np.sqrt(np.mean(input_signal**2)) + output_rms = np.sqrt(np.mean(output_signal**2)) + + attenuation_db = 20 * np.log10(output_rms / input_rms) + assert attenuation_db < -20, f"60 Hz only attenuated by {attenuation_db:.1f} dB" + + def test_notch_passes_other_frequencies(self): + """Test that notch filter passes non-target frequencies""" + fs = 250.0 + notch = IIRFilter.create_notch(60.0, fs, Q=30.0, num_channels=1) + + # Generate 10 Hz signal (alpha band) + t = np.arange(0, 2.0, 1/fs) + signal_10hz = 50.0 * np.sin(2 * np.pi * 10.0 * t) + + outputs = [] + for sample in signal_10hz: + output = notch.process(np.array([sample])) + outputs.append(output[0]) + + output_signal = np.array(outputs[-250:]) + input_signal = signal_10hz[-250:] + + input_rms = np.sqrt(np.mean(input_signal**2)) + output_rms = np.sqrt(np.mean(output_signal**2)) + + attenuation = abs(output_rms - input_rms) / input_rms + assert attenuation < 0.05, f"10 Hz attenuated by {attenuation*100:.1f}%" + + def test_bandpass_filter_creation(self): + """Test bandpass filter creation""" + bp = IIRFilter.create_bandpass(0.5, 45.0, fs=250.0, order=4, num_channels=8) + assert bp.sos is not None + assert bp.num_channels == 8 + + def test_bandpass_attenuates_out_of_band(self): + """Test bandpass filter attenuates frequencies outside passband""" + fs = 250.0 + bp = IIRFilter.create_bandpass(1.0, 40.0, fs, order=4, num_channels=1) + + # Test 0.1 Hz signal (below passband) + t = np.arange(0, 10.0, 1/fs) # Longer for low freq + signal_low = 50.0 * np.sin(2 * np.pi * 0.1 * t) + + outputs = [] + for sample in signal_low: + output = bp.process(np.array([sample])) + outputs.append(output[0]) + + # Check output is significantly attenuated + output_rms = np.sqrt(np.mean(np.array(outputs[-500:])**2)) + input_rms = np.sqrt(np.mean(signal_low[-500:]**2)) + + assert output_rms < input_rms * 0.3, "0.1 Hz should be attenuated" + + def test_highpass_filter(self): + """Test highpass filter creation and behavior""" + fs = 250.0 + hp = IIRFilter.create_highpass(1.0, fs, order=4, num_channels=1) + + # DC should be blocked + outputs = [] + for _ in range(500): + output = hp.process(np.array([100.0])) + outputs.append(output[0]) + + assert abs(outputs[-1]) < 1.0, "DC should be blocked by highpass" + + def test_lowpass_filter(self): + """Test lowpass filter creation and behavior""" + fs = 250.0 + lp = IIRFilter.create_lowpass(40.0, fs, order=4, num_channels=1) + + # High frequency should be attenuated + t = np.arange(0, 2.0, 1/fs) + signal_100hz = 50.0 * np.sin(2 * np.pi * 100.0 * t) + + outputs = [] + for sample in signal_100hz: + output = lp.process(np.array([sample])) + outputs.append(output[0]) + + output_rms = np.sqrt(np.mean(np.array(outputs[-250:])**2)) + input_rms = np.sqrt(np.mean(signal_100hz[-250:]**2)) + + assert output_rms < input_rms * 0.1, "100 Hz should be strongly attenuated" + + def test_batch_processing(self): + """Test batch processing matches sample-by-sample""" + fs = 250.0 + + # Create two identical filters + bp1 = IIRFilter.create_bandpass(1.0, 40.0, fs, order=4, num_channels=2) + bp2 = IIRFilter.create_bandpass(1.0, 40.0, fs, order=4, num_channels=2) + + # Generate test signal + samples = np.random.randn(100, 2) * 50 + + # Process sample-by-sample + outputs1 = [] + for sample in samples: + outputs1.append(bp1.process(sample)) + outputs1 = np.array(outputs1) + + # Process as batch + outputs2 = bp2.process_batch(samples) + + np.testing.assert_array_almost_equal(outputs1, outputs2, decimal=10) + + +# ============================================================================ +# NOTCH FILTER BANK TESTS +# ============================================================================ + +@pytest.mark.skipif(not SCIPY_AVAILABLE, reason="scipy not installed") +class TestNotchFilterBank: + """Tests for NotchFilterBank class""" + + def test_creates_harmonics(self, capsys): + """Test that filter bank creates filters for harmonics""" + bank = NotchFilterBank(60.0, fs=250.0, num_harmonics=2, Q=30.0, num_channels=8) + + # Should have 2 filters: 60 Hz and 120 Hz + assert len(bank.filters) == 2 + + # Check console output + captured = capsys.readouterr() + assert "60" in captured.out or "60.0" in captured.out + assert "120" in captured.out or "120.0" in captured.out + + def test_respects_nyquist(self): + """Test that harmonics above Nyquist are not created""" + # At 250 Hz sample rate, Nyquist is 125 Hz + # So 60, 120 should work, but 180 should be skipped + bank = NotchFilterBank(60.0, fs=250.0, num_harmonics=4, Q=30.0, num_channels=8) + + # Only 60 and 120 should be created (180, 240 > 125) + assert len(bank.filters) == 2 + + def test_attenuates_all_harmonics(self): + """Test that all harmonic frequencies are attenuated""" + fs = 1000.0 # Higher sample rate to test more harmonics + bank = NotchFilterBank(60.0, fs, num_harmonics=3, Q=30.0, num_channels=1) + + for freq in [60.0, 120.0, 180.0]: + bank.reset() + t = np.arange(0, 1.0, 1/fs) + signal = 50.0 * np.sin(2 * np.pi * freq * t) + + outputs = [] + for sample in signal: + output = bank.process(np.array([sample])) + outputs.append(output[0]) + + output_rms = np.sqrt(np.mean(np.array(outputs[-500:])**2)) + input_rms = np.sqrt(np.mean(signal[-500:]**2)) + + attenuation_db = 20 * np.log10(output_rms / input_rms + 1e-10) + assert attenuation_db < -15, f"{freq} Hz only attenuated by {attenuation_db:.1f} dB" + + +# ============================================================================ +# ARTIFACT REJECTOR TESTS +# ============================================================================ + +class TestArtifactRejector: + """Tests for ArtifactRejector class""" + + def test_initialization(self): + """Test ArtifactRejector initializes correctly""" + ar = ArtifactRejector(threshold_uv=150.0, blanking_samples=5, num_channels=8) + assert ar.threshold == 150.0 + assert ar.blanking_samples == 5 + assert ar.num_channels == 8 + assert ar.artifact_count == 0 + + def test_passes_normal_samples(self): + """Test that samples within threshold pass unchanged""" + ar = ArtifactRejector(threshold_uv=150.0, blanking_samples=5, num_channels=2) + + sample = np.array([50.0, -50.0]) + cleaned, flags = ar.process(sample) + + np.testing.assert_array_equal(cleaned, sample) + assert not np.any(flags) + + def test_detects_artifacts(self): + """Test that samples exceeding threshold are flagged""" + ar = ArtifactRejector(threshold_uv=150.0, blanking_samples=5, num_channels=2) + + # First, establish "last good" with a normal sample + ar.process(np.array([50.0, 50.0])) + + # Now send artifact + sample = np.array([200.0, 50.0]) # Channel 0 exceeds threshold + cleaned, flags = ar.process(sample) + + assert flags[0] == True + assert flags[1] == False + assert cleaned[0] == 50.0 # Replaced with last good + assert cleaned[1] == 50.0 + + def test_blanking_period(self): + """Test that blanking continues for specified samples""" + ar = ArtifactRejector(threshold_uv=150.0, blanking_samples=3, num_channels=1) + + # Establish baseline + ar.process(np.array([50.0])) + + # Trigger artifact + ar.process(np.array([200.0])) + + # Next 3 samples should be blanked even if normal + for i in range(3): + cleaned, flags = ar.process(np.array([60.0])) + assert flags[0] == True, f"Sample {i+1} should still be blanked" + + # 4th sample should be normal + cleaned, flags = ar.process(np.array([60.0])) + assert flags[0] == False + + def test_artifact_rate_calculation(self): + """Test artifact rate percentage calculation""" + ar = ArtifactRejector(threshold_uv=100.0, blanking_samples=0, num_channels=1) + + # Process 100 samples, 10 are artifacts + for i in range(100): + value = 150.0 if i % 10 == 0 else 50.0 + ar.process(np.array([value])) + + rate = ar.get_artifact_rate() + assert 9 < rate < 11 # Should be ~10% + + def test_reset(self): + """Test that reset clears state""" + ar = ArtifactRejector(threshold_uv=150.0, blanking_samples=5, num_channels=2) + + # Accumulate some state + for _ in range(50): + ar.process(np.array([200.0, 200.0])) + + ar.reset() + + assert ar.artifact_count == 0 + assert ar.total_samples == 0 + assert np.all(ar.blanking_counter == 0) + + +# ============================================================================ +# COMMON AVERAGE REFERENCE TESTS +# ============================================================================ + +class TestCommonAverageReference: + """Tests for CommonAverageReference class""" + + def test_subtracts_mean(self): + """Test that CAR subtracts channel mean""" + car = CommonAverageReference(num_channels=4) + + sample = np.array([10.0, 20.0, 30.0, 40.0]) + output = car.process(sample) + + # Mean is 25, so output should be [-15, -5, 5, 15] + expected = sample - 25.0 + np.testing.assert_array_almost_equal(output, expected) + + def test_output_mean_is_zero(self): + """Test that output has zero mean""" + car = CommonAverageReference(num_channels=8) + + sample = np.random.randn(8) * 50 + output = car.process(sample) + + assert abs(np.mean(output)) < 1e-10 + + def test_exclude_channels(self): + """Test that excluded channels don't affect average""" + car = CommonAverageReference(num_channels=4, exclude_channels=[3]) + + # Channel 3 has large value but should be excluded + sample = np.array([10.0, 20.0, 30.0, 1000.0]) + output = car.process(sample) + + # Mean of channels 0,1,2 is 20 + expected = sample - 20.0 + np.testing.assert_array_almost_equal(output, expected) + + def test_all_excluded_returns_unchanged(self): + """Test that excluding all channels returns unchanged sample""" + car = CommonAverageReference(num_channels=3, exclude_channels=[0, 1, 2]) + + sample = np.array([10.0, 20.0, 30.0]) + output = car.process(sample) + + np.testing.assert_array_equal(output, sample) + + +# ============================================================================ +# EXPONENTIAL SMOOTHER TESTS +# ============================================================================ + +class TestExponentialSmoother: + """Tests for ExponentialSmoother class""" + + def test_initialization(self): + """Test smoother initializes correctly""" + smoother = ExponentialSmoother(alpha=0.3, num_channels=8) + assert smoother.alpha == 0.3 + assert not smoother.initialized + + def test_first_sample_passes_through(self): + """Test first sample passes through unchanged""" + smoother = ExponentialSmoother(alpha=0.3, num_channels=2) + + sample = np.array([100.0, 200.0]) + output = smoother.process(sample) + + np.testing.assert_array_equal(output, sample) + assert smoother.initialized + + def test_smoothing_effect(self): + """Test that smoothing reduces noise""" + smoother = ExponentialSmoother(alpha=0.1, num_channels=1) + + # Generate noisy signal + np.random.seed(42) + clean = 50.0 + noisy = clean + np.random.randn(500) * 20 + + outputs = [] + for sample in noisy: + output = smoother.process(np.array([sample])) + outputs.append(output[0]) + + # Smoothed output should have less variance + output_std = np.std(outputs[-200:]) + input_std = np.std(noisy[-200:]) + + assert output_std < input_std * 0.5 + + def test_step_response(self): + """Test step response converges to new value""" + smoother = ExponentialSmoother(alpha=0.3, num_channels=1) + + # Initialize at 0 + smoother.process(np.array([0.0])) + + # Step to 100 + outputs = [] + for _ in range(50): + output = smoother.process(np.array([100.0])) + outputs.append(output[0]) + + # Should converge to ~100 + assert outputs[-1] > 99.0 + + def test_reset(self): + """Test reset clears state""" + smoother = ExponentialSmoother(alpha=0.3, num_channels=2) + + smoother.process(np.array([100.0, 200.0])) + smoother.reset() + + assert not smoother.initialized + assert np.all(smoother.ema == 0) + + +# ============================================================================ +# DSP PIPELINE TESTS +# ============================================================================ + +class TestDSPPipeline: + """Tests for complete DSP pipeline""" + + def test_minimal_pipeline(self): + """Test pipeline with only DC blocker""" + config = DSPConfig( + sample_rate=250.0, + num_channels=4, + dc_block_enabled=True, + notch_enabled=False, + bandpass_enabled=False, + artifact_enabled=False, + car_enabled=False, + smoothing_enabled=False, + ) + + pipeline = DSPPipeline(config) + + # Should have 1 filter (DC block) + assert len(pipeline.filters) == 1 + assert pipeline.filters[0][0] == "DC Block" + + @pytest.mark.skipif(not SCIPY_AVAILABLE, reason="scipy not installed") + def test_full_pipeline(self): + """Test pipeline with all filters enabled""" + config = DSPConfig( + sample_rate=250.0, + num_channels=8, + dc_block_enabled=True, + notch_enabled=True, + notch_freq=60.0, + bandpass_enabled=True, + artifact_enabled=True, + car_enabled=True, + smoothing_enabled=True, + ) + + pipeline = DSPPipeline(config) + + # Should have multiple filters + filter_names = [name for name, _ in pipeline.filters] + assert "DC Block" in filter_names + assert "Notch" in filter_names + assert "Bandpass" in filter_names + assert "CAR" in filter_names + assert "Smooth" in filter_names + + def test_process_returns_correct_shape(self): + """Test that process returns correct array shapes""" + config = DSPConfig( + sample_rate=250.0, + num_channels=8, + dc_block_enabled=True, + notch_enabled=False, + bandpass_enabled=False, + artifact_enabled=True, + ) + + pipeline = DSPPipeline(config) + + sample = np.random.randn(8) * 50 + processed, flags = pipeline.process(sample) + + assert processed.shape == (8,) + assert flags.shape == (8,) + assert flags.dtype == bool + + def test_batch_process(self): + """Test batch processing""" + config = DSPConfig( + sample_rate=250.0, + num_channels=4, + dc_block_enabled=True, + notch_enabled=False, + bandpass_enabled=False, + artifact_enabled=True, + ) + + pipeline = DSPPipeline(config) + + samples = np.random.randn(100, 4) * 50 + processed, flags = pipeline.process_batch(samples) + + assert processed.shape == (100, 4) + assert flags.shape == (100, 4) + + @pytest.mark.skipif(not SCIPY_AVAILABLE, reason="scipy not installed") + def test_removes_powerline_interference(self): + """Test that full pipeline removes 60 Hz interference""" + config = DSPConfig( + sample_rate=250.0, + num_channels=1, + dc_block_enabled=True, + notch_enabled=True, + notch_freq=60.0, + bandpass_enabled=True, + highpass_freq=0.5, + lowpass_freq=45.0, + artifact_enabled=False, + ) + + pipeline = DSPPipeline(config) + + # Generate signal: 10 Hz alpha + 60 Hz powerline + fs = 250.0 + t = np.arange(0, 4.0, 1/fs) + alpha = 30.0 * np.sin(2 * np.pi * 10.0 * t) + powerline = 20.0 * np.sin(2 * np.pi * 60.0 * t) + signal = alpha + powerline + + outputs = [] + for sample in signal: + processed, _ = pipeline.process(np.array([sample])) + outputs.append(processed[0]) + + # Analyze output frequency content (last 2 seconds) + output_signal = np.array(outputs[-500:]) + + # FFT analysis + from numpy.fft import rfft, rfftfreq + freqs = rfftfreq(len(output_signal), 1/fs) + spectrum = np.abs(rfft(output_signal)) + + # Find power at 10 Hz and 60 Hz + idx_10hz = np.argmin(np.abs(freqs - 10.0)) + idx_60hz = np.argmin(np.abs(freqs - 60.0)) + + power_10hz = spectrum[idx_10hz] + power_60hz = spectrum[idx_60hz] + + # 60 Hz should be much weaker than 10 Hz + assert power_60hz < power_10hz * 0.1, "60 Hz not sufficiently attenuated" + + def test_get_stats(self): + """Test statistics retrieval""" + config = DSPConfig( + sample_rate=250.0, + num_channels=4, + dc_block_enabled=True, + artifact_enabled=True, + ) + + pipeline = DSPPipeline(config) + + # Process some samples with artifacts + for _ in range(100): + sample = np.random.randn(4) * 50 + pipeline.process(sample) + + stats = pipeline.get_stats() + + assert "num_filters" in stats + assert "filter_chain" in stats + assert "artifact_rate_percent" in stats + + +# ============================================================================ +# CONFIGURATION TESTS +# ============================================================================ + +class TestDSPConfig: + """Tests for DSPConfig dataclass""" + + def test_default_values(self): + """Test default configuration values""" + config = DSPConfig() + + assert config.sample_rate == 250.0 + assert config.num_channels == 8 + assert config.dc_block_enabled == True + assert config.notch_enabled == True + assert config.notch_freq == 60.0 + assert config.bandpass_enabled == True + assert config.highpass_freq == 0.5 + assert config.lowpass_freq == 45.0 + assert config.artifact_enabled == True + assert config.artifact_threshold == 150.0 + assert config.car_enabled == False + assert config.smoothing_enabled == False + + def test_custom_values(self): + """Test custom configuration""" + config = DSPConfig( + sample_rate=500.0, + num_channels=16, + notch_freq=50.0, + highpass_freq=1.0, + lowpass_freq=40.0, + car_enabled=True, + ) + + assert config.sample_rate == 500.0 + assert config.num_channels == 16 + assert config.notch_freq == 50.0 + assert config.highpass_freq == 1.0 + assert config.lowpass_freq == 40.0 + assert config.car_enabled == True + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/scripts/tests/test_lsl_bridge.py b/scripts/tests/test_lsl_bridge.py new file mode 100644 index 0000000..057ca52 --- /dev/null +++ b/scripts/tests/test_lsl_bridge.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python3 +""" +Unit tests for lsl_ws_bridge.py + +Tests cover: +- LSL stream discovery simulation +- LSL channel format parsing +- Stream metadata handling +- WebSocket message formats +- LSLSimulator class + +Run with: pytest scripts/tests/test_lsl_bridge.py -v +""" + +import sys +import os +import struct +import json +import pytest +import numpy as np + +# Add parent directory to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from lsl_ws_bridge import ( + LSLChannelFormat, + LSLStreamInfo, + LSLSimulator, + LSL_AVAILABLE, +) + + +# ============================================================================ +# CHANNEL FORMAT TESTS +# ============================================================================ + +class TestLSLChannelFormat: + """Tests for LSL channel format enum""" + + def test_format_values(self): + """Test channel format numeric values""" + assert LSLChannelFormat.FLOAT32.value == 1 + assert LSLChannelFormat.DOUBLE64.value == 2 + + def test_format_names(self): + """Test format enum names""" + assert LSLChannelFormat.FLOAT32.name == "FLOAT32" + assert LSLChannelFormat.DOUBLE64.name == "DOUBLE64" + + +# ============================================================================ +# STREAM INFO TESTS +# ============================================================================ + +class TestLSLStreamInfo: + """Tests for LSLStreamInfo dataclass""" + + def test_default_creation(self): + """Test creating stream info with required fields""" + info = LSLStreamInfo( + name="TestStream", + stream_type="EEG", + channel_count=8, + sampling_rate=250.0, + source_id="test-123", + ) + + assert info.name == "TestStream" + assert info.stream_type == "EEG" + assert info.channel_count == 8 + assert info.sampling_rate == 250.0 + assert info.source_id == "test-123" + + def test_optional_fields(self): + """Test optional channel labels""" + labels = ["Fp1", "Fp2", "C3", "C4", "P3", "P4", "O1", "O2"] + info = LSLStreamInfo( + name="EEG", + stream_type="EEG", + channel_count=8, + sampling_rate=256.0, + source_id="openbci", + channel_labels=labels, + ) + + assert info.channel_labels == labels + + def test_to_dict(self): + """Test conversion to dictionary (for JSON)""" + from dataclasses import asdict + + info = LSLStreamInfo( + name="MyStream", + stream_type="EEG", + channel_count=4, + sampling_rate=500.0, + source_id="src", + ) + + d = asdict(info) + + assert d["name"] == "MyStream" + assert d["stream_type"] == "EEG" + assert d["channel_count"] == 4 + assert d["sampling_rate"] == 500.0 + + +# ============================================================================ +# SIMULATOR TESTS +# ============================================================================ + +class TestLSLSimulator: + """Tests for LSLSimulator class""" + + def test_initialization(self): + """Test simulator initialization""" + sim = LSLSimulator( + stream_name="TestEEG", + channel_count=8, + sample_rate=250.0, + ) + + assert sim.stream_info.name == "TestEEG" + assert sim.stream_info.channel_count == 8 + assert sim.stream_info.sampling_rate == 250.0 + assert sim.is_streaming == False + + def test_get_stream_info(self): + """Test getting stream info""" + sim = LSLSimulator( + stream_name="OpenBCI_EEG", + channel_count=16, + sample_rate=125.0, + ) + + info = sim.get_stream_info() + + assert info.name == "OpenBCI_EEG" + assert info.channel_count == 16 + assert info.sampling_rate == 125.0 + assert info.stream_type == "EEG" + + def test_start_stop(self): + """Test start and stop streaming""" + sim = LSLSimulator( + stream_name="Test", + channel_count=4, + sample_rate=250.0, + ) + + sim.start() + assert sim.is_streaming == True + + sim.stop() + assert sim.is_streaming == False + + def test_pull_sample_shape(self): + """Test pulled sample has correct shape""" + sim = LSLSimulator( + stream_name="Test", + channel_count=8, + sample_rate=250.0, + ) + sim.start() + + import time + time.sleep(0.01) # Let it generate a sample + + sample, timestamp = sim.pull_sample() + + if sample is not None: + assert len(sample) == 8 + assert isinstance(timestamp, float) + + sim.stop() + + def test_simulated_values_range(self): + """Test simulated values are in EEG range""" + sim = LSLSimulator( + stream_name="Test", + channel_count=8, + sample_rate=1000.0, # Fast for testing + ) + sim.start() + + import time + samples = [] + for _ in range(100): + sample, ts = sim.pull_sample() + if sample is not None: + samples.append(sample) + time.sleep(0.001) + + sim.stop() + + if samples: + all_values = np.array(samples).flatten() + # Should be in µV range + assert np.abs(all_values).max() < 200 + + def test_timestamp_increases(self): + """Test timestamps are monotonically increasing""" + sim = LSLSimulator( + stream_name="Test", + channel_count=4, + sample_rate=1000.0, + ) + sim.start() + + import time + timestamps = [] + for _ in range(50): + sample, ts = sim.pull_sample() + if ts is not None and ts > 0: + timestamps.append(ts) + time.sleep(0.001) + + sim.stop() + + if len(timestamps) > 1: + # Check monotonic + for i in range(1, len(timestamps)): + assert timestamps[i] >= timestamps[i-1] + + +# ============================================================================ +# WEBSOCKET MESSAGE FORMAT TESTS +# ============================================================================ + +class TestWebSocketMessages: + """Tests for WebSocket message formats""" + + def test_metadata_message(self): + """Test metadata message structure""" + metadata = { + "type": "metadata", + "stream": { + "name": "OpenBCI_EEG", + "stream_type": "EEG", + "channel_count": 8, + "sampling_rate": 250.0, + "channel_labels": ["Fp1", "Fp2", "C3", "C4", "P3", "P4", "O1", "O2"], + } + } + + # Should be valid JSON + json_str = json.dumps(metadata) + parsed = json.loads(json_str) + + assert parsed["type"] == "metadata" + assert parsed["stream"]["name"] == "OpenBCI_EEG" + assert parsed["stream"]["channel_count"] == 8 + + def test_discover_command(self): + """Test discover command format""" + cmd = {"command": "discover"} + json_str = json.dumps(cmd) + parsed = json.loads(json_str) + + assert parsed["command"] == "discover" + + def test_connect_command(self): + """Test connect command format""" + cmd = { + "command": "connect", + "name": "OpenBCI_EEG", + "stream_type": "EEG", + } + json_str = json.dumps(cmd) + parsed = json.loads(json_str) + + assert parsed["command"] == "connect" + assert parsed["name"] == "OpenBCI_EEG" + + def test_streams_response(self): + """Test streams list response format""" + response = { + "type": "streams", + "streams": [ + { + "name": "OpenBCI_EEG", + "stream_type": "EEG", + "channel_count": 8, + "sampling_rate": 250.0, + }, + { + "name": "Muse", + "stream_type": "EEG", + "channel_count": 4, + "sampling_rate": 256.0, + }, + ] + } + + json_str = json.dumps(response) + parsed = json.loads(json_str) + + assert parsed["type"] == "streams" + assert len(parsed["streams"]) == 2 + + def test_error_response(self): + """Test error response format""" + error = { + "type": "error", + "message": "Stream not found", + "code": "STREAM_NOT_FOUND", + } + + json_str = json.dumps(error) + parsed = json.loads(json_str) + + assert parsed["type"] == "error" + assert "Stream not found" in parsed["message"] + + +# ============================================================================ +# BINARY PACKET FORMAT TESTS +# ============================================================================ + +class TestBinaryPacketFormat: + """Tests for binary sample packet format""" + + def test_header_format(self): + """Test binary packet header structure""" + magic = 0xEEEE + packet_type = 0x01 # Raw samples + num_samples = 10 + num_channels = 8 + timestamp = 1234567890.123456 + + # Pack header + header = struct.pack('>HBHBD', magic, packet_type, num_samples, num_channels, timestamp) + + # Should be 14 bytes + assert len(header) == 14 + + # Unpack and verify + unpacked = struct.unpack('>HBHBD', header) + assert unpacked[0] == magic + assert unpacked[1] == packet_type + assert unpacked[2] == num_samples + assert unpacked[3] == num_channels + assert abs(unpacked[4] - timestamp) < 0.000001 + + def test_sample_packing(self): + """Test sample data packing""" + samples = [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + ] + + # Pack samples + data = b'' + for sample in samples: + data += struct.pack(f'>{len(sample)}f', *sample) + + # Should be 3 samples × 4 channels × 4 bytes = 48 bytes + assert len(data) == 48 + + # Unpack and verify + offset = 0 + for i, sample in enumerate(samples): + unpacked = struct.unpack('>4f', data[offset:offset+16]) + for j, val in enumerate(sample): + assert abs(unpacked[j] - val) < 0.0001 + offset += 16 + + def test_full_packet(self): + """Test complete packet creation and parsing""" + # Create packet + magic = 0xEEEE + packet_type = 0x01 + num_samples = 2 + num_channels = 4 + timestamp = 1000.5 + samples = [[10.0, 20.0, 30.0, 40.0], [50.0, 60.0, 70.0, 80.0]] + + # Pack + header = struct.pack('>HBHBD', magic, packet_type, num_samples, num_channels, timestamp) + data = b'' + for sample in samples: + data += struct.pack(f'>{num_channels}f', *sample) + + packet = header + data + + # Parse + parsed_header = struct.unpack('>HBHBD', packet[:14]) + assert parsed_header[0] == magic + assert parsed_header[2] == num_samples + + # Parse samples + offset = 14 + for i in range(num_samples): + parsed_sample = struct.unpack(f'>{num_channels}f', packet[offset:offset+num_channels*4]) + for j in range(num_channels): + assert abs(parsed_sample[j] - samples[i][j]) < 0.0001 + offset += num_channels * 4 + + +# ============================================================================ +# STREAM DISCOVERY SIMULATION TESTS +# ============================================================================ + +class TestStreamDiscovery: + """Tests for stream discovery simulation""" + + def test_discover_returns_list(self): + """Test that discovery returns a list of streams""" + # Simulate what discover would return + discovered = [ + LSLStreamInfo( + name="OpenBCI_EEG", + stream_type="EEG", + channel_count=8, + sampling_rate=250.0, + source_id="openbci-1", + ), + LSLStreamInfo( + name="Markers", + stream_type="Markers", + channel_count=1, + sampling_rate=0.0, # Irregular + source_id="markers-1", + ), + ] + + assert len(discovered) == 2 + assert discovered[0].stream_type == "EEG" + assert discovered[1].stream_type == "Markers" + + def test_filter_by_type(self): + """Test filtering streams by type""" + all_streams = [ + LSLStreamInfo("EEG1", "EEG", 8, 250.0, "1"), + LSLStreamInfo("EEG2", "EEG", 16, 500.0, "2"), + LSLStreamInfo("Markers", "Markers", 1, 0.0, "3"), + LSLStreamInfo("Audio", "Audio", 2, 44100.0, "4"), + ] + + eeg_only = [s for s in all_streams if s.stream_type == "EEG"] + + assert len(eeg_only) == 2 + + +# ============================================================================ +# CHANNEL LABEL TESTS +# ============================================================================ + +class TestChannelLabels: + """Tests for channel label handling""" + + def test_standard_10_20_labels(self): + """Test standard 10-20 system labels""" + labels_8ch = ["Fp1", "Fp2", "C3", "C4", "P3", "P4", "O1", "O2"] + + assert len(labels_8ch) == 8 + assert "Fp1" in labels_8ch + assert "O2" in labels_8ch + + def test_generate_default_labels(self): + """Test generating default channel labels""" + num_channels = 16 + labels = [f"Ch{i+1}" for i in range(num_channels)] + + assert len(labels) == 16 + assert labels[0] == "Ch1" + assert labels[15] == "Ch16" + + def test_muse_labels(self): + """Test Muse headband channel labels""" + muse_labels = ["TP9", "AF7", "AF8", "TP10"] + + assert len(muse_labels) == 4 + # Left temporal + assert "TP9" in muse_labels + # Right temporal + assert "TP10" in muse_labels + + +# ============================================================================ +# INTEGRATION TESTS +# ============================================================================ + +class TestLSLSimulatorIntegration: + """Integration tests for LSL simulator""" + + def test_full_session(self): + """Test complete simulator session""" + import time + + sim = LSLSimulator( + stream_name="IntegrationTest", + channel_count=8, + sample_rate=250.0, + ) + + # Get info + info = sim.get_stream_info() + assert info.name == "IntegrationTest" + + # Start + sim.start() + assert sim.is_streaming + + # Collect samples + samples = [] + start = time.time() + while time.time() - start < 0.1: + sample, ts = sim.pull_sample() + if sample is not None: + samples.append((sample, ts)) + time.sleep(0.001) + + # Should have collected some samples + assert len(samples) > 0 + + # Stop + sim.stop() + assert not sim.is_streaming + + def test_multiple_simulators(self): + """Test running multiple simulators""" + sim1 = LSLSimulator("Stream1", 4, 250.0) + sim2 = LSLSimulator("Stream2", 8, 500.0) + + assert sim1.stream_info.channel_count == 4 + assert sim2.stream_info.channel_count == 8 + + sim1.start() + sim2.start() + + assert sim1.is_streaming + assert sim2.is_streaming + + sim1.stop() + sim2.stop() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/scripts/tests/test_pieeg_bridge.py b/scripts/tests/test_pieeg_bridge.py new file mode 100644 index 0000000..d7c2612 --- /dev/null +++ b/scripts/tests/test_pieeg_bridge.py @@ -0,0 +1,455 @@ +#!/usr/bin/env python3 +""" +Unit tests for pieeg_ws_bridge.py + +Tests cover: +- ADS1299 constants and enums +- PiEEGConfig dataclass +- PiEEGSimulator +- Sample packing/unpacking +- WebSocket message handling + +Run with: pytest scripts/tests/test_pieeg_bridge.py -v +""" + +import sys +import os +import struct +import json +import pytest +import numpy as np + +# Add parent directory to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from pieeg_ws_bridge import ( + ADS1299Register, + ADS1299Command, + ADS1299Gain, + ADS1299SampleRate, + PiEEGConfig, + PiEEGSimulator, + VREF, + NUM_CHANNELS, + STATUS_BYTES, + BYTES_PER_CHANNEL, + SAMPLE_BYTES, +) + + +# ============================================================================ +# CONSTANTS TESTS +# ============================================================================ + +class TestADS1299Constants: + """Tests for ADS1299 hardware constants""" + + def test_register_addresses(self): + """Test key register addresses""" + assert ADS1299Register.ID == 0x00 + assert ADS1299Register.CONFIG1 == 0x01 + assert ADS1299Register.CONFIG2 == 0x02 + assert ADS1299Register.CONFIG3 == 0x03 + assert ADS1299Register.CH1SET == 0x05 + assert ADS1299Register.CH8SET == 0x0C + + def test_command_values(self): + """Test SPI command values""" + assert ADS1299Command.WAKEUP == 0x02 + assert ADS1299Command.RESET == 0x06 + assert ADS1299Command.START == 0x08 + assert ADS1299Command.STOP == 0x0A + assert ADS1299Command.RDATAC == 0x10 + assert ADS1299Command.SDATAC == 0x11 + assert ADS1299Command.RREG == 0x20 + assert ADS1299Command.WREG == 0x40 + + def test_gain_values(self): + """Test PGA gain register values""" + assert ADS1299Gain.GAIN_1 == 0x00 + assert ADS1299Gain.GAIN_2 == 0x10 + assert ADS1299Gain.GAIN_4 == 0x20 + assert ADS1299Gain.GAIN_6 == 0x30 + assert ADS1299Gain.GAIN_8 == 0x40 + assert ADS1299Gain.GAIN_12 == 0x50 + assert ADS1299Gain.GAIN_24 == 0x60 + + def test_sample_rate_values(self): + """Test sample rate register values""" + assert ADS1299SampleRate.SPS_16000 == 0x00 + assert ADS1299SampleRate.SPS_8000 == 0x01 + assert ADS1299SampleRate.SPS_4000 == 0x02 + assert ADS1299SampleRate.SPS_2000 == 0x03 + assert ADS1299SampleRate.SPS_1000 == 0x04 + assert ADS1299SampleRate.SPS_500 == 0x05 + assert ADS1299SampleRate.SPS_250 == 0x06 + + def test_hardware_constants(self): + """Test hardware spec constants""" + assert VREF == 4.5 + assert NUM_CHANNELS == 8 + assert STATUS_BYTES == 3 + assert BYTES_PER_CHANNEL == 3 + assert SAMPLE_BYTES == 27 # 3 + 8*3 + + +# ============================================================================ +# CONFIG TESTS +# ============================================================================ + +class TestPiEEGConfig: + """Tests for PiEEGConfig dataclass""" + + def test_default_values(self): + """Test default configuration""" + config = PiEEGConfig() + + assert config.spi_bus == 0 + assert config.spi_device == 0 + assert config.spi_speed == 2000000 + assert config.sample_rate == ADS1299SampleRate.SPS_250 + assert config.gain == ADS1299Gain.GAIN_24 + assert config.num_channels == 8 + assert config.daisy_chain == False + + def test_custom_values(self): + """Test custom configuration""" + config = PiEEGConfig( + sample_rate=ADS1299SampleRate.SPS_500, + gain=ADS1299Gain.GAIN_12, + num_channels=16, + daisy_chain=True, + ) + + assert config.sample_rate == ADS1299SampleRate.SPS_500 + assert config.gain == ADS1299Gain.GAIN_12 + assert config.num_channels == 16 + assert config.daisy_chain == True + + +# ============================================================================ +# SIMULATOR TESTS +# ============================================================================ + +class TestPiEEGSimulator: + """Tests for PiEEGSimulator class""" + + def test_initialization(self): + """Test simulator initializes correctly""" + config = PiEEGConfig() + sim = PiEEGSimulator(config) + + assert sim.is_streaming == False + assert sim.sample_count == 0 + + def test_connect(self): + """Test simulator connect returns True""" + config = PiEEGConfig() + sim = PiEEGSimulator(config) + + assert sim.connect() == True + + def test_start_stop_streaming(self): + """Test streaming state management""" + config = PiEEGConfig() + sim = PiEEGSimulator(config) + sim.connect() + + sim.start_streaming() + assert sim.is_streaming == True + assert sim.sample_count == 0 + + sim.stop_streaming() + assert sim.is_streaming == False + + def test_sample_rate_hz(self): + """Test sample rate Hz conversion""" + test_cases = [ + (ADS1299SampleRate.SPS_250, 250), + (ADS1299SampleRate.SPS_500, 500), + (ADS1299SampleRate.SPS_1000, 1000), + (ADS1299SampleRate.SPS_2000, 2000), + ] + + for rate_enum, expected_hz in test_cases: + config = PiEEGConfig(sample_rate=rate_enum) + sim = PiEEGSimulator(config) + assert sim._get_sample_rate_hz() == expected_hz + + def test_read_sample_when_not_streaming(self): + """Test read_sample returns None when not streaming""" + config = PiEEGConfig() + sim = PiEEGSimulator(config) + sim.connect() + + # Not streaming yet + result = sim.read_sample() + # May return None due to rate limiting + # Just ensure no exception + + def test_read_sample_returns_correct_channels(self): + """Test read_sample returns correct number of channels""" + config = PiEEGConfig(num_channels=8) + sim = PiEEGSimulator(config) + sim.connect() + sim.start_streaming() + + # Wait a bit to ensure sample is ready + import time + time.sleep(0.01) + + sample = sim.read_sample() + if sample is not None: + assert len(sample) == 8 + + def test_simulated_values_in_eeg_range(self): + """Test simulated values are in realistic EEG range""" + config = PiEEGConfig() + sim = PiEEGSimulator(config) + sim.connect() + sim.start_streaming() + + import time + samples = [] + for _ in range(100): + sample = sim.read_sample() + if sample is not None: + samples.append(sample) + time.sleep(0.001) + + if samples: + all_values = np.array(samples).flatten() + # EEG typically ±100 µV, simulator adds some noise + assert np.abs(all_values).max() < 200 # µV + + def test_disconnect(self): + """Test disconnect stops streaming""" + config = PiEEGConfig() + sim = PiEEGSimulator(config) + sim.connect() + sim.start_streaming() + + sim.disconnect() + assert sim.is_streaming == False + + +# ============================================================================ +# SCALE FACTOR TESTS +# ============================================================================ + +class TestScaleFactor: + """Tests for µV scale factor calculation""" + + def test_gain_24_scale_factor(self): + """Test scale factor with gain 24""" + # Formula: (2 * VREF / gain) / (2^24) * 1e6 + # = (2 * 4.5 / 24) / 16777216 * 1e6 + # = 0.375 / 16777216 * 1e6 + # ≈ 0.0223517 µV/LSB + + gain = 24 + expected_scale = (2 * VREF / gain) / (2**24) * 1e6 + + assert abs(expected_scale - 0.0223517) < 0.0001 + + def test_gain_1_scale_factor(self): + """Test scale factor with gain 1 (maximum range)""" + gain = 1 + expected_scale = (2 * VREF / gain) / (2**24) * 1e6 + + # Should be 24x larger than gain 24 + assert abs(expected_scale - 0.0223517 * 24) < 0.01 + + def test_full_scale_range(self): + """Test full scale input range""" + gain = 24 + # Full scale = ±VREF/gain = ±4.5/24 = ±187.5 mV + # In µV: ±187,500 µV + + scale = (2 * VREF / gain) / (2**24) * 1e6 + full_scale_uv = scale * (2**23 - 1) # Max positive value + + assert abs(full_scale_uv - 187500) < 500 # Within 0.3% + + +# ============================================================================ +# PACKET FORMAT TESTS +# ============================================================================ + +class TestPacketFormat: + """Tests for WebSocket packet format""" + + def test_pack_samples_format(self): + """Test binary packet structure""" + # Simulate what _pack_samples does + samples = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]] + timestamp = 12345.6789 + num_samples = len(samples) + num_channels = len(samples[0]) + + # Header: magic (2) + num_samples (2) + num_channels (1) + timestamp (8) + header = struct.pack('>HHBD', 0xEEEE, num_samples, num_channels, timestamp) + + assert len(header) == 13 + + # Data: float32 × channels × samples + data = struct.pack(f'>{num_channels}f', *samples[0]) + assert len(data) == num_channels * 4 # 32 bytes + + def test_unpack_header(self): + """Test unpacking packet header""" + # Create header + magic = 0xEEEE + num_samples = 5 + num_channels = 8 + timestamp = 1234567890.123 + + header = struct.pack('>HHBD', magic, num_samples, num_channels, timestamp) + + # Unpack + unpacked = struct.unpack('>HHBD', header) + + assert unpacked[0] == magic + assert unpacked[1] == num_samples + assert unpacked[2] == num_channels + assert abs(unpacked[3] - timestamp) < 0.001 + + def test_unpack_sample_data(self): + """Test unpacking sample data""" + channels = [10.5, -20.3, 30.1, -40.7, 50.2, -60.8, 70.4, -80.9] + + # Pack + data = struct.pack(f'>8f', *channels) + + # Unpack + unpacked = struct.unpack(f'>8f', data) + + for i, (original, recovered) in enumerate(zip(channels, unpacked)): + assert abs(original - recovered) < 0.001, f"Channel {i} mismatch" + + +# ============================================================================ +# 24-BIT PARSING TESTS +# ============================================================================ + +class Test24BitParsing: + """Tests for 24-bit ADC value parsing""" + + def test_parse_positive_value(self): + """Test parsing positive 24-bit value""" + # 24-bit big-endian value: 0x123456 = 1193046 + data = bytes([0x12, 0x34, 0x56]) + + value = (data[0] << 16) | (data[1] << 8) | data[2] + # No sign extension needed for positive + assert value == 0x123456 + + def test_parse_negative_value(self): + """Test parsing negative 24-bit value (two's complement)""" + # -1 in 24-bit = 0xFFFFFF + data = bytes([0xFF, 0xFF, 0xFF]) + + value = (data[0] << 16) | (data[1] << 8) | data[2] + # Sign extend + if value & 0x800000: + value -= 0x1000000 + + assert value == -1 + + def test_parse_max_positive(self): + """Test parsing maximum positive value""" + # 0x7FFFFF = 8388607 (max positive for 24-bit signed) + data = bytes([0x7F, 0xFF, 0xFF]) + + value = (data[0] << 16) | (data[1] << 8) | data[2] + if value & 0x800000: + value -= 0x1000000 + + assert value == 8388607 + + def test_parse_min_negative(self): + """Test parsing minimum negative value""" + # 0x800000 = -8388608 (min negative for 24-bit signed) + data = bytes([0x80, 0x00, 0x00]) + + value = (data[0] << 16) | (data[1] << 8) | data[2] + if value & 0x800000: + value -= 0x1000000 + + assert value == -8388608 + + def test_parse_zero(self): + """Test parsing zero""" + data = bytes([0x00, 0x00, 0x00]) + + value = (data[0] << 16) | (data[1] << 8) | data[2] + if value & 0x800000: + value -= 0x1000000 + + assert value == 0 + + +# ============================================================================ +# INTEGRATION TESTS +# ============================================================================ + +class TestSimulatorIntegration: + """Integration tests for simulator""" + + def test_full_session(self): + """Test complete simulator session""" + import time + + config = PiEEGConfig(sample_rate=ADS1299SampleRate.SPS_250) + sim = PiEEGSimulator(config) + + # Connect + assert sim.connect() == True + + # Start streaming + sim.start_streaming() + assert sim.is_streaming == True + + # Collect samples + samples = [] + start = time.time() + while time.time() - start < 0.1: # 100ms + sample = sim.read_sample() + if sample is not None: + samples.append(sample) + time.sleep(0.001) + + # Should have some samples + assert len(samples) > 0 + + # Stop + sim.stop_streaming() + assert sim.is_streaming == False + + # Disconnect + sim.disconnect() + + def test_sample_count_increments(self): + """Test that sample count increments correctly""" + import time + + config = PiEEGConfig() + sim = PiEEGSimulator(config) + sim.connect() + sim.start_streaming() + + initial_count = sim.sample_count + + # Read some samples + for _ in range(50): + sim.read_sample() + time.sleep(0.001) + + # Count should have increased + assert sim.sample_count >= initial_count + + sim.disconnect() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 2c89ecef2116ef7924bca2b97a989690f2f9fa5c Mon Sep 17 00:00:00 2001 From: Youssef Date: Sun, 1 Feb 2026 09:59:54 -0500 Subject: [PATCH 4/5] =?UTF-8?q?-=20Added=20missing=20LSLStreamInfo=20datac?= =?UTF-8?q?lass=20to=20the=20LSL=20bridge=20-=20Updated=20LSLSimulator=20t?= =?UTF-8?q?o=20expose=20stream=5Finfo,=20is=5Fstreaming,=20get=5Fstream=5F?= =?UTF-8?q?info(),=20and=20pull=5Fsample()=20for=20testing=20-=20Fixed=20s?= =?UTF-8?q?truct=20format=20chars=20(D=20=E2=86=92=20d=20for=20double)=20i?= =?UTF-8?q?n=20test=20assertions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 + scripts/lsl_ws_bridge.py | 152 ++++++++++++++++++++--------- scripts/tests/test_lsl_bridge.py | 12 +-- scripts/tests/test_pieeg_bridge.py | 6 +- 4 files changed, 122 insertions(+), 53 deletions(-) diff --git a/.gitignore b/.gitignore index b45c786..3a5741e 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,8 @@ dist-ssr cypress/videos cypress/screenshots cypress/downloads +# Python +__pycache__/ +*.py[cod] +*$py.class +.pytest_cache/ \ No newline at end of file diff --git a/scripts/lsl_ws_bridge.py b/scripts/lsl_ws_bridge.py index 9397b2c..ff81e7b 100644 --- a/scripts/lsl_ws_bridge.py +++ b/scripts/lsl_ws_bridge.py @@ -129,6 +129,17 @@ class LSLChannelFormat(Enum): # DATA STRUCTURES # ============================================================================ +@dataclass +class LSLStreamInfo: + """Lightweight stream info for external use and testing""" + name: str + stream_type: str + channel_count: int + sampling_rate: float + source_id: str + channel_labels: Optional[List[str]] = None + + @dataclass class StreamMetadata: """Metadata about an LSL stream""" @@ -709,55 +720,91 @@ class LSLSimulator: def __init__( self, - name: str = "PhantomLoop_Simulated_EEG", + stream_name: str = "PhantomLoop_Simulated_EEG", stream_type: str = "EEG", channel_count: int = 8, - sampling_rate: float = 250.0, + sample_rate: float = 250.0, + name: str = None, # Alias for backward compatibility + sampling_rate: float = None, # Alias for backward compatibility ): - self.name = name + # Handle aliases for backward compatibility + self._name = stream_name if name is None else name self.stream_type = stream_type self.channel_count = channel_count - self.sampling_rate = sampling_rate - self.running = False + self._sampling_rate = sample_rate if sampling_rate is None else sampling_rate + self._is_streaming = False self.outlet = None + self._sample_buffer: List[Tuple[List[float], float]] = [] + self._buffer_lock = threading.Lock() + + # Create stream_info for test compatibility + self.stream_info = LSLStreamInfo( + name=self._name, + stream_type=self.stream_type, + channel_count=self.channel_count, + sampling_rate=self._sampling_rate, + source_id="phantomloop-sim-001", + ) + + @property + def name(self) -> str: + return self._name + + @property + def sampling_rate(self) -> float: + return self._sampling_rate + + @property + def is_streaming(self) -> bool: + return self._is_streaming + + @property + def running(self) -> bool: + """Alias for backward compatibility""" + return self._is_streaming + + @running.setter + def running(self, value: bool): + self._is_streaming = value + + def get_stream_info(self) -> LSLStreamInfo: + """Get stream information""" + return self.stream_info def start(self): """Start simulated LSL outlet""" - if not LSL_AVAILABLE: - logger.error("pylsl not available for simulation") - return - - from pylsl import StreamOutlet, StreamInfo as LSLStreamInfo - - # Create stream info - info = LSLStreamInfo( - self.name, - self.stream_type, - self.channel_count, - self.sampling_rate, - 'float32', - 'phantomloop-sim-001' - ) + self._is_streaming = True - # Add channel descriptions - desc = info.desc() - channels = desc.append_child("channels") - for i in range(self.channel_count): - ch = channels.append_child("channel") - ch.append_child_value("label", f"Ch{i+1}") - ch.append_child_value("type", "EEG") - ch.append_child_value("unit", "µV") - - # Add acquisition info - acq = desc.append_child("acquisition") - acq.append_child_value("manufacturer", "PhantomLoop") - acq.append_child_value("model", "Simulated EEG") - - # Create outlet - self.outlet = StreamOutlet(info) - self.running = True - - logger.info(f"Started simulated LSL outlet: {self.name}") + if LSL_AVAILABLE: + from pylsl import StreamOutlet, StreamInfo as PyLSLStreamInfo + + # Create stream info + info = PyLSLStreamInfo( + self._name, + self.stream_type, + self.channel_count, + self._sampling_rate, + 'float32', + 'phantomloop-sim-001' + ) + + # Add channel descriptions + desc = info.desc() + channels = desc.append_child("channels") + for i in range(self.channel_count): + ch = channels.append_child("channel") + ch.append_child_value("label", f"Ch{i+1}") + ch.append_child_value("type", "EEG") + ch.append_child_value("unit", "µV") + + # Add acquisition info + acq = desc.append_child("acquisition") + acq.append_child_value("manufacturer", "PhantomLoop") + acq.append_child_value("model", "Simulated EEG") + + # Create outlet + self.outlet = StreamOutlet(info) + logger.info(f"Started simulated LSL outlet: {self._name}") # Start streaming thread thread = threading.Thread(target=self._stream_loop, daemon=True) @@ -765,10 +812,10 @@ def start(self): def _stream_loop(self): """Generate and push simulated EEG samples""" - sample_interval = 1.0 / self.sampling_rate + sample_interval = 1.0 / self._sampling_rate phase = np.zeros(self.channel_count) - while self.running and self.outlet: + while self._is_streaming: # Generate simulated EEG (alpha waves + noise) sample = [] for ch in range(self.channel_count): @@ -778,14 +825,31 @@ def _stream_loop(self): sample.append(float(alpha + noise)) phase[ch] += sample_interval - # Push sample - self.outlet.push_sample(sample) + timestamp = time.time() + + # Push to LSL outlet if available + if self.outlet: + self.outlet.push_sample(sample) + + # Also buffer for pull_sample() + with self._buffer_lock: + self._sample_buffer.append((sample, timestamp)) + # Keep buffer reasonable + if len(self._sample_buffer) > 1000: + self._sample_buffer = self._sample_buffer[-500:] time.sleep(sample_interval) + def pull_sample(self, timeout: float = 0.0) -> Tuple[Optional[List[float]], float]: + """Pull a sample from the buffer (for testing without pylsl)""" + with self._buffer_lock: + if self._sample_buffer: + return self._sample_buffer.pop(0) + return None, 0.0 + def stop(self): """Stop the simulator""" - self.running = False + self._is_streaming = False self.outlet = None diff --git a/scripts/tests/test_lsl_bridge.py b/scripts/tests/test_lsl_bridge.py index 057ca52..3bd14f2 100644 --- a/scripts/tests/test_lsl_bridge.py +++ b/scripts/tests/test_lsl_bridge.py @@ -329,13 +329,13 @@ def test_header_format(self): timestamp = 1234567890.123456 # Pack header - header = struct.pack('>HBHBD', magic, packet_type, num_samples, num_channels, timestamp) + header = struct.pack('>HBHBd', magic, packet_type, num_samples, num_channels, timestamp) # Should be 14 bytes assert len(header) == 14 # Unpack and verify - unpacked = struct.unpack('>HBHBD', header) + unpacked = struct.unpack('>HBHBd', header) assert unpacked[0] == magic assert unpacked[1] == packet_type assert unpacked[2] == num_samples @@ -377,7 +377,7 @@ def test_full_packet(self): samples = [[10.0, 20.0, 30.0, 40.0], [50.0, 60.0, 70.0, 80.0]] # Pack - header = struct.pack('>HBHBD', magic, packet_type, num_samples, num_channels, timestamp) + header = struct.pack('>HBHBd', magic, packet_type, num_samples, num_channels, timestamp) data = b'' for sample in samples: data += struct.pack(f'>{num_channels}f', *sample) @@ -385,7 +385,7 @@ def test_full_packet(self): packet = header + data # Parse - parsed_header = struct.unpack('>HBHBD', packet[:14]) + parsed_header = struct.unpack('>HBHBd', packet[:14]) assert parsed_header[0] == magic assert parsed_header[2] == num_samples @@ -521,8 +521,8 @@ def test_full_session(self): def test_multiple_simulators(self): """Test running multiple simulators""" - sim1 = LSLSimulator("Stream1", 4, 250.0) - sim2 = LSLSimulator("Stream2", 8, 500.0) + sim1 = LSLSimulator(stream_name="Stream1", channel_count=4, sample_rate=250.0) + sim2 = LSLSimulator(stream_name="Stream2", channel_count=8, sample_rate=500.0) assert sim1.stream_info.channel_count == 4 assert sim2.stream_info.channel_count == 8 diff --git a/scripts/tests/test_pieeg_bridge.py b/scripts/tests/test_pieeg_bridge.py index d7c2612..2be8484 100644 --- a/scripts/tests/test_pieeg_bridge.py +++ b/scripts/tests/test_pieeg_bridge.py @@ -288,7 +288,7 @@ def test_pack_samples_format(self): num_channels = len(samples[0]) # Header: magic (2) + num_samples (2) + num_channels (1) + timestamp (8) - header = struct.pack('>HHBD', 0xEEEE, num_samples, num_channels, timestamp) + header = struct.pack('>HHBd', 0xEEEE, num_samples, num_channels, timestamp) assert len(header) == 13 @@ -304,10 +304,10 @@ def test_unpack_header(self): num_channels = 8 timestamp = 1234567890.123 - header = struct.pack('>HHBD', magic, num_samples, num_channels, timestamp) + header = struct.pack('>HHBd', magic, num_samples, num_channels, timestamp) # Unpack - unpacked = struct.unpack('>HHBD', header) + unpacked = struct.unpack('>HHBd', header) assert unpacked[0] == magic assert unpacked[1] == num_samples From f782ba18fc75dea6cdd48fe43dbba245f9690f16 Mon Sep 17 00:00:00 2001 From: Youssef Date: Sun, 1 Feb 2026 10:01:08 -0500 Subject: [PATCH 5/5] fix: update CI workflow to include Python bridge tests and correct branch names --- .github/workflows/ci.yml | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0652a82..05067c9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: CI on: push: - branches: [main, cerelog-esp-eeg-experiment] + branches: [main, dev] pull_request: branches: [main] @@ -30,6 +30,31 @@ jobs: - name: Run type check run: npx tsc --noEmit + python-test: + name: Python Bridge Tests + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + cache: 'pip' + cache-dependency-path: scripts/requirements-dev.txt + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r scripts/requirements-dev.txt + + - name: Run Python tests + run: | + cd scripts + python -m pytest tests/ -v --tb=short + lint: name: Lint runs-on: ubuntu-latest @@ -53,7 +78,7 @@ jobs: build: name: Build runs-on: ubuntu-latest - needs: [test, lint] + needs: [test, python-test, lint] steps: - name: Checkout repository