diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py index 802bfb36b..2a0bbd63c 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py @@ -94,13 +94,33 @@ def stream(self, method="connect"): async def stream_async(self, method): return await self.tcp.stream_async(method) + @property + def command(self) -> str: + """Get the base SSH command""" + return self.call("get_ssh_command") + + @property + def identity(self) -> str | None: + """ + Get the SSH identity (private key) as a string. + + Returns: + The SSH identity key content, or None if not configured. + + Raises: + ConfigurationError: If `ssh_identity_file` is configured on the + driver but cannot be read. + """ + return self.call("get_ssh_identity") + + @property + def username(self) -> str: + """Get the default SSH username""" + return self.call("get_default_username") + def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult: """Run SSH command with the given parameters and arguments""" # Get SSH command and default username from driver - ssh_command = self.call("get_ssh_command") - default_username = self.call("get_default_username") - ssh_identity = self.call("get_ssh_identity") - if options.direct: # Use direct TCP address try: @@ -111,7 +131,7 @@ def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult: if not host or not port: raise ValueError(f"Invalid address format: {address}") self.logger.debug("Using direct TCP connection for SSH - host: %s, port: %s", host, port) - return self._run_ssh_local(host, port, ssh_command, options, default_username, ssh_identity, args) + return self._run_ssh_local(host, port, options, args) except (DriverMethodNotImplemented, ValueError) as e: self.logger.error("Direct address connection failed (%s), falling back to SSH port forwarding", e) return self.run(SSHCommandRunOptions( @@ -127,11 +147,12 @@ def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult: ) as addr: host, port = addr self.logger.debug("SSH port forward established - host: %s, port: %s", host, port) - return self._run_ssh_local(host, port, ssh_command, options, default_username, ssh_identity, args) + return self._run_ssh_local(host, port, options, args) - def _run_ssh_local(self, host, port, ssh_command, options, default_username, ssh_identity, args): + def _run_ssh_local(self, host, port, options, args): """Run SSH command with the given host, port, and arguments""" # Create temporary identity file if needed + ssh_identity = self.identity identity_file = None temp_file = None if ssh_identity: @@ -154,7 +175,7 @@ def _run_ssh_local(self, host, port, ssh_command, options, default_username, ssh try: # Build SSH command arguments - ssh_args = self._build_ssh_command_args(ssh_command, port, default_username, identity_file, args) + ssh_args = self._build_ssh_command_args(port, identity_file, args) # Separate SSH options from command arguments ssh_options, command_args = self._separate_ssh_options_and_command_args(args) @@ -173,10 +194,11 @@ def _run_ssh_local(self, host, port, ssh_command, options, default_username, ssh except Exception as e: self.logger.warning("Failed to clean up temporary identity file %s: %s", identity_file, str(e)) - def _build_ssh_command_args(self, ssh_command, port, default_username, identity_file, args): + def _build_ssh_command_args(self, port, identity_file, args): """Build initial SSH command arguments""" # Split the SSH command into individual arguments - ssh_args = shlex.split(ssh_command) + ssh_args = shlex.split(self.command) + default_username = self.username # Add identity file if provided if identity_file: diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py index 368f31bec..92a540406 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py @@ -692,3 +692,18 @@ def test_ssh_identity_temp_file_cleanup_error(): assert result.return_code == 0 assert result.stdout == "some stdout" + + +def test_ssh_client_properties(): + """Test that the client properties correctly reflect the driver configuration""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY, + ssh_command="my-ssh-command", + ) + + with serve(instance) as client: + assert client.username == "testuser" + assert client.identity == TEST_SSH_KEY + assert client.command == "my-ssh-command"