Skip to content

Commit 6efbf13

Browse files
committed
lazy loading of all modules in init
1 parent 35db8e2 commit 6efbf13

1 file changed

Lines changed: 213 additions & 94 deletions

File tree

eval_protocol/__init__.py

Lines changed: 213 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -8,102 +8,142 @@
88
tool-augmented models using self-contained task bundles.
99
"""
1010

11+
import importlib
1112
import warnings
13+
from typing import TYPE_CHECKING
1214

13-
from .auth import get_fireworks_account_id, get_fireworks_api_key
14-
from .common_utils import load_jsonl
15-
from .config import RewardKitConfig, get_config, load_config
16-
from .mcp_env import (
17-
AnthropicPolicy,
18-
FireworksPolicy,
19-
LiteLLMPolicy,
20-
OpenAIPolicy,
21-
make,
22-
rollout,
23-
test_mcp,
24-
)
25-
from .data_loader import DynamicDataLoader, InlineDataLoader
26-
from . import mcp, rewards
27-
from .models import EvaluateResult, Message, MetricResult, EvaluationRow, InputMetadata, Status
28-
from .playback_policy import PlaybackPolicyBase
29-
from .resources import create_llm_resource
30-
from .reward_function import RewardFunction
31-
from .typed_interface import reward_function
32-
from .quickstart.aha_judge import aha_judge
33-
from .utils.evaluation_row_utils import (
34-
multi_turn_assistant_to_ground_truth,
35-
assistant_to_ground_truth,
36-
filter_longest_conversation,
37-
)
38-
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor, GithubActionRolloutProcessor
39-
from .pytest.parameterize import DefaultParameterIdGenerator
40-
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
41-
from .log_utils.rollout_id_filter import RolloutIdFilter
42-
from .log_utils.util import setup_rollout_logging_for_elasticsearch_handler
43-
from .log_utils.fireworks_tracing_http_handler import FireworksTracingHttpHandler
44-
from .log_utils.elasticsearch_client import ElasticsearchConfig
45-
46-
47-
from .types.remote_rollout_processor import (
48-
InitRequest,
49-
RolloutMetadata,
50-
StatusResponse,
51-
create_langfuse_config_tags,
52-
DataLoaderConfig,
53-
)
54-
55-
try:
56-
from .adapters import OpenAIResponsesAdapter
57-
except ImportError:
58-
OpenAIResponsesAdapter = None
59-
60-
try:
61-
from .adapters import LangfuseAdapter, create_langfuse_adapter
62-
except ImportError:
63-
LangfuseAdapter = None
64-
65-
try:
66-
from .adapters import BraintrustAdapter, create_braintrust_adapter
67-
except ImportError:
68-
BraintrustAdapter = None
69-
70-
try:
71-
from .adapters import LangSmithAdapter
72-
except ImportError:
73-
LangSmithAdapter = None
74-
75-
76-
try:
77-
from .adapters import WeaveAdapter
78-
except ImportError:
79-
WeaveAdapter = None
80-
81-
try:
82-
from .proxy import create_app, AuthProvider, AccountInfo # pyright: ignore[reportAssignmentType]
83-
except ImportError:
84-
85-
def create_app(*args, **kwargs):
86-
raise ImportError(
87-
"Proxy functionality requires additional dependencies. "
88-
"Please install with: pip install eval-protocol[proxy]"
89-
)
90-
91-
class AuthProvider:
92-
def __init__(self, *args, **kwargs):
93-
raise ImportError(
94-
"Proxy functionality requires additional dependencies. "
95-
"Please install with: pip install eval-protocol[proxy]"
96-
)
97-
98-
class AccountInfo:
99-
def __init__(self, *args, **kwargs):
100-
raise ImportError(
101-
"Proxy functionality requires additional dependencies. "
102-
"Please install with: pip install eval-protocol[proxy]"
103-
)
15+
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
10416

17+
# Lazy import mappings: name -> (module_path, attribute_name or None for module import)
18+
_LAZY_IMPORTS = {
19+
# From .auth
20+
"get_fireworks_account_id": (".auth", "get_fireworks_account_id"),
21+
"get_fireworks_api_key": (".auth", "get_fireworks_api_key"),
22+
# From .common_utils
23+
"load_jsonl": (".common_utils", "load_jsonl"),
24+
# From .config
25+
"RewardKitConfig": (".config", "RewardKitConfig"),
26+
"get_config": (".config", "get_config"),
27+
"load_config": (".config", "load_config"),
28+
# From .mcp_env
29+
"AnthropicPolicy": (".mcp_env", "AnthropicPolicy"),
30+
"FireworksPolicy": (".mcp_env", "FireworksPolicy"),
31+
"LiteLLMPolicy": (".mcp_env", "LiteLLMPolicy"),
32+
"OpenAIPolicy": (".mcp_env", "OpenAIPolicy"),
33+
"make": (".mcp_env", "make"),
34+
"rollout": (".mcp_env", "rollout"),
35+
"test_mcp": (".mcp_env", "test_mcp"),
36+
# From .data_loader
37+
"DynamicDataLoader": (".data_loader", "DynamicDataLoader"),
38+
"InlineDataLoader": (".data_loader", "InlineDataLoader"),
39+
# Submodules
40+
"mcp": (".mcp", None),
41+
"rewards": (".rewards", None),
42+
# From .models
43+
"EvaluateResult": (".models", "EvaluateResult"),
44+
"Message": (".models", "Message"),
45+
"MetricResult": (".models", "MetricResult"),
46+
"EvaluationRow": (".models", "EvaluationRow"),
47+
"InputMetadata": (".models", "InputMetadata"),
48+
"Status": (".models", "Status"),
49+
# From .playback_policy
50+
"PlaybackPolicyBase": (".playback_policy", "PlaybackPolicyBase"),
51+
# From .resources
52+
"create_llm_resource": (".resources", "create_llm_resource"),
53+
# From .reward_function
54+
"RewardFunction": (".reward_function", "RewardFunction"),
55+
# From .typed_interface
56+
"reward_function": (".typed_interface", "reward_function"),
57+
# From .quickstart.aha_judge
58+
"aha_judge": (".quickstart.aha_judge", "aha_judge"),
59+
# From .utils.evaluation_row_utils
60+
"multi_turn_assistant_to_ground_truth": (".utils.evaluation_row_utils", "multi_turn_assistant_to_ground_truth"),
61+
"assistant_to_ground_truth": (".utils.evaluation_row_utils", "assistant_to_ground_truth"),
62+
"filter_longest_conversation": (".utils.evaluation_row_utils", "filter_longest_conversation"),
63+
# From .pytest
64+
"evaluation_test": (".pytest", "evaluation_test"),
65+
"SingleTurnRolloutProcessor": (".pytest", "SingleTurnRolloutProcessor"),
66+
"RemoteRolloutProcessor": (".pytest", "RemoteRolloutProcessor"),
67+
"GithubActionRolloutProcessor": (".pytest", "GithubActionRolloutProcessor"),
68+
# From .pytest.parameterize
69+
"DefaultParameterIdGenerator": (".pytest.parameterize", "DefaultParameterIdGenerator"),
70+
# From .log_utils
71+
"ElasticsearchDirectHttpHandler": (
72+
".log_utils.elasticsearch_direct_http_handler",
73+
"ElasticsearchDirectHttpHandler",
74+
),
75+
"RolloutIdFilter": (".log_utils.rollout_id_filter", "RolloutIdFilter"),
76+
"setup_rollout_logging_for_elasticsearch_handler": (
77+
".log_utils.util",
78+
"setup_rollout_logging_for_elasticsearch_handler",
79+
),
80+
"FireworksTracingHttpHandler": (".log_utils.fireworks_tracing_http_handler", "FireworksTracingHttpHandler"),
81+
"ElasticsearchConfig": (".log_utils.elasticsearch_client", "ElasticsearchConfig"),
82+
# From .types.remote_rollout_processor
83+
"InitRequest": (".types.remote_rollout_processor", "InitRequest"),
84+
"RolloutMetadata": (".types.remote_rollout_processor", "RolloutMetadata"),
85+
"StatusResponse": (".types.remote_rollout_processor", "StatusResponse"),
86+
"create_langfuse_config_tags": (".types.remote_rollout_processor", "create_langfuse_config_tags"),
87+
"DataLoaderConfig": (".types.remote_rollout_processor", "DataLoaderConfig"),
88+
}
89+
90+
# Optional imports that may not be available
91+
_OPTIONAL_IMPORTS = {
92+
"OpenAIResponsesAdapter": (".adapters", "OpenAIResponsesAdapter"),
93+
"LangfuseAdapter": (".adapters", "LangfuseAdapter"),
94+
"create_langfuse_adapter": (".adapters", "create_langfuse_adapter"),
95+
"BraintrustAdapter": (".adapters", "BraintrustAdapter"),
96+
"create_braintrust_adapter": (".adapters", "create_braintrust_adapter"),
97+
"LangSmithAdapter": (".adapters", "LangSmithAdapter"),
98+
"WeaveAdapter": (".adapters", "WeaveAdapter"),
99+
"create_app": (".proxy", "create_app"),
100+
"AuthProvider": (".proxy", "AuthProvider"),
101+
"AccountInfo": (".proxy", "AccountInfo"),
102+
}
103+
104+
105+
def __getattr__(name: str):
106+
"""Lazy import handler for module-level attributes."""
107+
# Check regular lazy imports
108+
if name in _LAZY_IMPORTS:
109+
module_path, attr_name = _LAZY_IMPORTS[name]
110+
module = importlib.import_module(module_path, package="eval_protocol")
111+
if attr_name is None:
112+
# It's a submodule import
113+
return module
114+
return getattr(module, attr_name)
115+
116+
# Check optional imports
117+
if name in _OPTIONAL_IMPORTS:
118+
module_path, attr_name = _OPTIONAL_IMPORTS[name]
119+
try:
120+
module = importlib.import_module(module_path, package="eval_protocol")
121+
return getattr(module, attr_name)
122+
except ImportError:
123+
# Return None or a placeholder for optional imports
124+
if name in ("create_app",):
125+
126+
def create_app(*args, **kwargs):
127+
raise ImportError(
128+
"Proxy functionality requires additional dependencies. "
129+
"Please install with: pip install eval-protocol[proxy]"
130+
)
131+
132+
return create_app
133+
elif name in ("AuthProvider", "AccountInfo"):
134+
135+
class _Placeholder:
136+
def __init__(self, *args, **kwargs):
137+
raise ImportError(
138+
"Proxy functionality requires additional dependencies. "
139+
"Please install with: pip install eval-protocol[proxy]"
140+
)
141+
142+
return _Placeholder
143+
return None
144+
145+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
105146

106-
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
107147

108148
__all__ = [
109149
"ElasticsearchConfig",
@@ -173,6 +213,85 @@ def __init__(self, *args, **kwargs):
173213
"AccountInfo",
174214
]
175215

176-
from . import _version
216+
# Version is loaded lazily too
217+
_version_info = None
218+
219+
220+
def _get_version():
221+
global _version_info
222+
if _version_info is None:
223+
from . import _version
224+
225+
_version_info = _version.get_versions()["version"]
226+
return _version_info
227+
228+
229+
# For TYPE_CHECKING, we provide type hints so IDEs can see the exports
230+
if TYPE_CHECKING:
231+
from .auth import get_fireworks_account_id, get_fireworks_api_key
232+
from .common_utils import load_jsonl
233+
from .config import RewardKitConfig, get_config, load_config
234+
from .mcp_env import (
235+
AnthropicPolicy,
236+
FireworksPolicy,
237+
LiteLLMPolicy,
238+
OpenAIPolicy,
239+
make,
240+
rollout,
241+
test_mcp,
242+
)
243+
from .data_loader import DynamicDataLoader, InlineDataLoader
244+
from . import mcp, rewards
245+
from .models import EvaluateResult, Message, MetricResult, EvaluationRow, InputMetadata, Status
246+
from .playback_policy import PlaybackPolicyBase
247+
from .resources import create_llm_resource
248+
from .reward_function import RewardFunction
249+
from .typed_interface import reward_function
250+
from .quickstart.aha_judge import aha_judge
251+
from .utils.evaluation_row_utils import (
252+
multi_turn_assistant_to_ground_truth,
253+
assistant_to_ground_truth,
254+
filter_longest_conversation,
255+
)
256+
from .pytest import (
257+
evaluation_test,
258+
SingleTurnRolloutProcessor,
259+
RemoteRolloutProcessor,
260+
GithubActionRolloutProcessor,
261+
)
262+
from .pytest.parameterize import DefaultParameterIdGenerator
263+
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
264+
from .log_utils.rollout_id_filter import RolloutIdFilter
265+
from .log_utils.util import setup_rollout_logging_for_elasticsearch_handler
266+
from .log_utils.fireworks_tracing_http_handler import FireworksTracingHttpHandler
267+
from .log_utils.elasticsearch_client import ElasticsearchConfig
268+
from .types.remote_rollout_processor import (
269+
InitRequest,
270+
RolloutMetadata,
271+
StatusResponse,
272+
create_langfuse_config_tags,
273+
DataLoaderConfig,
274+
)
275+
from .adapters import (
276+
OpenAIResponsesAdapter,
277+
LangfuseAdapter,
278+
create_langfuse_adapter,
279+
BraintrustAdapter,
280+
create_braintrust_adapter,
281+
LangSmithAdapter,
282+
WeaveAdapter,
283+
)
284+
from .proxy import create_app, AuthProvider, AccountInfo
285+
286+
287+
# __version__ property - computed lazily
288+
class _VersionModule:
289+
@property
290+
def __version__(self):
291+
return _get_version()
292+
293+
294+
import sys
177295

178-
__version__ = _version.get_versions()["version"]
296+
_this_module = sys.modules[__name__]
297+
_this_module.__class__ = type("module", (type(_this_module),), {"__version__": property(lambda self: _get_version())})

0 commit comments

Comments
 (0)