Skip to content

Commit 84d407b

Browse files
committed
Add benchmarking
1 parent d8fbf9c commit 84d407b

File tree

6 files changed

+328
-119
lines changed

6 files changed

+328
-119
lines changed

.metadata/metadata.org

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,7 @@ docker-container-port:
813813
# File edits may be overwritten!
814814
[build-system]
815815
requires = ["setuptools"]
816-
build-backed = "setuptools.build_meta"
816+
build-backend = "setuptools.build_meta"
817817
#+END_SRC
818818

819819
#+HEADER: :tangle (if tangle-external-files "../setup.cfg" "no")

arena_interface/__about__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
# This file is generated automatically from metadata
33
# File edits may be overwritten!
44

5-
__version__ = '5.0.0'
5+
__version__ = '5.1.0'
66
__description__ = 'Python interface to the Reiser lab ArenaController.'
77
__license__ = 'BSD-3-Clause'
88
__url__ = 'https://github.com/janelia-python/arena_interface_python'
99
__author__ = 'Peter Polidoro'
1010
__email__ = 'peter@polidoro.io'
11-
__copyright__ = '2025 Howard Hughes Medical Institute'
11+
__copyright__ = '2026 Howard Hughes Medical Institute'

arena_interface/arena_interface.py

Lines changed: 128 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
"""Python interface to the Reiser lab ArenaController."""
2+
from __future__ import annotations
3+
4+
import math
5+
import os
26
import socket
37
import struct
48
import time
@@ -9,7 +13,9 @@
913
import pstats
1014

1115

12-
PORT = 62222
16+
ETHERNET_SERVER_PORT = 62222
17+
# Backwards-compat alias (older code may import PORT)
18+
PORT = ETHERNET_SERVER_PORT
1319
PATTERN_HEADER_SIZE = 7
1420
BYTE_COUNT_PER_PANEL_GRAYSCALE = 132
1521
REPEAT_LIMIT = 4
@@ -23,6 +29,10 @@
2329
ANALOG_OUTPUT_VALUE_MIN = 100
2430
ANALOG_OUTPUT_VALUE_MAX = 4095
2531

32+
# Chunk size used for optional STREAM_FRAME chunked sends.
33+
# Keep this comfortably below typical MTU to avoid excessive fragmentation.
34+
CHUNK_SIZE = 4096
35+
2636

2737
class ArenaInterface():
2838
"""Python interface to the Reiser lab ArenaController."""
@@ -31,6 +41,7 @@ def __init__(self, debug=False):
3141
self._debug = debug
3242
self._serial = None
3343
self._ethernet_ip_address = ''
44+
self._ethernet_socket: socket.socket | None = None
3445
atexit.register(self._exit)
3546

3647
def _debug_print(self, *args):
@@ -42,8 +53,11 @@ def _exit(self):
4253
"""
4354
Close the serial connection to provide some clean up.
4455
"""
45-
if self._serial:
46-
self._serial.close()
56+
try:
57+
self.close()
58+
except Exception:
59+
# Best-effort cleanup only.
60+
pass
4761

4862
def _connect_ethernet_socket(self, repeat_count=10, reuse=True):
4963
"""Connect (or reuse) a TCP socket to the firmware's Ethernet server."""
@@ -94,6 +108,20 @@ def _recv_exact(ethernet_socket: socket.socket, n: int) -> bytes:
94108
data += chunk
95109
return data
96110

