Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,33 @@ def stream(self, method="connect"):
async def stream_async(self, method):
return await self.tcp.stream_async(method)

@property
def command(self) -> str:
"""Get the base SSH command"""
return self.call("get_ssh_command")

@property
def identity(self) -> str | None:
"""
Get the SSH identity (private key) as a string.

Returns:
The SSH identity key content, or None if not configured.

Raises:
ConfigurationError: If `ssh_identity_file` is configured on the
driver but cannot be read.
"""
return self.call("get_ssh_identity")

@property
def username(self) -> str:
"""Get the default SSH username"""
return self.call("get_default_username")

def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult:
"""Run SSH command with the given parameters and arguments"""
# Get SSH command and default username from driver
ssh_command = self.call("get_ssh_command")
default_username = self.call("get_default_username")
ssh_identity = self.call("get_ssh_identity")

if options.direct:
# Use direct TCP address
try:
Expand All @@ -111,7 +131,7 @@ def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult:
if not host or not port:
raise ValueError(f"Invalid address format: {address}")
self.logger.debug("Using direct TCP connection for SSH - host: %s, port: %s", host, port)
return self._run_ssh_local(host, port, ssh_command, options, default_username, ssh_identity, args)
return self._run_ssh_local(host, port, options, args)
except (DriverMethodNotImplemented, ValueError) as e:
self.logger.error("Direct address connection failed (%s), falling back to SSH port forwarding", e)
return self.run(SSHCommandRunOptions(
Expand All @@ -127,11 +147,12 @@ def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult:
) as addr:
host, port = addr
self.logger.debug("SSH port forward established - host: %s, port: %s", host, port)
return self._run_ssh_local(host, port, ssh_command, options, default_username, ssh_identity, args)
return self._run_ssh_local(host, port, options, args)

def _run_ssh_local(self, host, port, ssh_command, options, default_username, ssh_identity, args):
def _run_ssh_local(self, host, port, options, args):
"""Run SSH command with the given host, port, and arguments"""
# Create temporary identity file if needed
ssh_identity = self.identity
identity_file = None
temp_file = None
if ssh_identity:
Expand All @@ -154,7 +175,7 @@ def _run_ssh_local(self, host, port, ssh_command, options, default_username, ssh

try:
# Build SSH command arguments
ssh_args = self._build_ssh_command_args(ssh_command, port, default_username, identity_file, args)
ssh_args = self._build_ssh_command_args(port, identity_file, args)

# Separate SSH options from command arguments
ssh_options, command_args = self._separate_ssh_options_and_command_args(args)
Expand All @@ -173,10 +194,11 @@ def _run_ssh_local(self, host, port, ssh_command, options, default_username, ssh
except Exception as e:
self.logger.warning("Failed to clean up temporary identity file %s: %s", identity_file, str(e))

def _build_ssh_command_args(self, ssh_command, port, default_username, identity_file, args):
def _build_ssh_command_args(self, port, identity_file, args):
"""Build initial SSH command arguments"""
# Split the SSH command into individual arguments
ssh_args = shlex.split(ssh_command)
ssh_args = shlex.split(self.command)
default_username = self.username

# Add identity file if provided
if identity_file:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,3 +692,18 @@ def test_ssh_identity_temp_file_cleanup_error():

assert result.return_code == 0
assert result.stdout == "some stdout"


def test_ssh_client_properties():
"""Test that the client properties correctly reflect the driver configuration"""
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
ssh_identity=TEST_SSH_KEY,
ssh_command="my-ssh-command",
)

with serve(instance) as client:
assert client.username == "testuser"
assert client.identity == TEST_SSH_KEY
assert client.command == "my-ssh-command"
Loading