diff --git a/device-connectors/src/testflinger_device_connectors/devices/__init__.py b/device-connectors/src/testflinger_device_connectors/devices/__init__.py index 3487d186a..20814cd61 100644 --- a/device-connectors/src/testflinger_device_connectors/devices/__init__.py +++ b/device-connectors/src/testflinger_device_connectors/devices/__init__.py @@ -72,6 +72,114 @@ def SerialLogger(host=None, port=None, filename=None): return StubSerialLogger(host, port, filename) +def import_ssh_key(key: str, keyfile: str = "key.pub") -> None: + """Import SSH key provided in Reserve data. + + :param key: SSH key to import. + :param keyfile: Output file where to store the imported key + :raises RuntimeError: If failure during import ssh keys + """ + cmd = ["ssh-import-id", "-o", keyfile, key] + for retry in range(10): + try: + subprocess.run( + cmd, + timeout=30, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + check=True, + ) + logger.info("Successfully imported key: %s", key) + break + + except subprocess.TimeoutExpired: + pass + except subprocess.CalledProcessError as exc: + output = (exc.stdout or b"").decode() + if "status_code=404" in output: + raise RuntimeError( + f"Failed to import ssh key: {key}. User not found." + ) from exc + + logger.error("Unable to import ssh key from: %s", key) + logger.info("Retrying...") + time.sleep(min(2**retry, 100)) + else: + raise RuntimeError( + f"Failed to import ssh key: {key}. Maximum retries reached" + ) + + +def copy_ssh_key( + device_ip: str, + username: str, + password: Optional[str] = None, + key: Optional[str] = None, +): + """If provided, copy the SSH `key` to the DUT, + otherwise copy the agent's using password authentication. + + :raises RuntimeError in case it can't copy the SSH keys + """ + if not key and not password: + raise ValueError("Cannot copy the agent's SSH key w/o password") + + if password: + cmd = ["sshpass", "-p", password] + else: + cmd = [] + + cmd.extend(["ssh-copy-id", "-f"]) + + if key: + cmd.extend(["-i", key]) + + cmd.extend( + [ + "-o", + "StrictHostKeyChecking=no", + "-o", + "UserKnownHostsFile=/dev/null", + "{}@{}".format(username, device_ip), + ] + ) + + for _retry in range(10): + # Retry ssh key copy just in case it's rebooting + try: + subprocess.check_call(cmd, timeout=30) + break + except ( + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + ): + logger.error("Error copying ssh key to device for: %s", key) + logger.info("Retrying...") + time.sleep(60) + + else: + logger.error("Failed to copy ssh key: %s", key) + raise RuntimeError + + +def copy_ssh_keys_to_devices(ssh_keys, device_ips, test_username="ubuntu"): + """Copy list of ssh keys to list of devices.""" + for key in ssh_keys: + with contextlib.suppress(FileNotFoundError): + os.unlink("key.pub") + + try: + # Import SSH Keys with ssh-import-id + import_ssh_key(key, keyfile="key.pub") + + # Attempt to copy keys only if import succeeds + with contextlib.suppress(RuntimeError): + for device_ip in device_ips: + copy_ssh_key(device_ip, test_username, key="key.pub") + except RuntimeError as exc: + logger.error(exc) + + class StubSerialLogger: """Fake SerialLogger when we don't have Serial Logger data defined.""" @@ -260,95 +368,6 @@ def allocate(self): """Allocate devices for multi-agent jobs (default method).""" pass - def import_ssh_key(self, key: str, keyfile: str = "key.pub") -> None: - """Import SSH key provided in Reserve data. - - :param key: SSH key to import. - :param keyfile: Output file where to store the imported key - :raises RuntimeError: If failure during import ssh keys - """ - cmd = ["ssh-import-id", "-o", keyfile, key] - for retry in range(10): - try: - subprocess.run( - cmd, - timeout=30, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - check=True, - ) - logger.info("Successfully imported key: %s", key) - break - - except subprocess.TimeoutExpired: - pass - except subprocess.CalledProcessError as exc: - output = (exc.stdout or b"").decode() - if "status_code=404" in output: - raise RuntimeError( - f"Failed to import ssh key: {key}. User not found." - ) from exc - - logger.error("Unable to import ssh key from: %s", key) - logger.info("Retrying...") - time.sleep(min(2**retry, 100)) - else: - raise RuntimeError( - f"Failed to import ssh key: {key}. Maximum retries reached" - ) - - def copy_ssh_key( - self, - device_ip: str, - username: str, - password: Optional[str] = None, - key: Optional[str] = None, - ): - """If provided, copy the SSH `key` to the DUT, - otherwise copy the agent's using password authentication. - - :raises RuntimeError in case it can't copy the SSH keys - """ - if not key and not password: - raise ValueError("Cannot copy the agent's SSH key w/o password") - - if password: - cmd = ["sshpass", "-p", password] - else: - cmd = [] - - cmd.extend(["ssh-copy-id", "-f"]) - - if key: - cmd.extend(["-i", key]) - - cmd.extend( - [ - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", - "{}@{}".format(username, device_ip), - ] - ) - - for _retry in range(10): - # Retry ssh key copy just in case it's rebooting - try: - subprocess.check_call(cmd, timeout=30) - break - except ( - subprocess.CalledProcessError, - subprocess.TimeoutExpired, - ): - logger.error("Error copying ssh key to device for: %s", key) - logger.info("Retrying...") - time.sleep(60) - - else: - logger.error("Failed to copy ssh key: %s", key) - raise RuntimeError - def reserve(self, args): """Reserve systems (default method).""" with open(args.config) as configfile: @@ -364,20 +383,7 @@ def reserve(self, args): device_ip = config["device_ip"] reserve_data = job_data["reserve_data"] ssh_keys = reserve_data.get("ssh_keys", []) - for key in ssh_keys: - with contextlib.suppress(FileNotFoundError): - os.unlink("key.pub") - - try: - # Import SSH Keys with ssh-import-id - self.import_ssh_key(key, keyfile="key.pub") - - # Attempt to copy keys only if import succeeds - with contextlib.suppress(RuntimeError): - self.copy_ssh_key(device_ip, test_username, key="key.pub") - except RuntimeError as exc: - logger.error(exc) - + copy_ssh_keys_to_devices(ssh_keys, [device_ip], test_username) # default reservation timeout is 1 hour timeout = int(reserve_data.get("timeout", "3600")) serial_host = config.get("serial_host") diff --git a/device-connectors/src/testflinger_device_connectors/devices/multi/multi.py b/device-connectors/src/testflinger_device_connectors/devices/multi/multi.py index 7e805d2a9..d088bc5e1 100644 --- a/device-connectors/src/testflinger_device_connectors/devices/multi/multi.py +++ b/device-connectors/src/testflinger_device_connectors/devices/multi/multi.py @@ -18,10 +18,14 @@ import logging import os import time +from datetime import datetime, timedelta import requests -from testflinger_device_connectors.devices import ProvisioningError +from testflinger_device_connectors.devices import ( + ProvisioningError, + copy_ssh_keys_to_devices, +) logger = logging.getLogger(__name__) @@ -80,6 +84,45 @@ def provision(self): self.save_job_list_file() + def reserve(self): + """Push ssh keys to each device in reservation phase.""" + logger.info("BEGIN multi device reservation") + job_data = self.job_data + try: + test_username = job_data["test_data"]["test_username"] + except KeyError: + test_username = "ubuntu" + reserve_data = job_data["reserve_data"] + ssh_keys = reserve_data.get("ssh_keys", []) + with open("job_list.json", "r") as json_file: + job_list = json.load(json_file) + device_ips = [job["device_info"]["device_ip"] for job in job_list] + copy_ssh_keys_to_devices(ssh_keys, device_ips, test_username) + print("*** TESTFLINGER SYSTEMS RESERVED ***") + print("You can now connect to the following devices:") + for job in job_list: + device_ip = job["device_info"]["device_ip"] + print(f"{test_username}@{device_ip}") + + timeout = int(reserve_data.get("timeout", "3600")) + now = datetime.now().astimezone().isoformat() + expire_time = ( + datetime.now().astimezone() + timedelta(seconds=timeout) + ).isoformat() + print("Current time: [{}]".format(now)) + print("Reservation expires at: [{}]".format(expire_time)) + print( + "Reservation will automatically timeout in {} seconds".format( + timeout + ) + ) + job_id = job_data.get("job_id", "") + print( + "To end the reservation sooner use: " + + "testflinger-cli cancel {}".format(job_id) + ) + time.sleep(timeout) + def terminate_if_parent_completed(self): """If parent job is completed or cancelled, cancel sub jobs.""" if self.this_job_completed(): diff --git a/device-connectors/src/testflinger_device_connectors/devices/multi/tests/test_multi.py b/device-connectors/src/testflinger_device_connectors/devices/multi/tests/test_multi.py index f5a516f73..48825ab22 100644 --- a/device-connectors/src/testflinger_device_connectors/devices/multi/tests/test_multi.py +++ b/device-connectors/src/testflinger_device_connectors/devices/multi/tests/test_multi.py @@ -14,6 +14,9 @@ """Unit tests for multi-device support code.""" +import json +import tempfile +from unittest.mock import patch from uuid import uuid4 import pytest @@ -98,3 +101,124 @@ def test_this_job_completed(): incomplete_client.get_status = lambda job_id: "something else" test_agent = Multi(test_config, job_data, incomplete_client) assert test_agent.this_job_completed() is False + + +@patch( + "testflinger_device_connectors.devices.multi.multi.copy_ssh_keys_to_devices" +) +@patch("time.sleep") +def test_multi_reserve(mock_sleep, mock_copy_keys): + """Test Multi.reserve method functionality.""" + test_config = {"agent_name": "test_agent"} + job_data = { + "job_id": "test-job-123", + "reserve_data": {"ssh_keys": ["key1", "key2"], "timeout": "1800"}, + "test_data": {"test_username": "testuser"}, + } + + # Create job_list.json file with mock data + job_list = [ + {"device_info": {"device_ip": "192.168.1.1"}}, + {"device_info": {"device_ip": "192.168.1.2"}}, + ] + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as tmp_file: + json.dump(job_list, tmp_file) + + # Mock builtins.open to return our temp file + # when job_list.json is requested + with patch("builtins.open", create=True) as mo: + mo.return_value.__enter__.return_value.read.return_value = json.dumps( + job_list + ) + + # Mock print to capture output + with patch("builtins.print") as mock_print: + test_agent = Multi( + test_config, job_data, MockTFClient("http://localhost") + ) + test_agent.reserve() + + # Verify copy_ssh_keys_to_devices was called with correct parameters + mock_copy_keys.assert_called_once_with( + ["key1", "key2"], ["192.168.1.1", "192.168.1.2"], "testuser" + ) + + # Verify time.sleep was called with timeout + mock_sleep.assert_called_once_with(1800) + + # Verify print statements were made + assert ( + mock_print.call_count >= 5 + ) # Multiple print statements in reserve method + + +@patch( + "testflinger_device_connectors.devices.multi.multi.copy_ssh_keys_to_devices" +) +@patch("time.sleep") +def test_multi_reserve_default_username(mock_sleep, mock_copy_keys): + """Test Multi.reserve method with default username.""" + test_config = {"agent_name": "test_agent"} + job_data = { + "job_id": "test-job-123", + "reserve_data": {"ssh_keys": ["key1"], "timeout": "3600"}, + # No test_data section - should default to ubuntu + } + + job_list = [{"device_info": {"device_ip": "192.168.1.1"}}] + + with patch("builtins.open", create=True) as mock_open: + mock_open.return_value.__enter__.return_value.read.return_value = ( + json.dumps(job_list) + ) + + with patch("builtins.print"): + test_agent = Multi( + test_config, job_data, MockTFClient("http://localhost") + ) + test_agent.reserve() + + # Verify copy_ssh_keys_to_devices was called + mock_copy_keys.assert_called_once_with(["key1"], ["192.168.1.1"], "ubuntu") + + # Verify time.sleep was called with timeout + mock_sleep.assert_called_once_with(3600) + + +@patch( + "testflinger_device_connectors.devices.multi.multi.copy_ssh_keys_to_devices" +) +@patch("time.sleep") +def test_multi_reserve_no_ssh_keys(mock_sleep, mock_copy_keys): + """Test Multi.reserve method with no SSH keys.""" + test_config = {"agent_name": "test_agent"} + job_data = { + "job_id": "test-job-123", + "reserve_data": { + "timeout": "1800" + # No ssh_keys provided + }, + "test_data": {"test_username": "testuser"}, + } + + job_list = [{"device_info": {"device_ip": "192.168.1.1"}}] + + with patch("builtins.open", create=True) as mock_open: + mock_open.return_value.__enter__.return_value.read.return_value = ( + json.dumps(job_list) + ) + + with patch("builtins.print"): + test_agent = Multi( + test_config, job_data, MockTFClient("http://localhost") + ) + test_agent.reserve() + + # Verify copy_ssh_keys_to_devices was called with empty list + mock_copy_keys.assert_called_once_with([], ["192.168.1.1"], "testuser") + + # Verify time.sleep was called with timeout + mock_sleep.assert_called_once_with(1800) diff --git a/device-connectors/src/testflinger_device_connectors/devices/tests/test_devices.py b/device-connectors/src/testflinger_device_connectors/devices/tests/test_devices.py index 54dc4be59..c1d5e8883 100644 --- a/device-connectors/src/testflinger_device_connectors/devices/tests/test_devices.py +++ b/device-connectors/src/testflinger_device_connectors/devices/tests/test_devices.py @@ -18,7 +18,11 @@ import unittest from unittest.mock import Mock, call, patch -from testflinger_device_connectors.devices import DefaultDevice +from testflinger_device_connectors.devices import ( + DefaultDevice, + copy_ssh_key, + copy_ssh_keys_to_devices, +) class DefaultDeviceTests(unittest.TestCase): @@ -30,8 +34,8 @@ def test_copy_ssh_id(self, mock_check): to the DUT. """ fake_config = {"device_ip": "10.10.10.10", "agent_name": "fake_agent"} - connector = DefaultDevice(fake_config) - connector.copy_ssh_key( + DefaultDevice(fake_config) + copy_ssh_key( "192.168.1.2", "username", "password", @@ -50,8 +54,8 @@ def test_copy_ssh_id_with_key(self, mock_check): to the DUT. """ fake_config = {"device_ip": "10.10.10.10", "agent_name": "fake_agent"} - connector = DefaultDevice(fake_config) - connector.copy_ssh_key( + DefaultDevice(fake_config) + copy_ssh_key( "192.168.1.2", "username", key="key.pub", @@ -72,12 +76,12 @@ def test_copy_ssh_id_raises(self, mock_check): exception after 3 failed attempts. """ fake_config = {"device_ip": "10.10.10.10", "agent_name": "fake_agent"} - connector = DefaultDevice(fake_config) + DefaultDevice(fake_config) mock_check.side_effect = subprocess.CalledProcessError(1, "") with self.assertRaises(RuntimeError): - connector.copy_ssh_key( + copy_ssh_key( "192.168.1.2", "username", "password", @@ -103,3 +107,267 @@ def test_write_device_info(self): assert all( device_info[key] == value for key, value in fake_config.items() ) + + +class CopySshKeysToDevicesTests(unittest.TestCase): + """Unit tests for copy_ssh_keys_to_devices function.""" + + @patch("testflinger_device_connectors.devices.copy_ssh_key") + @patch("testflinger_device_connectors.devices.import_ssh_key") + @patch("os.unlink") + def test_copy_ssh_keys_to_devices_success( + self, mock_unlink, mock_import, mock_copy + ): + """Test successful copying of SSH keys to multiple devices.""" + ssh_keys = ["key1", "key2"] + device_ips = ["192.168.1.1", "192.168.1.2"] + + copy_ssh_keys_to_devices(ssh_keys, device_ips, "testuser") + + # Should unlink key.pub twice (once per key) + assert mock_unlink.call_count == 2 + mock_unlink.assert_has_calls([call("key.pub"), call("key.pub")]) + + # Should import each key + mock_import.assert_has_calls( + [call("key1", keyfile="key.pub"), call("key2", keyfile="key.pub")] + ) + + # Should copy each key to each device (2 keys * 2 devices = 4 calls) + expected_copy_calls = [ + call("192.168.1.1", "testuser", key="key.pub"), + call("192.168.1.2", "testuser", key="key.pub"), + call("192.168.1.1", "testuser", key="key.pub"), + call("192.168.1.2", "testuser", key="key.pub"), + ] + mock_copy.assert_has_calls(expected_copy_calls) + + @patch("testflinger_device_connectors.devices.copy_ssh_key") + @patch("testflinger_device_connectors.devices.import_ssh_key") + @patch("os.unlink") + def test_copy_ssh_keys_to_devices_import_failure( + self, mock_unlink, mock_import, mock_copy + ): + """Test handling of import_ssh_key failure.""" + ssh_keys = ["key1"] + device_ips = ["192.168.1.1"] + + # Make import_ssh_key raise RuntimeError + mock_import.side_effect = RuntimeError("Import failed") + + copy_ssh_keys_to_devices(ssh_keys, device_ips, "testuser") + + # Should still attempt to unlink + mock_unlink.assert_called_once_with("key.pub") + + # Should attempt to import + mock_import.assert_called_once_with("key1", keyfile="key.pub") + + # Should not attempt to copy since import failed + mock_copy.assert_not_called() + + @patch("testflinger_device_connectors.devices.copy_ssh_key") + @patch("testflinger_device_connectors.devices.import_ssh_key") + @patch("os.unlink") + def test_copy_ssh_keys_to_devices_copy_failure( + self, mock_unlink, mock_import, mock_copy + ): + """Test handling of copy_ssh_key failure.""" + ssh_keys = ["key1"] + device_ips = ["192.168.1.1"] + + # Make copy_ssh_key raise RuntimeError + mock_copy.side_effect = RuntimeError("Copy failed") + + copy_ssh_keys_to_devices(ssh_keys, device_ips, "testuser") + + # Should unlink and import successfully + mock_unlink.assert_called_once_with("key.pub") + mock_import.assert_called_once_with("key1", keyfile="key.pub") + + # Should attempt to copy but fail gracefully + mock_copy.assert_called_once_with( + "192.168.1.1", "testuser", key="key.pub" + ) + + @patch("testflinger_device_connectors.devices.copy_ssh_key") + @patch("testflinger_device_connectors.devices.import_ssh_key") + @patch("os.unlink", side_effect=FileNotFoundError) + def test_copy_ssh_keys_to_devices_file_not_found( + self, mock_unlink, mock_import, mock_copy + ): + """Test handling when key.pub file doesn't exist.""" + ssh_keys = ["key1"] + device_ips = ["192.168.1.1"] + + copy_ssh_keys_to_devices(ssh_keys, device_ips, "testuser") + + # Should attempt to unlink but suppress FileNotFoundError + mock_unlink.assert_called_once_with("key.pub") + + # Should continue with import and copy + mock_import.assert_called_once_with("key1", keyfile="key.pub") + mock_copy.assert_called_once_with( + "192.168.1.1", "testuser", key="key.pub" + ) + + @patch("testflinger_device_connectors.devices.copy_ssh_key") + @patch("testflinger_device_connectors.devices.import_ssh_key") + @patch("os.unlink") + def test_copy_ssh_keys_to_devices_empty_lists( + self, mock_unlink, mock_import, mock_copy + ): + """Test with empty SSH keys and device lists.""" + copy_ssh_keys_to_devices([], []) + + # Should not call any functions + mock_unlink.assert_not_called() + mock_import.assert_not_called() + mock_copy.assert_not_called() + + +class DefaultDeviceReserveTests(unittest.TestCase): + """Unit tests for DefaultDevice.reserve method.""" + + @patch("testflinger_device_connectors.devices.copy_ssh_keys_to_devices") + @patch("time.sleep") + def test_reserve_with_ssh_keys(self, mock_sleep, mock_copy_keys): + """Test DefaultDevice.reserve method with SSH keys.""" + config_data = {"device_ip": "192.168.1.10", "agent_name": "test_agent"} + + job_data = { + "reserve_data": {"ssh_keys": ["key1", "key2"], "timeout": "1800"}, + "test_data": {"test_username": "testuser"}, + } + + # Create a mock args object + mock_args = Mock() + mock_args.config = "test_config.yaml" + + # Mock file operations + with ( + patch( + "testflinger_device_connectors.get_test_opportunity", + return_value=job_data, + ), + patch("builtins.open") as mock_open, + ): + # Mock the config file read + mock_open.return_value.__enter__.return_value = Mock() + mock_open.return_value.__enter__.return_value.read.return_value = ( + '{"device_ip": "192.168.1.10"}' + ) + + with ( + patch( + "yaml.safe_load", + return_value={"device_ip": "192.168.1.10"}, + ), + patch("builtins.print"), + ): + device = DefaultDevice(config_data) + device.reserve(mock_args) + + # Verify copy_ssh_keys_to_devices was called with correct parameters + mock_copy_keys.assert_called_once_with( + ["key1", "key2"], ["192.168.1.10"], "testuser" + ) + + # Verify sleep was called with timeout + mock_sleep.assert_called_once_with(1800) + + @patch("testflinger_device_connectors.devices.copy_ssh_keys_to_devices") + @patch("time.sleep") + def test_reserve_no_ssh_keys(self, mock_sleep, mock_copy_keys): + """Test DefaultDevice.reserve method with no SSH keys.""" + config_data = {"device_ip": "192.168.1.10", "agent_name": "test_agent"} + + job_data = { + "reserve_data": { + "timeout": "3600" + # No ssh_keys provided + }, + "test_data": {"test_username": "testuser"}, + } + + mock_args = Mock() + mock_args.config = "test_config.yaml" + + with ( + patch( + "testflinger_device_connectors.get_test_opportunity", + return_value=job_data, + ), + patch("builtins.open") as mock_open, + ): + # Mock the config file read + mock_open.return_value.__enter__.return_value = Mock() + mock_open.return_value.__enter__.return_value.read.return_value = ( + '{"device_ip": "192.168.1.10"}' + ) + + with ( + patch( + "yaml.safe_load", + return_value={"device_ip": "192.168.1.10"}, + ), + patch("builtins.print"), + ): + device = DefaultDevice(config_data) + device.reserve(mock_args) + + # Verify copy_ssh_keys_to_devices was called with empty list + mock_copy_keys.assert_called_once_with( + [], ["192.168.1.10"], "testuser" + ) + + # Verify sleep was called with timeout + mock_sleep.assert_called_once_with(3600) + + @patch("testflinger_device_connectors.devices.copy_ssh_keys_to_devices") + @patch("time.sleep") + def test_reserve_default_timeout(self, mock_sleep, mock_copy_keys): + """Test DefaultDevice.reserve method with default timeout.""" + config_data = {"device_ip": "192.168.1.10", "agent_name": "test_agent"} + + job_data = { + "reserve_data": { + "ssh_keys": ["key1"] + # No timeout provided - should default to 3600 + }, + "test_data": {"test_username": "testuser"}, + } + + mock_args = Mock() + mock_args.config = "test_config.yaml" + + with ( + patch( + "testflinger_device_connectors.get_test_opportunity", + return_value=job_data, + ), + patch("builtins.open") as mock_open, + ): + # Mock the config file read + mock_open.return_value.__enter__.return_value = Mock() + mock_open.return_value.__enter__.return_value.read.return_value = ( + '{"device_ip": "192.168.1.10"}' + ) + + with ( + patch( + "yaml.safe_load", + return_value={"device_ip": "192.168.1.10"}, + ), + patch("builtins.print"), + ): + device = DefaultDevice(config_data) + device.reserve(mock_args) + + # Verify copy_ssh_keys_to_devices was called + mock_copy_keys.assert_called_once_with( + ["key1"], ["192.168.1.10"], "testuser" + ) + + # Verify sleep was called with default timeout of 3600 + mock_sleep.assert_called_once_with(3600)