111+
def _read(self, transport, n: int) -> bytes:
112+
"""Read exactly n bytes from a serial or Ethernet transport."""
113+
if transport is None:
114+
raise RuntimeError("No transport provided")
115+
if isinstance(transport, socket.socket):
116+
return self._recv_exact(transport, n)
117+
# Assume pyserial-like.
118+
data = transport.read(n)
119+
if data is None:
120+
return b""
121+
if len(data) != n:
122+
raise TimeoutError(f"serial read short: expected {n}, got {len(data)}")
123+
return data
124+
97125
def _send_and_receive(self, cmd, ethernet_socket=None):
98126
"""Send a command and wait for a binary response.
99127
@@ -158,7 +186,7 @@ def set_ethernet_mode(self, ip_address):
158186
def set_serial_mode(self, port, baudrate=SERIAL_BAUDRATE):
159187
"""Set serial mode specifying the serial port."""
160188
self._close_ethernet_socket()
161-
self._ethernet_ip_address = None
189+
self._ethernet_ip_address = ''
162190
if self._serial:
163191
self._serial.close()
164192

@@ -258,12 +286,31 @@ def play_pattern_analog_closed_loop(self, pattern_id, gain, runtime_duration, in
258286
break
259287
self._debug_print('response: ', response)
260288

261-
def show_pattern_frame(self, pattern_id, frame_index, ethernet_socket=None):
262-
"""Show pattern frame."""
289+
def show_pattern_frame(
290+
self,
291+
pattern_id,
292+
frame_index,
293+
frame_rate: int = 0,
294+
runtime_duration: int = 0,
295+
gain: int = 0x10,
296+
ethernet_socket=None,
297+
):
298+
"""Show pattern frame.
299+
300+
Parameters
301+
----------
302+
pattern_id:
303+
Pattern ID on the controller.
304+
frame_index:
305+
Initial frame index.
306+
frame_rate:
307+
Target refresh rate (Hz). Some firmware builds use this as the mode
308+
target_hz for perf sessions.
309+
runtime_duration:
310+
Duration in 100ms ticks (same unit as play_pattern TRIAL_PARAMS).
311+
Use 0 for "run until interrupted".
312+
"""
263313
control_mode = 0x03
264-
frame_rate = 0
265-
gain = 0x10 # dummy value
266-
runtime_duration = 0
267314
cmd_bytes = struct.pack('<BBBHhHhH',
268315
0x0c,
269316
0x08,
@@ -307,7 +354,7 @@ def stream_pattern_frame_indicies(self, pattern_id, frame_index_min, frame_index
307354
frames_displayed_count = 0
308355
frames_to_display_count = int((frame_rate * runtime_duration) / RUNTIME_DURATION_PER_SECOND)
309356
ethernet_socket = self._connect_ethernet_socket()
310-
self.show_pattern_frame(pattern_id, frame_index_min, ethernet_socket)
357+
self.show_pattern_frame(pattern_id, frame_index_min, ethernet_socket=ethernet_socket)
311358
stream_frames_start_time = time.time_ns()
312359
while frames_displayed_count < frames_to_display_count:
313360
pattern_start_time = time.time_ns()
@@ -334,6 +381,14 @@ def get_ethernet_ip_address(self):
334381
"""Get Ethernet IP address."""
335382
return self._send_and_receive(b'\x01\x66')
336383

384+
def get_perf_stats(self, ethernet_socket=None) -> bytes:
385+
"""Fetch a raw performance stats snapshot (binary payload)."""
386+
return self._send_and_receive(b'\x01\x71', ethernet_socket)
387+
388+
def reset_perf_stats(self, ethernet_socket=None):
389+
"""Reset performance counters on the device."""
390+
self._send_and_receive(b'\x01\x72', ethernet_socket)
391+
337392
def all_on(self):
338393
"""Turn all panels on."""
339394
self._send_and_receive(b'\x01\xff')
@@ -395,45 +450,69 @@ def stream_frames(
395450
stream_cmd_coalesced=False,
396451
progress_interval_s=1.0,
397452
):
398-
"""Stream a pattern file's frames at a fixed rate for a fixed duration.
399-
400-
Returns a dict with basic host-side throughput stats.
453+
"""Stream a `.pattern` file's frames at a fixed rate for a fixed duration.
454+
455+
Notes
456+
-----
457+
- `runtime_duration` uses the same unit as TRIAL_PARAMS: **100ms ticks**.
458+
For example, `runtime_duration=50` streams for ~5 seconds.
459+
- The `.pattern` file format expected here is:
460+
[uint32_le frame_size][frame0 bytes][frame1 bytes]...
461+
462+
Returns
463+
-------
464+
dict
465+
Basic host-side throughput stats.
401466
"""
402467
# Read pattern file: [uint32 frame_size][frame0][frame1]...
403468
with open(pattern_path, 'rb') as f:
404-
frame_size = struct.unpack('<I', f.read(4))[0]
469+
frame_size_raw = f.read(4)
470+
if len(frame_size_raw) != 4:
471+
raise ValueError(f"{pattern_path} is too small to be a .pattern file")
472+
frame_size = struct.unpack('<I', frame_size_raw)[0]
473+
if frame_size <= 0:
474+
raise ValueError(f"invalid frame_size={frame_size} in {pattern_path}")
475+
405476
file_size = os.path.getsize(pattern_path)
477+
if (file_size - 4) % frame_size != 0:
478+
raise ValueError(
479+
f"file size {file_size} not compatible with frame_size {frame_size} in {pattern_path}"
480+
)
406481
num_frames = int((file_size - 4) / frame_size)
407482
frames = [f.read(frame_size) for _ in range(num_frames)]
408483

409-
frames_total = int(runtime_duration * frame_rate)
410-
frame_period_ns = int((1 / frame_rate) * 1e9)
484+
runtime_duration_s = float(runtime_duration) / float(RUNTIME_DURATION_PER_SECOND)
485+
frames_total = int(runtime_duration_s * float(frame_rate))
486+
frame_period_ns = int((1.0 / float(frame_rate)) * 1e9) if frame_rate else 0
411487

