Skip to content

Commit 278b663

Browse files
committed
Add data loader abstraction to evaluation_test
1 parent 64abf2d commit 278b663

File tree

7 files changed

+421
-32
lines changed

7 files changed

+421
-32
lines changed

eval_protocol/adapters/langfuse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def __init__(self):
219219
if not LANGFUSE_AVAILABLE:
220220
raise ImportError("Langfuse not installed. Install with: pip install 'eval-protocol[langfuse]'")
221221

222-
self.client = get_client()
222+
self.client = get_client() # pyright: ignore[reportCallIssue]
223223

224224
def get_evaluation_rows(
225225
self,

eval_protocol/adapters/langsmith.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ class LangSmithAdapter(BaseAdapter):
3535
- outputs: { messages: [...] } | { content } | { result } | { answer } | { output } | str | list[dict]
3636
"""
3737

38-
def __init__(self, client: Optional[Client] = None) -> None:
38+
def __init__(self, client: Optional[Any] = None) -> None:
3939
if not LANGSMITH_AVAILABLE:
4040
raise ImportError("LangSmith not installed. Install with: pip install 'eval-protocol[langsmith]'")
41-
self.client = client or Client()
41+
self.client = client or Client() # pyright: ignore[reportCallIssue]
4242

4343
def get_evaluation_rows(
4444
self,

eval_protocol/pytest/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
88
from .rollout_processor import RolloutProcessor
99
from .types import RolloutProcessorConfig
10+
from .data_loaders import (
11+
EvaluationDataLoader,
12+
InlineDataLoader,
13+
LangfuseAdapterLoader,
14+
LangfuseLoaderConfig,
15+
)
1016

1117
# Conditional import for optional dependencies
1218
try:
@@ -38,6 +44,10 @@
3844
"ExceptionHandlerConfig",
3945
"BackoffConfig",
4046
"get_default_exception_handler_config",
47+
"EvaluationDataLoader",
48+
"InlineDataLoader",
49+
"LangfuseAdapterLoader",
50+
"LangfuseLoaderConfig",
4151
]
4252

4353
# Only add to __all__ if available
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""Data loader abstractions for evaluation tests."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass, field
6+
from typing import Any, Callable, Protocol, Sequence
7+
8+
from eval_protocol.adapters.base import BaseAdapter
9+
from eval_protocol.models import EvaluationRow, Message
10+
from eval_protocol.pytest.types import EvaluationTestMode, InputMessagesParam
11+
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
12+
13+
14+
@dataclass(slots=True)
15+
class DataLoaderContext:
16+
"""Context provided to loader variants when materializing data."""
17+
18+
max_rows: int | None
19+
preprocess_fn: Callable[[list[EvaluationRow]], list[EvaluationRow]] | None
20+
logger: DatasetLogger
21+
invocation_id: str
22+
experiment_id: str
23+
mode: EvaluationTestMode
24+
25+
26+
@dataclass(slots=True)
27+
class DataLoaderResult:
28+
"""Rows and metadata returned by a loader variant."""
29+
30+
rows: list[EvaluationRow]
31+
source_id: str
32+
source_metadata: dict[str, Any] = field(default_factory=dict)
33+
raw_payload: Any | None = None
34+
preprocessed: bool = False
35+
36+
37+
@dataclass(slots=True)
38+
class DataLoaderVariant:
39+
"""Single parameterizable variant from a data loader."""
40+
41+
id: str
42+
description: str
43+
loader: Callable[[DataLoaderContext], DataLoaderResult]
44+
metadata: dict[str, Any] = field(default_factory=dict)
45+
46+
def load(self, ctx: DataLoaderContext) -> DataLoaderResult:
47+
"""Load a dataset for this variant using the provided context."""
48+
49+
return self.loader(ctx)
50+
51+
52+
class EvaluationDataLoader(Protocol):
53+
"""Protocol for data loaders that can be consumed by ``evaluation_test``."""
54+
55+
def variants(self) -> Sequence[DataLoaderVariant]:
56+
"""Return parameterizable variants emitted by this loader."""
57+
58+
...
59+
60+
61+
@dataclass(slots=True)
62+
class InlineDataLoader(EvaluationDataLoader):
63+
"""Data loader for inline ``EvaluationRow`` or message payloads."""
64+
65+
rows: Sequence[EvaluationRow] | None = None
66+
messages: Sequence[InputMessagesParam] | None = None
67+
variant_id: str = "inline"
68+
description: str | None = None
69+
70+
def __post_init__(self) -> None:
71+
if self.rows is None and self.messages is None:
72+
raise ValueError("InlineDataLoader requires rows or messages to be provided")
73+
74+
def variants(self) -> Sequence[DataLoaderVariant]:
75+
def _load(ctx: DataLoaderContext) -> DataLoaderResult:
76+
resolved_rows: list[EvaluationRow] = []
77+
if self.rows is not None:
78+
resolved_rows.extend(row.model_copy(deep=True) for row in self.rows)
79+
if self.messages is not None:
80+
for dataset_messages in self.messages:
81+
row_messages: list[Message] = []
82+
for msg in dataset_messages:
83+
if isinstance(msg, Message):
84+
row_messages.append(msg.model_copy(deep=True))
85+
else:
86+
row_messages.append(Message.model_validate(msg))
87+
resolved_rows.append(EvaluationRow(messages=row_messages))
88+
89+
if ctx.max_rows is not None:
90+
resolved_rows = resolved_rows[: ctx.max_rows]
91+
92+
metadata = {
93+
"data_loader_variant_id": self.variant_id,
94+
"data_loader_type": "inline",
95+
"row_count": len(resolved_rows),
96+
}
97+
98+
return DataLoaderResult(
99+
rows=resolved_rows,
100+
source_id=self.variant_id,
101+
source_metadata=metadata,
102+
)
103+
104+
description = self.description or self.variant_id
105+
return [
106+
DataLoaderVariant(
107+
id=self.variant_id,
108+
description=description,
109+
loader=_load,
110+
metadata={"type": "inline"},
111+
)
112+
]
113+
114+
115+
@dataclass(slots=True)
116+
class LangfuseLoaderConfig:
117+
"""Configuration for a single Langfuse adapter variant."""
118+
119+
id: str
120+
kwargs: dict[str, Any] = field(default_factory=dict)
121+
description: str | None = None
122+
123+
124+
@dataclass(slots=True)
125+
class LangfuseAdapterLoader(EvaluationDataLoader):
126+
"""Wrap a ``LangfuseAdapter`` (or compatible adapter) as a data loader."""
127+
128+
adapter: BaseAdapter
129+
variants_config: Sequence[LangfuseLoaderConfig]
130+
131+
def variants(self) -> Sequence[DataLoaderVariant]:
132+
loader_variants: list[DataLoaderVariant] = []
133+
134+
for config in self.variants_config:
135+
136+
def _load(ctx: DataLoaderContext, *, _config: LangfuseLoaderConfig = config) -> DataLoaderResult:
137+
rows = self.adapter.get_evaluation_rows(**_config.kwargs)
138+
if ctx.max_rows is not None:
139+
rows = rows[: ctx.max_rows]
140+
141+
metadata = {
142+
"data_loader_variant_id": _config.id,
143+
"data_loader_type": "langfuse",
144+
"adapter_kwargs": _config.kwargs,
145+
}
146+
147+
return DataLoaderResult(
148+
rows=[row.model_copy(deep=True) for row in rows],
149+
source_id=_config.id,
150+
source_metadata=metadata,
151+
)
152+
153+
loader_variants.append(
154+
DataLoaderVariant(
155+
id=config.id,
156+
description=config.description or config.id,
157+
loader=_load,
158+
metadata={"type": "langfuse", "adapter_kwargs": config.kwargs},
159+
)
160+
)
161+
162+
return loader_variants
163+
164+
165+
__all__ = [
166+
"DataLoaderContext",
167+
"DataLoaderResult",
168+
"DataLoaderVariant",
169+
"EvaluationDataLoader",
170+
"InlineDataLoader",
171+
"LangfuseAdapterLoader",
172+
"LangfuseLoaderConfig",
173+
]

0 commit comments

Comments
 (0)