11import asyncio
2+ import io
3+ import logging
24from abc import ABC , abstractmethod
35from typing import *
46
5- from pubsub import pub
7+ import serial
8+ import serial_asyncio
69
710from meshtastic .protobuf .mesh_pb2 import FromRadio , ToRadio
811
912
13+ # magic number used in streaming client headers
14+ HEADER_MAGIC : bytes = b"\x94 \xc3 "
15+
16+
17+ class ConnectionError (Exception ):
18+ """Base class for MeshConnection-related errors."""
19+
20+
21+ class BadPayloadError (ConnectionError ):
22+ def __init__ (self , payload , reason : str ):
23+ self .payload = payload
24+ super ().__init__ (reason )
25+
26+
1027class MeshConnection (ABC ):
1128 """A client API connection to a meshtastic radio."""
1229
13- def __init__ (self ):
14- self ._on_disconnect : asyncio .Event = asyncio .Event ()
30+ def __init__ (self , name : str ):
31+ self .name : str = name
32+ self .on_disconnect : asyncio .Event = asyncio .Event ()
33+ self ._is_ready : bool = False
34+ self ._send_lock : asyncio .Lock = asyncio .Lock ()
35+ self ._recv_lock : asyncio .Lock = asyncio .Lock ()
36+ self ._init_task : asyncio .Task = asyncio .create_task (self ._initialize ())
37+ self ._init_task .add_done_callback (self ._after_initialize )
38+
39+ @abstractmethod
40+ async def _initialize (self ):
41+ """Perform any connection initialization that must be performed async
42+ (and therefore not from the constructor)."""
1543
1644 @abstractmethod
1745 async def _send_bytes (self , msg : buffer ):
@@ -23,11 +51,6 @@ async def _recv_bytes(self) -> buffer:
2351 """Recieve bytes from the mesh device."""
2452 pass
2553
26- @abstractmethod
27- def close (self ):
28- """Close the connection"""
29- pass
30-
3154 @staticmethod
3255 @abstractmethod
3356 async def get_available () -> AsyncGenerator [Any ]:
@@ -37,22 +60,121 @@ async def get_available() -> AsyncGenerator[Any]:
3760 constructor."""
3861 pass
3962
40- def __enter__ (self ):
41- return self
63+ def ready (self ):
64+ return self . _is_ready
4265
43- def __exit__ (self , exc_type , exc_value , trace ):
44- self .close ()
66+ def _after_initialize (self ):
67+ self ._is_ready = True
68+ del self ._init_task
4569
4670 async def send (self , message : ToRadio ):
4771 """Send something to the connected device."""
48- msg_str : str = message .SerializeToString ()
49- await self ._send_bytes (bytes (msg_str ))
72+ async with self ._send_lock :
73+ msg_str : str = message .SerializeToString ()
74+ await self ._send_bytes (bytes (msg_str ))
5075
5176 async def recv (self ) -> FromRadio :
5277 """Recieve something from the connected device."""
53- msg_bytes : buffer = await self ._recv_bytes ()
54- return FromRadio .FromString (str (msg_bytes ))
78+ async with self ._recv_lock :
79+ msg_bytes : buffer = await self ._recv_bytes ()
80+ return FromRadio .FromString (str (msg_bytes , errors = "ignore" ))
5581
5682 async def listen (self ) -> AsyncGenerator [FromRadio ]:
57- while True :
83+ while not self . on_disconnect . is_set () :
5884 yield await self .recv ()
85+
86+ def close (self ):
87+ """Close the connection.
88+ Overloaders should remember to call supermethod"""
89+ if not self .ready ():
90+ self ._init_task .cancel ()
91+
92+ self .on_disconnect .set ()
93+
94+ def __enter__ (self ):
95+ return self
96+
97+ def __exit__ (self , exc_type , exc_value , trace ):
98+ self .close ()
99+
100+
101+ class StreamConnection (MeshConnection ):
102+ """Base class for connections using the aio stream API"""
103+ def __init__ (self , name : str ):
104+ self ._reader : Optional [asyncio .StreamReader ] = None
105+ self ._writer : Optional [asyncio .StreamWriter ] = None
106+ self .stream_debug_out : io .StringIO = io .StringIO ()
107+ super ().__init__ (name )
108+
109+ def _handle_debug (self , debug_out : bytes ):
110+ self .stream_debug_out .write (str (debug_out ))
111+ self .stream_debug_out .flush ()
112+
113+ async def _send_bytes (self , msg : buffer ):
114+ length : int = len (msg )
115+ if length > 512 :
116+ raise BadPayloadError (msg , "Cannot send client API messages over 512 bytes" )
117+
118+ self ._writer .write (HEADER_MAGIC )
119+ self ._writer .write (length .to_bytes (2 , "big" ))
120+ self ._writer .write (msg )
121+ await self ._writer .drain ()
122+
123+ async def _find_stream_header (self ):
124+ """Consumes and logs debug out bytes until a valid header is detected"""
125+ try :
126+ while True :
127+ from_stream : bytes = await self ._reader .readuntil ((b'\n ' , HEADER_MAGIC ))
128+ if from_stream .endswith (HEADER_MAGIC ):
129+ self ._handle_debug (from_stream [:- 2 ])
130+ return
131+ else :
132+ self ._handle_debug (from_stream )
133+
134+ except asyncio .IncompleteReadError as err :
135+ if len (err .partial ) > 0 :
136+ self ._handle_debug (err .partial )
137+ raise
138+
139+ async def _recv_bytes (self ) -> buffer :
140+ try :
141+ while True :
142+ await self ._find_stream_header ()
143+ size_bytes : bytes = await self ._reader .readexactly (2 )
144+ size : int = int .from_bytes (size_bytes , "big" )
145+ if 0 < size <= 512 :
146+ return await self ._reader .readexactly (size )
147+
148+ self ._handle_debug (size_bytes )
149+
150+ except asyncio .LimitOverrunError as err :
151+ raise ConnectionError (
152+ "Read buffer overrun while reading stream" ) from err
153+
154+ except asyncio .IncompleteReadError :
155+ logging .error (f"Connection to { self .name } terminated: stream EOF reached" )
156+ self .close ()
157+
158+ def close (self ):
159+ super ().close ()
160+ self ._writer .close ()
161+ self .stream_debug_out .close ()
162+ asyncio .as_completed ((self ._writer .wait_closed (),))
163+
164+
165+ class SerialConnection (StreamConnection ):
166+ def __init__ (self , portaddr : str , baudrate : int = 115200 ):
167+ self .port : str = portaddr
168+ self .baudrate : int = baudrate
169+ super ().__init__ (portaddr )
170+
171+ async def _initialize (self ):
172+ self ._reader , self ._writer = await serial_asyncio .open_serial_connectio (
173+ port = self ._port , baudrate = self ._baudrate ,
174+ )
175+
176+ @staticmethod
177+ async def get_available () -> AsyncGenerator [str ]:
178+ for port in serial .tools .list_ports .comports ():
179+ if port .hwid != "n/a" :
180+ yield port .device
0 commit comments