Skip to content

Commit 2e2dabf

Browse files
committed
Code cleanup and improvement
1 parent d880108 commit 2e2dabf

1 file changed

Lines changed: 100 additions & 41 deletions

File tree

can_waveshare/bus.py

Lines changed: 100 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import errno
4+
import logging
45
import select
56
import socket
67
import time
@@ -88,10 +89,11 @@ class WaveShareBus(can.BusABC):
8889
8990
Optional kwargs:
9091
- receive_own_messages: bool (default False) – suppress echoed frames (best effort)
91-
- tcp_nodelay: bool (default True)
92+
- tcp_nodelay: bool (default True) | alias: tcp_tune=True
9293
- keepalive: bool (default True)
9394
- timeout: float | None – default timeout used as base for send/recv select
9495
- can_filters: list[dict] – software filters applied in this backend
96+
- echo_window: float (default 0.1s) – suppression window for own echoes
9597
9698
Limitations:
9799
- CAN-FD is NOT supported (Waveshare wire format is limited to 8 data bytes).
@@ -109,8 +111,13 @@ def __init__(
109111
keepalive: bool = True,
110112
timeout: Optional[float] = None,
111113
can_filters: Optional[List[dict[str, Any]]] = None,
114+
echo_window: float = 0.1,
112115
**kwargs: Any,
113116
) -> None:
117+
# normalize legacy kw: bridge may pass tcp_tune=True
118+
if "tcp_tune" in kwargs:
119+
tcp_nodelay = bool(kwargs.pop("tcp_tune"))
120+
114121
# Allow passing host/port via channel
115122
ch_host, ch_port = _parse_channel(channel) if channel else (None, None)
116123
host = host or ch_host
@@ -122,40 +129,55 @@ def __init__(
122129
)
123130

124131
self.channel_info = f"Waveshare TCP {host}:{port}"
125-
# Initialize BusABC (sets up periodic send plumbing, etc.)
126-
super().__init__(channel=channel or f"{host}:{port}", **kwargs)
132+
# Initialize BusABC (don’t forward unknown kwargs)
133+
super().__init__(channel=channel or f"{host}:{port}")
127134

128135
self._host: str = host
129136
self._port: int = int(port)
130137
self._timeout_default: Optional[float] = timeout
131138
self._filters: List[dict[str, Any]] = list(can_filters or [])
139+
self._echo_window: float = float(echo_window)
132140

133-
# Socket setup
134-
self._sock = socket.socket(
135-
socket.AF_INET6 if _is_ipv6_literal(host) else socket.AF_INET,
136-
socket.SOCK_STREAM,
137-
)
138-
try:
139-
if tcp_nodelay:
140-
self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
141-
except OSError:
142-
pass
143-
try:
144-
if keepalive:
145-
self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
146-
except OSError:
147-
pass
141+
self._started: bool = False
142+
self._closed: bool = False
143+
self._sock: Optional[socket.socket] = None
148144

149-
# Connect
145+
# Resolve host (IPv4/IPv6/DNS) and try candidates in order
150146
try:
151-
self._sock.connect((self._host, self._port))
147+
infos = socket.getaddrinfo(self._host, self._port, 0, socket.SOCK_STREAM)
152148
except OSError as e:
153-
self._sock.close()
154149
raise can.CanError(
155-
f"WaveShareBus: connect to {self._host}:{self._port} failed: {e}"
150+
f"WaveShareBus: resolve failed for {self._host}:{self._port}: {e}"
156151
) from e
157152

158-
self._closed: bool = False
153+
last_err: Optional[OSError] = None
154+
for family, socktype, proto, _, sockaddr in infos:
155+
s = socket.socket(family, socktype, proto)
156+
try:
157+
if tcp_nodelay:
158+
try:
159+
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
160+
except OSError:
161+
pass
162+
if keepalive:
163+
_apply_keepalive(s)
164+
165+
s.connect(sockaddr)
166+
# success
167+
self._sock = s
168+
self._started = True
169+
break
170+
except OSError as e:
171+
last_err = e
172+
try:
173+
s.close()
174+
except OSError:
175+
pass
176+
continue
177+
else:
178+
raise can.CanError(
179+
f"WaveShareBus: connect to {self._host}:{self._port} failed: {last_err}"
180+
) from last_err
159181

