diff --git a/src/viam/rpc/dial.py b/src/viam/rpc/dial.py index 80016f617..7b6361ce2 100644 --- a/src/viam/rpc/dial.py +++ b/src/viam/rpc/dial.py @@ -82,6 +82,24 @@ class DialOptions: """Number of seconds before the dial connection times out Set to 20sec to match _defaultOfferDeadline in goutils/rpc/wrtc_call_queue.go""" + force_relay: bool = False + """Force ICE transport policy to relay-only so only TURN candidates are used. + Useful for testing relay connectivity through a TURN server. Mutually exclusive + with ``force_p2p``: if both are set the connection will fail because ``force_p2p`` + strips the very TURN servers ``force_relay`` requires.""" + + force_p2p: bool = False + """Strip TURN servers from the ICE config so only host and server-reflexive + candidates are used. Useful for testing direct connectivity without relay + fallback. Setting this alongside ``turn_uri`` is a no-op for the filter, since + TURN servers are removed before any filtering happens.""" + + turn_uri: Optional[str] = None + """Filter the signaling server's TURN list to only the server whose parsed URI + matches (compared by scheme, host, port, and transport — transport defaults to + UDP if unspecified). Set to ``None`` to use all TURN servers. + Example: ``"turn:turn.viam.com:443"``.""" + def __init__( self, *, @@ -95,6 +113,9 @@ def __init__( timeout: float = 20, initial_connection_attempts: int = 3, initial_connection_attempt_timeout: Optional[float] = None, + force_relay: bool = False, + force_p2p: bool = False, + turn_uri: Optional[str] = None, ) -> None: self.disable_webrtc = disable_webrtc self.auth_entity = auth_entity @@ -106,6 +127,9 @@ def __init__( self.timeout = timeout self.initial_connection_attempts = initial_connection_attempts self.initial_connection_attempt_timeout = initial_connection_attempt_timeout if initial_connection_attempt_timeout else timeout + self.force_relay = force_relay + self.force_p2p = force_p2p + self.turn_uri = turn_uri @classmethod def with_api_key(cls, api_key: str, api_key_id: str) -> Self: @@ -236,10 +260,10 @@ def __init__(self) -> None: LOGGER.debug("Creating new viam-rust-utils runtime") libname = pathlib.Path(__file__).parent.absolute() / f"libviam_rust_utils.{suffix}" self._lib = ctypes.CDLL(libname.__str__()) - self._lib.init_rust_runtime.argtypes = () - self._lib.init_rust_runtime.restype = ctypes.c_void_p + self._lib.viam_init_rust_runtime.argtypes = () + self._lib.viam_init_rust_runtime.restype = ctypes.c_void_p - self._lib.dial.argtypes = ( + self._lib.viam_dial_with_opts.argtypes = ( ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p, @@ -247,16 +271,32 @@ def __init__(self) -> None: ctypes.c_bool, ctypes.c_float, ctypes.c_void_p, + ctypes.c_void_p, ) - self._lib.dial.restype = ctypes.c_void_p + self._lib.viam_dial_with_opts.restype = ctypes.c_void_p + + self._lib.viam_dial_opts_new.argtypes = () + self._lib.viam_dial_opts_new.restype = ctypes.c_void_p + + self._lib.viam_dial_opts_free.argtypes = (ctypes.c_void_p,) + self._lib.viam_dial_opts_free.restype = None + + self._lib.viam_dial_opts_set_force_relay.argtypes = (ctypes.c_void_p, ctypes.c_bool) + self._lib.viam_dial_opts_set_force_relay.restype = None - self._lib.free_rust_runtime.argtypes = (ctypes.c_void_p,) - self._lib.free_rust_runtime.restype = None + self._lib.viam_dial_opts_set_force_p2p.argtypes = (ctypes.c_void_p, ctypes.c_bool) + self._lib.viam_dial_opts_set_force_p2p.restype = None - self._lib.free_string.argtypes = (ctypes.c_void_p,) - self._lib.free_string.restype = None + self._lib.viam_dial_opts_set_turn_uri.argtypes = (ctypes.c_void_p, ctypes.c_char_p) + self._lib.viam_dial_opts_set_turn_uri.restype = None - self._ptr = self._lib.init_rust_runtime() + self._lib.viam_free_rust_runtime.argtypes = (ctypes.c_void_p,) + self._lib.viam_free_rust_runtime.restype = None + + self._lib.viam_free_string.argtypes = (ctypes.c_void_p,) + self._lib.viam_free_string.restype = None + + self._ptr = self._lib.viam_init_rust_runtime() async def dial(self, address: str, options: DialOptions) -> Tuple[Optional[str], ctypes.c_void_p]: type = options.credentials.type if options.credentials else "" @@ -268,27 +308,36 @@ async def dial(self, address: str, options: DialOptions) -> Tuple[Optional[str], ) LOGGER.debug(f"Dialing {address} using viam-rust-utils library") - path_ptr = await to_thread( - self._lib.dial, - address.encode("utf-8"), - options.auth_entity.encode("utf-8") if options.auth_entity else None, - type.encode("utf-8") if type else None, - payload.encode("utf-8") if payload else None, - insecure, - ctypes.c_float(options.timeout), - self._ptr, - ) + opts_handle = self._lib.viam_dial_opts_new() + try: + self._lib.viam_dial_opts_set_force_relay(opts_handle, options.force_relay) + self._lib.viam_dial_opts_set_force_p2p(opts_handle, options.force_p2p) + if options.turn_uri: + self._lib.viam_dial_opts_set_turn_uri(opts_handle, options.turn_uri.encode("utf-8")) + path_ptr = await to_thread( + self._lib.viam_dial_with_opts, + address.encode("utf-8"), + options.auth_entity.encode("utf-8") if options.auth_entity else None, + type.encode("utf-8") if type else None, + payload.encode("utf-8") if payload else None, + insecure, + ctypes.c_float(options.timeout), + self._ptr, + opts_handle, + ) + finally: + self._lib.viam_dial_opts_free(opts_handle) path = ctypes.cast(path_ptr, ctypes.c_char_p).value path = path.decode("utf-8") if path else "" return (path, path_ptr) def release(self): LOGGER.debug("Freeing viam-rust-utils runtime") - self._lib.free_rust_runtime(self._ptr) + self._lib.viam_free_rust_runtime(self._ptr) def free_str(self, ptr: ctypes.c_void_p): LOGGER.debug("Freeing socket string") - self._lib.free_string(ptr) + self._lib.viam_free_string(ptr) async def dial(address: str, options: Optional[DialOptions] = None) -> ViamChannel: