diff --git a/synapse/examples/tap_example.py b/synapse/examples/tap_example.py new file mode 100644 index 00000000..029efc03 --- /dev/null +++ b/synapse/examples/tap_example.py @@ -0,0 +1,101 @@ +import synapse as syn +import sys +import time + +from synapse.client.taps import Tap + +SIMULATED_PERIPHERAL_ID = 100 + +if __name__ == "__main__": + uri = sys.argv[1] if len(sys.argv) > 1 else "127.0.0.1:647" + device = syn.Device(uri) + info = device.info() + if info is None: + print("Couldn't get device info") + sys.exit(1) + + print("Device info:") + print(info) + + channels = [ + syn.Channel( + id=channel_num, + electrode_id=channel_num * 2, + reference_id=channel_num * 2 + 1, + ) + for channel_num in range(32) + ] + + broadband = syn.BroadbandSource( + # Use the simulated peripheral (100), or replace with your own + peripheral_id=SIMULATED_PERIPHERAL_ID, + sample_rate_hz=30000, + bit_width=12, + gain=20.0, + signal=syn.SignalConfig( + electrode=syn.ElectrodeConfig( + channels=channels, + low_cutoff_hz=500.0, + high_cutoff_hz=6000.0, + ) + ), + ) + + config = syn.Config() + config.add_node(broadband) + + device.configure(config) + + # export the config to a json file for using with CLI + # from google.protobuf.json_format import MessageToJson + # with open("device_config.json", "w") as f: + # f.write(MessageToJson(config.to_proto())) + # print("Config written to device_config.json") + + device.start() + + info = device.info() + if info is None: + print("Couldn't get device info") + sys.exit(1) + print("Configured device info:") + print(info) + + # stream with tap api + tap_client = Tap(uri) + tap_client.connect("broadband_source_sim") + + should_run = True + total_bytes_read = 0 + start_time = time.time() + last_update_time = start_time + update_interval_sec = 1 + while should_run: + try: + # Wait for data + syn_data = tap_client.read() + bytes_read = len(syn_data) + if syn_data is None or bytes_read == 0: + print("Failed to read data from node") + continue + # Do something with the data + total_bytes_read += bytes_read + + current_time = time.time() + if (current_time - last_update_time) >= update_interval_sec: + sys.stdout.write("\r") + sys.stdout.write( + f"{total_bytes_read} bytes in {time.time() - start_time:.2f} sec" + ) + last_update_time = current_time + + if current_time - start_time > 5: + should_run = False + + except KeyboardInterrupt: + print("Keyboard interrupt detected, stopping") + should_run = False + + print("Stopping device") + device.stop() + diff --git a/synapse/server/nodes/base.py b/synapse/server/nodes/base.py index 413c4936..95bdccf1 100644 --- a/synapse/server/nodes/base.py +++ b/synapse/server/nodes/base.py @@ -71,3 +71,6 @@ def node_socket(self): bind=f"{self.socket[0]}:{self.socket[1]}", type=self.type, ) + + def tap_connections(self): + return [] diff --git a/synapse/server/rpc.py b/synapse/server/rpc.py index e124a23b..14d678b4 100644 --- a/synapse/server/rpc.py +++ b/synapse/server/rpc.py @@ -9,6 +9,7 @@ from synapse.api.node_pb2 import NodeConnection, NodeType from synapse.api.logging_pb2 import LogLevel, LogQueryResponse from synapse.api.query_pb2 import QueryResponse +from synapse.api.tap_pb2 import ListTapsResponse from synapse.api.status_pb2 import DeviceState, Status, StatusCode from synapse.api.device_pb2 import DeviceConfiguration, DeviceInfo from synapse.api.synapse_pb2_grpc import ( @@ -69,6 +70,7 @@ class SynapseServicer(SynapseDeviceServicer): def __init__(self, name, serial, iface_ip, node_object_map, peripherals): self.name = name self.serial = serial + self.iface_ip = iface_ip self.node_object_map = node_object_map self.peripherals = peripherals self.logger = logging.getLogger("server") @@ -168,6 +170,10 @@ async def Query(self, request, context): # handle query + taps = [] + for node in self.nodes: + taps.extend(node.tap_connections()) + return QueryResponse( data=[1, 2, 3, 4, 5], status=Status( @@ -176,6 +182,7 @@ async def Query(self, request, context): sockets=self._sockets_status_info(), state=self.state, ), + list_taps_response=ListTapsResponse(taps=taps), ) async def GetLogs(self, request, context): @@ -320,7 +327,7 @@ def _reconfigure(self, configuration): "Creating %s node(%d)" % (NodeType.Name(node.type), node.id) ) node = self.node_object_map[node.type](node.id) - if node.type in [NodeType.kStreamIn]: + if node.type in [NodeType.kStreamIn, NodeType.kBroadbandSource]: node.configure_iface_ip(self.iface_ip) status = node.configure(config) diff --git a/synapse/simulator/nodes/broadband_source.py b/synapse/simulator/nodes/broadband_source.py index efb9cf39..469a04f8 100644 --- a/synapse/simulator/nodes/broadband_source.py +++ b/synapse/simulator/nodes/broadband_source.py @@ -2,9 +2,13 @@ import random import time +import zmq + from synapse.api.node_pb2 import NodeType from synapse.api.nodes.broadband_source_pb2 import BroadbandSourceConfig from synapse.server.nodes.base import BaseNode +from synapse.api.tap_pb2 import TapConnection, TapType +from synapse.api.datatype_pb2 import BroadbandFrame from synapse.server.status import Status from synapse.utils.ndtp_types import ElectricalBroadbandData @@ -16,6 +20,11 @@ class BroadbandSource(BaseNode): def __init__(self, id): super().__init__(id, NodeType.kBroadbandSource) self.__config: BroadbandSourceConfig = None + self.zmq_context = None + self.zmq_socket = None + self.seq_number = 0 + self.iface_ip = None + self.port = None def config(self): c = super().config() @@ -39,16 +48,25 @@ async def run(self): if not c.HasField("signal") or not c.signal: self.logger.error("node signal not configured") return - + if not c.signal.HasField("electrode") or not c.signal.electrode: self.logger.error("node signal electrode not configured") return - + e = c.signal.electrode if not e.channels: self.logger.error("node signal electrode channels not configured") return + if not self.zmq_context: + if not self.iface_ip: + self.logger.error("iface_ip not configured") + return + + self.zmq_context = zmq.Context() + self.zmq_socket = self.zmq_context.socket(zmq.PUB) + self.port = self.zmq_socket.bind_to_random_port(f"tcp://{self.iface_ip}") + channels = e.channels bit_width = c.bit_width if c.bit_width else 4 sample_rate_hz = c.sample_rate_hz if c.sample_rate_hz else 16000 @@ -56,20 +74,63 @@ async def run(self): t_last_ns = time.time_ns() while self.running: await asyncio.sleep(0.01) - + now = time.time_ns() elapsed_ns = now - t_last_ns n_samples = int(sample_rate_hz * elapsed_ns / 1e9) - samples = [[ch.id, [r_sample(bit_width) for _ in range(n_samples)]] for ch in channels] - data = ElectricalBroadbandData( - bit_width=bit_width, - is_signed=False, - sample_rate=sample_rate_hz, - t0=t_last_ns, - samples=samples + try: + # for backwards compatibility + data = ElectricalBroadbandData( + bit_width=bit_width, + is_signed=False, + sample_rate=sample_rate_hz, + t0=t_last_ns, + samples=samples + ) + await self.emit_data(data) + + # send data over tap + for i in range(n_samples): + frame = BroadbandFrame( + timestamp_ns = t_last_ns + int(i * 1e9 / sample_rate_hz), + sequence_number = self.seq_number, + frame_data = [chan_samples[i] for _, chan_samples in samples], + sample_rate_hz = sample_rate_hz, + ) + try: + self.zmq_socket.send(frame.SerializeToString()) + self.seq_number += 1 + except Exception as e: + self.logger.error(f"Error sending data: {e}") + + t_last_ns = now + except Exception as e: + print(f"Error sending data: {e}") + + def stop(self): + """Clean up ZMQ resources.""" + if self.zmq_socket: + self.zmq_socket.close() + self.zmq_socket = None + + if self.zmq_context: + self.zmq_context.destroy() + self.zmq_context = None + + return super().stop() + + def configure_iface_ip(self, iface_ip): + self.iface_ip = iface_ip + + def tap_connections(self): + return [ + TapConnection( + name="broadband_source_sim", + endpoint=f"tcp://{self.iface_ip}:{self.port}", + message_type="synapse.BroadbandFrame", + tap_type=TapType.TAP_TYPE_PRODUCER, ) + ] - await self.emit_data(data) - t_last_ns = now