Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sdk/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,8 @@ python_version = "3.8"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true

[dependency-groups]
dev = [
"pytest>=8.3.5",
]
31 changes: 20 additions & 11 deletions sdk/python/src/p95/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Optional
from typing import Literal, Optional, Set


def _default_logdir() -> str:
Expand Down Expand Up @@ -49,8 +49,9 @@ class SDKConfig:
retry_delay: float = 1.0


# Global configuration instance
# Global configuration instance and set of fields explicitly set via configure()
_config = SDKConfig()
_explicitly_set: Set[str] = set()


def configure(
Expand All @@ -69,6 +70,8 @@ def configure(
"""
Configure the SDK globally.

Explicit calls to configure() take priority over environment variables.

Args:
mode: Operating mode ("local" for file-based, "remote" for API server)
base_url: Base URL for the p95 API server (remote mode)
Expand Down Expand Up @@ -99,16 +102,20 @@ def configure(
api_key="p95_xxxx",
)
"""
global _config
global _config, _explicitly_set

if mode is not None:
_config.mode = mode
_explicitly_set.add("mode")
if base_url is not None:
_config.base_url = base_url
_explicitly_set.add("base_url")
if api_key is not None:
_config.api_key = api_key
_explicitly_set.add("api_key")
if logdir is not None:
_config.logdir = logdir
_explicitly_set.add("logdir")
if batch_size is not None:
_config.batch_size = batch_size
if flush_interval is not None:
Expand Down Expand Up @@ -144,18 +151,20 @@ def _detect_mode() -> Literal["local", "remote"]:


def get_config() -> SDKConfig:
"""Get the current SDK configuration with environment variable overrides."""
global _config
"""Get the current SDK configuration.

# Auto-detect mode from environment
_config.mode = _detect_mode()
configure() takes highest priority. Environment variables fill in any
fields not explicitly set. Hardcoded defaults are used for the rest.
"""
global _config

# Override with environment variables
if os.environ.get("P95_LOGDIR"):
if "mode" not in _explicitly_set:
_config.mode = _detect_mode()
if "logdir" not in _explicitly_set and os.environ.get("P95_LOGDIR"):
_config.logdir = os.environ["P95_LOGDIR"]
if os.environ.get("P95_URL"):
if "base_url" not in _explicitly_set and os.environ.get("P95_URL"):
_config.base_url = os.environ["P95_URL"]
if os.environ.get("P95_API_KEY"):
if "api_key" not in _explicitly_set and os.environ.get("P95_API_KEY"):
_config.api_key = os.environ["P95_API_KEY"]

return _config
4 changes: 2 additions & 2 deletions sdk/python/src/p95/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def _run_remote_agent(
sweep_data=sweep,
is_local=False,
)
token = _set_sweep_context(ctx)
_set_sweep_context(ctx)

try:
# Call the training function
Expand Down Expand Up @@ -478,7 +478,7 @@ def _run_local_agent(
project=project,
is_local=True,
)
token = _set_sweep_context(ctx)
_set_sweep_context(ctx)

try:
# Call the training function - it will create its own Run
Expand Down
113 changes: 113 additions & 0 deletions sdk/python/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Tests for config priority: configure() should win over environment variables."""

import importlib

import pytest


def reload_config():
"""Reload the config module to reset global state between tests."""
import p95.config as config_mod

importlib.reload(config_mod)
return config_mod


def test_env_var_overrides_default(monkeypatch):
"""P95_LOGDIR env var switches mode to local (sanity check)."""
monkeypatch.setenv("P95_LOGDIR", "/tmp/logs")
monkeypatch.delenv("P95_URL", raising=False)
monkeypatch.delenv("P95_API_KEY", raising=False)

config_mod = reload_config()
config = config_mod.get_config()

assert config.mode == "local"


def test_env_var_remote_detection(monkeypatch):
"""P95_URL env var switches mode to remote."""
monkeypatch.setenv("P95_URL", "http://example.com")
monkeypatch.setenv("P95_API_KEY", "test-key")
monkeypatch.delenv("P95_LOGDIR", raising=False)

config_mod = reload_config()
config = config_mod.get_config()

assert config.mode == "remote"


def test_configure_takes_priority_over_env_var(monkeypatch):
"""
configure(mode=...) must take priority over environment variables.
"""
monkeypatch.setenv("P95_LOGDIR", "/tmp/logs")
monkeypatch.delenv("P95_URL", raising=False)
monkeypatch.delenv("P95_API_KEY", raising=False)

config_mod = reload_config()
config_mod.configure(mode="remote", base_url="http://example.com", api_key="key")
config = config_mod.get_config()

assert config.mode == "remote"


@pytest.mark.parametrize(
("configure_kwargs", "env_var", "env_value", "expected_attr", "expected_value"),
[
({"mode": "remote"}, "P95_LOGDIR", "/tmp/logs", "mode", "remote"),
(
{"logdir": "/custom/logs"},
"P95_LOGDIR",
"/tmp/logs",
"logdir",
"/custom/logs",
),
(
{"base_url": "http://configured.example.com"},
"P95_URL",
"http://env.example.com",
"base_url",
"http://configured.example.com",
),
(
{"api_key": "configured-key"},
"P95_API_KEY",
"env-key",
"api_key",
"configured-key",
),
],
)
def test_configure_field_takes_priority_over_matching_env(
monkeypatch, configure_kwargs, env_var, env_value, expected_attr, expected_value
):
"""Fields explicitly set via configure() should not be overridden by env vars."""
monkeypatch.setenv(env_var, env_value)
monkeypatch.delenv("P95_LOGDIR", raising=False)
monkeypatch.delenv("P95_URL", raising=False)
monkeypatch.delenv("P95_API_KEY", raising=False)
monkeypatch.setenv(env_var, env_value)

config_mod = reload_config()
config_mod.configure(**configure_kwargs)
config = config_mod.get_config()

assert getattr(config, expected_attr) == expected_value


def test_unset_fields_still_populate_from_env_when_some_fields_configured(monkeypatch):
"""
Explicitly setting one field should not block env vars for other unset fields.
"""
monkeypatch.setenv("P95_URL", "http://env.example.com")
monkeypatch.setenv("P95_API_KEY", "env-key")
monkeypatch.delenv("P95_LOGDIR", raising=False)

config_mod = reload_config()
config_mod.configure(base_url="http://configured.example.com")
config = config_mod.get_config()

assert config.base_url == "http://configured.example.com"
assert config.api_key == "env-key"
assert config.mode == "remote"
12 changes: 11 additions & 1 deletion sdk/python/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading