Skip to content

Commit b769276

Browse files
committed
speed up import
- split some of the import to be lazy - change logging verbosity
1 parent 5ed139a commit b769276

File tree

4 files changed

+58
-65
lines changed

4 files changed

+58
-65
lines changed

eval_protocol/__init__.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,17 @@
88
tool-augmented models using self-contained task bundles.
99
"""
1010

11+
from importlib import import_module
12+
from typing import Any
1113
import warnings
1214

13-
from .adapters.braintrust import reward_fn_to_scorer, scorer_to_reward_fn
15+
# Lightweight imports (no heavy optional dependencies)
16+
from .integrations.braintrust import reward_fn_to_scorer, scorer_to_reward_fn
1417
from .auth import get_fireworks_account_id, get_fireworks_api_key
1518
from .common_utils import load_jsonl
1619
from .config import RewardKitConfig, get_config, load_config
17-
from .mcp_env import (
18-
AnthropicPolicy,
19-
OpenAIPolicy,
20-
LiteLLMPolicy,
21-
FireworksPolicy,
22-
make,
23-
rollout,
24-
test_mcp,
25-
)
2620

27-
# Try to import FireworksPolicy if available
28-
try:
29-
from .mcp_env import FireworksPolicy
30-
31-
_FIREWORKS_AVAILABLE = True
32-
except (ImportError, AttributeError):
33-
_FIREWORKS_AVAILABLE = False
3421
# Import submodules to make them available via eval_protocol.rewards, etc.
35-
from . import mcp, rewards
3622
from .models import EvaluateResult, Message, MetricResult
3723
from .playback_policy import PlaybackPolicyBase
3824
from .resources import create_llm_resource
@@ -41,6 +27,7 @@
4127

4228
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
4329

30+
# Public API (static exports only; dynamic MCP symbols are provided via __getattr__)
4431
__all__ = [
4532
# Core interfaces
4633
"Message",
@@ -59,14 +46,6 @@
5946
"RewardKitConfig",
6047
# Utilities
6148
"load_jsonl",
62-
# MCP Environment API
63-
"make",
64-
"rollout",
65-
"LiteLLMPolicy",
66-
"AnthropicPolicy",
67-
"FireworksPolicy",
68-
"OpenAIPolicy",
69-
"test_mcp",
7049
# Playback functionality
7150
"PlaybackPolicyBase",
7251
# Resource management
@@ -76,6 +55,30 @@
7655
"mcp",
7756
]
7857

58+
59+
def __getattr__(name: str) -> Any:
60+
"""Lazily import heavy MCP environment symbols to speed up package import.
61+
62+
This defers importing modules that depend on optional or heavy dependencies
63+
(e.g., vendored tau2, OpenAI clients) until they are actually used.
64+
"""
65+
if name in {
66+
"make",
67+
"rollout",
68+
"LiteLLMPolicy",
69+
"AnthropicPolicy",
70+
"FireworksPolicy",
71+
"OpenAIPolicy",
72+
"test_mcp",
73+
}:
74+
m = import_module(".mcp_env", __name__)
75+
return getattr(m, name)
76+
if name in {"mcp", "rewards"}:
77+
# Lazy-load subpackages for attribute access like eval_protocol.mcp
78+
return import_module(f".{name}", __name__)
79+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
80+
81+
7982
from . import _version
8083

8184
__version__ = _version.get_versions()["version"]

eval_protocol/adapters/__init__.py

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,38 +10,24 @@
1010
- TRL integration (legacy)
1111
"""
1212

13-
# Conditional imports based on available dependencies
14-
try:
15-
from .langfuse import LangfuseAdapter, create_langfuse_adapter
16-
__all__ = ["LangfuseAdapter", "create_langfuse_adapter"]
17-
except ImportError:
18-
__all__ = []
13+
from importlib import import_module
14+
from typing import Any
1915