412-
analog_update_period_ns = int((1 / analog_update_rate) * 1e9)
413-
analog_start_time_ns = time.perf_counter_ns()
488+
analog_update_period_ns = int((1.0 / float(analog_update_rate)) * 1e9) if analog_update_rate else 0
414489

415-
analog_amplitude = (2 ** 16) / 2 - 1
416-
analog_offset = (2 ** 16) / 2
490+
# Map waveform output [-1..1] into a conservative 12-bit-ish range.
491+
analog_amplitude = (ANALOG_OUTPUT_VALUE_MAX - ANALOG_OUTPUT_VALUE_MIN) / 2.0
492+
analog_offset = (ANALOG_OUTPUT_VALUE_MAX + ANALOG_OUTPUT_VALUE_MIN) / 2.0
417493

418-
def analog_waveform_for(name):
494+
def analog_waveform_for(name: str):
419495
if name == 'sin':
420-
return np.sin
421-
elif name == 'square':
422-
return lambda x: np.sign(np.sin(x))
423-
elif name == 'sawtooth':
424-
return lambda x: 2 * (x / (2 * np.pi) - np.floor(1 / 2 + x / (2 * np.pi)))
425-
elif name == 'triangle':
426-
return lambda x: 2 * np.abs(2 * (x / (2 * np.pi) - np.floor(1 / 2 + x / (2 * np.pi)))) - 1
427-
elif name == 'constant':
428-
return lambda x: 0
429-
else:
430-
raise ValueError(f'Invalid analog output waveform: {name}')
496+
return math.sin
497+
if name == 'square':
498+
return lambda x: 1.0 if math.sin(x) >= 0 else -1.0
499+
if name == 'sawtooth':
500+
return lambda x: 2.0 * (x / (2.0 * math.pi) - math.floor(0.5 + x / (2.0 * math.pi)))
501+
if name == 'triangle':
502+
return lambda x: 2.0 * abs(2.0 * (x / (2.0 * math.pi) - math.floor(0.5 + x / (2.0 * math.pi)))) - 1.0
503+
if name == 'constant':
504+
return lambda x: 0.0
505+
raise ValueError(f'Invalid analog output waveform: {name}')
431506

432507
ethernet_socket = self._connect_ethernet_socket(reuse=True)
433508

434509
bytes_sent = 0
435510
frames_streamed = 0
436511

512+
wf = analog_waveform_for(str(analog_out_waveform))
513+
last_analog_update_ns = 0
514+
analog_value_cached = int(round(analog_offset))
515+
437516
start_time_ns = time.perf_counter_ns()
438517
next_progress_ns = None
439518
if progress_interval_s and (progress_interval_s > 0):
@@ -444,13 +523,20 @@ def analog_waveform_for(name):
444523
frame = frames[frame_index]
445524

446525
# Analog output update (optional)
447-
if (time.perf_counter_ns() - analog_start_time_ns) > (i * analog_update_period_ns):
448-
analog_phase = (i / analog_update_rate) * analog_frequency * 2 * np.pi
449-
analog_output_value = analog_amplitude * analog_waveform_for(analog_out_waveform)(analog_phase) + analog_offset
450-
# Ensure this is an int in-range for uint16.
451-
analog_output_value = int(max(0, min(65535, round(float(analog_output_value)))))
452-
else:
453-
analog_output_value = 0
526+
now_ns = time.perf_counter_ns()
527+
if analog_update_period_ns and (now_ns - last_analog_update_ns) >= analog_update_period_ns:
528+
t_s = (now_ns - start_time_ns) / 1e9
529+
analog_phase = (t_s * float(analog_frequency)) * (2.0 * math.pi)
530+
analog_output_value_f = analog_amplitude * float(wf(analog_phase)) + analog_offset
531+
analog_value_cached = int(
532+
max(
533+
ANALOG_OUTPUT_VALUE_MIN,
534+
min(ANALOG_OUTPUT_VALUE_MAX, round(analog_output_value_f)),
535+
)
536+
)
537+
last_analog_update_ns = now_ns
538+
539+
analog_output_value = int(analog_value_cached)
454540

455541
# Stream frame header: cmd(0x32), data_len(uint16), analog(uint16), reserved(uint16)
456542
data_len = len(frame)
@@ -476,8 +562,9 @@ def analog_waveform_for(name):
476562
next_progress_ns += int(progress_interval_s * 1e9)
477563

478564
# Rate limiting (busy-wait)
479-
while (time.perf_counter_ns() - start_time_ns) < ((i + 1) * frame_period_ns):
480-
pass
565+
if frame_period_ns:
566+
while (time.perf_counter_ns() - start_time_ns) < ((i + 1) * frame_period_ns):
567+
pass
481568

482569
# End the mode
483570
self._send_and_receive(bytes([1, 0]), ethernet_socket)

0 commit comments

Comments
 (0)