Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 89 additions & 69 deletions verifiers/utils/env_utils.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,106 @@
from __future__ import annotations

import importlib
import inspect
import logging
from importlib.metadata import entry_points
from typing import Callable

from verifiers.envs.environment import Environment

LOGGER = logging.getLogger("verifiers.utils.env_utils")


def _call_loader(
func: Callable[..., Environment], env_id: str, **env_args
) -> Environment:
sig = inspect.signature(func)

if env_args:
LOGGER.info(
"Using provided args: "
+ ", ".join(f"{k}={v!r}" for k, v in env_args.items())
)

defaults = []
for name, p in sig.parameters.items():
if name not in env_args and p.default is not inspect._empty:
defaults.append(f"{name}={p.default!r}")
if defaults:
LOGGER.info("Using default args: " + ", ".join(defaults))

env = func(**env_args)
LOGGER.info(f"Successfully loaded environment '{env_id}'")
return env


def _load_from_target_spec(target: str, env_id: str, **env_args) -> Environment:
mod, sep, attr = target.partition(":")
if not sep or not attr:
raise AttributeError(f"Invalid target spec '{target}'. Expected 'module:attr'.")
module = importlib.import_module(mod)
func = getattr(module, attr)
if not callable(func):
raise TypeError(f"Target '{target}' is not callable")
return _call_loader(func, env_id, **env_args)


def _load_via_entry_point_exact(env_id: str, **env_args) -> Environment | None:
"""Exact match on the 'verifiers' entry point name. No aliasing or splitting."""
eps = entry_points(group="verifiers")
matches = [ep for ep in eps if ep.name == env_id]
if not matches:
return None
if len(matches) > 1:
details = ", ".join(ep.value for ep in matches)
raise RuntimeError(
f"Multiple 'verifiers' entry points named '{env_id}' found: {details}"
)
func = matches[0].load()
if not callable(func):
raise TypeError(
f"Entry point '{env_id}' did not load a callable; got {type(func)!r}"
)
return _call_loader(func, env_id, **env_args)


def load_environment(env_id: str, **env_args) -> Environment:
logger = logging.getLogger("verifiers.utils.env_utils")
logger.info(f"Loading environment: {env_id}")
LOGGER.info(f"Loading environment: {env_id}")

module_name = env_id.replace("-", "_")
try:
module = importlib.import_module(module_name)
# 1) Explicit module target: "pkg.mod:callable"
if ":" in env_id:
try:
return _load_from_target_spec(env_id, env_id, **env_args)
except Exception as e:
LOGGER.error(f"Failed to load environment {env_id} via target spec: {e}")
raise RuntimeError(f"Failed to load environment '{env_id}': {e}") from e

if not hasattr(module, "load_environment"):
raise AttributeError(
f"Module '{module_name}' does not have a 'load_environment' function. "
f"This usually means there's a package name collision. Please either:\n"
f"1. Rename your environment (e.g. suffix with '-env')\n"
f"2. Remove unneeded files with the same name\n"
f"3. Check that you've installed the correct environment package"
)

env_load_func: Callable[..., Environment] = getattr(module, "load_environment")
sig = inspect.signature(env_load_func)
defaults_info = []
for param_name, param in sig.parameters.items():
if param.default != inspect.Parameter.empty:
if isinstance(param.default, (dict, list)):
defaults_info.append(f"{param_name}={param.default}")
elif isinstance(param.default, str):
defaults_info.append(f"{param_name}='{param.default}'")
else:
defaults_info.append(f"{param_name}={param.default}")
else:
defaults_info.append(f"{param_name}=<required>")

if defaults_info:
logger.debug(f"Environment defaults: {', '.join(defaults_info)}")

if env_args:
provided_params = set(env_args.keys())
else:
provided_params = set()

all_params = set(sig.parameters.keys())
default_params = all_params - provided_params

if provided_params:
provided_values = []
for param_name in provided_params:
provided_values.append(f"{param_name}={env_args[param_name]}")
logger.info(f"Using provided args: {', '.join(provided_values)}")

if default_params:
default_values = []
for param_name in default_params:
param = sig.parameters[param_name]
if param.default != inspect.Parameter.empty:
if isinstance(param.default, str):
default_values.append(f"{param_name}='{param.default}'")
else:
default_values.append(f"{param_name}={param.default}")
if default_values:
logger.info(f"Using default args: {', '.join(default_values)}")

env_instance: Environment = env_load_func(**env_args)

logger.info(f"Successfully loaded environment '{env_id}'")

return env_instance
# 2) Prefer entry points (exact match only)
try:
ep_env = _load_via_entry_point_exact(env_id, **env_args)
if ep_env is not None:
return ep_env
except Exception as e:
LOGGER.error(f"Failed to load environment {env_id} via entry point: {e}")
raise RuntimeError(f"Failed to load environment '{env_id}': {e}") from e

# 3) Back-compat fallback: import by module name (slug or namespaced ID's tail)
module_name = env_id.split("/")[-1].replace("-", "_")
try:
module = importlib.import_module(module_name)
except ImportError as e:
logger.error(
f"Failed to import environment module {module_name} for env_id {env_id}: {str(e)}"
LOGGER.error(
f"Failed to import environment module {module_name} for env_id {env_id}: {e}"
)
raise ValueError(
f"Could not import '{env_id}' environment. Ensure the package for the '{env_id}' environment is installed."
f"Could not import '{env_id}'. Install a package that exposes a matching "
f"[project.entry-points.verifiers] = \"{env_id}\" entry or provide 'module:attr'."
) from e
except Exception as e:
logger.error(
f"Failed to load environment {env_id} with args {env_args}: {str(e)}"

if not hasattr(module, "load_environment"):
raise AttributeError(
f"Module '{module_name}' has no 'load_environment'. "
f"Prefer registering an entry point named '{env_id}' under the 'verifiers' group."
)
raise RuntimeError(f"Failed to load environment '{env_id}': {str(e)}") from e

return _call_loader(getattr(module, "load_environment"), env_id, **env_args)