11"""Python interface to the Reiser lab ArenaController."""
2+ from __future__ import annotations
3+
4+ import math
5+ import os
26import socket
37import struct
48import time
913import pstats
1014
1115
12- PORT = 62222
16+ ETHERNET_SERVER_PORT = 62222
17+ # Backwards-compat alias (older code may import PORT)
18+ PORT = ETHERNET_SERVER_PORT
1319PATTERN_HEADER_SIZE = 7
1420BYTE_COUNT_PER_PANEL_GRAYSCALE = 132
1521REPEAT_LIMIT = 4
2329ANALOG_OUTPUT_VALUE_MIN = 100
2430ANALOG_OUTPUT_VALUE_MAX = 4095
2531
32+ # Chunk size used for optional STREAM_FRAME chunked sends.
33+ # Keep this comfortably below typical MTU to avoid excessive fragmentation.
34+ CHUNK_SIZE = 4096
35+
2636
2737class ArenaInterface ():
2838 """Python interface to the Reiser lab ArenaController."""
@@ -31,6 +41,7 @@ def __init__(self, debug=False):
3141 self ._debug = debug
3242 self ._serial = None
3343 self ._ethernet_ip_address = ''
44+ self ._ethernet_socket : socket .socket | None = None
3445 atexit .register (self ._exit )
3546
3647 def _debug_print (self , * args ):
@@ -42,8 +53,11 @@ def _exit(self):
4253 """
4354 Close the serial connection to provide some clean up.
4455 """
45- if self ._serial :
46- self ._serial .close ()
56+ try :
57+ self .close ()
58+ except Exception :
59+ # Best-effort cleanup only.
60+ pass
4761
4862 def _connect_ethernet_socket (self , repeat_count = 10 , reuse = True ):
4963 """Connect (or reuse) a TCP socket to the firmware's Ethernet server."""
@@ -94,6 +108,20 @@ def _recv_exact(ethernet_socket: socket.socket, n: int) -> bytes:
94108 data += chunk
95109 return data
96110
111+ def _read (self , transport , n : int ) -> bytes :
112+ """Read exactly n bytes from a serial or Ethernet transport."""
113+ if transport is None :
114+ raise RuntimeError ("No transport provided" )
115+ if isinstance (transport , socket .socket ):
116+ return self ._recv_exact (transport , n )
117+ # Assume pyserial-like.
118+ data = transport .read (n )
119+ if data is None :
120+ return b""
121+ if len (data ) != n :
122+ raise TimeoutError (f"serial read short: expected { n } , got { len (data )} " )
123+ return data
124+
97125 def _send_and_receive (self , cmd , ethernet_socket = None ):
98126 """Send a command and wait for a binary response.
99127
@@ -158,7 +186,7 @@ def set_ethernet_mode(self, ip_address):
158186 def set_serial_mode (self , port , baudrate = SERIAL_BAUDRATE ):
159187 """Set serial mode specifying the serial port."""
160188 self ._close_ethernet_socket ()
161- self ._ethernet_ip_address = None
189+ self ._ethernet_ip_address = ''
162190 if self ._serial :
163191 self ._serial .close ()
164192
@@ -258,12 +286,31 @@ def play_pattern_analog_closed_loop(self, pattern_id, gain, runtime_duration, in
258286 break
259287 self ._debug_print ('response: ' , response )
260288
261- def show_pattern_frame (self , pattern_id , frame_index , ethernet_socket = None ):
262- """Show pattern frame."""
289+ def show_pattern_frame (
290+ self ,
291+ pattern_id ,
292+ frame_index ,
293+ frame_rate : int = 0 ,
294+ runtime_duration : int = 0 ,
295+ gain : int = 0x10 ,
296+ ethernet_socket = None ,
297+ ):
298+ """Show pattern frame.
299+
300+ Parameters
301+ ----------
302+ pattern_id:
303+ Pattern ID on the controller.
304+ frame_index:
305+ Initial frame index.
306+ frame_rate:
307+ Target refresh rate (Hz). Some firmware builds use this as the mode
308+ target_hz for perf sessions.
309+ runtime_duration:
310+ Duration in 100ms ticks (same unit as play_pattern TRIAL_PARAMS).
311+ Use 0 for "run until interrupted".
312+ """
263313 control_mode = 0x03
264- frame_rate = 0
265- gain = 0x10 # dummy value
266- runtime_duration = 0
267314 cmd_bytes = struct .pack ('<BBBHhHhH' ,
268315 0x0c ,
269316 0x08 ,
@@ -307,7 +354,7 @@ def stream_pattern_frame_indicies(self, pattern_id, frame_index_min, frame_index
307354 frames_displayed_count = 0
308355 frames_to_display_count = int ((frame_rate * runtime_duration ) / RUNTIME_DURATION_PER_SECOND )
309356 ethernet_socket = self ._connect_ethernet_socket ()
310- self .show_pattern_frame (pattern_id , frame_index_min , ethernet_socket )
357+ self .show_pattern_frame (pattern_id , frame_index_min , ethernet_socket = ethernet_socket )
311358 stream_frames_start_time = time .time_ns ()
312359 while frames_displayed_count < frames_to_display_count :
313360 pattern_start_time = time .time_ns ()
@@ -334,6 +381,14 @@ def get_ethernet_ip_address(self):
334381 """Get Ethernet IP address."""
335382 return self ._send_and_receive (b'\x01 \x66 ' )
336383
384+ def get_perf_stats (self , ethernet_socket = None ) -> bytes :
385+ """Fetch a raw performance stats snapshot (binary payload)."""
386+ return self ._send_and_receive (b'\x01 \x71 ' , ethernet_socket )
387+
388+ def reset_perf_stats (self , ethernet_socket = None ):
389+ """Reset performance counters on the device."""
390+ self ._send_and_receive (b'\x01 \x72 ' , ethernet_socket )
391+
337392 def all_on (self ):
338393 """Turn all panels on."""
339394 self ._send_and_receive (b'\x01 \xff ' )
@@ -395,45 +450,69 @@ def stream_frames(
395450 stream_cmd_coalesced = False ,
396451 progress_interval_s = 1.0 ,
397452 ):
398- """Stream a pattern file's frames at a fixed rate for a fixed duration.
399-
400- Returns a dict with basic host-side throughput stats.
453+ """Stream a `.pattern` file's frames at a fixed rate for a fixed duration.
454+
455+ Notes
456+ -----
457+ - `runtime_duration` uses the same unit as TRIAL_PARAMS: **100ms ticks**.
458+ For example, `runtime_duration=50` streams for ~5 seconds.
459+ - The `.pattern` file format expected here is:
460+ [uint32_le frame_size][frame0 bytes][frame1 bytes]...
461+
462+ Returns
463+ -------
464+ dict
465+ Basic host-side throughput stats.
401466 """
402467 # Read pattern file: [uint32 frame_size][frame0][frame1]...
403468 with open (pattern_path , 'rb' ) as f :
404- frame_size = struct .unpack ('<I' , f .read (4 ))[0 ]
469+ frame_size_raw = f .read (4 )
470+ if len (frame_size_raw ) != 4 :
471+ raise ValueError (f"{ pattern_path } is too small to be a .pattern file" )
472+ frame_size = struct .unpack ('<I' , frame_size_raw )[0 ]
473+ if frame_size <= 0 :
474+ raise ValueError (f"invalid frame_size={ frame_size } in { pattern_path } " )
475+
405476 file_size = os .path .getsize (pattern_path )
477+ if (file_size - 4 ) % frame_size != 0 :
478+ raise ValueError (
479+ f"file size { file_size } not compatible with frame_size { frame_size } in { pattern_path } "
480+ )
406481 num_frames = int ((file_size - 4 ) / frame_size )
407482 frames = [f .read (frame_size ) for _ in range (num_frames )]
408483
409- frames_total = int (runtime_duration * frame_rate )
410- frame_period_ns = int ((1 / frame_rate ) * 1e9 )
484+ runtime_duration_s = float (runtime_duration ) / float (RUNTIME_DURATION_PER_SECOND )
485+ frames_total = int (runtime_duration_s * float (frame_rate ))
486+ frame_period_ns = int ((1.0 / float (frame_rate )) * 1e9 ) if frame_rate else 0
411487
412- analog_update_period_ns = int ((1 / analog_update_rate ) * 1e9 )
413- analog_start_time_ns = time .perf_counter_ns ()
488+ analog_update_period_ns = int ((1.0 / float (analog_update_rate )) * 1e9 ) if analog_update_rate else 0
414489
415- analog_amplitude = (2 ** 16 ) / 2 - 1
416- analog_offset = (2 ** 16 ) / 2
490+ # Map waveform output [-1..1] into a conservative 12-bit-ish range.
491+ analog_amplitude = (ANALOG_OUTPUT_VALUE_MAX - ANALOG_OUTPUT_VALUE_MIN ) / 2.0
492+ analog_offset = (ANALOG_OUTPUT_VALUE_MAX + ANALOG_OUTPUT_VALUE_MIN ) / 2.0
417493
418- def analog_waveform_for (name ):
494+ def analog_waveform_for (name : str ):
419495 if name == 'sin' :
420- return np .sin
421- elif name == 'square' :
422- return lambda x : np .sign (np .sin (x ))
423- elif name == 'sawtooth' :
424- return lambda x : 2 * (x / (2 * np .pi ) - np .floor (1 / 2 + x / (2 * np .pi )))
425- elif name == 'triangle' :
426- return lambda x : 2 * np .abs (2 * (x / (2 * np .pi ) - np .floor (1 / 2 + x / (2 * np .pi )))) - 1
427- elif name == 'constant' :
428- return lambda x : 0
429- else :
430- raise ValueError (f'Invalid analog output waveform: { name } ' )
496+ return math .sin
497+ if name == 'square' :
498+ return lambda x : 1.0 if math .sin (x ) >= 0 else - 1.0
499+ if name == 'sawtooth' :
500+ return lambda x : 2.0 * (x / (2.0 * math .pi ) - math .floor (0.5 + x / (2.0 * math .pi )))
501+ if name == 'triangle' :
502+ return lambda x : 2.0 * abs (2.0 * (x / (2.0 * math .pi ) - math .floor (0.5 + x / (2.0 * math .pi )))) - 1.0
503+ if name == 'constant' :
504+ return lambda x : 0.0
505+ raise ValueError (f'Invalid analog output waveform: { name } ' )
431506
432507 ethernet_socket = self ._connect_ethernet_socket (reuse = True )
433508
434509 bytes_sent = 0
435510 frames_streamed = 0
436511
512+ wf = analog_waveform_for (str (analog_out_waveform ))
513+ last_analog_update_ns = 0
514+ analog_value_cached = int (round (analog_offset ))
515+
437516 start_time_ns = time .perf_counter_ns ()
438517 next_progress_ns = None
439518 if progress_interval_s and (progress_interval_s > 0 ):
@@ -444,13 +523,20 @@ def analog_waveform_for(name):
444523 frame = frames [frame_index ]
445524
446525 # Analog output update (optional)
447- if (time .perf_counter_ns () - analog_start_time_ns ) > (i * analog_update_period_ns ):
448- analog_phase = (i / analog_update_rate ) * analog_frequency * 2 * np .pi
449- analog_output_value = analog_amplitude * analog_waveform_for (analog_out_waveform )(analog_phase ) + analog_offset
450- # Ensure this is an int in-range for uint16.
451- analog_output_value = int (max (0 , min (65535 , round (float (analog_output_value )))))
452- else :
453- analog_output_value = 0
526+ now_ns = time .perf_counter_ns ()
527+ if analog_update_period_ns and (now_ns - last_analog_update_ns ) >= analog_update_period_ns :
528+ t_s = (now_ns - start_time_ns ) / 1e9
529+ analog_phase = (t_s * float (analog_frequency )) * (2.0 * math .pi )
530+ analog_output_value_f = analog_amplitude * float (wf (analog_phase )) + analog_offset
531+ analog_value_cached = int (
532+ max (
533+ ANALOG_OUTPUT_VALUE_MIN ,
534+ min (ANALOG_OUTPUT_VALUE_MAX , round (analog_output_value_f )),
535+ )
536+ )
537+ last_analog_update_ns = now_ns
538+
539+ analog_output_value = int (analog_value_cached )
454540
455541 # Stream frame header: cmd(0x32), data_len(uint16), analog(uint16), reserved(uint16)
456542 data_len = len (frame )
@@ -476,8 +562,9 @@ def analog_waveform_for(name):
476562 next_progress_ns += int (progress_interval_s * 1e9 )
477563
478564 # Rate limiting (busy-wait)
479- while (time .perf_counter_ns () - start_time_ns ) < ((i + 1 ) * frame_period_ns ):
480- pass
565+ if frame_period_ns :
566+ while (time .perf_counter_ns () - start_time_ns ) < ((i + 1 ) * frame_period_ns ):
567+ pass
481568
482569 # End the mode
483570 self ._send_and_receive (bytes ([1 , 0 ]), ethernet_socket )
0 commit comments