diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py index 1bf239c85..802bfb36b 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py @@ -13,6 +13,40 @@ from jumpstarter.client.decorators import driver_click_command +@dataclass +class SSHCommandRunResult: + """Result of executing an SSH command""" + return_code: int + stdout: str | bytes + stderr: str | bytes + + @staticmethod + def from_completed_process(result: subprocess.CompletedProcess) -> "SSHCommandRunResult": + return SSHCommandRunResult( + return_code=result.returncode, + stdout=result.stdout or "", + stderr=result.stderr or "", + ) + + +@dataclass +class SSHCommandRunOptions: + """ + Options for running an SSH command + + Attributes: + direct: If True, connect directly to the host's TCP address. + If False, use SSH port forwarding. + capture_output: If True, capture stdout and stderr. + If False, they are inherited from the parent process. + capture_as_text: If True and output is captured, decode stdout and + stderr as text. Otherwise, they are captured as bytes. + """ + direct: bool = False + capture_output: bool = True + capture_as_text: bool = True + + @dataclass(kw_only=True) class SSHWrapperClient(CompositeClient): """ @@ -30,11 +64,25 @@ def cli(self): @click.option("--direct", is_flag=True, help="Use direct TCP address") @click.argument("args", nargs=-1) def ssh(direct, args): - result = self.run(direct, args) - self.logger.debug(f"SSH result: {result}") - if result != 0: - click.get_current_context().exit(result) - return result + options = SSHCommandRunOptions( + direct=direct, + # For the CLI, we never capture output so that interactive shells + # and long-running commands stream their output directly. + capture_output=False, + ) + + result = self.run(options, args) + self.logger.debug("SSH exit code: %s", result.return_code) + + if result.stdout: + click.echo(result.stdout, nl=False) + if result.stderr: + click.echo(result.stderr, nl=False, err=True) + + if result.return_code != 0: + click.get_current_context().exit(result.return_code) + + return result.return_code return ssh @@ -46,14 +94,14 @@ def stream(self, method="connect"): async def stream_async(self, method): return await self.tcp.stream_async(method) - def run(self, direct, args): + def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult: """Run SSH command with the given parameters and arguments""" # Get SSH command and default username from driver ssh_command = self.call("get_ssh_command") default_username = self.call("get_default_username") ssh_identity = self.call("get_ssh_identity") - if direct: + if options.direct: # Use direct TCP address try: address = self.tcp.address() # (format: "tcp://host:port") @@ -62,23 +110,26 @@ def run(self, direct, args): port = parsed.port if not host or not port: raise ValueError(f"Invalid address format: {address}") - self.logger.debug(f"Using direct TCP connection for SSH - host: {host}, port: {port}") - return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args) + self.logger.debug("Using direct TCP connection for SSH - host: %s, port: %s", host, port) + return self._run_ssh_local(host, port, ssh_command, options, default_username, ssh_identity, args) except (DriverMethodNotImplemented, ValueError) as e: - self.logger.error(f"Direct address connection failed ({e}), falling back to SSH port forwarding") - return self.run(False, args) + self.logger.error("Direct address connection failed (%s), falling back to SSH port forwarding", e) + return self.run(SSHCommandRunOptions( + direct=False, + capture_output=options.capture_output, + capture_as_text=options.capture_as_text, + ), args) else: # Use SSH port forwarding (default behavior) self.logger.debug("Using SSH port forwarding for SSH connection") with TcpPortforwardAdapter( client=self.tcp, ) as addr: - host = addr[0] - port = addr[1] - self.logger.debug(f"SSH port forward established - host: {host}, port: {port}") - return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args) + host, port = addr + self.logger.debug("SSH port forward established - host: %s, port: %s", host, port) + return self._run_ssh_local(host, port, ssh_command, options, default_username, ssh_identity, args) - def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity, args): + def _run_ssh_local(self, host, port, ssh_command, options, default_username, ssh_identity, args): """Run SSH command with the given host, port, and arguments""" # Create temporary identity file if needed identity_file = None @@ -91,9 +142,9 @@ def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity # Set proper permissions (600) for SSH key os.chmod(temp_file.name, 0o600) identity_file = temp_file.name - self.logger.debug(f"Created temporary identity file: {identity_file}") + self.logger.debug("Created temporary identity file: %s", identity_file) except Exception as e: - self.logger.error(f"Failed to create temporary identity file: {e}") + self.logger.error("Failed to create temporary identity file: %s", e) if temp_file: try: os.unlink(temp_file.name) @@ -112,15 +163,15 @@ def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args) # Execute the command - return self._execute_ssh_command(ssh_args) + return self._execute_ssh_command(ssh_args, options) finally: # Clean up temporary identity file if identity_file: try: os.unlink(identity_file) - self.logger.debug(f"Cleaned up temporary identity file: {identity_file}") + self.logger.debug("Cleaned up temporary identity file: %s", identity_file) except Exception as e: - self.logger.warning(f"Failed to clean up temporary identity file {identity_file}: {e}") + self.logger.warning("Failed to clean up temporary identity file %s: %s", identity_file, str(e)) def _build_ssh_command_args(self, ssh_command, port, default_username, identity_file, args): """Build initial SSH command arguments""" @@ -192,8 +243,8 @@ def _separate_ssh_options_and_command_args(self, args): i += 1 # Debug output - self.logger.debug(f"SSH options: {ssh_options}") - self.logger.debug(f"Command args: {command_args}") + self.logger.debug("SSH options: %s", ssh_options) + self.logger.debug("Command args: %s", command_args) return ssh_options, command_args @@ -209,16 +260,21 @@ def _build_final_ssh_command(self, ssh_args, ssh_options, host, command_args): # Add command arguments ssh_args.extend(command_args) - self.logger.debug(f"Running SSH command: {ssh_args}") + self.logger.debug("Running SSH command: %s", ssh_args) return ssh_args - def _execute_ssh_command(self, ssh_args): + def _execute_ssh_command(self, ssh_args, options: SSHCommandRunOptions) -> SSHCommandRunResult: """Execute the SSH command and return the result""" try: - result = subprocess.run(ssh_args) - return result.returncode + result = subprocess.run(ssh_args, capture_output=options.capture_output, text=options.capture_as_text) + return SSHCommandRunResult.from_completed_process(result) except FileNotFoundError: self.logger.error( - f"SSH command '{ssh_args[0]}' not found. Please ensure SSH is installed and available in PATH." + "SSH command '%s' not found. Please ensure SSH is installed and available in PATH.", + ssh_args[0], + ) + return SSHCommandRunResult( + return_code=127, # Standard exit code for "command not found" + stdout="", + stderr=f"SSH command '{ssh_args[0]}' not found", ) - return 127 # Standard exit code for "command not found" diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py index 02ea1ba1a..368f31bec 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py @@ -5,6 +5,7 @@ import pytest from jumpstarter_driver_network.driver import TcpNetwork +from jumpstarter_driver_ssh.client import SSHCommandRunOptions, SSHCommandRunResult from jumpstarter_driver_ssh.driver import SSHWrapper from jumpstarter.common.exceptions import ConfigurationError @@ -51,10 +52,11 @@ def test_ssh_command_with_default_username(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Test SSH command with default username - result = client.run(False, ["hostname"]) + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -69,7 +71,8 @@ def test_ssh_command_with_default_username(): assert "127.0.0.1" in call_args assert "hostname" in call_args # Should be preserved as command argument - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_command_without_default_username(): @@ -81,10 +84,11 @@ def test_ssh_command_without_default_username(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Test SSH command without default username - result = client.run(False, ["hostname"]) + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -97,7 +101,8 @@ def test_ssh_command_without_default_username(): assert "127.0.0.1" in call_args assert "hostname" in call_args # Should be preserved as command argument - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_command_with_user_override(): @@ -109,10 +114,11 @@ def test_ssh_command_with_user_override(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Test SSH command with -l flag overriding default username - result = client.run(False, ["-l", "myuser", "hostname"]) + result = client.run(SSHCommandRunOptions(direct=False), ["-l", "myuser", "hostname"]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -128,7 +134,8 @@ def test_ssh_command_with_user_override(): assert "127.0.0.1" in call_args assert "hostname" in call_args # Should be preserved as command argument - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_command_with_port(): @@ -140,7 +147,7 @@ def test_ssh_command_with_port(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Mock the TcpPortforwardAdapter to return the expected port with patch('jumpstarter_driver_ssh.client.TcpPortforwardAdapter') as mock_adapter: @@ -148,7 +155,8 @@ def test_ssh_command_with_port(): mock_adapter.return_value.__exit__.return_value = None # Test SSH command with custom port - result = client.run(False, ["hostname"]) + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -167,7 +175,8 @@ def test_ssh_command_with_port(): assert "127.0.0.1" in call_args assert "hostname" in call_args # Should be preserved as command argument - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_command_with_direct_flag(): @@ -179,12 +188,13 @@ def test_ssh_command_with_direct_flag(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Mock the tcp.address() method with patch.object(client.tcp, 'address', return_value="tcp://192.168.1.100:22"): # Test SSH command with direct flag - result = client.run(True, ["hostname"]) + result = client.run(SSHCommandRunOptions(direct=True), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -198,7 +208,8 @@ def test_ssh_command_with_direct_flag(): assert "192.168.1.100" in call_args assert "hostname" in call_args # Should be preserved as command argument - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_command_error_handling(): @@ -213,10 +224,13 @@ def test_ssh_command_error_handling(): mock_run.side_effect = FileNotFoundError("SSH not found") # Test SSH command error handling - result = client.run(False, ["hostname"]) + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) # Should return error code 127 - assert result == 127 + assert result.return_code == 127 + assert result.stdout == "" + assert "not found" in result.stderr def test_ssh_command_with_multiple_ssh_options(): @@ -228,12 +242,13 @@ def test_ssh_command_with_multiple_ssh_options(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Test SSH command with multiple SSH options - result = client.run(False, [ + result = client.run(SSHCommandRunOptions(direct=False), [ "-o", "StrictHostKeyChecking=no", "-i", "/path/to/key", "command", "arg1", "arg2" ]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -252,7 +267,8 @@ def test_ssh_command_with_multiple_ssh_options(): assert "arg1" in call_args assert "arg2" in call_args - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_command_with_unknown_option_treated_as_command(): @@ -264,10 +280,11 @@ def test_ssh_command_with_unknown_option_treated_as_command(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Test SSH command with unknown option - result = client.run(False, ["-l", "user", "-unknown", "command", "arg1"]) + result = client.run(SSHCommandRunOptions(direct=False), ["-l", "user", "-unknown", "command", "arg1"]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -284,7 +301,8 @@ def test_ssh_command_with_unknown_option_treated_as_command(): assert "command" in call_args assert "arg1" in call_args - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_command_with_no_ssh_options(): @@ -296,10 +314,11 @@ def test_ssh_command_with_no_ssh_options(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Test SSH command with no SSH options - result = client.run(False, ["command", "arg1", "arg2"]) + result = client.run(SSHCommandRunOptions(direct=False), ["command", "arg1", "arg2"]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -312,7 +331,8 @@ def test_ssh_command_with_no_ssh_options(): assert "arg1" in call_args assert "arg2" in call_args - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_command_with_command_l_flag_does_not_interfere_with_username_injection(): @@ -324,10 +344,11 @@ def test_ssh_command_with_command_l_flag_does_not_interfere_with_username_inject with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Test SSH command with -l flag in the command (like ls -la -l ajo) - result = client.run(False, ["ls", "-la", "-l", "ajo"]) + result = client.run(SSHCommandRunOptions(direct=False), ["ls", "-la", "-l", "ajo"]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -354,7 +375,8 @@ def test_ssh_command_with_command_l_flag_does_not_interfere_with_username_inject assert ssh_l_index < hostname_index < command_l_index - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_identity_string_configuration(): @@ -444,10 +466,11 @@ def test_ssh_command_with_identity_string(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Test SSH command with identity string - result = client.run(False, ["hostname"]) + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -470,7 +493,8 @@ def test_ssh_command_with_identity_string(): assert "127.0.0.1" in call_args assert "hostname" in call_args - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_command_with_identity_file(): @@ -492,10 +516,11 @@ def test_ssh_command_with_identity_file(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Test SSH command with identity file - result = client.run(False, ["hostname"]) + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -519,7 +544,8 @@ def test_ssh_command_with_identity_file(): assert "127.0.0.1" in call_args assert "hostname" in call_args - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" finally: # Clean up the temporary file os.unlink(temp_file_path) @@ -534,10 +560,11 @@ def test_ssh_command_without_identity(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") # Test SSH command without identity - result = client.run(False, ["hostname"]) + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) # Verify subprocess.run was called assert mock_run.called @@ -554,7 +581,8 @@ def test_ssh_command_without_identity(): assert "127.0.0.1" in call_args assert "hostname" in call_args - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_identity_temp_file_creation_and_cleanup(): @@ -567,7 +595,7 @@ def test_ssh_identity_temp_file_creation_and_cleanup(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") with patch('tempfile.NamedTemporaryFile') as mock_temp_file: with patch('os.chmod') as mock_chmod: @@ -580,7 +608,8 @@ def test_ssh_identity_temp_file_creation_and_cleanup(): mock_temp_file.return_value = mock_temp_file_instance # Test SSH command with identity - result = client.run(False, ["hostname"]) + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) # Verify temporary file was created mock_temp_file.assert_called_once_with(mode='w', delete=False, suffix='_ssh_key') @@ -593,7 +622,8 @@ def test_ssh_identity_temp_file_creation_and_cleanup(): # Verify temporary file was cleaned up mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345") - assert result == 0 + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_identity_temp_file_creation_error(): @@ -614,7 +644,7 @@ def test_ssh_identity_temp_file_creation_error(): # Test SSH command with identity should raise an error # The exception will be wrapped in an ExceptionGroup due to the context manager with pytest.raises(ExceptionGroup) as exc_info: - client.run(False, ["hostname"]) + client.run(SSHCommandRunOptions(direct=False), ["hostname"]) # Check that the original OSError is in the exception group assert any(isinstance(e, OSError) and "Permission denied" in str(e) for e in exc_info.value.exceptions) @@ -630,7 +660,7 @@ def test_ssh_identity_temp_file_cleanup_error(): with serve(instance) as client: with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0) + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") with patch('tempfile.NamedTemporaryFile') as mock_temp_file: with patch('os.chmod') as mock_chmod: @@ -647,15 +677,18 @@ def test_ssh_identity_temp_file_cleanup_error(): # Test SSH command with identity - should still succeed but log warning with patch.object(client, 'logger') as mock_logger: - result = client.run(False, ["hostname"]) + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) # Verify chmod was called mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600) # Verify warning was logged - mock_logger.warning.assert_called_once() - warning_call = mock_logger.warning.call_args[0][0] - assert "Failed to clean up temporary identity file" in warning_call - assert "/tmp/test_ssh_key_12345" in warning_call - - assert result == 0 + mock_logger.warning.assert_called_once_with( + "Failed to clean up temporary identity file %s: %s", + "/tmp/test_ssh_key_12345", + str(mock_unlink.side_effect) + ) + + assert result.return_code == 0 + assert result.stdout == "some stdout"