Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/clabe/pickers/dataverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
import pydantic
import requests
from aind_behavior_curriculum import TrainerState
from aind_behavior_services.rig import Rig
from pydantic import BaseModel, SecretStr, computed_field, field_validator

from .. import ui
from .._typing import TTask
from ..launcher import Launcher
from ..services import ServiceSettings
from ..utils.aind_auth import validate_aind_username
from ..utils.aind_validators import validate_rig_computer_name, validate_username
from ..utils.keepass import KeePass, KeePassSettings
from .default_behavior import DefaultBehaviorPicker, DefaultBehaviorPickerSettings

Expand Down Expand Up @@ -498,7 +499,8 @@ def __init__(
settings: DefaultBehaviorPickerSettings,
launcher: Launcher,
ui_helper: Optional[ui.IUiHelper] = None,
experimenter_validator: Optional[Callable[[str], bool]] = validate_aind_username,
experimenter_validator: Optional[Callable[[str], bool]] = validate_username,
rig_validator: Optional[Callable[[Rig], Rig]] = validate_rig_computer_name,
):
"""
Initializes the DataversePicker.
Expand All @@ -508,9 +510,14 @@ def __init__(
settings: Settings containing configuration including config_library_dir
ui_helper: Helper for user interface interactions
experimenter_validator: Function to validate the experimenter's username. If None, no validation is performed
rig_validator: Function to validate the rig configuration. If None, no validation is performed
"""
super().__init__(
settings=settings, launcher=launcher, ui_helper=ui_helper, experimenter_validator=experimenter_validator
settings=settings,
launcher=launcher,
ui_helper=ui_helper,
experimenter_validator=experimenter_validator,
rig_validator=rig_validator,
)
self._dataverse_client = (
dataverse_client
Expand Down
11 changes: 8 additions & 3 deletions src/clabe/pickers/default_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..constants import ByAnimalFiles
from ..launcher import Launcher
from ..services import ServiceSettings
from ..utils.aind_auth import validate_aind_username
from ..utils.aind_validators import validate_rig_computer_name, validate_username

logger = logging.getLogger(__name__)
T = TypeVar("T")
Expand Down Expand Up @@ -70,7 +70,8 @@ def __init__(
settings: DefaultBehaviorPickerSettings,
launcher: Launcher,
ui_helper: Optional[ui.IUiHelper] = None,
experimenter_validator: Optional[Callable[[str], bool]] = validate_aind_username,
experimenter_validator: Optional[Callable[[str], bool]] = validate_username,
rig_validator: Optional[Callable[[Rig], Rig]] = validate_rig_computer_name,
use_cache: bool = True,
):
"""
Expand All @@ -80,14 +81,16 @@ def __init__(
settings: Settings containing configuration including config_library_dir. By default, attempts to rely on DefaultBehaviorPickerSettings to automatic loading from yaml files
launcher: The launcher instance for managing experiment execution
ui_helper: Helper for user interface interactions. If None, uses launcher's ui_helper. Defaults to None
experimenter_validator: Function to validate the experimenter's username. If None, no validation is performed. Defaults to validate_aind_username
experimenter_validator: Function to validate the experimenter's username. If None, no validation is performed. Defaults to validate_username
rig_validator: Function to validate the rig configuration. If None, no validation is performed. Defaults to validate_rig_computer_name
use_cache: Whether to use caching for selections. Defaults to True
"""
self._launcher = launcher
self._ui_helper = launcher.ui_helper if ui_helper is None else ui_helper
self._settings = settings
self._ensure_directories()
self._experimenter_validator = experimenter_validator
self._rig_validator = rig_validator
self._trainer_state: Optional[TrainerState] = None
self._session: Optional[Session] = None
self._cache_manager = CacheManager.get_instance()
Expand Down Expand Up @@ -241,6 +244,8 @@ def pick_rig(self, model: Type[TRig]) -> TRig:
rig = self._load_rig_from_path(Path(rig_path), model)
assert rig_path is not None
assert rig is not None
if self._rig_validator:
rig = self._rig_validator(rig)
# Add the selected rig path to the cache
self._cache_manager.add_to_cache("rigs", rig_path)
return rig
Expand Down
45 changes: 7 additions & 38 deletions src/clabe/utils/aind_auth.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,9 @@
import logging
from typing import Optional
from urllib.parse import quote
import warnings

import requests
from .aind_validators import validate_username as validate_aind_username # noqa: F401

logger = logging.getLogger(__name__)

_AD_ENDPOINT = "http://aind-metadata-service/api/v2/active_directory"


def validate_aind_username(
username: str,
timeout: Optional[float] = 2,
) -> bool:
"""
Validates if the given username exists in the AIND Active Directory.

Queries the AIND metadata service to verify the username exists.
Returns False (instead of raising) on network errors so callers can
decide how to handle the degraded state.

Args:
username: The username to validate.
timeout: Timeout in seconds for the HTTP request. Defaults to 2.

Returns:
bool: True if the username was found, False otherwise.

Example:
```python
is_valid = validate_aind_username("j.doe")
```
"""
try:
response = requests.get(f"{_AD_ENDPOINT}/{quote(username, safe='')}", timeout=timeout)
return response.ok
except requests.RequestException as e:
logger.warning("Failed to validate username '%s': %s", username, e)
return False
warnings.warn(
"The 'clabe.utils.aind_auth' module is deprecated and will be removed in a future version. Use 'clabe.utils.aind_validators' instead.",
FutureWarning,
stacklevel=2,
)
80 changes: 80 additions & 0 deletions src/clabe/utils/aind_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, Optional, TypeVar
from urllib.parse import quote

import requests

if TYPE_CHECKING:
from aind_behavior_services.rig import Rig

TRig = TypeVar("TRig", bound=Rig)
else:
TRig = TypeVar("TRig")

logger = logging.getLogger(__name__)

_ACTIVEDIRECTORY_ENDPOINT = "http://aind-metadata-service/api/v2/active_directory"


def validate_username(
username: str,
timeout: Optional[float] = 2,
) -> bool:
"""
Validates if the given username exists in the AIND Active Directory.

Queries the AIND metadata service to verify the username exists.
Returns False (instead of raising) on network errors so callers can
decide how to handle the degraded state.

Args:
username: The username to validate.
timeout: Timeout in seconds for the HTTP request. Defaults to 2.

Returns:
bool: True if the username was found, False otherwise.

Example:
```python
is_valid = validate_username("j.doe")
```
"""
try:
response = requests.get(f"{_ACTIVEDIRECTORY_ENDPOINT}/{quote(username, safe='')}", timeout=timeout)
return response.ok
except requests.RequestException as e:
logger.warning("Failed to validate username '%s': %s", username, e)
return False


def validate_rig_computer_name(rig: TRig) -> TRig:
"""Ensures rig and computer name are set from environment variables if available, otherwise defaults to rig configuration values."""
rig_name = os.environ.get("aibs_comp_id", None)
computer_name = os.environ.get("hostname", None)

if rig_name is None:
logger.warning(
"'%s' environment variable not set. Defaulting to rig name from configuration. %s",
"aibs_comp_id",
rig.rig_name,
)
rig_name = rig.rig_name
if computer_name is None:
computer_name = rig.computer_name
logger.warning(
"'hostname' environment variable not set. Defaulting to computer name from configuration. %s",
rig.computer_name,
)

if rig_name != rig.rig_name or computer_name != rig.computer_name:
logger.warning(
"Rig name or computer name from environment variables do not match the rig configuration. "
"Forcing rig name: %s and computer name: %s from environment variables.",
rig_name,
computer_name,
)
_rig = rig.model_copy(update={"rig_name": rig_name, "computer_name": computer_name})
return _rig
2 changes: 1 addition & 1 deletion src/clabe/xml_rpc/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _normalize_returncode(returncode: int | None) -> int | None:
for Access Violation = 3221225477). XML-RPC only supports signed 32-bit
integers, so codes above 2**31-1 must be reinterpreted as negative values.
XML-RPC docs: https://xmlrpc.com/spec.md

