|
| 1 | +"""Google BigQuery adapter for Eval Protocol. |
| 2 | +
|
| 3 | +This adapter allows querying data from Google BigQuery tables and converting it |
| 4 | +to EvaluationRow format for use in evaluation pipelines. |
| 5 | +""" |
| 6 | + |
| 7 | +import logging |
| 8 | +from typing import Any, Callable, Dict, Iterator, List, Optional, Union |
| 9 | + |
| 10 | +from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message |
| 11 | + |
| 12 | +logger = logging.getLogger(__name__) |
| 13 | + |
# Optional-dependency guard: the BigQuery client libraries are an extra
# (`eval-protocol[bigquery]`), so import failures are tolerated here and
# surfaced later as an ImportError when BigQueryAdapter is instantiated.
try:
    from google.auth.exceptions import DefaultCredentialsError
    from google.cloud import bigquery
    from google.cloud.exceptions import Forbidden, NotFound
    from google.oauth2 import service_account

    BIGQUERY_AVAILABLE = True  # flag checked by BigQueryAdapter.__init__
except ImportError:
    BIGQUERY_AVAILABLE = False
    logger.warning("Google Cloud BigQuery not installed. Install with: pip install 'eval-protocol[bigquery]'")

# Type alias for transformation function: maps one BigQuery row (as a dict-like
# mapping of column name -> value) to the evaluation-format dict described in
# the BigQueryAdapter class docstring.
TransformFunction = Callable[[Dict[str, Any]], Dict[str, Any]]
| 27 | + |
| 28 | + |
class BigQueryAdapter:
    """Adapter to query data from Google BigQuery and convert to EvaluationRow format.

    This adapter connects to Google BigQuery, executes SQL queries, and applies
    a user-provided transformation function to convert each row to the format
    expected by EvaluationRow.

    The transformation function should take a BigQuery row dictionary and return:
    {
        'messages': List[Dict] - list of message dictionaries with 'role' and 'content'
        'ground_truth': Optional[str] - expected answer/output
        'metadata': Optional[Dict] - any additional metadata to preserve
        'tools': Optional[List[Dict]] - tool definitions for tool calling scenarios
    }
    """

    def __init__(
        self,
        transform_fn: TransformFunction,
        dataset_id: Optional[str] = None,
        credentials_path: Optional[str] = None,
        location: Optional[str] = None,
        **client_kwargs,
    ):
        """Initialize the BigQuery adapter.

        Args:
            transform_fn: Function to transform BigQuery rows to evaluation format
            dataset_id: Google Cloud project ID (if None, uses default from
                environment). NOTE(review): despite its name this value is passed
                to the client as the *project*; renaming would break callers.
            credentials_path: Path to service account JSON file (if None, uses
                Application Default Credentials)
            location: Default location for BigQuery jobs
            **client_kwargs: Additional arguments to pass to BigQuery client

        Raises:
            ImportError: If google-cloud-bigquery is not installed
            DefaultCredentialsError: If authentication fails
        """
        if not BIGQUERY_AVAILABLE:
            raise ImportError(
                "Google Cloud BigQuery not installed. Install with: pip install 'eval-protocol[bigquery]'"
            )

        self.transform_fn = transform_fn
        self.dataset_id = dataset_id
        self.location = location

        # Only forward options the caller actually supplied, so that
        # google-cloud defaults (env project, Application Default Credentials)
        # still apply when an option is omitted.
        try:
            client_args: Dict[str, Any] = {}
            if dataset_id:
                client_args["project"] = dataset_id
            if credentials_path:
                credentials = service_account.Credentials.from_service_account_file(credentials_path)
                client_args["credentials"] = credentials
            if location:
                client_args["location"] = location

            client_args.update(client_kwargs)
            self.client = bigquery.Client(**client_args)

        except DefaultCredentialsError as e:
            logger.error("Failed to authenticate with BigQuery: %s", e)
            raise
        except Exception as e:
            logger.error("Failed to initialize BigQuery client: %s", e)
            raise

    def get_evaluation_rows(
        self,
        query: str,
        query_params: Optional[List[Union[bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter]]] = None,
        limit: Optional[int] = None,
        offset: int = 0,
        model_name: str = "gpt-3.5-turbo",
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
        **completion_params_kwargs,
    ) -> Iterator[EvaluationRow]:
        """Execute BigQuery query and convert results to EvaluationRow format.

        Rows whose transformation fails are logged and skipped; they do not
        abort the stream and do not count toward ``limit``.

        Args:
            query: SQL query to execute
            query_params: Optional list of query parameters for parameterized queries
            limit: Maximum number of rows to yield (applied client-side, after
                the BigQuery query runs)
            offset: Number of result rows to skip (applied client-side)
            model_name: Model name for completion parameters
            temperature: Temperature for completion parameters
            max_tokens: Max tokens for completion parameters
            **completion_params_kwargs: Additional completion parameters

        Yields:
            EvaluationRow: Converted evaluation rows

        Raises:
            NotFound: If the query references non-existent tables/datasets
            Forbidden: If insufficient permissions
        """
        try:
            job_config = bigquery.QueryJobConfig()
            if query_params:
                job_config.query_parameters = query_params

            # BUGFIX: QueryJobConfig has no ``location`` field, so assigning
            # ``job_config.location`` silently dropped the configured location.
            # Client.query() takes the job location directly.
            query_job = self.client.query(query, job_config=job_config, location=self.location)

            results = query_job.result()

            completion_params: CompletionParams = {
                "model": model_name,
                "temperature": temperature,
                "max_tokens": max_tokens,
                **completion_params_kwargs,
            }

            # Client-side offset/limit over the streamed result set.
            row_count = 0  # rows consumed from BigQuery (skipped + failed + yielded)
            processed_count = 0  # rows successfully converted and yielded

            for raw_row in results:
                # Apply offset: skip without attempting conversion.
                if row_count < offset:
                    row_count += 1
                    continue

                # Apply limit on *successfully yielded* rows.
                if limit is not None and processed_count >= limit:
                    break

                try:
                    eval_row = self._convert_row_to_evaluation_row(raw_row, processed_count, completion_params)
                    yield eval_row
                    processed_count += 1

                except (AttributeError, ValueError, KeyError) as e:
                    # A malformed row or transform result skips the row rather
                    # than aborting the whole stream.
                    logger.warning("Failed to convert row %d: %s", row_count, e)

                row_count += 1

        except (NotFound, Forbidden) as e:
            logger.error("BigQuery access error: %s", e)
            raise
        except Exception as e:
            logger.error("Error executing BigQuery query: %s", e)
            raise

    @staticmethod
    def _json_safe(value: Any) -> Any:
        """Best-effort conversion of a BigQuery cell value to a JSON-serializable type.

        TIMESTAMP/DATETIME/DATE/TIME become ISO-8601 strings, NUMERIC
        (``decimal.Decimal``) becomes a string (avoids float precision loss),
        and BYTES becomes base64 text. Anything else passes through unchanged.
        """
        if isinstance(value, (datetime.datetime, datetime.date, datetime.time)):
            return value.isoformat()
        if isinstance(value, decimal.Decimal):
            return str(value)
        if isinstance(value, bytes):
            return base64.b64encode(value).decode("ascii")
        return value

    def _convert_row_to_evaluation_row(
        self,
        raw_row: Dict[str, Any],
        row_index: int,
        completion_params: CompletionParams,
    ) -> EvaluationRow:
        """Convert a single BigQuery row to EvaluationRow format.

        Args:
            raw_row: BigQuery row (dict-like mapping of column name -> value)
            row_index: Index of the row in the result set
            completion_params: Completion parameters to use

        Returns:
            EvaluationRow object.

        Raises:
            ValueError: If the transform result is missing required fields or
                contains malformed messages.
        """
        # Apply user transformation
        transformed = self.transform_fn(raw_row)

        # Validate required fields
        if "messages" not in transformed:
            raise ValueError("Transform function must return 'messages' field")

        # Convert message dictionaries to Message objects
        messages = []
        for msg_dict in transformed["messages"]:
            if not isinstance(msg_dict, dict):
                raise ValueError("Each message must be a dictionary")
            if "role" not in msg_dict:
                raise ValueError("Each message must have a 'role' field")

            messages.append(
                Message(
                    role=msg_dict["role"],
                    content=msg_dict.get("content"),
                    name=msg_dict.get("name"),
                    tool_call_id=msg_dict.get("tool_call_id"),
                    tool_calls=msg_dict.get("tool_calls"),
                    function_call=msg_dict.get("function_call"),
                )
            )

        # Extract other fields
        ground_truth = transformed.get("ground_truth")
        tools = transformed.get("tools")
        user_metadata = transformed.get("metadata", {})

        # BUGFIX: fall back to the client's project when dataset_id is unset so
        # row_id does not become "None_<index>".
        project = self.dataset_id or self.client.project

        # Create dataset info
        dataset_info = {
            "source": "bigquery",
            "dataset_id": project,
            "row_index": row_index,
            "transform_function": (
                self.transform_fn.__name__ if hasattr(self.transform_fn, "__name__") else "anonymous"
            ),
        }

        # Add user metadata
        dataset_info.update(user_metadata)

        # Add original row data (prefixed to avoid key conflicts), coercing
        # BigQuery-specific types to JSON-serializable values.
        for key, value in raw_row.items():
            dataset_info[f"original_{key}"] = self._json_safe(value)

        # Create input metadata (following HuggingFace pattern)
        input_metadata = InputMetadata(
            row_id=f"{project}_{row_index}",
            completion_params=completion_params,
            dataset_info=dataset_info,
            session_data={
                "dataset_source": "bigquery",
            },
        )

        return EvaluationRow(
            messages=messages,
            tools=tools,
            input_metadata=input_metadata,
            ground_truth=str(ground_truth) if ground_truth is not None else None,
        )
| 258 | + |
| 259 | + |
def create_bigquery_adapter(
    transform_fn: TransformFunction,
    dataset_id: Optional[str] = None,
    credentials_path: Optional[str] = None,
    location: Optional[str] = None,
    **client_kwargs,
) -> BigQueryAdapter:
    """Build and return a :class:`BigQueryAdapter`.

    Thin convenience wrapper around the ``BigQueryAdapter`` constructor; all
    arguments are forwarded unchanged.

    Args:
        transform_fn: Function to transform BigQuery rows to evaluation format
        dataset_id: Google Cloud project ID
        credentials_path: Path to service account JSON file
        location: Default location for BigQuery jobs
        **client_kwargs: Additional arguments for BigQuery client

    Returns:
        BigQueryAdapter instance
    """
    adapter_options = dict(
        transform_fn=transform_fn,
        dataset_id=dataset_id,
        credentials_path=credentials_path,
        location=location,
    )
    adapter_options.update(client_kwargs)
    return BigQueryAdapter(**adapter_options)