20-
try:
21-
from .huggingface import (
22-
HuggingFaceAdapter,
23-
create_huggingface_adapter,
24-
create_gsm8k_adapter,
25-
create_math_adapter,
26-
)
27-
__all__.extend([
28-
"HuggingFaceAdapter",
29-
"create_huggingface_adapter",
30-
"create_gsm8k_adapter",
31-
"create_math_adapter",
32-
])
33-
except ImportError:
34-
pass
16+
__all__ = []
3517

36-
# Legacy adapters (always available)
37-
try:
38-
from .braintrust import reward_fn_to_scorer, scorer_to_reward_fn
39-
__all__.extend(["scorer_to_reward_fn", "reward_fn_to_scorer"])
40-
except ImportError:
41-
pass
4218

43-
try:
44-
from .trl import create_trl_adapter
45-
__all__.extend(["create_trl_adapter"])
46-
except ImportError:
47-
pass
19+
def __getattr__(name: str) -> Any:
20+
# Lazy import optional adapters to avoid import-time side effects and heavy deps
21+
if name in {"LangfuseAdapter", "create_langfuse_adapter"}:
22+
m = import_module(".langfuse", __name__)
23+
return getattr(m, name)
24+
if name in {"HuggingFaceAdapter", "create_huggingface_adapter", "create_gsm8k_adapter", "create_math_adapter"}:
25+
m = import_module(".huggingface", __name__)
26+
return getattr(m, name)
27+
if name in {"reward_fn_to_scorer", "scorer_to_reward_fn"}:
28+
m = import_module(".braintrust", __name__)
29+
return getattr(m, name)
30+
if name in {"create_trl_adapter"}:
31+
m = import_module(".trl", __name__)
32+
return getattr(m, name)
33+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

vendor/tau2/utils/llm_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from litellm.caching.caching import Cache
88
from litellm.main import ModelResponse, Usage
99
from loguru import logger
10+
import os
1011

1112
from vendor.tau2.config import (
1213
DEFAULT_LLM_CACHE_TYPE,
@@ -70,9 +71,8 @@
7071

7172

7273
ALLOW_SONNET_THINKING = False
73-
74-
if not ALLOW_SONNET_THINKING:
75-
logger.warning("Sonnet thinking is disabled")
74+
if os.getenv("TAU2_VERBOSE") == "1" and not ALLOW_SONNET_THINKING:
75+
logger.info("Sonnet thinking is disabled")
7676

7777

7878
def _parse_ft_model_name(model: str) -> str:

vendor/tau2/utils/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
from dotenv import load_dotenv
1010
from loguru import logger
1111

12+
_TAU2_VERBOSE = os.getenv("TAU2_VERBOSE") == "1"
13+
1214
res = load_dotenv()
13-
if not res:
15+
if not res and _TAU2_VERBOSE:
1416
logger.warning("No .env file found")
1517

1618
# Try to get data directory from environment variable first
@@ -19,15 +21,17 @@
1921
if DATA_DIR_ENV:
2022
# Use environment variable if set
2123
DATA_DIR = Path(DATA_DIR_ENV)
22-
logger.info(f"Using data directory from environment: {DATA_DIR}")
24+
if _TAU2_VERBOSE:
25+
logger.info(f"Using data directory from environment: {DATA_DIR}")
2326
else:
2427
# Fallback to vendored tau2 directory
2528
SOURCE_DIR = Path(__file__).parents[1] # vendor/tau2/
2629
DATA_DIR = SOURCE_DIR / "data"
27-
logger.info(f"Using data directory from vendored tau2: {DATA_DIR}")
30+
if _TAU2_VERBOSE:
31+
logger.info(f"Using data directory from vendored tau2: {DATA_DIR}")
2832

2933
# Check if data directory exists and is accessible
30-
if not DATA_DIR.exists():
34+
if not DATA_DIR.exists() and _TAU2_VERBOSE:
3135
logger.warning(f"Data directory does not exist: {DATA_DIR}")
3236
logger.warning(
3337
"Set TAU2_DATA_DIR environment variable to point to your data directory"

0 commit comments

Comments
 (0)