"""
if returncode is None:
return None
Expand Down
87 changes: 71 additions & 16 deletions tests/utils/test_aind_auth.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
from unittest.mock import MagicMock, patch

import pytest
import requests

from clabe.utils import aind_auth
from clabe.utils import aind_validators


def test_validate_aind_username_valid():
@pytest.fixture
def mock_rig():
rig = MagicMock()
rig.rig_name = "rig_1"
rig.computer_name = "host_1"

def _model_copy(update=None):
new_rig = MagicMock()
new_rig.rig_name = (update or {}).get("rig_name", rig.rig_name)
new_rig.computer_name = (update or {}).get("computer_name", rig.computer_name)
return new_rig

rig.model_copy.side_effect = _model_copy
return rig


def test_validate_username_valid():
"""Returns True when the metadata service finds the user."""
with patch("clabe.utils.aind_auth.requests.get") as mock_get:
with patch("clabe.utils.aind_validators.requests.get") as mock_get:
mock_response = MagicMock()
mock_response.ok = True
mock_response.json.return_value = {
Expand All @@ -17,53 +34,91 @@ def test_validate_aind_username_valid():
}
mock_get.return_value = mock_response

assert aind_auth.validate_aind_username("j.doe") is True
assert aind_validators.validate_username("j.doe") is True
mock_get.assert_called_once_with(
"http://aind-metadata-service/api/v2/active_directory/j.doe",
timeout=2,
)


def test_validate_aind_username_invalid():
def test_validate_username_invalid():
"""Returns False when the metadata service does not find the user."""
with patch("clabe.utils.aind_auth.requests.get") as mock_get:
with patch("clabe.utils.aind_validators.requests.get") as mock_get:
mock_response = MagicMock()
mock_response.ok = False
mock_get.return_value = mock_response

assert aind_auth.validate_aind_username("no.one") is False
assert aind_validators.validate_username("no.one") is False


def test_validate_aind_username_request_exception():
def test_validate_username_request_exception():
"""Returns False (with a logged warning) on a network error."""
with patch("clabe.utils.aind_auth.requests.get", side_effect=requests.RequestException("timeout")):
assert aind_auth.validate_aind_username("j.doe") is False
with patch("clabe.utils.aind_validators.requests.get", side_effect=requests.RequestException("timeout")):
assert aind_validators.validate_username("j.doe") is False


def test_validate_aind_username_custom_timeout():
def test_validate_username_custom_timeout():
"""Passes the timeout argument through to requests.get."""
with patch("clabe.utils.aind_auth.requests.get") as mock_get:
with patch("clabe.utils.aind_validators.requests.get") as mock_get:
mock_response = MagicMock()
mock_response.ok = True
mock_response.json.return_value = {"username": "j.doe"}
mock_get.return_value = mock_response

aind_auth.validate_aind_username("j.doe", timeout=10)
aind_validators.validate_username("j.doe", timeout=10)
mock_get.assert_called_once_with(
"http://aind-metadata-service/api/v2/active_directory/j.doe",
timeout=10,
)


def test_validate_aind_username_encodes_special_chars():
def test_validate_username_encodes_special_chars():
"""URL-encodes special characters in the username."""
with patch("clabe.utils.aind_auth.requests.get") as mock_get:
with patch("clabe.utils.aind_validators.requests.get") as mock_get:
mock_response = MagicMock()
mock_response.ok = False
mock_get.return_value = mock_response

aind_auth.validate_aind_username("../admin")
aind_validators.validate_username("../admin")
mock_get.assert_called_once_with(
"http://aind-metadata-service/api/v2/active_directory/..%2Fadmin",
timeout=2,
)


# --- validate_rig_computer_name ---


def test_validate_rig_no_env_vars_returns_original_values(mock_rig):
"""When env vars are absent the returned rig keeps the original values."""
with patch.dict("os.environ", {}, clear=True):
result = aind_validators.validate_rig_computer_name(mock_rig)

assert result.rig_name == "rig_1"
assert result.computer_name == "host_1"


def test_validate_rig_matching_env_vars_returns_same_values(mock_rig):
"""When env vars match the rig config the returned rig has the same values."""
with patch.dict("os.environ", {"aibs_comp_id": "rig_1", "hostname": "host_1"}):
result = aind_validators.validate_rig_computer_name(mock_rig)

assert result.rig_name == "rig_1"
assert result.computer_name == "host_1"


def test_validate_rig_differing_env_vars_returns_updated_rig(mock_rig):
"""When env vars differ from the rig config the returned rig has the env var values."""
with patch.dict("os.environ", {"aibs_comp_id": "rig_2", "hostname": "host_2"}):
result = aind_validators.validate_rig_computer_name(mock_rig)

assert result.rig_name == "rig_2"
assert result.computer_name == "host_2"


def test_validate_rig_returns_new_object_not_original(mock_rig):
"""The function always returns a new rig copy, never mutates the input."""
with patch.dict("os.environ", {"aibs_comp_id": "rig_1", "hostname": "host_1"}):
result = aind_validators.validate_rig_computer_name(mock_rig)

assert result is not mock_rig
Loading