Skip to content

Commit 5d9fb07

Browse files
committed
Fix completion params getting overwritten
1 parent 8e8eef4 commit 5d9fb07

2 files changed

Lines changed: 29 additions & 17 deletions

File tree

eval_protocol/adapters/langfuse.py

Lines changed: 24 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -379,22 +379,32 @@ def _create_input_metadata(self, trace: Any, observations: List[Any]) -> InputMe
379379
Returns:
380380
InputMetadata object
381381
"""
382-
# Extract completion parameters from observations
382+
# Extract completion parameters from trace input first, then observations
383383
completion_params = {}
384384

385-
# Look for model parameters in observations
386-
for obs in observations:
387-
if hasattr(obs, "model") and obs.model:
388-
completion_params["model"] = obs.model
389-
if hasattr(obs, "model_parameters") and obs.model_parameters:
390-
params = obs.model_parameters
391-
if "temperature" in params:
392-
completion_params["temperature"] = params["temperature"]
393-
if "max_tokens" in params:
394-
completion_params["max_tokens"] = params["max_tokens"]
395-
if "top_p" in params:
396-
completion_params["top_p"] = params["top_p"]
397-
break
385+
# First check trace input for evaluation test completion_params
386+
if hasattr(trace, "input") and trace.input:
387+
if isinstance(trace.input, dict):
388+
kwargs = trace.input.get("kwargs", {})
389+
if "completion_params" in kwargs:
390+
trace_completion_params = kwargs["completion_params"]
391+
if trace_completion_params and isinstance(trace_completion_params, dict):
392+
completion_params.update(trace_completion_params)
393+
394+
# Fallback: Look for model parameters in observations if not found in trace input
395+
if not completion_params:
396+
for obs in observations:
397+
if hasattr(obs, "model") and obs.model:
398+
completion_params["model"] = obs.model
399+
if hasattr(obs, "model_parameters") and obs.model_parameters:
400+
params = obs.model_parameters
401+
if "temperature" in params:
402+
completion_params["temperature"] = params["temperature"]
403+
if "max_tokens" in params:
404+
completion_params["max_tokens"] = params["max_tokens"]
405+
if "top_p" in params:
406+
completion_params["top_p"] = params["top_p"]
407+
break
398408

399409
# Create dataset info from trace metadata
400410
dataset_info = {

eval_protocol/pytest/evaluation_test.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -271,9 +271,11 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
271271
passed=None,
272272
)
273273
for row in data:
274-
row.input_metadata.completion_params = (
275-
completion_params if completion_params is not None else {}
276-
)
274+
# Only set completion_params if they don't already exist
275+
if not row.input_metadata.completion_params:
276+
row.input_metadata.completion_params = (
277+
completion_params if completion_params is not None else {}
278+
)
277279
# Add mode to session_data
278280
if row.input_metadata.session_data is None:
279281
row.input_metadata.session_data = {}

0 commit comments

Comments
 (0)