160182
# Best-effort own-message suppression (for echoing bridges)
161183
self._suppress_own: bool = not bool(receive_own_messages)
@@ -166,22 +188,28 @@ def __init__(
166188
def shutdown(self) -> None:
167189
self._closed = True
168190
try:
169-
self._sock.shutdown(socket.SHUT_RDWR)
170-
except OSError:
171-
pass
172-
try:
173-
self._sock.close()
191+
if self._sock is not None:
192+
try:
193+
self._sock.shutdown(socket.SHUT_RDWR)
194+
except OSError:
195+
pass
196+
try:
197+
self._sock.close()
198+
except OSError:
199+
pass
174200
finally:
201+
self._sock = None
202+
self._started = False
175203
super().shutdown()
176204

177205
def fileno(self) -> Optional[int]:
178206
try:
179-
return self._sock.fileno()
207+
return self._sock.fileno() if self._sock is not None else None
180208
except OSError:
181209
return None
182210

183211
def send(self, msg: can.Message, timeout: Optional[float] = None) -> None:
184-
if self._closed:
212+
if self._closed or self._sock is None:
185213
raise can.CanError("WaveShareBus is closed")
186214

187215
if getattr(msg, "is_fd", False):
@@ -195,14 +223,16 @@ def send(self, msg: can.Message, timeout: Optional[float] = None) -> None:
195223

196224
if extended:
197225
if not (0 <= arb_id <= 0x1FFFFFFF):
198-
raise can.CanError(f"Invalid 29-bit CAN ID: {arb_id:#x}")
226+
raise can.CanError(f"Invalid 29-bit CAN ID: 0x{arb_id:X}")
199227
else:
200228
if not (0 <= arb_id <= 0x7FF):
201-
raise can.CanError(f"Invalid 11-bit CAN ID: {arb_id:#x}")
229+
raise can.CanError(f"Invalid 11-bit CAN ID: 0x{arb_id:X}")
202230

203231
data = bytes(msg.data or b"")
204232
if len(data) > 8:
205-
raise can.CanError("Data length > 8 not supported by Waveshare format")
233+
raise can.CanError(
234+
f"Data length {len(data)} > 8 not supported by Waveshare format"
235+
)
206236

207237
# For RTR, dlc declares requested length. Prefer explicit msg.dlc if present.
208238
try:
@@ -227,7 +257,7 @@ def send(self, msg: can.Message, timeout: Optional[float] = None) -> None:
227257
try:
228258
_send_all(self._sock, payload)
229259
except OSError as e:
230-
raise can.CanError(f"WaveShareBus.send failed: {e}") from e
260+
raise can.CanError(f"WaveShareBus.send failed (id=0x{arb_id:X}): {e}") from e
231261

232262
# record for own-echo suppression (if active)
233263
if self._suppress_own:
@@ -236,7 +266,7 @@ def send(self, msg: can.Message, timeout: Optional[float] = None) -> None:
236266
)
237267

238268
def recv(self, timeout: Optional[float] = None) -> Optional[can.Message]:
239-
if self._closed:
269+
if self._closed or self._sock is None:
240270
return None
241271

242272
deadline: Optional[float] = None
@@ -284,7 +314,9 @@ def recv(self, timeout: Optional[float] = None) -> Optional[can.Message]:
284314
):
285315
continue
286316

287-
if self._suppress_own and _matches_recent(self._recent_tx, msg, window=0.1):
317+
if self._suppress_own and _matches_recent(
318+
self._recent_tx, msg, window=self._echo_window
319+
):
288320
continue
289321

290322
return msg
@@ -294,6 +326,20 @@ def recv(self, timeout: Optional[float] = None) -> Optional[can.Message]:
294326
def set_filters(self, filters: Optional[Iterable[dict[str, Any]]]) -> None:
295327
self._filters = list(filters or [])
296328

329+
@property
330+
def is_open(self) -> bool:
331+
return not self._closed and self._started and self._sock is not None
332+
333+
def __del__(self):
334+
# Warn only if object started and not shutdown; avoid noise on failed __init__
335+
try:
336+
if getattr(self, "_started", False) and not getattr(self, "_closed", True):
337+
logging.getLogger(__name__).warning(
338+
"WaveShareBus was not properly shut down"
339+
)
340+
except Exception:
341+
pass
342+
297343

298344
# ---------- helpers ----------
299345

@@ -302,6 +348,24 @@ class _SocketClosed(Exception):
302348
pass
303349

304350

351+
def _apply_keepalive(sock: socket.socket) -> None:
352+
"""Best-effort TCP keepalive tuning (portable; guarded)."""
353+
try:
354+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
355+
# Linux-specific options guarded by hasattr:
356+
if hasattr(socket, "TCP_KEEPIDLE"):
357+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
358+
if hasattr(socket, "TCP_KEEPINTVL"):
359+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10)
360+
if hasattr(socket, "TCP_KEEPCNT"):
361+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3)
362+
if hasattr(socket, "TCP_USER_TIMEOUT"):
363+
# 120s TX user-timeout; quicker broken-link detection on some kernels
364+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, 120_000)
365+
except OSError:
366+
pass
367+
368+
305369
def _choose_timeout(
306370
specific: Optional[float], default: Optional[float]
307371
) -> Optional[float]:
@@ -408,8 +472,3 @@ def _parse_channel(channel: Optional[str]) -> Tuple[Optional[str], Optional[int]
408472
except ValueError:
409473
return (None, None)
410474
return (None, None)
411-
412-
413-
def _is_ipv6_literal(host: str) -> bool:
414-
"""Heuristic: treat as IPv6 if it contains ':' and not a bracketed literal removed by _parse_channel."""
415-
return ":" in host

0 commit comments

Comments
 (0)