Skip to content

Commit 11330f5

Browse files
authored
BigQuery Adapter (#86)
* BigQuery * removing unneeded
1 parent 44654a5 commit 11330f5

File tree

6 files changed

+1069
-171
lines changed

6 files changed

+1069
-171
lines changed

eval_protocol/adapters/__init__.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,42 +6,64 @@
66
Available adapters:
77
- LangfuseAdapter: Pull data from Langfuse deployments
88
- HuggingFaceAdapter: Load datasets from HuggingFace Hub
9+
- BigQueryAdapter: Query data from Google BigQuery
910
- Braintrust integration (legacy)
1011
- TRL integration (legacy)
1112
"""
1213

1314
# Conditional imports based on available dependencies.
# Each adapter pulls in optional third-party packages, so every import is
# wrapped in try/except ImportError and __all__ only advertises the adapters
# that actually loaded.
try:
    from .langfuse import LangfuseAdapter, create_langfuse_adapter

    __all__ = ["LangfuseAdapter", "create_langfuse_adapter"]
except ImportError:
    # Langfuse is the first (and unconditional) initializer of __all__,
    # so it must still be defined when the import fails.
    __all__ = []

try:
    from .huggingface import (
        HuggingFaceAdapter,
        create_gsm8k_adapter,
        create_huggingface_adapter,
        create_math_adapter,
    )

    __all__.extend(
        [
            "HuggingFaceAdapter",
            "create_huggingface_adapter",
            "create_gsm8k_adapter",
            "create_math_adapter",
        ]
    )
except ImportError:
    pass


try:
    from .bigquery import (
        BigQueryAdapter,
        create_bigquery_adapter,
    )

    __all__.extend(
        [
            "BigQueryAdapter",
            "create_bigquery_adapter",
        ]
    )
except ImportError:
    pass

# Legacy adapters (always available)
try:
    from .braintrust import reward_fn_to_scorer, scorer_to_reward_fn

    __all__.extend(["scorer_to_reward_fn", "reward_fn_to_scorer"])
except ImportError:
    pass

try:
    from .trl import create_trl_adapter

    __all__.extend(["create_trl_adapter"])
except ImportError:
    pass

eval_protocol/adapters/bigquery.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
"""Google BigQuery adapter for Eval Protocol.
2+
3+
This adapter allows querying data from Google BigQuery tables and converting it
4+
to EvaluationRow format for use in evaluation pipelines.
5+
"""
6+
7+
import logging
8+
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
9+
10+
from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message
11+
12+
logger = logging.getLogger(__name__)
13+
14+
try:
15+
from google.auth.exceptions import DefaultCredentialsError
16+
from google.cloud import bigquery
17+
from google.cloud.exceptions import Forbidden, NotFound
18+
from google.oauth2 import service_account
19+
20+
BIGQUERY_AVAILABLE = True
21+
except ImportError:
22+
BIGQUERY_AVAILABLE = False
23+
logger.warning("Google Cloud BigQuery not installed. Install with: pip install 'eval-protocol[bigquery]'")
24+
25+
# Type alias for transformation function
26+
TransformFunction = Callable[[Dict[str, Any]], Dict[str, Any]]
27+
28+
29+
class BigQueryAdapter:
    """Adapter to query data from Google BigQuery and convert to EvaluationRow format.

    This adapter connects to Google BigQuery, executes SQL queries, and applies
    a user-provided transformation function to convert each row to the format
    expected by EvaluationRow.

    The transformation function should take a BigQuery row dictionary and return:
        {
            'messages': List[Dict] - list of message dictionaries with 'role' and 'content'
            'ground_truth': Optional[str] - expected answer/output
            'metadata': Optional[Dict] - any additional metadata to preserve
            'tools': Optional[List[Dict]] - tool definitions for tool calling scenarios
        }
    """

    def __init__(
        self,
        transform_fn: TransformFunction,
        dataset_id: Optional[str] = None,
        credentials_path: Optional[str] = None,
        location: Optional[str] = None,
        **client_kwargs,
    ):
        """Initialize the BigQuery adapter.

        Args:
            transform_fn: Function to transform BigQuery rows to evaluation format
            dataset_id: Google Cloud project ID (if None, uses default from environment)
            credentials_path: Path to service account JSON file (if None, uses default auth)
            location: Default location for BigQuery jobs
            **client_kwargs: Additional arguments to pass to BigQuery client

        Raises:
            ImportError: If google-cloud-bigquery is not installed
            DefaultCredentialsError: If authentication fails
        """
        if not BIGQUERY_AVAILABLE:
            raise ImportError(
                "Google Cloud BigQuery not installed. Install with: pip install 'eval-protocol[bigquery]'"
            )

        self.transform_fn = transform_fn
        self.dataset_id = dataset_id
        self.location = location

        # Only pass options the caller actually provided, so google-cloud's own
        # defaults (ADC credentials, default project) apply otherwise.
        try:
            client_args = {}
            if dataset_id:
                client_args["project"] = dataset_id
            if credentials_path:
                credentials = service_account.Credentials.from_service_account_file(credentials_path)
                client_args["credentials"] = credentials
            if location:
                client_args["location"] = location

            client_args.update(client_kwargs)
            self.client = bigquery.Client(**client_args)

        except DefaultCredentialsError as e:
            logger.error("Failed to authenticate with BigQuery: %s", e)
            raise
        except Exception as e:
            logger.error("Failed to initialize BigQuery client: %s", e)
            raise

    def get_evaluation_rows(
        self,
        query: str,
        query_params: Optional[List[Union[bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter]]] = None,
        limit: Optional[int] = None,
        offset: int = 0,
        model_name: str = "gpt-3.5-turbo",
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
        **completion_params_kwargs,
    ) -> Iterator[EvaluationRow]:
        """Execute BigQuery query and convert results to EvaluationRow format.

        Args:
            query: SQL query to execute
            query_params: Optional list of query parameters for parameterized queries
            limit: Maximum number of rows to return (applied after BigQuery query)
            offset: Number of rows to skip (applied after BigQuery query)
            model_name: Model name for completion parameters
            temperature: Temperature for completion parameters
            max_tokens: Max tokens for completion parameters
            **completion_params_kwargs: Additional completion parameters

        Yields:
            EvaluationRow: Converted evaluation rows

        Raises:
            NotFound: If the query references non-existent tables/datasets
            Forbidden: If insufficient permissions
        """
        try:
            # Configure query job
            job_config = bigquery.QueryJobConfig()
            if query_params:
                job_config.query_parameters = query_params
            if self.location:
                job_config.location = self.location

            query_job = self.client.query(query, job_config=job_config)

            # Blocks until the query finishes, then iterates result rows.
            results = query_job.result()

            completion_params: CompletionParams = {
                "model": model_name,
                "temperature": temperature,
                "max_tokens": max_tokens,
                **completion_params_kwargs,
            }

            # Offset/limit are applied client-side while streaming results:
            # row_count tracks the absolute result index, processed_count the
            # number of rows successfully yielded.
            row_count = 0
            processed_count = 0

            for raw_row in results:
                # Apply offset
                if row_count < offset:
                    row_count += 1
                    continue

                # Apply limit
                if limit is not None and processed_count >= limit:
                    break

                try:
                    eval_row = self._convert_row_to_evaluation_row(raw_row, processed_count, completion_params)
                    # Explicit None check: truthiness of an EvaluationRow is not
                    # a reliable signal, and a valid-but-falsy row must not be
                    # silently dropped.
                    if eval_row is not None:
                        yield eval_row
                        processed_count += 1

                except (AttributeError, ValueError, KeyError) as e:
                    # Per-row conversion failures are logged and skipped so one
                    # malformed row does not abort the whole stream.
                    logger.warning("Failed to convert row %d: %s", row_count, e)

                row_count += 1

        except (NotFound, Forbidden) as e:
            logger.error("BigQuery access error: %s", e)
            raise
        except Exception as e:
            logger.error("Error executing BigQuery query: %s", e)
            raise

    def _convert_row_to_evaluation_row(
        self,
        raw_row: Dict[str, Any],
        row_index: int,
        completion_params: CompletionParams,
    ) -> EvaluationRow:
        """Convert a single BigQuery row to EvaluationRow format.

        Args:
            raw_row: BigQuery row (dict-like; must support .items())
            row_index: Index of the row in the result set
            completion_params: Completion parameters to use

        Returns:
            EvaluationRow object.

        Raises:
            ValueError: If the transform output is missing 'messages' or any
                message is malformed.
        """
        # Apply user transformation
        transformed = self.transform_fn(raw_row)

        # Validate required fields
        if "messages" not in transformed:
            raise ValueError("Transform function must return 'messages' field")

        # Convert message dictionaries to Message objects
        messages = []
        for msg_dict in transformed["messages"]:
            if not isinstance(msg_dict, dict):
                raise ValueError("Each message must be a dictionary")
            if "role" not in msg_dict:
                raise ValueError("Each message must have a 'role' field")

            messages.append(
                Message(
                    role=msg_dict["role"],
                    content=msg_dict.get("content"),
                    name=msg_dict.get("name"),
                    tool_call_id=msg_dict.get("tool_call_id"),
                    tool_calls=msg_dict.get("tool_calls"),
                    function_call=msg_dict.get("function_call"),
                )
            )

        # Extract other fields
        ground_truth = transformed.get("ground_truth")
        tools = transformed.get("tools")
        user_metadata = transformed.get("metadata", {})

        # Prefer the explicit project/dataset id; fall back to the client's
        # resolved project so row ids and dataset info stay meaningful (and
        # consistent with each other) when dataset_id was not provided.
        source_id = self.dataset_id or self.client.project

        # Create dataset info
        dataset_info = {
            "source": "bigquery",
            "dataset_id": source_id,
            "row_index": row_index,
            "transform_function": (
                self.transform_fn.__name__ if hasattr(self.transform_fn, "__name__") else "anonymous"
            ),
        }

        # Add user metadata (may override the defaults above)
        dataset_info.update(user_metadata)

        # Add original row data (with prefix to avoid conflicts)
        for key, value in raw_row.items():
            # NOTE(review): values are stored as-is; BigQuery-specific types
            # (e.g. dates, Decimals) are presumably JSON-serializable downstream
            # — confirm if dataset_info is ever serialized.
            dataset_info[f"original_{key}"] = value

        # Create input metadata (following HuggingFace pattern)
        input_metadata = InputMetadata(
            row_id=f"{source_id}_{row_index}",
            completion_params=completion_params,
            dataset_info=dataset_info,
            session_data={
                "dataset_source": "bigquery",
            },
        )

        return EvaluationRow(
            messages=messages,
            tools=tools,
            input_metadata=input_metadata,
            ground_truth=str(ground_truth) if ground_truth is not None else None,
        )
258+
259+
260+
def create_bigquery_adapter(
    transform_fn: TransformFunction,
    dataset_id: Optional[str] = None,
    credentials_path: Optional[str] = None,
    location: Optional[str] = None,
    **client_kwargs,
) -> BigQueryAdapter:
    """Factory function to create a BigQuery adapter.

    Args:
        transform_fn: Function to transform BigQuery rows to evaluation format
        dataset_id: Google Cloud project ID
        credentials_path: Path to service account JSON file
        location: Default location for BigQuery jobs
        **client_kwargs: Additional arguments for BigQuery client

    Returns:
        BigQueryAdapter instance
    """
    # Thin convenience wrapper: collect the named options and forward
    # everything unchanged to the BigQueryAdapter constructor.
    adapter_kwargs = {
        "transform_fn": transform_fn,
        "dataset_id": dataset_id,
        "credentials_path": credentials_path,
        "location": location,
    }
    adapter_kwargs.update(client_kwargs)
    return BigQueryAdapter(**adapter_kwargs)

0 commit comments

Comments
 (0)