diff --git a/src/clabe/pickers/dataverse.py b/src/clabe/pickers/dataverse.py index 5e7a8e32..c4324683 100644 --- a/src/clabe/pickers/dataverse.py +++ b/src/clabe/pickers/dataverse.py @@ -11,13 +11,14 @@ import pydantic import requests from aind_behavior_curriculum import TrainerState +from aind_behavior_services.rig import Rig from pydantic import BaseModel, SecretStr, computed_field, field_validator from .. import ui from .._typing import TTask from ..launcher import Launcher from ..services import ServiceSettings -from ..utils.aind_auth import validate_aind_username +from ..utils.aind_validators import validate_rig_computer_name, validate_username from ..utils.keepass import KeePass, KeePassSettings from .default_behavior import DefaultBehaviorPicker, DefaultBehaviorPickerSettings @@ -498,7 +499,8 @@ def __init__( settings: DefaultBehaviorPickerSettings, launcher: Launcher, ui_helper: Optional[ui.IUiHelper] = None, - experimenter_validator: Optional[Callable[[str], bool]] = validate_aind_username, + experimenter_validator: Optional[Callable[[str], bool]] = validate_username, + rig_validator: Optional[Callable[[Rig], Rig]] = validate_rig_computer_name, ): """ Initializes the DataversePicker. @@ -508,9 +510,14 @@ def __init__( settings: Settings containing configuration including config_library_dir ui_helper: Helper for user interface interactions experimenter_validator: Function to validate the experimenter's username. If None, no validation is performed + rig_validator: Function to validate the rig configuration. If None, no validation is performed """ super().__init__( - settings=settings, launcher=launcher, ui_helper=ui_helper, experimenter_validator=experimenter_validator + settings=settings, + launcher=launcher, + ui_helper=ui_helper, + experimenter_validator=experimenter_validator, + rig_validator=rig_validator, ) self._dataverse_client = ( dataverse_client diff --git a/src/clabe/pickers/default_behavior.py b/src/clabe/pickers/default_behavior.py index 13a72774..793d6d9f 100644 --- a/src/clabe/pickers/default_behavior.py +++ b/src/clabe/pickers/default_behavior.py @@ -15,7 +15,7 @@ from ..constants import ByAnimalFiles from ..launcher import Launcher from ..services import ServiceSettings -from ..utils.aind_auth import validate_aind_username +from ..utils.aind_validators import validate_rig_computer_name, validate_username logger = logging.getLogger(__name__) T = TypeVar("T") @@ -70,7 +70,8 @@ def __init__( settings: DefaultBehaviorPickerSettings, launcher: Launcher, ui_helper: Optional[ui.IUiHelper] = None, - experimenter_validator: Optional[Callable[[str], bool]] = validate_aind_username, + experimenter_validator: Optional[Callable[[str], bool]] = validate_username, + rig_validator: Optional[Callable[[Rig], Rig]] = validate_rig_computer_name, use_cache: bool = True, ): """ @@ -80,7 +81,8 @@ def __init__( settings: Settings containing configuration including config_library_dir. By default, attempts to rely on DefaultBehaviorPickerSettings to automatic loading from yaml files launcher: The launcher instance for managing experiment execution ui_helper: Helper for user interface interactions. If None, uses launcher's ui_helper. Defaults to None - experimenter_validator: Function to validate the experimenter's username. If None, no validation is performed. Defaults to validate_aind_username + experimenter_validator: Function to validate the experimenter's username. If None, no validation is performed. Defaults to validate_username + rig_validator: Function to validate the rig configuration. If None, no validation is performed. Defaults to validate_rig_computer_name use_cache: Whether to use caching for selections. Defaults to True """ self._launcher = launcher @@ -88,6 +90,7 @@ def __init__( self._settings = settings self._ensure_directories() self._experimenter_validator = experimenter_validator + self._rig_validator = rig_validator self._trainer_state: Optional[TrainerState] = None self._session: Optional[Session] = None self._cache_manager = CacheManager.get_instance() @@ -241,6 +244,8 @@ def pick_rig(self, model: Type[TRig]) -> TRig: rig = self._load_rig_from_path(Path(rig_path), model) assert rig_path is not None assert rig is not None + if self._rig_validator: + rig = self._rig_validator(rig) # Add the selected rig path to the cache self._cache_manager.add_to_cache("rigs", rig_path) return rig diff --git a/src/clabe/utils/aind_auth.py b/src/clabe/utils/aind_auth.py index 612e9f21..8f1ddce1 100644 --- a/src/clabe/utils/aind_auth.py +++ b/src/clabe/utils/aind_auth.py @@ -1,40 +1,9 @@ -import logging -from typing import Optional -from urllib.parse import quote +import warnings -import requests +from .aind_validators import validate_username as validate_aind_username # noqa: F401 -logger = logging.getLogger(__name__) - -_AD_ENDPOINT = "http://aind-metadata-service/api/v2/active_directory" - - -def validate_aind_username( - username: str, - timeout: Optional[float] = 2, -) -> bool: - """ - Validates if the given username exists in the AIND Active Directory. - - Queries the AIND metadata service to verify the username exists. - Returns False (instead of raising) on network errors so callers can - decide how to handle the degraded state. - - Args: - username: The username to validate. - timeout: Timeout in seconds for the HTTP request. Defaults to 2. - - Returns: - bool: True if the username was found, False otherwise. - - Example: - ```python - is_valid = validate_aind_username("j.doe") - ``` - """ - try: - response = requests.get(f"{_AD_ENDPOINT}/{quote(username, safe='')}", timeout=timeout) - return response.ok - except requests.RequestException as e: - logger.warning("Failed to validate username '%s': %s", username, e) - return False +warnings.warn( + "The 'clabe.utils.aind_auth' module is deprecated and will be removed in a future version. Use 'clabe.utils.aind_validators' instead.", + FutureWarning, + stacklevel=2, +) diff --git a/src/clabe/utils/aind_validators.py b/src/clabe/utils/aind_validators.py new file mode 100644 index 00000000..257960f6 --- /dev/null +++ b/src/clabe/utils/aind_validators.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING, Optional, TypeVar +from urllib.parse import quote + +import requests + +if TYPE_CHECKING: + from aind_behavior_services.rig import Rig + + TRig = TypeVar("TRig", bound=Rig) +else: + TRig = TypeVar("TRig") + +logger = logging.getLogger(__name__) + +_ACTIVEDIRECTORY_ENDPOINT = "http://aind-metadata-service/api/v2/active_directory" + + +def validate_username( + username: str, + timeout: Optional[float] = 2, +) -> bool: + """ + Validates if the given username exists in the AIND Active Directory. + + Queries the AIND metadata service to verify the username exists. + Returns False (instead of raising) on network errors so callers can + decide how to handle the degraded state. + + Args: + username: The username to validate. + timeout: Timeout in seconds for the HTTP request. Defaults to 2. + + Returns: + bool: True if the username was found, False otherwise. + + Example: + ```python + is_valid = validate_username("j.doe") + ``` + """ + try: + response = requests.get(f"{_ACTIVEDIRECTORY_ENDPOINT}/{quote(username, safe='')}", timeout=timeout) + return response.ok + except requests.RequestException as e: + logger.warning("Failed to validate username '%s': %s", username, e) + return False + + +def validate_rig_computer_name(rig: TRig) -> TRig: + """Ensures rig and computer name are set from environment variables if available, otherwise defaults to rig configuration values.""" + rig_name = os.environ.get("aibs_comp_id", None) + computer_name = os.environ.get("hostname", None) + + if rig_name is None: + logger.warning( + "'%s' environment variable not set. Defaulting to rig name from configuration. %s", + "aibs_comp_id", + rig.rig_name, + ) + rig_name = rig.rig_name + if computer_name is None: + computer_name = rig.computer_name + logger.warning( + "'hostname' environment variable not set. Defaulting to computer name from configuration. %s", + rig.computer_name, + ) + + if rig_name != rig.rig_name or computer_name != rig.computer_name: + logger.warning( + "Rig name or computer name from environment variables do not match the rig configuration. " + "Forcing rig name: %s and computer name: %s from environment variables.", + rig_name, + computer_name, + ) + _rig = rig.model_copy(update={"rig_name": rig_name, "computer_name": computer_name}) + return _rig diff --git a/src/clabe/xml_rpc/_server.py b/src/clabe/xml_rpc/_server.py index d25eda62..80163972 100644 --- a/src/clabe/xml_rpc/_server.py +++ b/src/clabe/xml_rpc/_server.py @@ -120,7 +120,7 @@ def _normalize_returncode(returncode: int | None) -> int | None: for Access Violation = 3221225477). XML-RPC only supports signed 32-bit integers, so codes above 2**31-1 must be reinterpreted as negative values. XML-RPC docs: https://xmlrpc.com/spec.md - + """ if returncode is None: return None diff --git a/tests/utils/test_aind_auth.py b/tests/utils/test_aind_auth.py index 34545a40..2ec66b09 100644 --- a/tests/utils/test_aind_auth.py +++ b/tests/utils/test_aind_auth.py @@ -1,13 +1,30 @@ from unittest.mock import MagicMock, patch +import pytest import requests -from clabe.utils import aind_auth +from clabe.utils import aind_validators -def test_validate_aind_username_valid(): +@pytest.fixture +def mock_rig(): + rig = MagicMock() + rig.rig_name = "rig_1" + rig.computer_name = "host_1" + + def _model_copy(update=None): + new_rig = MagicMock() + new_rig.rig_name = (update or {}).get("rig_name", rig.rig_name) + new_rig.computer_name = (update or {}).get("computer_name", rig.computer_name) + return new_rig + + rig.model_copy.side_effect = _model_copy + return rig + + +def test_validate_username_valid(): """Returns True when the metadata service finds the user.""" - with patch("clabe.utils.aind_auth.requests.get") as mock_get: + with patch("clabe.utils.aind_validators.requests.get") as mock_get: mock_response = MagicMock() mock_response.ok = True mock_response.json.return_value = { @@ -17,53 +34,91 @@ def test_validate_aind_username_valid(): } mock_get.return_value = mock_response - assert aind_auth.validate_aind_username("j.doe") is True + assert aind_validators.validate_username("j.doe") is True mock_get.assert_called_once_with( "http://aind-metadata-service/api/v2/active_directory/j.doe", timeout=2, ) -def test_validate_aind_username_invalid(): +def test_validate_username_invalid(): """Returns False when the metadata service does not find the user.""" - with patch("clabe.utils.aind_auth.requests.get") as mock_get: + with patch("clabe.utils.aind_validators.requests.get") as mock_get: mock_response = MagicMock() mock_response.ok = False mock_get.return_value = mock_response - assert aind_auth.validate_aind_username("no.one") is False + assert aind_validators.validate_username("no.one") is False -def test_validate_aind_username_request_exception(): +def test_validate_username_request_exception(): """Returns False (with a logged warning) on a network error.""" - with patch("clabe.utils.aind_auth.requests.get", side_effect=requests.RequestException("timeout")): - assert aind_auth.validate_aind_username("j.doe") is False + with patch("clabe.utils.aind_validators.requests.get", side_effect=requests.RequestException("timeout")): + assert aind_validators.validate_username("j.doe") is False -def test_validate_aind_username_custom_timeout(): +def test_validate_username_custom_timeout(): """Passes the timeout argument through to requests.get.""" - with patch("clabe.utils.aind_auth.requests.get") as mock_get: + with patch("clabe.utils.aind_validators.requests.get") as mock_get: mock_response = MagicMock() mock_response.ok = True mock_response.json.return_value = {"username": "j.doe"} mock_get.return_value = mock_response - aind_auth.validate_aind_username("j.doe", timeout=10) + aind_validators.validate_username("j.doe", timeout=10) mock_get.assert_called_once_with( "http://aind-metadata-service/api/v2/active_directory/j.doe", timeout=10, ) -def test_validate_aind_username_encodes_special_chars(): +def test_validate_username_encodes_special_chars(): """URL-encodes special characters in the username.""" - with patch("clabe.utils.aind_auth.requests.get") as mock_get: + with patch("clabe.utils.aind_validators.requests.get") as mock_get: mock_response = MagicMock() mock_response.ok = False mock_get.return_value = mock_response - aind_auth.validate_aind_username("../admin") + aind_validators.validate_username("../admin") mock_get.assert_called_once_with( "http://aind-metadata-service/api/v2/active_directory/..%2Fadmin", timeout=2, ) + + +# --- validate_rig_computer_name --- + + +def test_validate_rig_no_env_vars_returns_original_values(mock_rig): + """When env vars are absent the returned rig keeps the original values.""" + with patch.dict("os.environ", {}, clear=True): + result = aind_validators.validate_rig_computer_name(mock_rig) + + assert result.rig_name == "rig_1" + assert result.computer_name == "host_1" + + +def test_validate_rig_matching_env_vars_returns_same_values(mock_rig): + """When env vars match the rig config the returned rig has the same values.""" + with patch.dict("os.environ", {"aibs_comp_id": "rig_1", "hostname": "host_1"}): + result = aind_validators.validate_rig_computer_name(mock_rig) + + assert result.rig_name == "rig_1" + assert result.computer_name == "host_1" + + +def test_validate_rig_differing_env_vars_returns_updated_rig(mock_rig): + """When env vars differ from the rig config the returned rig has the env var values.""" + with patch.dict("os.environ", {"aibs_comp_id": "rig_2", "hostname": "host_2"}): + result = aind_validators.validate_rig_computer_name(mock_rig) + + assert result.rig_name == "rig_2" + assert result.computer_name == "host_2" + + +def test_validate_rig_returns_new_object_not_original(mock_rig): + """The function always returns a new rig copy, never mutates the input.""" + with patch.dict("os.environ", {"aibs_comp_id": "rig_1", "hostname": "host_1"}): + result = aind_validators.validate_rig_computer_name(mock_rig) + + assert result is not mock_rig