diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 06945f46..00000000 --- a/.flake8 +++ /dev/null @@ -1,3 +0,0 @@ -[flake8] -max-line-length = 119 -ignore = E203, W503 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 08aaf406..8c1b0691 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,11 +42,14 @@ jobs: - name: Install tau2 for testing run: uv pip install git+https://github.com/sierra-research/tau2-bench.git@main - - name: Lint with flake8 - run: uv run flake8 eval_protocol tests examples scripts --count --exit-zero --max-complexity=10 --max-line-length=88 --statistics + - name: Ruff format (check) + run: uv run ruff format --check . - - name: Type check with mypy - run: uv run mypy eval_protocol + - name: Ruff lint + run: uv run ruff check . + + - name: Type check with pyright + run: uv run pyright test-core: name: Core Tests (Python ${{ matrix.python-version }}) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 43c0f8c1..441a2de7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,43 +1,29 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks +exclude: | + (^vite-app/|\.snap$) repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: trailing-whitespace + exclude: "(^vite-app/|\\.snap$)" - id: end-of-file-fixer + exclude: "(^vite-app/|\\.snap$)" - id: check-yaml - id: check-added-large-files - id: check-merge-conflict - id: check-toml - id: detect-private-key -- repo: https://github.com/psf/black - rev: 25.1.0 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.8 hooks: - - id: black - args: [--line-length=119] + - id: ruff-format + - id: ruff + args: ["--fix"] -- repo: https://github.com/pycqa/isort - rev: 6.0.1 +- repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.403 hooks: - - id: isort - name: isort (python) - args: ["--profile", "black", "--line-length", "119", "--filter-files"] - -- repo: https://github.com/pycqa/flake8 - rev: 7.3.0 - hooks: - - id: flake8 - args: [--max-line-length=119, --max-complexity=100, "--ignore=E402,F401,F541,W503,E203,F811,E226,F841,E704,E713,E712,E231,E731,E501"] - # additional_dependencies: [flake8-docstrings, flake8-import-order] # Optional: add flake8 plugins - -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.17.0 - hooks: - - id: mypy - args: [--ignore-missing-imports, --install-types, --non-interactive] - additional_dependencies: - - types-requests - - types-setuptools - # Add other types-* packages your project uses + - id: pyright diff --git a/LICENSE b/LICENSE index e926381a..4bff8e12 100644 --- a/LICENSE +++ b/LICENSE @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/development/normalize_sandbox_fusion.py b/development/normalize_sandbox_fusion.py index fa191283..988f740e 100644 --- a/development/normalize_sandbox_fusion.py +++ b/development/normalize_sandbox_fusion.py @@ -56,7 +56,7 @@ try: repobench_p_tokenizer = AutoTokenizer.from_pretrained("gpt2") except OSError: - print("Warning: Could not load gpt2 tokenizer for Repobench-P. " "Falling back to basic split for token counting.") + print("Warning: Could not load gpt2 tokenizer for Repobench-P. Falling back to basic split for token counting.") repobench_p_tokenizer = None @@ -108,8 +108,7 @@ def format_aider_prompt(problem_json: dict) -> str: """Format the prompt for Aider benchmark style problems.""" question = problem_json.get("content", "") return ( - f"{question}\n\nPlease generate the code in the following format:\n" - "```python\n# Your code response here\n```" + f"{question}\n\nPlease generate the code in the following format:\n```python\n# Your code response here\n```" ) @@ -327,7 +326,7 @@ def normalize_problem_to_openai_format( try: labels = json.loads(labels_data) except json.JSONDecodeError: - print(f"Warning: Skipping ID {problem_id_str} in {filename} " "- malformed JSON in labels.") + print(f"Warning: Skipping ID {problem_id_str} in {filename} - malformed JSON in labels.") return None elif isinstance(labels_data, dict): labels = labels_data @@ -426,10 +425,10 @@ def normalize_problem_to_openai_format( ) return None if not final_user_content.strip() or not final_assistant_content.strip(): - print(f"Warning: Skipping ID {problem_id_str} in {filename} - " "empty processed content.") + print(f"Warning: Skipping ID {problem_id_str} in {filename} - empty processed content.") return None if final_assistant_content.strip() == "import sys; sys.exit(0)": - print(f"Warning: Skipping ID {problem_id_str} in {filename} - " "placeholder solution.") + print(f"Warning: Skipping ID {problem_id_str} in {filename} - placeholder solution.") return None return { @@ -439,7 +438,7 @@ def normalize_problem_to_openai_format( ] } except Exception as e: - print(f"Warning: Skipping ID {problem_id_str} in {filename} - " f"error ({type(e).__name__}: {e}).") + print(f"Warning: Skipping ID {problem_id_str} in {filename} - error ({type(e).__name__}: {e}).") import traceback traceback.print_exc() @@ -474,7 +473,7 @@ def main(): file_error_count += 1 continue - print(f"Processing file {filename_idx + 1}/{len(ALL_SOURCE_JSONL_FILES)}: " f"{filename}...") + print(f"Processing file {filename_idx + 1}/{len(ALL_SOURCE_JSONL_FILES)}: {filename}...") lines_in_file = 0 processed_in_file = 0 skipped_in_file = 0 @@ -488,7 +487,7 @@ def main(): try: problem_data = json.loads(stripped_line) except json.JSONDecodeError: - print(f"Warning: Malformed JSON on line {line_number} " f"in {filepath}. Skipping line.") + print(f"Warning: Malformed JSON on line {line_number} in {filepath}. Skipping line.") skipped_in_file += 1 continue @@ -507,7 +506,7 @@ def main(): processed_count += processed_in_file skipped_count += skipped_in_file except Exception as e: - print(f"Error processing file {filepath}: {type(e).__name__}: {e}. " "Skipping rest of file.") + print(f"Error processing file {filepath}: {type(e).__name__}: {e}. Skipping rest of file.") import traceback traceback.print_exc() diff --git a/development/notes/pytest_integration_proposal.md b/development/notes/pytest_integration_proposal.md index c9496587..784cc215 100644 --- a/development/notes/pytest_integration_proposal.md +++ b/development/notes/pytest_integration_proposal.md @@ -115,7 +115,7 @@ def tau2_rollout_processor(row: EvaluationRow, model: str, input_params: Dict, * # from the dataset and provide a simulated tool response. # 4. Call the model again with the tool response. # 5. Construct a final EvaluationRow with the full transcript. - + # The logic is encapsulated here, away from the test definition. processed_row = ep.default_rollout_processor(row, model, input_params)[0] # Simplified for example return [processed_row] @@ -186,11 +186,11 @@ def best_of_n_processor(row: EvaluationRow, model: str, input_params: Dict, **kw # Then, apply a reward function to score each candidate. scored_rows = ep.evaluate(candidate_rows, score_politeness) - + # Finally, select the best row. # This logic could be encapsulated in a helper, e.g., ep.select_best(). best_row = select_best_by_group(scored_rows, score_key='politeness') - + return [best_row] @evaluation_test( diff --git a/development/utils/subprocess_manager.py b/development/utils/subprocess_manager.py index 5af6c768..1d568c87 100644 --- a/development/utils/subprocess_manager.py +++ b/development/utils/subprocess_manager.py @@ -139,7 +139,7 @@ def start_ngrok_and_get_url( # Or by setting NGROK_AUTHTOKEN environment variable. # Forcing it via command line is also an option but less common for persistent setup. print( - f"Note: Ngrok authtoken should be pre-configured by the user (e.g., 'ngrok config add-authtoken ') or via NGROK_AUTHTOKEN env var." + "Note: Ngrok authtoken should be pre-configured by the user (e.g., 'ngrok config add-authtoken ') or via NGROK_AUTHTOKEN env var." ) # Example if passing via env for the subprocess: # ngrok_env = os.environ.copy() diff --git a/eval_protocol/adapters/CONTRIBUTING.md b/eval_protocol/adapters/CONTRIBUTING.md index 18f31378..e47e06e9 100644 --- a/eval_protocol/adapters/CONTRIBUTING.md +++ b/eval_protocol/adapters/CONTRIBUTING.md @@ -37,36 +37,36 @@ except ImportError: class YourCustomAdapter: """Adapter for integrating with Your Custom Data Source. - + This adapter loads data from Your Custom Data Source and converts it to EvaluationRow format for use in evaluation pipelines. - + Examples: Basic usage: >>> adapter = YourCustomAdapter(api_key="your_key") >>> rows = list(adapter.get_evaluation_rows(limit=10)) """ - + def __init__(self, **config): """Initialize the adapter with configuration.""" if not DEPENDENCY_AVAILABLE: raise ImportError("your_external_library not installed") - + # Initialize your client/connection here self.client = your_external_library.Client(**config) - + def get_evaluation_rows(self, **kwargs) -> Iterator[EvaluationRow]: """Main method to fetch and convert data to EvaluationRow format. - + Args: **kwargs: Adapter-specific parameters - + Yields: EvaluationRow: Converted evaluation rows """ # Implement your data fetching logic raw_data = self.client.fetch_data(**kwargs) - + for item in raw_data: try: eval_row = self._convert_to_evaluation_row(item) @@ -75,51 +75,51 @@ class YourCustomAdapter: except Exception as e: logger.warning(f"Failed to convert item: {e}") continue - + def _convert_to_evaluation_row(self, raw_item: Any) -> Optional[EvaluationRow]: """Convert a raw data item to EvaluationRow format. - + Args: raw_item: Raw data item from your source - + Returns: EvaluationRow or None if conversion fails """ # Extract messages from your data format messages = self._extract_messages(raw_item) - + # Extract metadata input_metadata = self._create_input_metadata(raw_item) - + # Extract ground truth if available ground_truth = self._extract_ground_truth(raw_item) - + # Extract tools if available (for tool calling scenarios) tools = self._extract_tools(raw_item) - + return EvaluationRow( messages=messages, tools=tools, input_metadata=input_metadata, ground_truth=ground_truth, ) - + def _extract_messages(self, raw_item: Any) -> List[Message]: """Extract conversation messages from raw data.""" # Implement message extraction logic # Convert your data format to List[Message] pass - + def _create_input_metadata(self, raw_item: Any) -> InputMetadata: """Create InputMetadata from raw data.""" # Implement metadata extraction pass - + def _extract_ground_truth(self, raw_item: Any) -> Optional[str]: """Extract ground truth if available.""" # Implement ground truth extraction pass - + def _extract_tools(self, raw_item: Any) -> Optional[List[Dict[str, Any]]]: """Extract tool definitions if available.""" # Implement tool extraction for tool calling scenarios @@ -149,7 +149,7 @@ message = Message( content="I'll help you with that calculation.", tool_calls=[{ "id": "call_123", - "type": "function", + "type": "function", "function": { "name": "calculate", "arguments": '{"x": 5, "y": 3}' @@ -185,7 +185,7 @@ input_metadata = InputMetadata( }, session_data={ "user_id": "user123", - "session_id": "session456", + "session_id": "session456", "timestamp": "2024-01-01T00:00:00Z", } ) @@ -259,7 +259,7 @@ def get_evaluation_rows(self, **kwargs) -> Iterator[EvaluationRow]: except Exception as e: logger.error(f"Failed to fetch data: {e}") return - + for item in data: try: row = self._convert_to_evaluation_row(item) @@ -298,36 +298,36 @@ from eval_protocol.models import EvaluationRow class TestYourCustomAdapter: """Test suite for YourCustomAdapter.""" - + def test_initialization(self): """Test adapter initialization.""" adapter = YourCustomAdapter(api_key="test_key") assert adapter.client is not None - + def test_get_evaluation_rows(self): """Test conversion to EvaluationRow format.""" adapter = YourCustomAdapter(api_key="test_key") - + # Mock the external API response with patch.object(adapter.client, 'fetch_data') as mock_fetch: mock_fetch.return_value = [ # Mock data in your format {"id": "1", "question": "Test?", "answer": "Yes"} ] - + rows = list(adapter.get_evaluation_rows(limit=1)) - + assert len(rows) == 1 assert isinstance(rows[0], EvaluationRow) assert len(rows[0].messages) > 0 - + def test_error_handling(self): """Test error handling.""" adapter = YourCustomAdapter(api_key="test_key") - + with patch.object(adapter.client, 'fetch_data') as mock_fetch: mock_fetch.side_effect = Exception("API Error") - + rows = list(adapter.get_evaluation_rows()) assert len(rows) == 0 # Should handle error gracefully ``` @@ -341,18 +341,18 @@ For simple chat data: ```python def _extract_messages(self, conversation: Dict) -> List[Message]: messages = [] - + # Add system prompt if available if conversation.get('system_prompt'): messages.append(Message(role="system", content=conversation['system_prompt'])) - + # Add conversation turns for turn in conversation['turns']: messages.append(Message( role=turn['role'], content=turn['content'] )) - + return messages ``` @@ -363,27 +363,27 @@ For tool calling scenarios: ```python def _extract_messages(self, trace: Dict) -> List[Message]: messages = [] - + for step in trace['steps']: if step['type'] == 'user_message': messages.append(Message(role="user", content=step['content'])) - + elif step['type'] == 'assistant_message': message = Message(role="assistant", content=step.get('content')) - + # Add tool calls if present if step.get('tool_calls'): message.tool_calls = step['tool_calls'] - + messages.append(message) - + elif step['type'] == 'tool_response': messages.append(Message( role="tool", content=step['content'], tool_call_id=step['tool_call_id'] )) - + return messages ``` @@ -515,10 +515,10 @@ Here are some potential adapters that would be valuable: - **OpenAI Evals**: Load data from OpenAI's evals repository - **LLM Evaluation Datasets**: MMLU, HellaSwag, etc. -- **Chat Platforms**: Discord, Slack conversation exports +- **Chat Platforms**: Discord, Slack conversation exports - **Monitoring Tools**: Other observability platforms - **Custom APIs**: Company-specific data sources - **File Formats**: Parquet, Excel, database exports - **Research Datasets**: Academic benchmarks and competitions -We welcome contributions for any of these or other creative integrations! \ No newline at end of file +We welcome contributions for any of these or other creative integrations! diff --git a/eval_protocol/adapters/huggingface.py b/eval_protocol/adapters/huggingface.py index 2825dafa..7f8b6902 100644 --- a/eval_protocol/adapters/huggingface.py +++ b/eval_protocol/adapters/huggingface.py @@ -413,7 +413,7 @@ def create_math_adapter( HuggingFaceAdapter configured for MATH dataset """ default_system_prompt = ( - "You are an expert mathematician. Solve this advanced math problem " "step by step, showing detailed work." + "You are an expert mathematician. Solve this advanced math problem step by step, showing detailed work." ) system_content = system_prompt or default_system_prompt diff --git a/eval_protocol/agent/orchestrator.py b/eval_protocol/agent/orchestrator.py index 2f737e2c..61be1091 100644 --- a/eval_protocol/agent/orchestrator.py +++ b/eval_protocol/agent/orchestrator.py @@ -416,9 +416,9 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) - episode_resource: Optional[ForkableResource] = None evaluation_result: Optional[Dict[str, Any]] = None - all_user_turns_successful_function_calls: List[List[Dict[str, Any]]] = ( - [] - ) # Track successful calls for reward fn, list of lists (per user turn) + all_user_turns_successful_function_calls: List[ + List[Dict[str, Any]] + ] = [] # Track successful calls for reward fn, list of lists (per user turn) conversation_messages: List[Dict[str, Any]] = [] # Use dicts for API compatibility # --- Agent Model Setup --- @@ -845,16 +845,16 @@ async def execute_task_poc(self, sample_data: Optional[Dict[str, Any]] = None) - eval_args["ground_truth"] = ground_truth_for_reward # Call the reward function - self.logger.info(f"=== CALLING REWARD FUNCTION DEBUG ===") + self.logger.info("=== CALLING REWARD FUNCTION DEBUG ===") self.logger.info(f"Reward function type: {type(self.reward_function)}") self.logger.info(f"Eval args keys: {list(eval_args.keys())}") self.logger.info(f"Task achieved: {eval_args.get('task_achieved', 'NOT_SET')}") self.logger.info(f"Messages count: {len(eval_args.get('messages', []))}") evaluation_result = self.reward_function(**eval_args) - self.logger.info(f"=== REWARD FUNCTION RESULT ===") + self.logger.info("=== REWARD FUNCTION RESULT ===") self.logger.info(f"Reward function result: {evaluation_result}") self.logger.info(f"Result type: {type(evaluation_result)}") - self.logger.info(f"=== END REWARD FUNCTION DEBUG ===") + self.logger.info("=== END REWARD FUNCTION DEBUG ===") # Return both the evaluation result and the inputs for trajectory capture return { diff --git a/eval_protocol/agent/resources/bfcl_sim_api_resource.py b/eval_protocol/agent/resources/bfcl_sim_api_resource.py index 638c915e..8bee80a7 100644 --- a/eval_protocol/agent/resources/bfcl_sim_api_resource.py +++ b/eval_protocol/agent/resources/bfcl_sim_api_resource.py @@ -4,8 +4,7 @@ from pathlib import Path # Import BFCL File and Directory for isinstance checks from local implementation -from .bfcl_envs.gorilla_file_system import Directory as BFCLDirectory -from .bfcl_envs.gorilla_file_system import File as BFCLFile +from .bfcl_envs.gorilla_file_system import Directory as BFCLDirectory, File as BFCLFile BFCL_TYPES_AVAILABLE = True import gc diff --git a/eval_protocol/agent/task_manager.py b/eval_protocol/agent/task_manager.py index 7e6ee6e3..5cff2f9d 100644 --- a/eval_protocol/agent/task_manager.py +++ b/eval_protocol/agent/task_manager.py @@ -918,7 +918,7 @@ def _save_detailed_results( output_path = Path(output_file) try: - self.logger.info(f"=== TRAJECTORY SAVE DEBUG START ===") + self.logger.info("=== TRAJECTORY SAVE DEBUG START ===") self.logger.info(f"Saving trajectory data to: {output_path}") self.logger.info(f"Chosen directory: {chosen_dir}") self.logger.info(f"Individual results count: {len(aggregated_result.get('individual_results', []))}") @@ -992,7 +992,7 @@ def _save_detailed_results( self.logger.info(f"Successfully saved trajectory data to: {output_path}") self.logger.info(f"Trajectory file size: {output_path.stat().st_size} bytes") - self.logger.info(f"=== TRAJECTORY SAVE DEBUG END ===") + self.logger.info("=== TRAJECTORY SAVE DEBUG END ===") return str(output_path) except Exception as e: diff --git a/eval_protocol/benchmarks/__init__.py b/eval_protocol/benchmarks/__init__.py index 18a872c7..e248fe9b 100644 --- a/eval_protocol/benchmarks/__init__.py +++ b/eval_protocol/benchmarks/__init__.py @@ -5,5 +5,3 @@ "get_benchmark_runner", "list_benchmarks", ] - - diff --git a/eval_protocol/benchmarks/registry.py b/eval_protocol/benchmarks/registry.py index 31840fd1..ce3c698e 100644 --- a/eval_protocol/benchmarks/registry.py +++ b/eval_protocol/benchmarks/registry.py @@ -199,6 +199,7 @@ def _composite_runner( _get_benchmark_runner = get_benchmark_runner import pathlib as _pathlib import time as _time + _json = json child_summaries: List[Dict[str, Any]] = [] @@ -302,9 +303,7 @@ def _composite_runner( f"EP Summary | suite={name} model={model} agg={combined['agg_score']:.3f} rows={total_rows}" ) else: - print( - f"EP Summary | suite={name} model={model} agg=None rows={total_rows}" - ) + print(f"EP Summary | suite={name} model={model} agg=None rows={total_rows}") except Exception: pass diff --git a/eval_protocol/benchmarks/run.py b/eval_protocol/benchmarks/run.py index 9195666f..a5afe900 100644 --- a/eval_protocol/benchmarks/run.py +++ b/eval_protocol/benchmarks/run.py @@ -14,10 +14,10 @@ from __future__ import annotations import argparse +import pkgutil +from importlib import import_module from typing import Any -from importlib import import_module -import pkgutil import eval_protocol.benchmarks.suites as suites_pkg from eval_protocol.benchmarks.registry import get_benchmark_runner, list_benchmarks @@ -49,7 +49,9 @@ def main() -> int: args = _parse_args() # Auto-import all suite modules so their @export_benchmark decorators register # Import all suite modules so their @export_benchmark decorators register - import sys, traceback + import sys + import traceback + for modinfo in pkgutil.iter_modules(suites_pkg.__path__): mod_name = f"{suites_pkg.__name__}.{modinfo.name}" try: @@ -96,5 +98,3 @@ def main() -> int: if __name__ == "__main__": raise SystemExit(main()) - - diff --git a/eval_protocol/benchmarks/suites/__init__.py b/eval_protocol/benchmarks/suites/__init__.py index 04746ad6..d0effd69 100644 --- a/eval_protocol/benchmarks/suites/__init__.py +++ b/eval_protocol/benchmarks/suites/__init__.py @@ -1,3 +1 @@ # Suite modules are auto-imported by eval_protocol.benchmarks.run to register benchmarks. - - diff --git a/eval_protocol/benchmarks/suites/aime25.py b/eval_protocol/benchmarks/suites/aime25.py index 92d7bedc..755795df 100644 --- a/eval_protocol/benchmarks/suites/aime25.py +++ b/eval_protocol/benchmarks/suites/aime25.py @@ -8,7 +8,7 @@ from eval_protocol.pytest.evaluation_test import evaluation_test SYSTEM_PROMPT = ( - "You are a helpful math assistant. Please reason step by step, and put your " "final answer within \\boxed{...}." + "You are a helpful math assistant. Please reason step by step, and put your final answer within \\boxed{...}." ) diff --git a/eval_protocol/cli_commands/agent_eval_cmd.py b/eval_protocol/cli_commands/agent_eval_cmd.py index f49ab0ea..08767d50 100644 --- a/eval_protocol/cli_commands/agent_eval_cmd.py +++ b/eval_protocol/cli_commands/agent_eval_cmd.py @@ -115,7 +115,7 @@ async def main_flow(): logger.info(f" - Standard deviation: {result.get('std_dev', 0.0):.4f}") logger.info(f" - Score range: {result['min_score']:.4f} - {result['max_score']:.4f}") if "aggregated_metrics" in result: - logger.info(f" - Aggregated metrics:") + logger.info(" - Aggregated metrics:") for metric_name, metric_data in result["aggregated_metrics"].items(): logger.info( f" * {metric_name}: avg={metric_data['avg_score']:.4f}, range={metric_data['min_score']:.4f}-{metric_data['max_score']:.4f}" diff --git a/eval_protocol/cli_commands/deploy.py b/eval_protocol/cli_commands/deploy.py index 7536969a..7250e291 100644 --- a/eval_protocol/cli_commands/deploy.py +++ b/eval_protocol/cli_commands/deploy.py @@ -16,8 +16,8 @@ # TODO: Consider moving subprocess_manager functions to a more central location if used by core CLI try: - from development.utils.subprocess_manager import start_ngrok_and_get_url # Added ngrok function from development.utils.subprocess_manager import ( + start_ngrok_and_get_url, # Added ngrok function start_process, start_serveo_and_get_url, stop_process, @@ -59,9 +59,12 @@ def start_ngrok_and_get_url(local_port, log_path): from eval_protocol.auth import get_fireworks_account_id -from eval_protocol.config import GCPCloudRunConfig, RewardKitConfig -from eval_protocol.config import _config_file_path as global_loaded_config_path -from eval_protocol.config import get_config +from eval_protocol.config import ( + GCPCloudRunConfig, + RewardKitConfig, + _config_file_path as global_loaded_config_path, + get_config, +) from eval_protocol.evaluation import create_evaluation from eval_protocol.gcp_tools import ( build_and_push_docker_image, diff --git a/eval_protocol/cli_commands/deploy_mcp.py b/eval_protocol/cli_commands/deploy_mcp.py index f71796d6..34cb6a6f 100644 --- a/eval_protocol/cli_commands/deploy_mcp.py +++ b/eval_protocol/cli_commands/deploy_mcp.py @@ -9,9 +9,12 @@ from pathlib import Path from typing import Dict, Optional -from eval_protocol.config import GCPCloudRunConfig, RewardKitConfig -from eval_protocol.config import _config_file_path as global_loaded_config_path -from eval_protocol.config import get_config +from eval_protocol.config import ( + GCPCloudRunConfig, + RewardKitConfig, + _config_file_path as global_loaded_config_path, + get_config, +) from eval_protocol.gcp_tools import ( build_and_push_docker_image, deploy_to_cloud_run, @@ -235,7 +238,7 @@ def _deploy_mcp_to_gcp_cloud_run(args, current_config, gcp_config_from_yaml): print("Failed to deploy to Cloud Run or retrieve service URL. Aborting.") return None - print(f"๐Ÿš€ Successfully deployed MCP server to Cloud Run!") + print("๐Ÿš€ Successfully deployed MCP server to Cloud Run!") print(f"๐Ÿ“ Service URL: {cloud_run_service_url}") print(f"๐Ÿ”— MCP Connection URL: {cloud_run_service_url}") print(f"๐Ÿ“‹ Service Name: {args.id}") diff --git a/eval_protocol/cli_commands/logs.py b/eval_protocol/cli_commands/logs.py index 07c211be..92b1be58 100644 --- a/eval_protocol/cli_commands/logs.py +++ b/eval_protocol/cli_commands/logs.py @@ -12,7 +12,7 @@ def logs_command(args): """Serve logs with file watching and real-time updates""" port = args.port - print(f"๐Ÿš€ Starting Eval Protocol Logs Server") + print("๐Ÿš€ Starting Eval Protocol Logs Server") print(f"๐ŸŒ URL: http://localhost:{port}") print(f"๐Ÿ”Œ WebSocket: ws://localhost:{port}/ws") print(f"๐Ÿ‘€ Watching paths: {['current directory']}") diff --git a/eval_protocol/cli_commands/preview.py b/eval_protocol/cli_commands/preview.py index 0df8e95a..ef438496 100644 --- a/eval_protocol/cli_commands/preview.py +++ b/eval_protocol/cli_commands/preview.py @@ -110,7 +110,7 @@ def preview_command(args): processed_messages.append(msg_item) else: print( - f"Warning: Sample {i+1} has unexpected message item type: {type(msg_item)}. Skipping this message item." + f"Warning: Sample {i + 1} has unexpected message item type: {type(msg_item)}. Skipping this message item." ) try: @@ -120,13 +120,13 @@ def preview_command(args): kwargs=sample_kwargs, ) except Exception as e: # Pydantic validation for EvaluationRequest - print(f"\n--- Sample {i+1} ---") + print(f"\n--- Sample {i + 1} ---") print(f" Error creating request payload for sample: {e}") print(f" Sample data: {sample_data}") print("--- End Sample ---") continue # Skip to next sample - print(f"\n--- Sample {i+1} ---") + print(f"\n--- Sample {i + 1} ---") try: response = requests.post( diff --git a/eval_protocol/config.py b/eval_protocol/config.py index a1fee7d6..7d75ab01 100644 --- a/eval_protocol/config.py +++ b/eval_protocol/config.py @@ -31,9 +31,9 @@ class RewardKitConfig(BaseModel): default_deployment_target: Optional[Literal["gcp-cloud-run", "aws-lambda", "fireworks", "local"]] = "fireworks" gcp_cloud_run: Optional[GCPCloudRunConfig] = GCPCloudRunConfig() aws_lambda: Optional[AWSLambdaConfig] = AWSLambdaConfig() - evaluator_endpoint_keys: Optional[Dict[str, str]] = ( - {} # Stores generated API keys for self-hosted evaluator endpoints - ) + evaluator_endpoint_keys: Optional[ + Dict[str, str] + ] = {} # Stores generated API keys for self-hosted evaluator endpoints # --- Global variable to hold the loaded configuration --- diff --git a/eval_protocol/dataset_logger/__init__.py b/eval_protocol/dataset_logger/__init__.py index 4d04ce7d..c087b6cd 100644 --- a/eval_protocol/dataset_logger/__init__.py +++ b/eval_protocol/dataset_logger/__init__.py @@ -22,7 +22,6 @@ def read(self, rollout_id=None): # Lazy property that creates the logger only when accessed class _LazyLogger(DatasetLogger): - def log(self, row): return _get_default_logger().log(row) diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index ee71a37e..fe58bb8a 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -309,7 +309,7 @@ def load_multi_metrics_folder(self, folder_path): files = self._load_python_files_from_folder(folder_path) self.code_files = files - logger.info(f"Loaded {len(files)} Python files from {folder_path} " f"for multi-metrics evaluation") + logger.info(f"Loaded {len(files)} Python files from {folder_path} for multi-metrics evaluation") return files def load_samples_from_jsonl(self, sample_file, max_samples=5): @@ -327,7 +327,7 @@ def load_samples_from_jsonl(self, sample_file, max_samples=5): sample = json.loads(line) samples.append(sample) except json.JSONDecodeError: - logger.warning(f"Invalid JSON on line {i+1}, skipping") + logger.warning(f"Invalid JSON on line {i + 1}, skipping") logger.info(f"Loaded {len(samples)} samples from {sample_file}") return samples @@ -444,7 +444,9 @@ def _get_combined_requirements(self) -> str: # Fallback for multi_metrics if requirements were loaded differently (hypothetical) # This attribute doesn't exist yet, placeholder for future enhancement if needed. if self._loaded_multi_metric_requirements_str: # type: ignore - requirements_list = [r.strip() for r in self._loaded_multi_metric_requirements_str.splitlines() if r.strip()] # type: ignore + requirements_list = [ + r.strip() for r in self._loaded_multi_metric_requirements_str.splitlines() if r.strip() + ] # type: ignore for req_item in requirements_list: all_requirements_set.add(req_item) @@ -458,7 +460,7 @@ def _simulated_preview(self, samples): for i, sample in enumerate(samples): try: if "messages" not in sample: - raise ValueError(f"Sample {i+1} is missing 'messages' field") + raise ValueError(f"Sample {i + 1} is missing 'messages' field") _ = sample.get("messages", []) _ = sample.get("ground_truth", []) _ = sample.get("tools", []) @@ -486,7 +488,7 @@ def _simulated_preview(self, samples): per_metric_evals=per_metric_evals, ) except Exception as e: - logger.error(f"Error processing sample {i+1}: {str(e)}") + logger.error(f"Error processing sample {i + 1}: {str(e)}") preview_result.add_result( sample_index=i, success=False, @@ -873,7 +875,7 @@ def preview_folder_evaluation( # This function might become redundant or need t if has_main_py and not multi_metrics: py_files = list(Path(evaluator_folder).glob("*.py")) if len(py_files) > 1: - logger.info(f"Auto-detecting multi-metrics mode based on folder structure for preview_folder_evaluation") + logger.info("Auto-detecting multi-metrics mode based on folder structure for preview_folder_evaluation") detected_multi_metrics = True # Call the unified preview_evaluation @@ -947,7 +949,7 @@ def create_evaluation( ) elif ts_mode_config: # ts_mode_config already handled in Evaluator.__init__ for self.code_files - logger.info(f"Configuring evaluator with direct Python code snippet (ts_mode).") + logger.info("Configuring evaluator with direct Python code snippet (ts_mode).") elif multi_metrics: # Folder-based multi_metrics if not folder: raise ValueError("`folder` must be specified for folder-based multi_metrics mode.") @@ -1008,7 +1010,7 @@ def deploy_folder_evaluation( # This function might become redundant or need to if has_main_py and not multi_metrics: # If user says not multi_metrics, but main.py is at root py_files = list(Path(evaluator_folder_abs).glob("*.py")) if len(py_files) > 1: # Heuristic: if multiple .py files at root with main.py, likely multi-metric - logger.info(f"Auto-detecting multi-metrics mode for deploy_folder_evaluation.") + logger.info("Auto-detecting multi-metrics mode for deploy_folder_evaluation.") detected_multi_metrics = True if detected_multi_metrics: diff --git a/eval_protocol/execution/pipeline.py b/eval_protocol/execution/pipeline.py index 7e7f7c29..e644ba32 100644 --- a/eval_protocol/execution/pipeline.py +++ b/eval_protocol/execution/pipeline.py @@ -78,7 +78,7 @@ def __init__(self, pipeline_cfg: DictConfig): if self.cfg.get("agent") and self.cfg.agent.get("type") == "mcp_agent": if not self.cfg.agent.get("intermediary_server_url"): raise ValueError("agent.intermediary_server_url must be configured for mcp_agent type.") - logger.info(f"Pipeline configured for mcp_agent. IntermediaryMCPClient will be initialized in run().") + logger.info("Pipeline configured for mcp_agent. IntermediaryMCPClient will be initialized in run().") async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str) -> List[Dict[str, Any]]: """Discover available tools from MCP backend for a sample.""" @@ -348,7 +348,7 @@ async def _execute_mcp_agent_rollout( ) except Exception as e_tool_exec: logger.error( - f"Sample {sample_id}, Turn {turn_num+1}: Error executing/parsing tool '{tool_name}': {e_tool_exec}", + f"Sample {sample_id}, Turn {turn_num + 1}: Error executing/parsing tool '{tool_name}': {e_tool_exec}", exc_info=True, ) error_payload = {"error": str(e_tool_exec)} @@ -381,11 +381,11 @@ async def _execute_mcp_agent_rollout( assistant_msg_for_history["content"] = final_llm_text_response current_messages_for_rollout.append(assistant_msg_for_history) final_assistant_output_for_log = final_llm_text_response - logger.info(f"Sample {sample_id}, Turn {turn_num+1}: LLM responded with text. Ending rollout.") + logger.info(f"Sample {sample_id}, Turn {turn_num + 1}: LLM responded with text. Ending rollout.") break else: logger.warning( - f"Sample {sample_id}, Turn {turn_num+1}: LLM provided no content or tool calls. Ending rollout." + f"Sample {sample_id}, Turn {turn_num + 1}: LLM provided no content or tool calls. Ending rollout." ) final_llm_text_response = "LLM provided no actionable response in this turn." assistant_msg_for_history["content"] = final_llm_text_response diff --git a/eval_protocol/gcp_tools.py b/eval_protocol/gcp_tools.py index b19089ca..b6131157 100644 --- a/eval_protocol/gcp_tools.py +++ b/eval_protocol/gcp_tools.py @@ -453,7 +453,7 @@ def ensure_gcp_secret( dry_run=True, ) - print(f"\n2. Simulating deploy to Cloud Run (dry_run=True)") + print("\n2. Simulating deploy to Cloud Run (dry_run=True)") deploy_to_cloud_run( service_name="my-reward-service", image_name_tag=ar_img_name, # Use AR image name @@ -465,7 +465,7 @@ def ensure_gcp_secret( dry_run=True, ) - print(f"\n3. Simulating ensure_artifact_registry_repo_exists (dry_run=True)") + print("\n3. Simulating ensure_artifact_registry_repo_exists (dry_run=True)") ensure_artifact_registry_repo_exists( project_id="my-test-project", region="us-central1", @@ -473,7 +473,7 @@ def ensure_gcp_secret( dry_run=True, ) - print(f"\n4. Simulating ensure_gcp_secret (dry_run=True)") + print("\n4. Simulating ensure_gcp_secret (dry_run=True)") ensure_gcp_secret( project_id="my-test-project", secret_id="my-test-api-key-secret", diff --git a/eval_protocol/generation/clients.py b/eval_protocol/generation/clients.py index 45be6ab0..873f587e 100644 --- a/eval_protocol/generation/clients.py +++ b/eval_protocol/generation/clients.py @@ -226,7 +226,7 @@ async def generate( # ... (rest of the error handling as before) ... elif response.status == 429: # Rate limit retry_after = int(response.headers.get("Retry-After", "5")) - logger.warning(f"Rate limited. Retrying after {retry_after}s (attempt {attempt+1}).") + logger.warning(f"Rate limited. Retrying after {retry_after}s (attempt {attempt + 1}).") await asyncio.sleep(retry_after) elif response.status in [401, 403]: # Auth errors error_text = await response.text() @@ -234,7 +234,7 @@ async def generate( return GenerationResult() # Empty result on auth error elif response.status >= 500: # Server errors logger.warning( - f"Fireworks API Server Error ({response.status}). Retrying (attempt {attempt+1})." + f"Fireworks API Server Error ({response.status}). Retrying (attempt {attempt + 1})." ) await asyncio.sleep(2**attempt) else: # Other client errors diff --git a/eval_protocol/generic_server.py b/eval_protocol/generic_server.py index 7efd444b..04e111e2 100644 --- a/eval_protocol/generic_server.py +++ b/eval_protocol/generic_server.py @@ -153,7 +153,7 @@ def load_reward_function(import_string: str): try: load_reward_function(args.import_string) except Exception: - print(f"Failed to load reward function. Exiting.") + print("Failed to load reward function. Exiting.") exit(1) if not _LOADED_REWARD_FUNCTION: diff --git a/eval_protocol/integrations/trl.py b/eval_protocol/integrations/trl.py index 62f89a34..9873c593 100644 --- a/eval_protocol/integrations/trl.py +++ b/eval_protocol/integrations/trl.py @@ -180,7 +180,7 @@ def trl_reward_pipeline( if scores: logger.debug( f"Batch rewards calculated by TRL adapter. Count: {len(scores)}, " - f"Min: {min(scores)}, Max: {max(scores)}, Avg: {sum(scores)/len(scores):.2f}" + f"Min: {min(scores)}, Max: {max(scores)}, Avg: {sum(scores) / len(scores):.2f}" ) return scores diff --git a/eval_protocol/mcp/execution/base_policy.py b/eval_protocol/mcp/execution/base_policy.py index 819b33dd..bdced48a 100644 --- a/eval_protocol/mcp/execution/base_policy.py +++ b/eval_protocol/mcp/execution/base_policy.py @@ -109,7 +109,6 @@ def add_tool_response( # Add control plane metadata if provided if reward != 0.0 or terminated or info: - tool_message["metadata"] = { "reward": reward, "terminated": terminated, @@ -182,7 +181,7 @@ async def _generate_live_tool_calls( # This is crucial for proper tool call ID management in add_tool_response assistant_message_for_history = { "role": "assistant", - "content": response["choices"][0]["message"].get("content", ""), + "content": response["choices"][0]["message"]["content"], } usage_stats = CompletionUsage( prompt_tokens=response["usage"]["prompt_tokens"], diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index b0359d79..1a36afef 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -80,7 +80,7 @@ def execute_rollouts( elif playback_mode: logger.info(f"๐ŸŽฌ Playback mode: Using recorded data from {playback_file}") else: - logger.info(f"๐Ÿš€ Live mode: No recording/playback") + logger.info("๐Ÿš€ Live mode: No recording/playback") # Initialize OpenAI format logging for terminated trajectories only openai_logger = None @@ -285,7 +285,6 @@ async def _execute_rollout( # Execute each tool call sequentially for tool_call in tool_calls: - # Execute tool call for this environment observation, reward, env_end, info = await envs.step(rollout_idx, tool_call) diff --git a/eval_protocol/mcp/execution/policy.py b/eval_protocol/mcp/execution/policy.py index f529a21d..c7f284f7 100644 --- a/eval_protocol/mcp/execution/policy.py +++ b/eval_protocol/mcp/execution/policy.py @@ -72,7 +72,7 @@ def __init__( self._setup_litellm_caching(use_caching, cache_type, redis_url) logger.info(f"โœ… Initialized LiteLLM policy: {self.model_id}") else: - logger.info(f"๐ŸŽฌ Playback mode: Skipping LiteLLM initialization for performance") + logger.info("๐ŸŽฌ Playback mode: Skipping LiteLLM initialization for performance") def _setup_litellm_caching( self, use_caching: bool, cache_type: Literal["memory", "redis", "dual", "s3", "disk"], redis_url: Optional[str] diff --git a/eval_protocol/mcp/mcpgym.py b/eval_protocol/mcp/mcpgym.py index cf942a0f..fb8d8caa 100644 --- a/eval_protocol/mcp/mcpgym.py +++ b/eval_protocol/mcp/mcpgym.py @@ -141,7 +141,7 @@ def _get_session_id(self, ctx: Context) -> str: Creates stable session IDs based on client info (seed + config + client details) for consistent session management across reconnections. """ - print(f"๐Ÿ” _get_session_id: Starting session ID extraction") + print("๐Ÿ” _get_session_id: Starting session ID extraction") print(f"๐Ÿ” _get_session_id: ctx type: {type(ctx)}") print(f"๐Ÿ” _get_session_id: hasattr(ctx, 'session'): {hasattr(ctx, 'session')}") @@ -208,7 +208,6 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]: return self.sessions[session_id] def _register_session_reset_endpoint(self): - @self.mcp.custom_route("/control/reset_session", methods=["POST"]) async def reset_session_endpoint(request: Request) -> JSONResponse: session_id = request.headers.get("mcp-session-id") diff --git a/eval_protocol/mcp/session/manager.py b/eval_protocol/mcp/session/manager.py index 71c23af0..a7ae679a 100644 --- a/eval_protocol/mcp/session/manager.py +++ b/eval_protocol/mcp/session/manager.py @@ -226,4 +226,4 @@ async def close(self): print(f"๐Ÿงน Closing {self.n} MCP sessions...") tasks = [self.connection_manager.close_session(session) for session in self.sessions] await asyncio.gather(*tasks) - print(f"โœ… All MCP sessions closed.") + print("โœ… All MCP sessions closed.") diff --git a/eval_protocol/mcp/simulation_server.py b/eval_protocol/mcp/simulation_server.py index f18bb0cf..7dfc11eb 100644 --- a/eval_protocol/mcp/simulation_server.py +++ b/eval_protocol/mcp/simulation_server.py @@ -398,7 +398,7 @@ def run(self, port: int = 8000, host: str = "127.0.0.1", **kwargs): host: Host to bind to **kwargs: Additional arguments for uvicorn """ - print(f"๐Ÿ“ก Starting simulation server with StreamableHTTPSessionManager") + print("๐Ÿ“ก Starting simulation server with StreamableHTTPSessionManager") print(f"๐ŸŽฎ Domain tools: {list(self._domain_tools.keys())}") print(f"๐Ÿ“ฆ Domain resources: {list(self._domain_resources.keys())}") if self.production_server_app: diff --git a/eval_protocol/mcp_agent/intermediary_server.py b/eval_protocol/mcp_agent/intermediary_server.py index 7aa3f399..368f0232 100644 --- a/eval_protocol/mcp_agent/intermediary_server.py +++ b/eval_protocol/mcp_agent/intermediary_server.py @@ -23,8 +23,7 @@ logger = logging.getLogger(__name__) # logger.setLevel(logging.DEBUG) # Removed: Let level be set by main config -from mcp.server.fastmcp.server import Context as FastMCPContext -from mcp.server.fastmcp.server import FastMCP +from mcp.server.fastmcp.server import Context as FastMCPContext, FastMCP # RequestContext is not directly used by handlers anymore, mcp_ctx is. diff --git a/eval_protocol/mcp_agent/orchestration/local_docker_client.py b/eval_protocol/mcp_agent/orchestration/local_docker_client.py index 56a7c6ba..58c098dd 100644 --- a/eval_protocol/mcp_agent/orchestration/local_docker_client.py +++ b/eval_protocol/mcp_agent/orchestration/local_docker_client.py @@ -198,7 +198,7 @@ async def _perform_startup_check(self, url: str, check: Dict[str, Any]) -> bool: res.raise_for_status() return True except Exception as e: - logger.warning(f"Startup check fail {attempt+1}/5: {e}") + logger.warning(f"Startup check fail {attempt + 1}/5: {e}") if attempt < 4: await asyncio.sleep(2) return False @@ -233,7 +233,6 @@ async def provision_instances( and (template_details or backend_config.template_data_path_host) and backend_config.container_template_data_path ): - host_path_for_commit = template_details or backend_config.template_data_path_host if not host_path_for_commit or not backend_config.container_template_data_path: raise ValueError( diff --git a/eval_protocol/mcp_agent/session.py b/eval_protocol/mcp_agent/session.py index d126dc50..a4e91550 100644 --- a/eval_protocol/mcp_agent/session.py +++ b/eval_protocol/mcp_agent/session.py @@ -12,8 +12,10 @@ # Option 1: Try mcp.server.transport # from mcp.server.transport import ReadStream, WriteStream # Option 2: If not found, use typing.Any as a fallback for type hints -from typing import Any as ReadStream # Fallback if specific types are not found -from typing import Any as WriteStream +from typing import ( + Any as ReadStream, # Fallback if specific types are not found + Any as WriteStream, +) from mcp.server.session import ServerSession # Correct base class diff --git a/eval_protocol/platform_api.py b/eval_protocol/platform_api.py index efea26a7..c5c4d62e 100644 --- a/eval_protocol/platform_api.py +++ b/eval_protocol/platform_api.py @@ -290,7 +290,7 @@ def delete_fireworks_secret( test_api_key = get_fireworks_api_key() # Not passed directly, functions will resolve test_api_base = get_fireworks_api_base() - logger.info(f"Attempting to use the following configuration for testing Fireworks secrets API:") + logger.info("Attempting to use the following configuration for testing Fireworks secrets API:") logger.info(f" Resolved FIREWORKS_ACCOUNT_ID: {test_account_id}") logger.info(f" Resolved FIREWORKS_API_BASE: {test_api_base}") logger.info( diff --git a/eval_protocol/pytest/default_dataset_adapter.py b/eval_protocol/pytest/default_dataset_adapter.py index 87377cff..7c4a7d73 100644 --- a/eval_protocol/pytest/default_dataset_adapter.py +++ b/eval_protocol/pytest/default_dataset_adapter.py @@ -7,4 +7,4 @@ def default_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]: """ Default dataset adapter that simply returns the rows as is. """ - return [EvaluationRow(**row) for row in rows] \ No newline at end of file + return [EvaluationRow(**row) for row in rows] diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index b7376e9c..e9bbc1e4 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -86,12 +86,12 @@ def start(self) -> None: try: with open(self._log_file_path, "r") as f: log_content = f.read() - print(f"โŒ Server failed to start!") + print("โŒ Server failed to start!") print(f"๐Ÿ“‹ Server log ({self._log_file_path}):") print("=" * 50) print(log_content) print("=" * 50) - raise RuntimeError(f"Server failed to start or become ready. Check log above for details.") + raise RuntimeError("Server failed to start or become ready. Check log above for details.") except Exception as e: stdout, stderr = self.process.communicate() raise RuntimeError(f"Server failed to start or become ready. stderr: {stderr}, log error: {e}") @@ -108,7 +108,7 @@ def _wait_for_server_ready(self, timeout: int = 15) -> bool: while time.time() - start_time < timeout: # Check if process is still running if self.process.poll() is not None: - print(f"Server process exited early") + print("Server process exited early") return False try: diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 38f66d54..c5717113 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -590,13 +590,13 @@ async def _execute_with_semaphore(row): # Check threshold after logging if threshold is not None and not passed: - assert ( - agg_score >= threshold.success - ), f"Aggregated score {agg_score:.3f} below threshold {threshold.success}" + assert agg_score >= threshold.success, ( + f"Aggregated score {agg_score:.3f} below threshold {threshold.success}" + ) if threshold.standard_deviation is not None: - assert ( - score_std <= threshold.standard_deviation - ), f"Standard deviation {score_std:.3f} above threshold {threshold.standard_deviation}" + assert score_std <= threshold.standard_deviation, ( + f"Standard deviation {score_std:.3f} above threshold {threshold.standard_deviation}" + ) except AssertionError: _log_eval_error("finished", data if "data" in locals() else None, passed=False) @@ -944,9 +944,9 @@ def run_evaluation_test_direct( pass if passed_threshold is not None and not passed: - assert ( - agg_score >= passed_threshold.success - ), f"Aggregated score {agg_score:.3f} below threshold {passed_threshold}" + assert agg_score >= passed_threshold.success, ( + f"Aggregated score {agg_score:.3f} below threshold {passed_threshold}" + ) return {"summary": summary_obj, "results": all_results} except Exception: diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index 186f7c7c..617b9e85 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -216,7 +216,9 @@ def generate_parameter_combinations( else: messages = [None] # type: ignore - kwargs: List[Optional[EvaluationInputParam]] = evaluation_test_kwargs if evaluation_test_kwargs is not None else [None] # type: ignore + kwargs: List[Optional[EvaluationInputParam]] = ( + evaluation_test_kwargs if evaluation_test_kwargs is not None else [None] + ) # type: ignore # Generate all combinations for ds in datasets: @@ -255,9 +257,9 @@ async def retry_handler(failed_row: EvaluationRow): current_attempts = retry_counts.get(rollout_id, 0) if current_attempts >= max_retry: - assert ( - failed_row.rollout_status and failed_row.rollout_status.status == "error" - ), f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" + assert failed_row.rollout_status and failed_row.rollout_status.status == "error", ( + f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" + ) failed_permanently.append(failed_row) await queue.put(failed_row) # put failed row on queue return diff --git a/eval_protocol/resources.py b/eval_protocol/resources.py index 3ca63aaf..d34c1a05 100644 --- a/eval_protocol/resources.py +++ b/eval_protocol/resources.py @@ -57,7 +57,7 @@ def setup(self) -> None: return try: - logger.debug(f"Setting up LLM deployment for model: " f"{self.llm_instance.model}") + logger.debug(f"Setting up LLM deployment for model: {self.llm_instance.model}") # For on-demand deployments, call apply() if hasattr(self.llm_instance, "deployment_type") and self.llm_instance.deployment_type == "on-demand": @@ -68,7 +68,7 @@ def setup(self) -> None: self._client = self.llm_instance self._is_setup = True - logger.info(f"LLM resource setup completed for model: " f"{self.llm_instance.model}") + logger.info(f"LLM resource setup completed for model: {self.llm_instance.model}") except Exception as e: logger.error(f"Failed to setup LLM resource: {e}") diff --git a/eval_protocol/rewards/accuracy_length.py b/eval_protocol/rewards/accuracy_length.py index 310353f5..b8e64eb2 100644 --- a/eval_protocol/rewards/accuracy_length.py +++ b/eval_protocol/rewards/accuracy_length.py @@ -132,7 +132,7 @@ def cosine_scaled_accuracy_length_reward( # Prepare detailed reason reward_type = "reward" if accuracy_success else "penalty" length_reason = ( - f"Length-based {reward_type}: {token_count}/{max_length} tokens, " f"cosine factor: {cosine_factor:.2f}" + f"Length-based {reward_type}: {token_count}/{max_length} tokens, cosine factor: {cosine_factor:.2f}" ) combined_reason = ( diff --git a/eval_protocol/rewards/apps_coding_reward.py b/eval_protocol/rewards/apps_coding_reward.py index 3089f56e..6cbd63b9 100644 --- a/eval_protocol/rewards/apps_coding_reward.py +++ b/eval_protocol/rewards/apps_coding_reward.py @@ -244,7 +244,7 @@ def evaluate_apps_solution(messages: List[Message], ground_truth: Optional[str], # as our system prompt now asks for a main() that handles IO. # The generated code itself should be a runnable script. del in_outs_for_check["fn_name"] - logger.info(f"Removed 'fn_name' from in_outs for check_correctness to use standard_input path.") + logger.info("Removed 'fn_name' from in_outs for check_correctness to use standard_input path.") final_code_to_execute = code_solution # The model's full response (after extraction) diff --git a/eval_protocol/rewards/apps_execution_utils.py b/eval_protocol/rewards/apps_execution_utils.py index c40baffe..74e3fbfa 100644 --- a/eval_protocol/rewards/apps_execution_utils.py +++ b/eval_protocol/rewards/apps_execution_utils.py @@ -47,7 +47,7 @@ def _temp_run( # Temporarily disable stdout/stderr redirection to see debug prints from run_test # sys.stdout = open(os.devnull, "w") # sys.stderr = open(os.devnull, "w") - print(f"[_temp_run] Executing run_test for sample. Debug prints from run_test should be visible.") + print("[_temp_run] Executing run_test for sample. Debug prints from run_test should be visible.") try: res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) diff --git a/eval_protocol/rewards/apps_testing_util.py b/eval_protocol/rewards/apps_testing_util.py index 0c349797..84f13f8a 100644 --- a/eval_protocol/rewards/apps_testing_util.py +++ b/eval_protocol/rewards/apps_testing_util.py @@ -109,7 +109,7 @@ def _load_module_from_string(module_name, code_string): try: exec(code_string, module.__dict__) # sys.modules[module_name] = module # Optional: if other parts of the code expect it in sys.modules - except Exception as e: + except Exception: raise return module diff --git a/eval_protocol/rewards/bfcl_reward.py b/eval_protocol/rewards/bfcl_reward.py index 44053ad8..ccf1cda4 100644 --- a/eval_protocol/rewards/bfcl_reward.py +++ b/eval_protocol/rewards/bfcl_reward.py @@ -286,7 +286,7 @@ def bfcl_reward( final_score = 0.0 reason = "State or function calls did not perfectly match ground truth." if state_match_score < 0.5: - reason += f" State match failed." + reason += " State match failed." if state_diffs: reason += f" Differences: {json.dumps(state_diffs)}" if func_match_score < 0.5: # Check against 0.5 as perfect score for this component diff --git a/eval_protocol/rewards/code_execution.py b/eval_protocol/rewards/code_execution.py index 6bfbe8d5..52db7db2 100644 --- a/eval_protocol/rewards/code_execution.py +++ b/eval_protocol/rewards/code_execution.py @@ -276,7 +276,6 @@ def _execute_code_in_process(execute_func: Callable, args: Tuple, timeout: int = Returns: Dictionary with execution results """ - import multiprocessing manager = multiprocessing.Manager() result_dict = manager.dict() diff --git a/eval_protocol/rewards/cpp_code.py b/eval_protocol/rewards/cpp_code.py index 9511f0e9..cb324273 100644 --- a/eval_protocol/rewards/cpp_code.py +++ b/eval_protocol/rewards/cpp_code.py @@ -484,7 +484,7 @@ async def run_cpp_test_cases( for i, test_case in enumerate(test_cases): test_input = test_case.get("input", "") expected_output = test_case.get("expected_output", "") - test_name = test_case.get("name", f"Test {i+1}") + test_name = test_case.get("name", f"Test {i + 1}") execution_result = await execute_cpp_code( code=code, diff --git a/eval_protocol/rewards/deepcoder_reward.py b/eval_protocol/rewards/deepcoder_reward.py index d0674dd7..ebdc44bb 100644 --- a/eval_protocol/rewards/deepcoder_reward.py +++ b/eval_protocol/rewards/deepcoder_reward.py @@ -9,9 +9,9 @@ from ..models import EvaluateResult, Message, MetricResult from ..reward_function import reward_function -from .code_execution import _HAS_E2B # Import _HAS_E2B to check E2B availability -from .code_execution import _run_test_cases # Import the main test case runner from .code_execution import ( + _HAS_E2B, # Import _HAS_E2B to check E2B availability + _run_test_cases, # Import the main test case runner compare_outputs, execute_code_with_e2b, execute_javascript_code, diff --git a/eval_protocol/rewards/language_consistency.py b/eval_protocol/rewards/language_consistency.py index 174f137e..bfdd2052 100644 --- a/eval_protocol/rewards/language_consistency.py +++ b/eval_protocol/rewards/language_consistency.py @@ -646,7 +646,7 @@ def language_consistency_reward( if total_counted == 0: return EvaluateResult( score=0.0, - reason=f"No language markers found in model response to evaluate.", + reason="No language markers found in model response to evaluate.", metrics={ "language_consistency": MetricResult( score=0.0, diff --git a/eval_protocol/rewards/tag_count.py b/eval_protocol/rewards/tag_count.py index 1d0e04e7..83acef6f 100644 --- a/eval_protocol/rewards/tag_count.py +++ b/eval_protocol/rewards/tag_count.py @@ -126,9 +126,9 @@ def _get_tag_reason(tag: str, opening_count: int, closing_count: int, require_ba return f"Found {opening_count} balanced '{tag}' tag(s)" else: if require_balanced: - return f"Unbalanced tags: {opening_count} opening vs " f"{closing_count} closing '{tag}' tags" + return f"Unbalanced tags: {opening_count} opening vs {closing_count} closing '{tag}' tags" else: - return f"Found '{tag}' tags (unbalanced: {opening_count} opening, " f"{closing_count} closing)" + return f"Found '{tag}' tags (unbalanced: {opening_count} opening, {closing_count} closing)" def _get_overall_reason( diff --git a/eval_protocol/stats/__init__.py b/eval_protocol/stats/__init__.py index c327d2ed..6cf5a888 100644 --- a/eval_protocol/stats/__init__.py +++ b/eval_protocol/stats/__init__.py @@ -1,5 +1,3 @@ """Statistical utilities for evaluation reporting (confidence intervals, etc.).""" from .confidence_intervals import compute_fixed_set_mu_ci # re-export - - diff --git a/eval_protocol/stats/confidence_intervals.py b/eval_protocol/stats/confidence_intervals.py index bf78934c..70a4dd2d 100644 --- a/eval_protocol/stats/confidence_intervals.py +++ b/eval_protocol/stats/confidence_intervals.py @@ -112,5 +112,3 @@ def compute_fixed_set_mu_ci( ci_high = min(1.0, mu_hat + margin) return float(mu_hat), float(ci_low), float(ci_high) - - diff --git a/eval_protocol/utils/logs_server.py b/eval_protocol/utils/logs_server.py index 46630cdf..e5e6e4a3 100644 --- a/eval_protocol/utils/logs_server.py +++ b/eval_protocol/utils/logs_server.py @@ -195,7 +195,6 @@ def _should_update_status(self, row: "EvaluationRow") -> bool: """Check if a row's status should be updated to 'stopped'.""" # Check if the row has running status and a PID if row.eval_metadata and row.eval_metadata.status == "running" and row.pid is not None: - # Check if the process is still running try: process = psutil.Process(row.pid) diff --git a/examples/adapters/README.md b/examples/adapters/README.md index 591bbbb2..4b8501ef 100644 --- a/examples/adapters/README.md +++ b/examples/adapters/README.md @@ -43,34 +43,6 @@ Loads datasets from HuggingFace Hub and converts them to EvaluationRow format. pip install 'eval-protocol[huggingface]' ``` -### 3. BigQuery Adapter (`bigquery_example.py`) - -Queries data from Google BigQuery tables and converts them to EvaluationRow format. - -**Features:** -- Execute custom SQL queries against BigQuery datasets -- Support for parameterized queries and batch processing -- Built-in convenience adapters for conversation and Q&A data -- Rich metadata preservation including query information -- Integration with Google Cloud authentication -- Schema introspection and dataset exploration - -**Prerequisites:** -```bash -pip install 'eval-protocol[bigquery]' -``` - -**Environment Variables:** -```bash -export GOOGLE_CLOUD_PROJECT="your-project-id" -export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service-account.json" # optional -``` - -**Alternative Authentication:** -```bash -gcloud auth application-default login -``` - ## Running the Examples ### Basic Usage @@ -82,9 +54,6 @@ python examples/adapters/langfuse_example.py # Run HuggingFace example python examples/adapters/huggingface_example.py -# Run BigQuery example -python examples/adapters/bigquery_example.py - # Run GSM8K replacement example python examples/adapters/gsm8k_replacement_example.py ``` @@ -97,11 +66,6 @@ export LANGFUSE_PUBLIC_KEY="pk_..." export LANGFUSE_SECRET_KEY="sk_..." python examples/adapters/langfuse_example.py -# Set up Google Cloud credentials for BigQuery -export GOOGLE_CLOUD_PROJECT="your-project-id" -export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service-account.json" # optional -python examples/adapters/bigquery_example.py - # HuggingFace works without credentials for public datasets python examples/adapters/huggingface_example.py ``` diff --git a/examples/adapters/gsm8k_replacement_example.py b/examples/adapters/gsm8k_replacement_example.py index a86de261..3c18775e 100644 --- a/examples/adapters/gsm8k_replacement_example.py +++ b/examples/adapters/gsm8k_replacement_example.py @@ -1,8 +1,8 @@ """ GSM8K Replacement Example -This example shows how to replace the static GSM8K JSONL file -(development/gsm8k_sample.jsonl) with the dynamic HuggingFace adapter +This example shows how to replace the static GSM8K JSONL file +(development/gsm8k_sample.jsonl) with the dynamic HuggingFace adapter to get fresh data from the GSM8K dataset. """ @@ -18,17 +18,17 @@ def load_original_gsm8k_sample() -> List[dict]: """Load the original GSM8K sample file for comparison.""" sample_file = Path("development/gsm8k_sample.jsonl") - + if not sample_file.exists(): print(f"โš ๏ธ Original sample file not found: {sample_file}") return [] - + data = [] - with open(sample_file, 'r') as f: + with open(sample_file, "r") as f: for line in f: if line.strip(): data.append(json.loads(line)) - + return data @@ -36,52 +36,52 @@ def demonstrate_old_vs_new_approach(): """Compare the old static file approach with the new adapter approach.""" print("๐Ÿ“Š Comparing Old vs New Approach") print("=" * 50) - + # OLD APPROACH: Static JSONL file print("๐Ÿ—‚๏ธ OLD APPROACH: Static JSONL File") print("-" * 35) - + original_data = load_original_gsm8k_sample() print(f"Loaded {len(original_data)} items from static file") - + if original_data: sample = original_data[0] print(f"Sample item fields: {list(sample.keys())}") print(f"Sample question: {sample.get('user_query', '')[:100]}...") print(f"Sample ground truth: {sample.get('ground_truth_for_eval', '')[:100]}...") - - print("\n" + "="*50 + "\n") - + + print("\n" + "=" * 50 + "\n") + # NEW APPROACH: HuggingFace Adapter print("๐Ÿค— NEW APPROACH: HuggingFace Adapter") print("-" * 38) - + try: # Create adapter adapter = create_gsm8k_adapter( system_prompt="You are a helpful assistant that solves math problems step by step." ) - + print("โœ… GSM8K adapter created successfully") - + # Get the same number of items as the original file num_items = len(original_data) if original_data else 6 rows = list(adapter.get_evaluation_rows(limit=num_items)) - + print(f"Retrieved {len(rows)} evaluation rows from HuggingFace") - + if rows: sample_row = rows[0] - print(f"Sample EvaluationRow fields: messages, tools, input_metadata, ground_truth") - + print("Sample EvaluationRow fields: messages, tools, input_metadata, ground_truth") + # Show the question from messages user_msg = next((msg for msg in sample_row.messages if msg.role == "user"), None) if user_msg: print(f"Sample question: {user_msg.content[:100]}...") - + if sample_row.ground_truth: print(f"Sample ground truth: {sample_row.ground_truth[:100]}...") - + except ImportError as e: print(f"โŒ Error: {e}") print("Install HuggingFace dependencies: pip install 'eval-protocol[huggingface]'") @@ -89,9 +89,9 @@ def demonstrate_old_vs_new_approach(): except Exception as e: print(f"โŒ Error with adapter: {e}") return - - print("\n" + "="*50 + "\n") - + + print("\n" + "=" * 50 + "\n") + # COMPARISON print("๐Ÿ” Key Differences") print("-" * 20) @@ -101,7 +101,7 @@ def demonstrate_old_vs_new_approach(): print(" โŒ Manual data preparation required") print(" โŒ Limited to pre-selected subset") print(" โŒ Requires manual format conversion") - + print("\nNEW APPROACH:") print(" โœ… Access to full GSM8K dataset (8,792 test problems)") print(" โœ… Automatic format conversion to EvaluationRow") @@ -115,10 +115,11 @@ def show_migration_example(): """Show how to migrate existing code from JSONL to adapter.""" print("\n๐Ÿ”„ Code Migration Example") print("=" * 30) - + print("OLD CODE:") print("-" * 10) - print(""" + print( + """ # Old way with static JSONL file input_dataset = ["development/gsm8k_sample.jsonl"] @@ -134,11 +135,13 @@ def show_migration_example(): ] ground_truth = item["ground_truth_for_eval"] # ... more manual processing -""") - +""" + ) + print("\nNEW CODE:") print("-" * 10) - print(""" + print( + """ # New way with HuggingFace adapter from eval_protocol.adapters.huggingface import create_gsm8k_adapter @@ -149,7 +152,7 @@ def show_migration_example(): # Get evaluation rows (already in correct format) evaluation_rows = list(adapter.get_evaluation_rows( - split="test", # or "train" + split="test", # or "train" limit=100, # Can get much more data than static file model_name="gpt-4", temperature=0.0, @@ -175,8 +178,9 @@ def custom_gsm8k_transform(row): config_name="main", transform_fn=custom_gsm8k_transform ) -""") - +""" + ) + print("\nโœ… Benefits of Migration:") print(" - More data available (6 โ†’ 8,792 problems)") print(" - Automatic format handling") @@ -189,30 +193,30 @@ def practical_migration_demo(): """Show a practical example of using the adapter in evaluation.""" print("\n๐Ÿงช Practical Evaluation Example") print("=" * 35) - + try: # Create adapter adapter = create_gsm8k_adapter() - + # Get a few problems for evaluation print("Loading GSM8K problems...") rows = list(adapter.get_evaluation_rows(limit=3)) print(f"โœ… Loaded {len(rows)} problems from GSM8K test set") - + # Simulate evaluation workflow for i, row in enumerate(rows): - print(f"\n๐Ÿ“ Problem {i+1}:") - + print(f"\n๐Ÿ“ Problem {i + 1}:") + # Show the problem user_msg = next((msg for msg in row.messages if msg.role == "user"), None) if user_msg: print(f" Question: {user_msg.content[:150]}...") - + # In a real scenario, you'd generate a response with your LLM # For this demo, we'll add a dummy response dummy_response = "Let me solve this step by step. After working through the math, the answer is 42." row.messages.append(Message(role="assistant", content=dummy_response)) - + # Evaluate with math reward function if row.ground_truth: try: @@ -222,7 +226,7 @@ def practical_migration_demo(): ) print(f" ๐Ÿ“Š Math evaluation score: {result.score:.2f}") print(f" ๐Ÿ’ญ Evaluation reason: {result.reason[:100]}...") - + # Show metadata if row.input_metadata: print(f" ๐Ÿท๏ธ Row ID: {row.input_metadata.row_id}") @@ -230,12 +234,12 @@ def practical_migration_demo(): dataset_info = row.input_metadata.dataset_info print(f" ๐Ÿ“š Dataset: {dataset_info.get('dataset_name', 'N/A')}") print(f" ๐Ÿ“ Row index: {dataset_info.get('row_index', 'N/A')}") - + except Exception as e: print(f" โŒ Evaluation error: {e}") - + print(f"\nโœ… Successfully processed {len(rows)} problems using the new adapter approach!") - + except Exception as e: print(f"โŒ Error in practical demo: {e}") @@ -244,9 +248,9 @@ def performance_comparison(): """Compare performance characteristics of both approaches.""" print("\nโšก Performance Considerations") print("=" * 35) - + import time - + # Time the old approach (if file exists) original_data = load_original_gsm8k_sample() if original_data: @@ -259,7 +263,7 @@ def performance_comparison(): print("๐Ÿ“ Static file not available for timing") old_time = 0 processed_old = 0 - + # Time the new approach try: start_time = time.time() @@ -267,9 +271,9 @@ def performance_comparison(): rows = list(adapter.get_evaluation_rows(split="test", limit=max(6, processed_old))) new_time = time.time() - start_time processed_new = len(rows) - + print(f"๐Ÿค— HuggingFace adapter: {processed_new} items in {new_time:.4f}s") - + if old_time > 0: if new_time > old_time: factor = new_time / old_time @@ -277,11 +281,11 @@ def performance_comparison(): else: factor = old_time / new_time print(f" ๐Ÿ“Š Adapter is {factor:.1f}x faster!") - - print(f"\n๐Ÿ’ก Trade-offs:") + + print("\n๐Ÿ’ก Trade-offs:") print(f" Static file: Fast ({old_time:.4f}s) but limited data ({processed_old} items)") print(f" Adapter: Slower ({new_time:.4f}s) but access to full dataset ({processed_new}+ items)") - + except Exception as e: print(f"โŒ Error timing adapter: {e}") @@ -293,16 +297,16 @@ def main(): print("This example shows how to replace the static GSM8K JSONL file") print("with the dynamic HuggingFace adapter for better data access.") print() - + # Run all demonstrations demonstrate_old_vs_new_approach() show_migration_example() practical_migration_demo() performance_comparison() - - print("\n" + "="*50) + + print("\n" + "=" * 50) print("๐ŸŽฏ MIGRATION SUMMARY") - print("="*50) + print("=" * 50) print("1. โœ… Replace static JSONL with HuggingFace adapter") print("2. โœ… Get access to full GSM8K dataset (8,792 test problems)") print("3. โœ… Automatic conversion to EvaluationRow format") @@ -318,4 +322,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/adapters/huggingface_example.py b/examples/adapters/huggingface_example.py index 2d79eae3..c4f1d907 100644 --- a/examples/adapters/huggingface_example.py +++ b/examples/adapters/huggingface_example.py @@ -9,10 +9,10 @@ from typing import List from eval_protocol.adapters.huggingface import ( - create_huggingface_adapter, + HuggingFaceAdapter, create_gsm8k_adapter, + create_huggingface_adapter, create_math_adapter, - HuggingFaceAdapter, ) from eval_protocol.models import EvaluationRow @@ -21,44 +21,47 @@ def gsm8k_example(): """Example using the GSM8K dataset.""" print("๐Ÿ“š Example 1: GSM8K Dataset") print("-" * 30) - + try: # Create GSM8K adapter using the convenience method adapter = create_gsm8k_adapter( - split="test", - system_prompt="You are a helpful assistant that solves math problems step by step." + split="test", system_prompt="You are a helpful assistant that solves math problems step by step." ) - + print("โœ… GSM8K adapter created successfully") print(f"๐Ÿ“Š Dataset info: {adapter.get_dataset_info()}") - + # Get a few evaluation rows - rows = list(adapter.get_evaluation_rows( - limit=3, - model_name="gpt-4", - temperature=0.0, - )) - + rows = list( + adapter.get_evaluation_rows( + limit=3, + model_name="gpt-4", + temperature=0.0, + ) + ) + print(f"\nRetrieved {len(rows)} evaluation rows from GSM8K test set:") - + for i, row in enumerate(rows): - print(f"\n Row {i+1}:") + print(f"\n Row {i + 1}:") print(f" - ID: {row.input_metadata.row_id if row.input_metadata else 'N/A'}") print(f" - Messages: {len(row.messages)}") - + # Show the math problem user_message = next((msg for msg in row.messages if msg.role == "user"), None) if user_message: - problem = user_message.content[:200] + "..." if len(user_message.content) > 200 else user_message.content + problem = ( + user_message.content[:200] + "..." if len(user_message.content) > 200 else user_message.content + ) print(f" - Problem: {problem}") - + # Show ground truth answer if row.ground_truth: answer_preview = row.ground_truth[:100] + "..." if len(row.ground_truth) > 100 else row.ground_truth print(f" - Ground truth: {answer_preview}") - + print() - + except ImportError as e: print(f"โŒ Error: {e}") print("Install HuggingFace dependencies: pip install 'eval-protocol[huggingface]'") @@ -70,42 +73,44 @@ def math_dataset_example(): """Example using the MATH competition dataset.""" print("๐Ÿงฎ Example 2: MATH Competition Dataset") print("-" * 40) - + try: # Create MATH dataset adapter - adapter = create_math_adapter( - system_prompt="You are an expert mathematician. Solve this step by step." - ) - + adapter = create_math_adapter(system_prompt="You are an expert mathematician. Solve this step by step.") + print("โœ… MATH dataset adapter created successfully") print(f"๐Ÿ“Š Dataset info: {adapter.get_dataset_info()}") - + # Get a few examples - rows = list(adapter.get_evaluation_rows( - limit=2, - model_name="gpt-4", - temperature=0.1, - )) - + rows = list( + adapter.get_evaluation_rows( + limit=2, + model_name="gpt-4", + temperature=0.1, + ) + ) + print(f"\nRetrieved {len(rows)} evaluation rows from MATH test set:") - + for i, row in enumerate(rows): - print(f"\n Row {i+1}:") - + print(f"\n Row {i + 1}:") + # Show the problem user_message = next((msg for msg in row.messages if msg.role == "user"), None) if user_message: - problem = user_message.content[:150] + "..." if len(user_message.content) > 150 else user_message.content + problem = ( + user_message.content[:150] + "..." if len(user_message.content) > 150 else user_message.content + ) print(f" - Problem: {problem}") - + # Show metadata if row.input_metadata and row.input_metadata.dataset_info: dataset_info = row.input_metadata.dataset_info - if 'original_type' in dataset_info: + if "original_type" in dataset_info: print(f" - Problem type: {dataset_info['original_type']}") - if 'original_level' in dataset_info: + if "original_level" in dataset_info: print(f" - Level: {dataset_info['original_level']}") - + except Exception as e: print(f"โŒ Error with MATH dataset: {e}") @@ -114,66 +119,70 @@ def custom_dataset_example(): """Example using a custom dataset with transformation function.""" print("๐Ÿ”ง Example 3: Custom Dataset with Transform Function") print("-" * 55) - + try: # Define transformation function for SQuAD dataset def squad_transform(row): """Transform SQuAD row to evaluation format.""" - context = row['context'] - question = row['question'] - answers = row['answers'] - + context = row["context"] + question = row["question"] + answers = row["answers"] + # Get first answer text - answer_text = answers['text'][0] if answers['text'] else "No answer provided" - + answer_text = answers["text"][0] if answers["text"] else "No answer provided" + return { - 'messages': [ - {'role': 'system', 'content': 'Answer the question based on the given context.'}, - {'role': 'user', 'content': f"Context: {context}\\n\\nQuestion: {question}"}, + "messages": [ + {"role": "system", "content": "Answer the question based on the given context."}, + {"role": "user", "content": f"Context: {context}\\n\\nQuestion: {question}"}, ], - 'ground_truth': answer_text, - 'metadata': { - 'dataset': 'squad', - 'context_length': len(context), - 'question_length': len(question), - 'num_possible_answers': len(answers['text']), - } + "ground_truth": answer_text, + "metadata": { + "dataset": "squad", + "context_length": len(context), + "question_length": len(question), + "num_possible_answers": len(answers["text"]), + }, } - + # Create adapter with transformation function adapter = create_huggingface_adapter( dataset_id="squad", transform_fn=squad_transform, ) - + print("โœ… Custom dataset adapter created successfully") - + # Get dataset info info = adapter.get_dataset_info() print(f"๐Ÿ“Š Dataset info: {info}") - + # Get a few examples - rows = list(adapter.get_evaluation_rows( - split="validation", # SQuAD has train/validation splits - limit=2, - model_name="gpt-3.5-turbo", - )) - + rows = list( + adapter.get_evaluation_rows( + split="validation", # SQuAD has train/validation splits + limit=2, + model_name="gpt-3.5-turbo", + ) + ) + print(f"\nRetrieved {len(rows)} evaluation rows:") - + for i, row in enumerate(rows): - print(f"\n Row {i+1}:") + print(f"\n Row {i + 1}:") print(f" - Messages: {len(row.messages)}") - + # Show question user_message = next((msg for msg in row.messages if msg.role == "user"), None) if user_message: - question = user_message.content[:100] + "..." if len(user_message.content) > 100 else user_message.content + question = ( + user_message.content[:100] + "..." if len(user_message.content) > 100 else user_message.content + ) print(f" - Question: {question}") - + # SQuAD answers are complex, so just show if we have ground truth print(f" - Has ground truth: {'Yes' if row.ground_truth else 'No'}") - + except Exception as e: print(f"โŒ Error with custom dataset: {e}") @@ -182,93 +191,85 @@ def local_file_example(): """Example loading a local dataset file.""" print("๐Ÿ“ Example 4: Local Dataset File") print("-" * 35) - + # Create a sample JSONL file for demonstration sample_file = "/tmp/sample_qa.jsonl" sample_data = [ - { - "id": "q1", - "question": "What is the capital of France?", - "answer": "Paris", - "category": "geography" - }, - { - "id": "q2", - "question": "What is 2 + 2?", - "answer": "4", - "category": "math" - }, + {"id": "q1", "question": "What is the capital of France?", "answer": "Paris", "category": "geography"}, + {"id": "q2", "question": "What is 2 + 2?", "answer": "4", "category": "math"}, { "id": "q3", "question": "Who wrote Romeo and Juliet?", "answer": "William Shakespeare", - "category": "literature" - } + "category": "literature", + }, ] - + try: import json - + # Write sample data - with open(sample_file, 'w') as f: + with open(sample_file, "w") as f: for item in sample_data: - f.write(json.dumps(item) + '\n') - + f.write(json.dumps(item) + "\n") + print(f"๐Ÿ“ Created sample file: {sample_file}") - + # Define transformation function for local data def local_qa_transform(row): - """Transform local Q&A data to evaluation format.""" + """Transform local Q&A data to evaluation format.""" return { - 'messages': [ - {'role': 'system', 'content': 'You are a knowledgeable assistant.'}, - {'role': 'user', 'content': row['question']}, + "messages": [ + {"role": "system", "content": "You are a knowledgeable assistant."}, + {"role": "user", "content": row["question"]}, ], - 'ground_truth': row['answer'], - 'metadata': { - 'id': row.get('id'), - 'category': row.get('category'), - 'dataset': 'local_qa_sample', - } + "ground_truth": row["answer"], + "metadata": { + "id": row.get("id"), + "category": row.get("category"), + "dataset": "local_qa_sample", + }, } - + # Load with adapter adapter = HuggingFaceAdapter.from_local( path=sample_file, transform_fn=local_qa_transform, ) - + print("โœ… Local file adapter created successfully") - + # Get all rows - rows = list(adapter.get_evaluation_rows( - model_name="gpt-3.5-turbo", - temperature=0.0, - )) - + rows = list( + adapter.get_evaluation_rows( + model_name="gpt-3.5-turbo", + temperature=0.0, + ) + ) + print(f"\nLoaded {len(rows)} rows from local file:") - + for i, row in enumerate(rows): - print(f"\n Row {i+1}:") - + print(f"\n Row {i + 1}:") + # Show question and answer user_msg = next((msg for msg in row.messages if msg.role == "user"), None) if user_msg: print(f" - Question: {user_msg.content}") - + if row.ground_truth: print(f" - Answer: {row.ground_truth}") - + # Show original metadata if row.input_metadata and row.input_metadata.dataset_info: - original_data = {k: v for k, v in row.input_metadata.dataset_info.items() if k.startswith('original_')} + original_data = {k: v for k, v in row.input_metadata.dataset_info.items() if k.startswith("original_")} if original_data: print(f" - Original data: {original_data}") - + # Clean up os.remove(sample_file) - print(f"\n๐Ÿงน Cleaned up sample file") - + print("\n๐Ÿงน Cleaned up sample file") + except Exception as e: print(f"โŒ Error with local file: {e}") @@ -277,35 +278,34 @@ def evaluation_integration_example(): """Show how to integrate with evaluation functions.""" print("\n๐Ÿงช Example 5: Integration with Evaluation") print("-" * 45) - + try: # Import evaluation functions - from eval_protocol.rewards.math import math_reward from eval_protocol.rewards.accuracy import accuracy_reward - + from eval_protocol.rewards.math import math_reward + # Create GSM8K adapter adapter = create_gsm8k_adapter(split="test") - + # Get a few rows for evaluation rows = list(adapter.get_evaluation_rows(limit=2)) - + print(f"Running evaluation on {len(rows)} GSM8K problems:") - + for i, row in enumerate(rows): - print(f"\n Problem {i+1}:") - + print(f"\n Problem {i + 1}:") + # Show the problem user_msg = next((msg for msg in row.messages if msg.role == "user"), None) if user_msg: print(f" Question: {user_msg.content[:100]}...") - + # For this example, we'll simulate an assistant response # In practice, this would come from your LLM - row.messages.append({ - "role": "assistant", - "content": "Let me solve this step by step... The answer is 42." - }) - + row.messages.append( + {"role": "assistant", "content": "Let me solve this step by step... The answer is 42."} + ) + # Evaluate with math reward if row.ground_truth: try: @@ -315,17 +315,17 @@ def evaluation_integration_example(): ) print(f" Math score: {math_result.score:.2f}") print(f" Reason: {math_result.reason[:100]}...") - + # Also try accuracy reward acc_result = accuracy_reward( messages=row.messages, ground_truth=row.ground_truth, ) print(f" Accuracy score: {acc_result.score:.2f}") - + except Exception as e: print(f" โŒ Evaluation error: {e}") - + except ImportError: print("Evaluation functions not available") except Exception as e: @@ -336,24 +336,26 @@ def batch_processing_example(): """Show how to process datasets in batches.""" print("\n๐Ÿ“ฆ Example 6: Batch Processing") print("-" * 35) - + try: adapter = create_gsm8k_adapter(split="test") - + batch_size = 5 total_processed = 0 - + print(f"Processing GSM8K test set in batches of {batch_size}:") - + # Process in batches for batch_start in range(0, 20, batch_size): # Process first 20 items - batch_rows = list(adapter.get_evaluation_rows( - limit=batch_size, - offset=batch_start, - )) - - print(f" Batch {batch_start//batch_size + 1}: {len(batch_rows)} rows") - + batch_rows = list( + adapter.get_evaluation_rows( + limit=batch_size, + offset=batch_start, + ) + ) + + print(f" Batch {batch_start // batch_size + 1}: {len(batch_rows)} rows") + # Process each row in the batch for row in batch_rows: # Here you would typically: @@ -361,9 +363,9 @@ def batch_processing_example(): # 2. Evaluate the response # 3. Store results total_processed += 1 - + print(f"โœ… Processed {total_processed} rows total") - + except Exception as e: print(f"โŒ Error in batch processing: {e}") @@ -372,40 +374,40 @@ def main(): """Run all examples.""" print("๐Ÿค— HuggingFace Dataset Adapter Examples") print("=" * 50) - + # Run examples gsm8k_example() - print("\n" + "="*50 + "\n") - + print("\n" + "=" * 50 + "\n") + math_dataset_example() - print("\n" + "="*50 + "\n") - + print("\n" + "=" * 50 + "\n") + custom_dataset_example() - print("\n" + "="*50 + "\n") - + print("\n" + "=" * 50 + "\n") + local_file_example() - print("\n" + "="*50 + "\n") - + print("\n" + "=" * 50 + "\n") + evaluation_integration_example() - print("\n" + "="*50 + "\n") - + print("\n" + "=" * 50 + "\n") + batch_processing_example() if __name__ == "__main__": try: main() - + print("\nโœ… All examples completed!") print("\nNext steps:") print("1. Choose the dataset that fits your needs") - print("2. Customize the system prompts for your use case") + print("2. Customize the system prompts for your use case") print("3. Integrate with your evaluation pipeline") print("4. Scale up to process full datasets") print("5. Use the EvaluationRow data for training or evaluation") - + except ImportError as e: print(f"โŒ Missing dependencies: {e}") print("Install with: pip install 'eval-protocol[huggingface]'") except Exception as e: - print(f"โŒ Error running examples: {e}") \ No newline at end of file + print(f"โŒ Error running examples: {e}") diff --git a/examples/adapters/langfuse_example.py b/examples/adapters/langfuse_example.py index 78937c80..39fe31f4 100644 --- a/examples/adapters/langfuse_example.py +++ b/examples/adapters/langfuse_example.py @@ -15,16 +15,16 @@ def main(): """Example usage of the Langfuse adapter.""" - + # Configuration - you can set these as environment variables public_key = os.getenv("LANGFUSE_PUBLIC_KEY", "your_public_key_here") - secret_key = os.getenv("LANGFUSE_SECRET_KEY", "your_secret_key_here") + secret_key = os.getenv("LANGFUSE_SECRET_KEY", "your_secret_key_here") host = os.getenv("LANGFUSE_HOST", "https://langfuse-web-prod-zfdbl7ykrq-uc.a.run.app") project_id = os.getenv("LANGFUSE_PROJECT_ID", "cmdj5yxhk0006s6022cyi0prv") - + print(f"Connecting to Langfuse at: {host}") print(f"Project ID: {project_id}\n") - + # Create the adapter try: adapter = create_langfuse_adapter( @@ -41,91 +41,99 @@ def main(): except Exception as e: print(f"โŒ Failed to create adapter: {e}") return - + # Example 1: Get recent evaluation rows print("\n๐Ÿ“Š Example 1: Get recent evaluation rows") try: - rows = list(adapter.get_evaluation_rows( - limit=5, - from_timestamp=datetime.now() - timedelta(days=7), - include_tool_calls=True, - )) - + rows = list( + adapter.get_evaluation_rows( + limit=5, + from_timestamp=datetime.now() - timedelta(days=7), + include_tool_calls=True, + ) + ) + print(f"Retrieved {len(rows)} evaluation rows") for i, row in enumerate(rows): - print(f" Row {i+1}:") + print(f" Row {i + 1}:") print(f" - ID: {row.input_metadata.row_id if row.input_metadata else 'N/A'}") print(f" - Messages: {len(row.messages)}") print(f" - Has tools: {'Yes' if row.tools else 'No'}") print(f" - Ground truth: {'Yes' if row.ground_truth else 'No'}") - + # Show first message content (truncated) if row.messages: content = row.messages[0].content or "" preview = content[:100] + "..." if len(content) > 100 else content print(f" - First message: {preview}") print() - + except Exception as e: print(f"โŒ Error retrieving rows: {e}") - + # Example 2: Filter by specific criteria print("\n๐Ÿ” Example 2: Filter by specific criteria") try: - rows = list(adapter.get_evaluation_rows( - limit=3, - tags=["production"], # Filter by tags if available - include_tool_calls=True, - )) - + rows = list( + adapter.get_evaluation_rows( + limit=3, + tags=["production"], # Filter by tags if available + include_tool_calls=True, + ) + ) + print(f"Retrieved {len(rows)} rows with 'production' tag") - + except Exception as e: print(f"โŒ Error with filtered query: {e}") - + # Example 3: Get specific traces by ID print("\n๐ŸŽฏ Example 3: Get specific traces by ID") try: # Replace with actual trace IDs from your Langfuse deployment trace_ids = ["trace_id_1", "trace_id_2"] # These would be real IDs - - rows = list(adapter.get_evaluation_rows_by_ids( - trace_ids=trace_ids, - include_tool_calls=True, - )) - + + rows = list( + adapter.get_evaluation_rows_by_ids( + trace_ids=trace_ids, + include_tool_calls=True, + ) + ) + print(f"Retrieved {len(rows)} rows by specific IDs") - + except Exception as e: print(f"โŒ Error retrieving specific traces: {e}") - + # Example 4: Extract different types of conversations print("\n๐Ÿ’ฌ Example 4: Analyze conversation types") try: rows = list(adapter.get_evaluation_rows(limit=10, include_tool_calls=True)) - + chat_only = [] tool_calling = [] - + for row in rows: - if row.tools and any(msg.tool_calls for msg in row.messages if hasattr(msg, 'tool_calls') and msg.tool_calls): + if row.tools and any( + msg.tool_calls for msg in row.messages if hasattr(msg, "tool_calls") and msg.tool_calls + ): tool_calling.append(row) else: chat_only.append(row) - + print(f"Chat-only conversations: {len(chat_only)}") print(f"Tool calling conversations: {len(tool_calling)}") - + # Show example of tool calling conversation if tool_calling: row = tool_calling[0] - print(f"\n๐Ÿ”ง Example tool calling conversation:") + print("\n๐Ÿ”ง Example tool calling conversation:") for i, msg in enumerate(row.messages): - print(f" {i+1}. {msg.role}: {msg.content[:50] if msg.content else '[No content]'}...") - if hasattr(msg, 'tool_calls') and msg.tool_calls: + print(f" {i + 1}. {msg.role}: {msg.content[:50] if msg.content else '[No content]'}...") + if hasattr(msg, "tool_calls") and msg.tool_calls: for tool_call in msg.tool_calls: print(f" ๐Ÿ›  Tool call: {tool_call}") - + except Exception as e: print(f"โŒ Error analyzing conversation types: {e}") @@ -133,11 +141,11 @@ def main(): def demonstrate_evaluation_integration(): """Show how to use Langfuse data with evaluation functions.""" print("\n๐Ÿงช Integration with Evaluation Functions") - + # This would typically be in a separate evaluation script try: from eval_protocol.rewards.math import math_reward - + # Create adapter (reuse configuration from main example) adapter = create_langfuse_adapter( public_key=os.getenv("LANGFUSE_PUBLIC_KEY", "your_public_key_here"), @@ -145,13 +153,13 @@ def demonstrate_evaluation_integration(): host=os.getenv("LANGFUSE_HOST", "https://langfuse-web-prod-zfdbl7ykrq-uc.a.run.app"), project_id=os.getenv("LANGFUSE_PROJECT_ID", "cmdj5yxhk0006s6022cyi0prv"), ) - + # Get data and evaluate rows = list(adapter.get_evaluation_rows(limit=3)) - + for i, row in enumerate(rows): - print(f"\nEvaluating row {i+1}:") - + print(f"\nEvaluating row {i + 1}:") + # Only evaluate if we have ground truth if row.ground_truth: try: @@ -164,8 +172,8 @@ def demonstrate_evaluation_integration(): except Exception as e: print(f" โŒ Evaluation failed: {e}") else: - print(f" โš ๏ธ No ground truth available for evaluation") - + print(" โš ๏ธ No ground truth available for evaluation") + except ImportError: print("Math reward function not available") except Exception as e: @@ -175,25 +183,27 @@ def demonstrate_evaluation_integration(): if __name__ == "__main__": print("๐Ÿš€ Langfuse Adapter Example") print("=" * 50) - + # Check if credentials are set - if not all([ - os.getenv("LANGFUSE_PUBLIC_KEY"), - os.getenv("LANGFUSE_SECRET_KEY"), - ]): + if not all( + [ + os.getenv("LANGFUSE_PUBLIC_KEY"), + os.getenv("LANGFUSE_SECRET_KEY"), + ] + ): print("โš ๏ธ To run this example with real data, set environment variables:") print(" export LANGFUSE_PUBLIC_KEY='your_public_key'") print(" export LANGFUSE_SECRET_KEY='your_secret_key'") print(" export LANGFUSE_HOST='your_langfuse_host' # optional") print(" export LANGFUSE_PROJECT_ID='your_project_id' # optional") print() - + main() demonstrate_evaluation_integration() - + print("\nโœ… Example completed!") print("\nNext steps:") print("1. Set up your Langfuse credentials") print("2. Modify the filters and parameters to match your data") print("3. Integrate with your evaluation pipeline") - print("4. Use the converted EvaluationRow data for training or evaluation") \ No newline at end of file + print("4. Use the converted EvaluationRow data for training or evaluation") diff --git a/examples/aime2025_chat_completion/README.md b/examples/aime2025_chat_completion/README.md index dbe79527..69a6ad6f 100644 --- a/examples/aime2025_chat_completion/README.md +++ b/examples/aime2025_chat_completion/README.md @@ -19,6 +19,3 @@ Environment variables expected: - `FIREWORKS_API_KEY` To scale up, adjust parameters in the decorator (e.g., `threshold_of_success`, `max_dataset_rows`). - - - diff --git a/examples/aime2025_chat_completion/__init__.py b/examples/aime2025_chat_completion/__init__.py index 8bcaacfb..470d6936 100644 --- a/examples/aime2025_chat_completion/__init__.py +++ b/examples/aime2025_chat_completion/__init__.py @@ -1,4 +1 @@ __all__ = ["main"] - - - diff --git a/examples/aime2025_chat_completion/main.py b/examples/aime2025_chat_completion/main.py index 92c6dd83..b6d12976 100644 --- a/examples/aime2025_chat_completion/main.py +++ b/examples/aime2025_chat_completion/main.py @@ -64,9 +64,7 @@ def evaluate( score=0.0, reason="No messages provided", is_score_valid=False, - metrics={ - "parse_status": MetricResult(score=0.0, is_score_valid=False, reason="empty messages") - }, + metrics={"parse_status": MetricResult(score=0.0, is_score_valid=False, reason="empty messages")}, ) last_msg = messages[-1] @@ -106,5 +104,3 @@ def evaluate( is_score_valid=is_valid, metrics=metrics, ) - - diff --git a/examples/blackjack_mcp/blackjack_adapter.py b/examples/blackjack_mcp/blackjack_adapter.py index 48c0b6b3..f88bc2cb 100644 --- a/examples/blackjack_mcp/blackjack_adapter.py +++ b/examples/blackjack_mcp/blackjack_adapter.py @@ -32,7 +32,7 @@ def create_environment(self, config: Optional[Dict[str, Any]] = None) -> Blackja natural = config.get("natural") if natural is None: natural = False - print(f"๐Ÿ” BlackjackAdapter.create_environment: natural is not set in the config, use False by default") + print("๐Ÿ” BlackjackAdapter.create_environment: natural is not set in the config, use False by default") if isinstance(natural, str): natural = natural.lower() == "true" print(f"๐Ÿ” BlackjackAdapter.create_environment: natural is a string, convert to boolean: {natural}") @@ -42,7 +42,7 @@ def create_environment(self, config: Optional[Dict[str, Any]] = None) -> Blackja sab = config.get("sab", False) if sab is None: sab = False - print(f"๐Ÿ” BlackjackAdapter.create_environment: sab is not set in the config, use False by default") + print("๐Ÿ” BlackjackAdapter.create_environment: sab is not set in the config, use False by default") if isinstance(sab, str): sab = sab.lower() == "true" print(f"๐Ÿ” BlackjackAdapter.create_environment: sab is a string, convert to boolean: {sab}") @@ -50,7 +50,7 @@ def create_environment(self, config: Optional[Dict[str, Any]] = None) -> Blackja sab = bool(sab) env = BlackjackEnv(render_mode="ansi", natural=natural, sab=sab) - print(f"๐Ÿ” BlackjackAdapter.create_environment: Created BlackjackEnv") + print("๐Ÿ” BlackjackAdapter.create_environment: Created BlackjackEnv") return env def create_environment_with_seed( diff --git a/examples/blackjack_mcp/blackjack_mcp.py b/examples/blackjack_mcp/blackjack_mcp.py index f8b0a877..0f40f6b1 100644 --- a/examples/blackjack_mcp/blackjack_mcp.py +++ b/examples/blackjack_mcp/blackjack_mcp.py @@ -71,7 +71,7 @@ def blackjack_move(action: str, ctx: Context) -> Dict[str, Any]: # Validate action if not action or not isinstance(action, str): raise ValueError( - f"Invalid action parameter: '{action}'. " f"Must be a non-empty string. Valid actions: STICK, HIT" + f"Invalid action parameter: '{action}'. Must be a non-empty string. Valid actions: STICK, HIT" ) action = action.strip().upper() diff --git a/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py b/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py index 69552c53..8a724565 100644 --- a/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py +++ b/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py @@ -562,7 +562,6 @@ async def test_multi_environment_sessions(multi_env_dataset, multi_env_recording # Start server for this test server = _create_test_server(9600) try: - # Set up recording os.environ["EP_PLAYBACK_FILE"] = multi_env_recording_file @@ -661,7 +660,7 @@ async def _validate_recording_integrity(recording_file: str, dataset: List[Dict] print("\n๐Ÿ Validating trajectory termination...") _validate_trajectory_termination(env_recordings, dataset) - print(f"โœ… Recording integrity validation completed") + print("โœ… Recording integrity validation completed") def _validate_no_repeated_initial_states(env_recordings: Dict, dataset: List[Dict]): @@ -746,9 +745,9 @@ def _validate_state_progression(env_recordings: Dict): try: response_data = json.loads(response) game_states.append(response_data) - print(f" Step {i+1}: Game state {response_data}") + print(f" Step {i + 1}: Game state {response_data}") except json.JSONDecodeError: - pytest.fail(f"โŒ Invalid JSON in tool response {i+1} for env {env_idx}: {response}") + pytest.fail(f"โŒ Invalid JSON in tool response {i + 1} for env {env_idx}: {response}") # Check that player_sum changes when HIT action is taken for i in range(len(game_states) - 1): @@ -763,21 +762,21 @@ def _validate_state_progression(env_recordings: Dict): if current_player_sum == next_player_sum: pytest.fail( f"โŒ STATE PROGRESSION BUG DETECTED in Env {env_idx}: " - f"After HIT action at step {i+1}, player_sum remained {current_player_sum}. " + f"After HIT action at step {i + 1}, player_sum remained {current_player_sum}. " f"When hitting, player should draw a card and player_sum should change. " f"Current state: {current_state}, Next state: {next_state}" ) else: print( - f" โœ… Step {i+1}: HIT action changed player_sum from {current_player_sum} to {next_player_sum}" + f" โœ… Step {i + 1}: HIT action changed player_sum from {current_player_sum} to {next_player_sum}" ) elif current_action == "STAND": # STAND action should not change player_sum (dealer's turn) print( - f" โ„น๏ธ Step {i+1}: STAND action - player_sum transition from {current_player_sum} to {next_player_sum}" + f" โ„น๏ธ Step {i + 1}: STAND action - player_sum transition from {current_player_sum} to {next_player_sum}" ) else: - print(f" โš ๏ธ Step {i+1}: Unknown action '{current_action}' - skipping validation") + print(f" โš ๏ธ Step {i + 1}: Unknown action '{current_action}' - skipping validation") print(f" โœ… Env {env_idx}: State progression validation completed successfully") @@ -833,7 +832,7 @@ def _validate_control_plane_sync(env_recordings: Dict, dataset: List[Dict]): elif terminated_steps == 0: print(f" โš ๏ธ Warning: No terminated=True found in {total_steps} steps (may be expected for short runs)") else: - print(f" โœ… Found some termination signals - control plane appears to be working") + print(" โœ… Found some termination signals - control plane appears to be working") def _validate_no_tool_calls_after_termination(env_recordings: Dict, dataset: List[Dict]): @@ -934,7 +933,7 @@ def _validate_trajectory_termination(env_recordings: Dict, dataset: List[Dict]): f"Expected: Substantial trajectories should end with terminated=True." ) elif last_terminated: - print(f" โœ… Trajectory properly terminated") + print(" โœ… Trajectory properly terminated") else: print(f" โ„น๏ธ Short trajectory ({total_steps} steps) - termination not required") @@ -1019,7 +1018,6 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks # Start server for this test server = _create_test_server(9700) try: - # Set up recording os.environ["EP_PLAYBACK_FILE"] = fireworks_multi_env_recording_file @@ -1144,7 +1142,6 @@ async def test_control_plane_state_querying(multi_env_dataset): # Start server for this test server = _create_test_server(9700) try: - # Create policy with shorter sequence for testing policy = create_blackjack_static_policy(action_sequence=["HIT", "STAND"]) diff --git a/examples/cliff_walking_mcp/cliff_walking_adapter.py b/examples/cliff_walking_mcp/cliff_walking_adapter.py index e8f79fb2..0445d35e 100644 --- a/examples/cliff_walking_mcp/cliff_walking_adapter.py +++ b/examples/cliff_walking_mcp/cliff_walking_adapter.py @@ -29,7 +29,7 @@ def create_environment(self, config: Optional[Dict[str, Any]] = None) -> CliffWa """ print(f"๐Ÿ” CliffWalkingAdapter.create_environment: config: {config}") env = CliffWalkingEnv(render_mode="ansi", is_slippery=False) - print(f"๐Ÿ” CliffWalkingAdapter.create_environment: Created CliffWalkingEnv") + print("๐Ÿ” CliffWalkingAdapter.create_environment: Created CliffWalkingEnv") return env def create_environment_with_seed( diff --git a/examples/cliff_walking_mcp/tests/test_cliff_walking_e2e.py b/examples/cliff_walking_mcp/tests/test_cliff_walking_e2e.py index fc327f62..9d5c6e23 100644 --- a/examples/cliff_walking_mcp/tests/test_cliff_walking_e2e.py +++ b/examples/cliff_walking_mcp/tests/test_cliff_walking_e2e.py @@ -565,7 +565,6 @@ async def test_multi_environment_sessions(multi_env_dataset, multi_env_recording # Start server for this test server = _create_test_server(9600) try: - # Set up recording os.environ["EP_PLAYBACK_FILE"] = multi_env_recording_file @@ -680,7 +679,7 @@ async def _validate_recording_integrity(recording_file: str, dataset: List[Dict] print("\n๐Ÿ Validating trajectory termination...") _validate_trajectory_termination(env_recordings, dataset) - print(f"โœ… Recording integrity validation completed") + print("โœ… Recording integrity validation completed") def _validate_multi_seed_environments(env_recordings: Dict, dataset: List[Dict]): @@ -760,9 +759,9 @@ def _validate_state_progression(env_recordings: Dict, dataset: List[Dict]): position = response_data.get("position") if position is not None: positions.append(position) - print(f" Step {i+1}: Position {position}") + print(f" Step {i + 1}: Position {position}") except json.JSONDecodeError: - pytest.fail(f"โŒ Invalid JSON in tool response {i+1} for env {env_idx}: {response}") + pytest.fail(f"โŒ Invalid JSON in tool response {i + 1} for env {env_idx}: {response}") if len(positions) < 2: print(f" Env {env_idx}: Only {len(positions)} valid positions, skipping progression check") @@ -801,7 +800,7 @@ def _validate_state_progression(env_recordings: Dict, dataset: List[Dict]): f"Full position sequence: {positions}" ) else: - print(f" โœ… Valid state progression - all position changes follow Cliff Walking rules") + print(" โœ… Valid state progression - all position changes follow Cliff Walking rules") def _validate_control_plane_sync(env_recordings: Dict, dataset: List[Dict]): @@ -855,7 +854,7 @@ def _validate_control_plane_sync(env_recordings: Dict, dataset: List[Dict]): elif terminated_steps == 0: print(f" โš ๏ธ Warning: No terminated=True found in {total_steps} steps (may be expected for short runs)") else: - print(f" โœ… Found some termination signals - control plane appears to be working") + print(" โœ… Found some termination signals - control plane appears to be working") def _validate_no_tool_calls_after_termination(env_recordings: Dict, dataset: List[Dict]): @@ -960,7 +959,7 @@ def _validate_trajectory_termination(env_recordings: Dict, dataset: List[Dict]): f"Expected: Substantial trajectories should end with terminated=True." ) elif last_terminated: - print(f" โœ… Trajectory properly terminated") + print(" โœ… Trajectory properly terminated") else: print(f" โ„น๏ธ Short trajectory ({total_steps} steps) - termination not required") @@ -1045,7 +1044,6 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks # Start server for this test server = _create_test_server(9700) try: - # Set up recording os.environ["EP_PLAYBACK_FILE"] = fireworks_multi_env_recording_file @@ -1173,7 +1171,6 @@ async def test_control_plane_state_querying(multi_env_dataset): # Start server for this test server = _create_test_server(9700) try: - # Create policy with shorter sequence for testing policy = create_cliff_walking_static_policy(action_sequence=["UP", "UP"]) diff --git a/examples/frozen_lake_mcp/frozen_lake_adapter.py b/examples/frozen_lake_mcp/frozen_lake_adapter.py index 4181e64b..f8f197df 100644 --- a/examples/frozen_lake_mcp/frozen_lake_adapter.py +++ b/examples/frozen_lake_mcp/frozen_lake_adapter.py @@ -46,12 +46,12 @@ def create_environment(self, config: Optional[Dict[str, Any]] = None) -> FrozenL desc = generate_random_map(size=grid_size, p=0.8, seed=seed) print(f"๐Ÿ” FrozenLakeAdapter.create_environment: Generated map desc: {desc}") else: - print(f"๐Ÿ” FrozenLakeAdapter.create_environment: Generating map without seed") + print("๐Ÿ” FrozenLakeAdapter.create_environment: Generating map without seed") desc = generate_random_map(size=grid_size, p=0.8) print(f"๐Ÿ” FrozenLakeAdapter.create_environment: Generated map desc: {desc}") env = FrozenLakeEnv(desc=desc, is_slippery=False, render_mode="ansi") - print(f"๐Ÿ” FrozenLakeAdapter.create_environment: Created FrozenLakeEnv") + print("๐Ÿ” FrozenLakeAdapter.create_environment: Created FrozenLakeEnv") return env def create_environment_with_seed( diff --git a/examples/frozen_lake_mcp/rollout_example.py b/examples/frozen_lake_mcp/rollout_example.py index 4f969f0c..31a471a9 100644 --- a/examples/frozen_lake_mcp/rollout_example.py +++ b/examples/frozen_lake_mcp/rollout_example.py @@ -90,7 +90,7 @@ async def rollout(self, envs: List[McpGym], policy: SimplePolicy, steps: int = 2 evaluation_rows = [] for i, env in enumerate(envs): - self.logger.info(f"Running rollout {i+1}/{len(envs)}") + self.logger.info(f"Running rollout {i + 1}/{len(envs)}") trajectory = { "environment": env.__class__.__name__, @@ -136,7 +136,7 @@ async def rollout(self, envs: List[McpGym], policy: SimplePolicy, steps: int = 2 break evaluation_rows.append(trajectory) - self.logger.info(f"Rollout {i+1} completed: {trajectory['total_reward']} total reward") + self.logger.info(f"Rollout {i + 1} completed: {trajectory['total_reward']} total reward") return evaluation_rows @@ -147,7 +147,7 @@ def print_trajectory_summary(self, evaluation_rows: List[Dict[str, Any]]): print("=" * 60) for i, traj in enumerate(evaluation_rows): - print(f"\nEvaluation {i+1}:") + print(f"\nEvaluation {i + 1}:") print(f" Environment: {traj['environment']}") print(f" Seed: {traj['seed']}") print(f" Steps: {len(traj['steps'])}") @@ -169,7 +169,7 @@ def print_trajectory_summary(self, evaluation_rows: List[Dict[str, Any]]): else 0 ) - print(f"\nOverall Statistics:") + print("\nOverall Statistics:") print(f" Total Environments: {len(evaluation_rows)}") print(f" Average Reward: {avg_reward:.2f}") print(f" Success Rate: {success_rate:.2%}") diff --git a/examples/frozen_lake_mcp/test_seed_logging.py b/examples/frozen_lake_mcp/test_seed_logging.py index edb1b272..d4dcd203 100644 --- a/examples/frozen_lake_mcp/test_seed_logging.py +++ b/examples/frozen_lake_mcp/test_seed_logging.py @@ -36,7 +36,7 @@ async def test_seed_logging(): # Reset environments to trigger session creation print("๐Ÿ”„ Resetting environments...") observations, tool_schemas, system_prompts = await envs.reset() - print(f"โœ… Reset complete") + print("โœ… Reset complete") print(f"๐Ÿ“Š Observations: {observations}") print(f"๐Ÿ› ๏ธ Tool schemas: {len(tool_schemas[0])} tools available") diff --git a/examples/frozen_lake_mcp/test_termination_fix.py b/examples/frozen_lake_mcp/test_termination_fix.py index cdebcd59..10823826 100755 --- a/examples/frozen_lake_mcp/test_termination_fix.py +++ b/examples/frozen_lake_mcp/test_termination_fix.py @@ -50,7 +50,7 @@ async def test_control_plane_separation(): successful_path = ["DOWN", "RIGHT", "RIGHT", "RIGHT", "DOWN", "DOWN"] for i, action in enumerate(successful_path): - print(f"\n--- Step {i+1}: {action} ---") + print(f"\n--- Step {i + 1}: {action} ---") # Execute tool call (data plane) tool_result = await session.call_tool("lake_move", {"action": action}) diff --git a/examples/frozen_lake_mcp/test_validation_logic.py b/examples/frozen_lake_mcp/test_validation_logic.py index 69477db6..8a476929 100644 --- a/examples/frozen_lake_mcp/test_validation_logic.py +++ b/examples/frozen_lake_mcp/test_validation_logic.py @@ -95,7 +95,7 @@ def test_validation_with_existing_data(): print("\nโœ… All validations passed - no bugs detected") return True else: - print(f"\nโŒ Validation caught bugs (as expected):") + print("\nโŒ Validation caught bugs (as expected):") print(f" - Repeated states bug: {'No' if repeated_states_ok else 'Yes'}") print(f" - Control plane sync bug: {'No' if control_plane_ok else 'Yes'}") print(f" - Trajectory termination bug: {'No' if trajectory_termination_ok else 'Yes'}") diff --git a/examples/frozen_lake_mcp/tests/test_frozen_lake_e2e.py b/examples/frozen_lake_mcp/tests/test_frozen_lake_e2e.py index e2c4c78e..bea2ad08 100644 --- a/examples/frozen_lake_mcp/tests/test_frozen_lake_e2e.py +++ b/examples/frozen_lake_mcp/tests/test_frozen_lake_e2e.py @@ -585,7 +585,6 @@ async def test_multi_environment_sessions(multi_env_dataset, multi_env_recording # Start server for this test server = _create_test_server(9600) try: - # Set up recording os.environ["EP_PLAYBACK_FILE"] = multi_env_recording_file @@ -737,9 +736,9 @@ async def _validate_recording_integrity(recording_file: str, dataset: List[Dict] response_data = json.loads(response) position = response_data.get("position") positions.append(position) - print(f" Step {i+1}: Position {position}") + print(f" Step {i + 1}: Position {position}") except json.JSONDecodeError: - pytest.fail(f"โŒ Invalid JSON in tool response {i+1} for env {env_idx}: {response}") + pytest.fail(f"โŒ Invalid JSON in tool response {i + 1} for env {env_idx}: {response}") # TODO: come back to fix this later. # if len(positions) >= 2: @@ -774,7 +773,7 @@ async def _validate_recording_integrity(recording_file: str, dataset: List[Dict] print("\n๐Ÿ Validating trajectory termination...") _validate_trajectory_termination(env_recordings, dataset) - print(f"โœ… Recording integrity validation completed") + print("โœ… Recording integrity validation completed") def _validate_no_repeated_states(env_recordings: Dict, dataset: List[Dict]): @@ -841,7 +840,7 @@ def _validate_no_repeated_states(env_recordings: Dict, dataset: List[Dict]): print( f"โš ๏ธ WARNING: Env {env_idx}: Position {longest_sequence[0]} repeated {longest_sequence[1]} times starting from step {longest_sequence[2]}." ) - print(f" This might indicate session state or control plane termination issues.") + print(" This might indicate session state or control plane termination issues.") print(f" All positions: {[pos for _, pos in positions]}") # For FireworksPolicy, log but don't fail the test as LLM behavior can vary # pytest.fail( @@ -907,7 +906,7 @@ def _validate_control_plane_sync(env_recordings: Dict, dataset: List[Dict]): elif terminated_steps == 0: print(f" โš ๏ธ Warning: No terminated=True found in {total_steps} steps (may be expected for short runs)") else: - print(f" โœ… Found some termination signals - control plane appears to be working") + print(" โœ… Found some termination signals - control plane appears to be working") def _validate_no_tool_calls_after_termination(env_recordings: Dict, dataset: List[Dict]): @@ -1012,7 +1011,7 @@ def _validate_trajectory_termination(env_recordings: Dict, dataset: List[Dict]): f"Expected: Substantial trajectories should end with terminated=True." ) elif last_terminated: - print(f" โœ… Trajectory properly terminated") + print(" โœ… Trajectory properly terminated") else: print(f" โ„น๏ธ Short trajectory ({total_steps} steps) - termination not required") @@ -1098,7 +1097,6 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks # Start server for this test server = _create_test_server(9700) try: - # Set up recording os.environ["EP_PLAYBACK_FILE"] = fireworks_multi_env_recording_file @@ -1227,7 +1225,6 @@ async def test_control_plane_state_querying(multi_env_dataset): # Start server for this test server = _create_test_server(9700) try: - # Create policy with shorter sequence for testing policy = create_frozen_lake_static_policy(action_sequence=["RIGHT", "DOWN"]) diff --git a/examples/lunar_lander_mcp/generate_sample_images.py b/examples/lunar_lander_mcp/generate_sample_images.py index 8057bb46..75f22387 100644 --- a/examples/lunar_lander_mcp/generate_sample_images.py +++ b/examples/lunar_lander_mcp/generate_sample_images.py @@ -121,7 +121,7 @@ def generate_sample_trajectory(): print(f"\n๐Ÿ“ Trajectory saved to {output_dir}") print(f" ๐Ÿ“Š {len(trajectory_data)} steps recorded") print(f" ๐Ÿ–ผ๏ธ {len(list(output_dir.glob('*.png')))} images saved") - print(f" ๐Ÿ“‹ Summary: trajectory_summary.json") + print(" ๐Ÿ“‹ Summary: trajectory_summary.json") env.close() return output_dir @@ -153,7 +153,7 @@ def save_frame(frame_data: str, output_path: Path, step: int, action: str): if __name__ == "__main__": try: output_dir = generate_sample_trajectory() - print(f"\nโœ… Sample trajectory generated successfully!") + print("\nโœ… Sample trajectory generated successfully!") print(f"๐Ÿ“ View images in: {output_dir.absolute()}") except Exception as e: diff --git a/examples/lunar_lander_mcp/simple_trajectory_test.py b/examples/lunar_lander_mcp/simple_trajectory_test.py index 18c42dad..802b87fe 100644 --- a/examples/lunar_lander_mcp/simple_trajectory_test.py +++ b/examples/lunar_lander_mcp/simple_trajectory_test.py @@ -90,7 +90,7 @@ async def test_lunar_lander_direct(): actions = ["NOTHING", "FIRE_MAIN", "FIRE_LEFT", "FIRE_RIGHT", "NOTHING"] for i, action in enumerate(actions): - print(f"๐ŸŽฎ Step {i+1}: {action}") + print(f"๐ŸŽฎ Step {i + 1}: {action}") # Call lander_action tool result = await session.call_tool("lander_action", {"action": action}) @@ -113,7 +113,7 @@ async def test_lunar_lander_direct(): "status": response_data.get("status", "Unknown"), } - with open(output_dir / f"step_{i+1:03d}_summary.json", "w") as f: + with open(output_dir / f"step_{i + 1:03d}_summary.json", "w") as f: json.dump(step_summary, f, indent=2) # Save rendered frame if available @@ -123,14 +123,14 @@ async def test_lunar_lander_direct(): image_data = frame_data.split(",")[1] image_bytes = base64.b64decode(image_data) - frame_path = output_dir / f"step_{i+1:03d}_{action.lower()}.png" + frame_path = output_dir / f"step_{i + 1:03d}_{action.lower()}.png" with open(frame_path, "wb") as f: f.write(image_bytes) print(f" ๐Ÿ’พ Saved frame: {frame_path}") else: - print(f" โš ๏ธ No rendered frame in response") + print(" โš ๏ธ No rendered frame in response") else: - print(f" โš ๏ธ No rendered_frame field in response") + print(" โš ๏ธ No rendered_frame field in response") except json.JSONDecodeError as e: print(f" โŒ Could not parse response as JSON: {e}") diff --git a/examples/lunar_lander_mcp/test_lunar_lander_conda.py b/examples/lunar_lander_mcp/test_lunar_lander_conda.py index 98d3c491..89f48597 100644 --- a/examples/lunar_lander_mcp/test_lunar_lander_conda.py +++ b/examples/lunar_lander_mcp/test_lunar_lander_conda.py @@ -225,14 +225,14 @@ async def __call__(self, tool_schemas, observations, system_prompts, user_prompt f" Step {step_idx}: control_plane_step is not a dict, type: {type(control_plane_step)}" ) else: - print(f" ๐Ÿ” No control plane messages found") + print(" ๐Ÿ” No control plane messages found") print(f" โœ… Episode {i} validation passed") print(f"๐Ÿ“ All evaluation data saved to {output_dir}") - print(f" - Episode summaries: episode_*_summary.json") - print(f" - Control plane debug data: episode_*_first_control_plane_debug.json") - print(f" - Rendered frames: episode_*_step_*.png (if available)") + print(" - Episode summaries: episode_*_summary.json") + print(" - Control plane debug data: episode_*_first_control_plane_debug.json") + print(" - Rendered frames: episode_*_step_*.png (if available)") print("๐ŸŽ‰ All tests passed! Conda isolation working correctly.") return True diff --git a/examples/lunar_lander_mcp/tests/test_lunar_lander_e2e.py b/examples/lunar_lander_mcp/tests/test_lunar_lander_e2e.py index 723b68bb..3bda23aa 100644 --- a/examples/lunar_lander_mcp/tests/test_lunar_lander_e2e.py +++ b/examples/lunar_lander_mcp/tests/test_lunar_lander_e2e.py @@ -618,7 +618,6 @@ async def test_multi_environment_sessions(multi_env_dataset, multi_env_recording # Start server for this test server = _create_test_server(9600) try: - # Set up recording os.environ["EP_PLAYBACK_FILE"] = multi_env_recording_file @@ -790,9 +789,9 @@ async def _validate_recording_integrity(recording_file: str, dataset: List[Dict] "reward": response_data.get("reward", 0.0), } states.append(state_info) - print(f" Step {i+1}: {state_info}") + print(f" Step {i + 1}: {state_info}") except (json.JSONDecodeError, TypeError) as e: - pytest.fail(f"โŒ Invalid JSON in tool response {i+1} for env {env_idx}: {response}. Error: {e}") + pytest.fail(f"โŒ Invalid JSON in tool response {i + 1} for env {env_idx}: {response}. Error: {e}") # For lunar lander, we expect state to change between steps if len(states) >= 2: @@ -819,7 +818,7 @@ async def _validate_recording_integrity(recording_file: str, dataset: List[Dict] print("\n๐Ÿ Validating trajectory termination...") _validate_trajectory_termination(env_recordings, dataset) - print(f"โœ… Recording integrity validation completed") + print("โœ… Recording integrity validation completed") def _validate_no_repeated_states(env_recordings: Dict, dataset: List[Dict]): @@ -899,7 +898,7 @@ def _validate_no_repeated_states(env_recordings: Dict, dataset: List[Dict]): print( f"โš ๏ธ WARNING: Env {env_idx}: Position {longest_sequence[0]} repeated {longest_sequence[1]} times starting from step {longest_sequence[2]}." ) - print(f" This might indicate session state or control plane termination issues.") + print(" This might indicate session state or control plane termination issues.") print(f" All positions: {[pos for _, pos in positions]}") else: print(f" โœ… Env {env_idx}: No repeated states detected - good state progression!") @@ -952,9 +951,9 @@ def _validate_control_plane_sync(env_recordings: Dict, dataset: List[Dict]): # f"Expected: At least some episodes should terminate when lander crashes or lands successfully." # ) if terminated_steps == 0: - print(f" โš ๏ธ Warning: No terminated=True found in (may be expected for short runs)") + print(" โš ๏ธ Warning: No terminated=True found in (may be expected for short runs)") else: - print(f" โœ… Found some termination signals - control plane appears to be working") + print(" โœ… Found some termination signals - control plane appears to be working") def _validate_no_tool_calls_after_termination(env_recordings: Dict, dataset: List[Dict]): @@ -1042,11 +1041,11 @@ def _validate_trajectory_termination(env_recordings: Dict, dataset: List[Dict]): if total_steps >= 8 and not last_terminated: print(f" โš ๏ธ Env {env_idx}: Trajectory has {total_steps} steps but final metadata shows terminated=False.") print( - f" This might indicate: 1) Episode still in progress, 2) Control plane sync issues, or 3) Lander hasn't crashed/landed yet" + " This might indicate: 1) Episode still in progress, 2) Control plane sync issues, or 3) Lander hasn't crashed/landed yet" ) print(f" Last metadata: {last_tool_metadata}") elif last_terminated: - print(f" โœ… Trajectory properly terminated") + print(" โœ… Trajectory properly terminated") else: print(f" โ„น๏ธ Short trajectory ({total_steps} steps) - termination not required") @@ -1103,7 +1102,6 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks # Start server for this test server = _create_test_server(9700) try: - # Set up recording os.environ["EP_PLAYBACK_FILE"] = fireworks_multi_env_recording_file @@ -1223,7 +1221,6 @@ async def test_control_plane_state_querying(multi_env_dataset): # Start server for this test server = _create_test_server(9700) try: - # Create policy with shorter sequence for testing policy = create_lunar_lander_static_policy(action_sequence=["FIRE_MAIN", "FIRE_LEFT"]) diff --git a/examples/mcp_agent_filesystem_rl/test_example.py b/examples/mcp_agent_filesystem_rl/test_example.py index ef7c5f73..ff848fc0 100644 --- a/examples/mcp_agent_filesystem_rl/test_example.py +++ b/examples/mcp_agent_filesystem_rl/test_example.py @@ -51,7 +51,6 @@ def test_reward_function_import(): print("Testing reward function import...") # Import EvaluateResult specifically within this function's scope - from eval_protocol.models import EvaluateResult from examples.mcp_agent_filesystem_rl import main as filesystem_rl_main assert hasattr(filesystem_rl_main, "evaluate"), "Reward function 'evaluate' not found in main.py" @@ -106,9 +105,9 @@ def test_reward_function_import(): assert isinstance(result_success, EvaluateResult), "evaluate function did not return an EvaluateResult" # Based on main.py logic, a perfect move should result in score 1.0 - assert ( - result_success.score == 1.0 - ), f"Expected score 1.0 for mock success, got {result_success.score}. Reason: {result_success.reason}" + assert result_success.score == 1.0, ( + f"Expected score 1.0 for mock success, got {result_success.score}. Reason: {result_success.reason}" + ) assert result_success.is_score_valid print("โœ“ Reward function 'evaluate' import and basic validation works") diff --git a/examples/rollout_control_plane_demo.py b/examples/rollout_control_plane_demo.py index 9c327e5e..9f6d2795 100644 --- a/examples/rollout_control_plane_demo.py +++ b/examples/rollout_control_plane_demo.py @@ -98,7 +98,6 @@ async def demonstrate_control_plane_rollout(): patch.object(GeneralMCPVectorEnv, "step") as mock_step, patch.object(GeneralMCPVectorEnv, "close") as mock_close, ): - # Setup mock vector environment mock_env = GeneralMCPVectorEnv(sessions, dataset_rows) mock_env.sessions = sessions @@ -273,14 +272,14 @@ def mock_step_side_effect(tool_calls): # Analyze the trajectory trajectory = trajectories[0] - print(f"Basic Trajectory Info:") + print("Basic Trajectory Info:") print(f" โ€ข Total Steps: {trajectory.steps}") print(f" โ€ข Total Reward: {trajectory.total_reward}") print(f" โ€ข Terminated: {trajectory.terminated}") print(f" โ€ข Duration: {trajectory.duration:.3f}s") print() - print(f"Data Plane Analysis (Observations):") + print("Data Plane Analysis (Observations):") print(f" โ€ข Observation Count: {len(trajectory.observations)}") for i, obs in enumerate(trajectory.observations): if i == 0: @@ -289,14 +288,14 @@ def mock_step_side_effect(tool_calls): print(f" Step {i}: {obs}") print() - print(f"Control Plane Analysis (Rewards/Termination):") + print("Control Plane Analysis (Rewards/Termination):") print(f" โ€ข Reward Count: {len(trajectory.rewards)}") print(f" โ€ข Rewards: {trajectory.rewards}") print(f" โ€ข Actions: {trajectory.actions}") print() # Validate control plane separation - print(f"Control Plane Separation Validation:") + print("Control Plane Separation Validation:") # Check data plane contains no rewards data_plane_clean = True diff --git a/examples/tau2_mcp/airplane_environment/airline_environment.py b/examples/tau2_mcp/airplane_environment/airline_environment.py index 0c1e2d14..f7c7a920 100644 --- a/examples/tau2_mcp/airplane_environment/airline_environment.py +++ b/examples/tau2_mcp/airplane_environment/airline_environment.py @@ -5,6 +5,7 @@ This module implements an AirlineEnvironment that integrates the ฯ„ยฒ-Bench simulation pattern (Agent/User/Environment communication) with the MCP-Gym framework. """ + import json import logging import os diff --git a/examples/tau2_mcp/mock_environment/mock_environment.py b/examples/tau2_mcp/mock_environment/mock_environment.py index 85ad84f5..fc255f5f 100644 --- a/examples/tau2_mcp/mock_environment/mock_environment.py +++ b/examples/tau2_mcp/mock_environment/mock_environment.py @@ -5,6 +5,7 @@ This module implements a MockEnvironment that integrates the ฯ„ยฒ-Bench simulation pattern (Agent/User/Environment communication) with the MCP-Gym framework. """ + import json import logging import os diff --git a/examples/tau2_mcp/retail_environment/retail_environment.py b/examples/tau2_mcp/retail_environment/retail_environment.py index 425ef785..d2163c0c 100644 --- a/examples/tau2_mcp/retail_environment/retail_environment.py +++ b/examples/tau2_mcp/retail_environment/retail_environment.py @@ -5,6 +5,7 @@ This module implements a RetailEnvironment that integrates the ฯ„ยฒ-Bench simulation pattern (Agent/User/Environment communication) with the MCP-Gym framework. """ + import json import logging import os diff --git a/examples/tau2_mcp/tests/test_tau2_e2e.py b/examples/tau2_mcp/tests/test_tau2_e2e.py index cb71fab7..ec7c3944 100644 --- a/examples/tau2_mcp/tests/test_tau2_e2e.py +++ b/examples/tau2_mcp/tests/test_tau2_e2e.py @@ -31,7 +31,6 @@ warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*class-based.*config.*") # Set environment variable to suppress pydantic warnings at runtime -import os os.environ["PYTHONWARNINGS"] = "ignore::UserWarning:pydantic,ignore::DeprecationWarning:pydantic" @@ -140,12 +139,12 @@ def start(self) -> None: try: with open(self._log_file_path, "r") as f: log_content = f.read() - print(f"โŒ Server failed to start!") + print("โŒ Server failed to start!") print(f"๐Ÿ“‹ Server log ({self._log_file_path}):") print("=" * 50) print(log_content) print("=" * 50) - raise RuntimeError(f"Server failed to start. Check log above for details.") + raise RuntimeError("Server failed to start. Check log above for details.") except Exception as e: stdout, stderr = self.process.communicate() raise RuntimeError(f"Server failed to start. stderr: {stderr}, log error: {e}") @@ -438,9 +437,9 @@ async def _validate_recording_integrity(recording_file: str, dataset: List[Dict] } states.append(state_info) - print(f" Step {i+1}: {state_info}") + print(f" Step {i + 1}: {state_info}") except (json.JSONDecodeError, TypeError) as e: - pytest.fail(f"โŒ Invalid JSON in tool response {i+1} for env {env_idx}: {response}. Error: {e}") + pytest.fail(f"โŒ Invalid JSON in tool response {i + 1} for env {env_idx}: {response}. Error: {e}") # For airline, we expect state to remain consistent between steps (same reservation details) if len(states) >= 2: @@ -467,7 +466,7 @@ async def _validate_recording_integrity(recording_file: str, dataset: List[Dict] print("\n๐Ÿ Validating trajectory termination...") _validate_trajectory_termination(env_recordings, dataset) - print(f"โœ… Recording integrity validation completed") + print("โœ… Recording integrity validation completed") def _validate_no_repeated_states(env_recordings: Dict, dataset: List[Dict]): @@ -554,7 +553,7 @@ def _validate_no_repeated_states(env_recordings: Dict, dataset: List[Dict]): print( f"โš ๏ธ WARNING: Env {env_idx}: State {longest_sequence[0]} repeated {longest_sequence[1]} times starting from step {longest_sequence[2]}." ) - print(f" This might indicate session state or control plane termination issues.") + print(" This might indicate session state or control plane termination issues.") print(f" All states: {[state for _, state in reservation_states]}") else: print(f" โœ… Env {env_idx}: No repeated states detected - good state progression!") @@ -598,9 +597,9 @@ def _validate_control_plane_sync(env_recordings: Dict, dataset: List[Dict]): print(f" โ„น๏ธ {missing_envs} environments not recorded (likely terminated immediately)") if terminated_steps == 0: - print(f" โš ๏ธ Warning: No terminated=True found in metadata (may be expected for short runs)") + print(" โš ๏ธ Warning: No terminated=True found in metadata (may be expected for short runs)") else: - print(f" โœ… Found some termination signals - control plane appears to be working") + print(" โœ… Found some termination signals - control plane appears to be working") def _validate_no_tool_calls_after_termination(env_recordings: Dict, dataset: List[Dict]): @@ -688,11 +687,11 @@ def _validate_trajectory_termination(env_recordings: Dict, dataset: List[Dict]): if total_steps >= 8 and not last_terminated: print(f" โš ๏ธ Env {env_idx}: Trajectory has {total_steps} steps but final metadata shows terminated=False.") print( - f" This might indicate: 1) Conversation still in progress, 2) Control plane sync issues, or 3) User still interacting" + " This might indicate: 1) Conversation still in progress, 2) Control plane sync issues, or 3) User still interacting" ) print(f" Last metadata: {last_tool_metadata}") elif last_terminated: - print(f" โœ… Trajectory properly terminated") + print(" โœ… Trajectory properly terminated") else: print(f" โ„น๏ธ Short trajectory ({total_steps} steps) - termination not required") @@ -914,7 +913,6 @@ async def test_fireworks_multi_airline_environment_sessions( # Start server for this test server = _create_test_server(9700) try: - # Set up recording os.environ["EP_PLAYBACK_FILE"] = fireworks_multi_env_airline_recording_file @@ -942,9 +940,9 @@ async def test_fireworks_multi_airline_environment_sessions( duration = time.time() - start_time # Validate results - assert len(evaluation_rows) == len( - multi_env_airline_dataset - ), "Should have evaluation row for each environment" + assert len(evaluation_rows) == len(multi_env_airline_dataset), ( + "Should have evaluation row for each environment" + ) assert all(eval_row.get_steps() > 0 for eval_row in evaluation_rows), "All evaluation rows should have steps" print( @@ -1057,7 +1055,6 @@ async def test_entire_airline_dataset(multi_env_airline_full_dataset, fireworks_ # Start server for this test server = _create_test_server(9700) try: - # Set up recording os.environ["EP_PLAYBACK_FILE"] = fireworks_multi_env_airline_recording_file @@ -1090,9 +1087,9 @@ async def test_entire_airline_dataset(multi_env_airline_full_dataset, fireworks_ duration = time.time() - start_time # Validate results - assert len(evaluation_rows) == len( - multi_env_airline_full_dataset - ), "Should have evaluation row for each environment" + assert len(evaluation_rows) == len(multi_env_airline_full_dataset), ( + "Should have evaluation row for each environment" + ) assert all(eval_row.get_steps() > 0 for eval_row in evaluation_rows), "All evaluation rows should have steps" print( @@ -1211,7 +1208,7 @@ async def test_entire_airline_dataset(multi_env_airline_full_dataset, fireworks_ all_results.append(result) # Summary Statistics - print(f"\n๐Ÿ“ˆ Summary Statistics:") + print("\n๐Ÿ“ˆ Summary Statistics:") avg_score = sum(r["score"] for r in all_results) / len(all_results) if all_results else 0 total_cost = sum(r["cost_info"]["total_cost"] for r in all_results) @@ -1219,7 +1216,7 @@ async def test_entire_airline_dataset(multi_env_airline_full_dataset, fireworks_ f" {policy.model_id}: {avg_score:.2%} success rate ({sum(r['score'] for r in all_results)}/{len(all_results)}) - Cost: ${total_cost:.2f}" ) print(f"\n๐Ÿ’ฐ Total evaluation cost: ${total_cost:.2f}") - print(f"๐Ÿ“Š Cost calculation uses actual API usage data.") + print("๐Ÿ“Š Cost calculation uses actual API usage data.") def save_results_jsonl( evaluation_records: List[Dict], output_file: str = "evaluation_outputs/all_evaluations.jsonl" @@ -1307,7 +1304,7 @@ def save_evaluation_files(evaluation_records: List[Dict], output_dir: str = "eva print(f"\n๐Ÿ“ Saved evaluation files to: {output_path}") print(f" - {len(evaluation_records)} individual evaluation files") - print(f" - 1 evaluation summary file") + print(" - 1 evaluation summary file") return output_path @@ -1382,7 +1379,7 @@ def save_trajectory_files(trajectory_records: List[Dict], output_dir: str = "tra print(f"\n๐Ÿ“ Saved trajectory files to: {output_path}") print(f" - {len(trajectory_records)} individual trajectory files") - print(f" - 1 trajectory summary file") + print(" - 1 trajectory summary file") return output_path diff --git a/examples/taxi_mcp_complete/local_testing/test_north_star.py b/examples/taxi_mcp_complete/local_testing/test_north_star.py index 3721b4e7..e0a3c75d 100644 --- a/examples/taxi_mcp_complete/local_testing/test_north_star.py +++ b/examples/taxi_mcp_complete/local_testing/test_north_star.py @@ -73,8 +73,8 @@ async def test_north_star_interface(): if recording_mode: print(f"๐Ÿ“ Recorded to: {playback_file}") - print(f"๐Ÿ’ฌ OpenAI format: clean_openai_format.jsonl") - print(f"๐Ÿ”„ Run again to test playback mode!") + print("๐Ÿ’ฌ OpenAI format: clean_openai_format.jsonl") + print("๐Ÿ”„ Run again to test playback mode!") else: # Assume ~90s for recording time for speedup calculation (taxi is more complex) estimated_recording_time = 90.0 diff --git a/examples/taxi_mcp_complete/mcp_server/simulation_server.py b/examples/taxi_mcp_complete/mcp_server/simulation_server.py index 45c11aec..22c767d3 100644 --- a/examples/taxi_mcp_complete/mcp_server/simulation_server.py +++ b/examples/taxi_mcp_complete/mcp_server/simulation_server.py @@ -40,8 +40,7 @@ def taxi_move(self, action: str, *, ctx, session_state) -> Dict[str, Any]: # Validate action if not action or not isinstance(action, str): raise ValueError( - f"Invalid action parameter: '{action}'. " - f"Must be a non-empty string. Valid actions: {self.ACTION_NAMES}" + f"Invalid action parameter: '{action}'. Must be a non-empty string. Valid actions: {self.ACTION_NAMES}" ) action = action.strip().upper() @@ -133,7 +132,7 @@ def main(): args = parser.parse_args() - print(f"๐Ÿš€ Starting Taxi Simulation Server") + print("๐Ÿš€ Starting Taxi Simulation Server") print(f"๐ŸŒ Host: {args.host}") print(f"๐ŸŒ Port: {args.port}") print("๐ŸŽฏ Framework: Unified SimulationServerBase") diff --git a/examples/trl_integration/working_grpo_example.py b/examples/trl_integration/working_grpo_example.py index 387040cd..7b5412df 100644 --- a/examples/trl_integration/working_grpo_example.py +++ b/examples/trl_integration/working_grpo_example.py @@ -372,7 +372,6 @@ def make_conversation(example): def combine_rewards( reward_adapter_configs: List[Dict[str, Any]], # Each dict: {'adapter': callable, 'weight': float} ) -> Callable[[List[Any], List[str]], List[float]]: # Corrected return type hint - total_weight = sum(c["weight"] for c in reward_adapter_configs) if abs(total_weight - 1.0) > 1e-6: logger.warning(f"Sum of weights is {total_weight}, normalizing to 1.0.") diff --git a/local_evals/model_comparison_eval.ipynb b/local_evals/model_comparison_eval.ipynb index e36dbe6c..231e4d95 100644 --- a/local_evals/model_comparison_eval.ipynb +++ b/local_evals/model_comparison_eval.ipynb @@ -1,1075 +1,1065 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "vscode": { - "languageId": "raw" - } - }, - "source": [ - "# Model Comparison Eval Harness: Tau2-Bench Airline\n", - "\n", - "This notebook compares different models on airline customer service scenarios using tau2-bench natural language evaluation.\n", - "\n", - "**Models being compared:**\n", - "- Claude 4 Opus (AnthropicPolicy)\n", - "- GPT 4.1 (OpenAIPolicy)\n", - "- Kimi K2 (FireworksPolicy)\n", - "\n", - "**Evaluation Framework:** tau2-bench with natural language assertions\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Install required packages\n", - "!pip install eval-protocol anthropic fireworks-ai tau2-bench pytest-asyncio\n", - "!pip install firectl # For sharing results\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "โœ… All imports successful!\n" - ] - }, - { - "data": { - "text/plain": [ - "3" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import asyncio\n", - "import json\n", - "import os\n", - "import time\n", - "from datetime import datetime\n", - "from pathlib import Path\n", - "from typing import Dict, List, Any, Tuple\n", - "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", - "import logging\n", - "from litellm import cost_per_token\n", - "from loguru import logger\n", - "\n", - "# Import eval protocol and tau2-bench\n", - "import eval_protocol as rk\n", - "from eval_protocol import reward_function, EvaluateResult\n", - "from eval_protocol.models import LLMUsageStats\n", - "\n", - "from examples.tau2_mcp.tests.test_tau2_e2e import MCPServerManager\n", - "\n", - "from vendor.tau2.evaluator.evaluator_nl_assertions import NLAssertionsEvaluator\n", - "from vendor.tau2.data_model.message import (\n", - " SystemMessage,\n", - " AssistantMessage,\n", - " UserMessage,\n", - " ToolMessage,\n", - ")\n", - "\n", - "print(\"โœ… All imports successful!\")\n", - "\n", - "logging.basicConfig(level=logging.WARNING, force=True)\n", - "\n", - "logger.remove() # Remove default handler\n", - "logger.add(lambda _: None, level=\"ERROR\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "vscode": { - "languageId": "raw" - } - }, - "source": [ - "## 1. Set Up Evaluation Benchmark\n", - "\n", - "First, let's load the evaluation dataset we want to benchmark our models on." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "โœ… Loaded airline dataset with 50 scenarios\n" - ] - } - ], - "source": [ - "with open(\"datasets/airline.json\", \"r\") as f:\n", - " tau2_eval_dataset = json.load(f)\n", - " # TODO: something here is broken\n", - "\n", - "print(f\"โœ… Loaded airline dataset with {len(tau2_eval_dataset)} scenarios\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "vscode": { - "languageId": "raw" - } - }, - "source": [ - "## 2. Evaluation Function: Tau2-Bench\n", - "\n", - "Now, let's implement the actual evaluation function (also called a reward function), based on Tau2-Bench. If you haven't heard of Tau2-Bench, it's a customer support benchmark from Sierra AI. Check out more information here: https://github.com/sierra-research/tau2-bench" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@reward_function\n", - "async def airline_eval(messages: List[Any], nl_assertions: List[str] = None, **kwargs) -> EvaluateResult:\n", - " \"\"\"\n", - " Evaluate airline conversation using tau2-bench NLAssertionsEvaluator.\n", - "\n", - " Args:\n", - " messages: Conversation between agent and customer\n", - " nl_assertions: List of natural language assertions to evaluate\n", - " **kwargs: Additional parameters\n", - "\n", - " Returns:\n", - " EvaluateResult with binary pass/fail and detailed assertion breakdown\n", - " \"\"\"\n", - " # Default assertions if none provided\n", - " if nl_assertions is None:\n", - " nl_assertions = [\"The agent handled the customer request appropriately according to airline policy\"]\n", - "\n", - " # Convert Message objects directly to tau2-bench message objects\n", - " trajectory_objects = []\n", - " for msg in messages:\n", - " role = msg.role\n", - " content = msg.content\n", - "\n", - " if role == \"system\":\n", - " trajectory_objects.append(SystemMessage(role=role, content=content))\n", - " elif role == \"assistant\":\n", - " trajectory_objects.append(AssistantMessage(role=role, content=content))\n", - " elif role == \"user\":\n", - " trajectory_objects.append(UserMessage(role=role, content=content))\n", - " elif role == \"tool\":\n", - " tool_id = msg.tool_call_id\n", - " trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content))\n", - "\n", - " # Run the synchronous tau2-bench evaluation in a thread pool to avoid blocking\n", - " loop = asyncio.get_event_loop()\n", - " nl_assertions_checks = await loop.run_in_executor(\n", - " None, \n", - " NLAssertionsEvaluator.evaluate_nl_assertions,\n", - " trajectory_objects, \n", - " nl_assertions\n", - " )\n", - "\n", - " all_expectations_met = all(result.met for result in nl_assertions_checks)\n", - " reward = 1.0 if all_expectations_met else 0.0\n", - "\n", - " # Build reason string\n", - " if all_expectations_met:\n", - " reason = f\"All {len(nl_assertions)} natural language assertions passed\"\n", - " else:\n", - " failed_assertions = [nl_assertions[i] for i, result in enumerate(nl_assertions_checks) if not result.met]\n", - " reason = f\"Failed assertions: {failed_assertions}\"\n", - "\n", - " return EvaluateResult(\n", - " score=reward,\n", - " reason=reason,\n", - " metrics={},\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "vscode": { - "languageId": "raw" - } - }, - "source": [ - "## 3. Set Up Model Policies\n", - "\n", - "Configure the three models we want to compare: Claude 4 Opus, GPT-4.1, and Kimi K2.\n" - ] - }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "# Model Comparison Eval Harness: Tau2-Bench Airline\n", + "\n", + "This notebook compares different models on airline customer service scenarios using tau2-bench natural language evaluation.\n", + "\n", + "**Models being compared:**\n", + "- Claude 4 Opus (AnthropicPolicy)\n", + "- GPT 4.1 (OpenAIPolicy)\n", + "- Kimi K2 (FireworksPolicy)\n", + "\n", + "**Evaluation Framework:** tau2-bench with natural language assertions\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install required packages\n", + "!pip install eval-protocol anthropic fireworks-ai tau2-bench pytest-asyncio\n", + "!pip install firectl # For sharing results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "โœ… All required API keys are set\n" - ] - } - ], - "source": [ - "# Check for required API keys (set these as environment variables)\n", - "# Example: export ANTHROPIC_API_KEY=your-key-here\n", - "\n", - "required_keys = [\"ANTHROPIC_API_KEY\", \"OPENAI_API_KEY\", \"FIREWORKS_API_KEY\"]\n", - "missing_keys = [key for key in required_keys if not os.getenv(key)]\n", - "\n", - "if missing_keys:\n", - " print(f\"โš ๏ธ Missing API keys: {missing_keys}\")\n", - " print(\"Please set these environment variables:\")\n", - " for key in missing_keys:\n", - " print(f\" export {key}='your-key-here'\")\n", - "else:\n", - " print(\"โœ… All required API keys are set\")\n" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "โœ… All imports successful!\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "โœ… Model policies created:\n", - " - Claude 4 Sonnet (Anthropic)\n", - " - Kimi K2 (Fireworks)\n" - ] - } - ], - "source": [ - "# Create model policies\n", - "openai_policy = rk.OpenAIPolicy(\n", - " model_id=\"gpt-4.1\",\n", - " temperature=0.1,\n", - " max_tokens=4096,\n", - ")\n", - "\n", - "anthropic_policy = rk.AnthropicPolicy(\n", - " model_id=\"claude-sonnet-4-20250514\",\n", - " temperature=0.1,\n", - " max_tokens=4096,\n", - ")\n", - "\n", - "kimi_policy = rk.FireworksPolicy(\n", - " model_id=\"accounts/fireworks/models/kimi-k2-instruct\",\n", - " temperature=0.1,\n", - " max_tokens=4096,\n", - ")\n", - "\n", - "models_to_test = {\n", - " # \"gpt-4.1\": {\n", - " # \"policy\": openai_policy,\n", - " # \"name\": \"GPT-4.1\",\n", - " # \"provider\": \"OpenAI\"\n", - " # },\n", - " \"claude-sonnet-4\": {\n", - " \"policy\": anthropic_policy,\n", - " \"name\": \"Claude 4 Sonnet\",\n", - " \"provider\": \"Anthropic\"\n", - " },\n", - " \"kimi-k2\": {\n", - " \"policy\": kimi_policy,\n", - " \"name\": \"Kimi K2\", \n", - " \"provider\": \"Fireworks\"\n", - " }\n", - "}\n", - "\n", - "print(\"โœ… Model policies created:\")\n", - "for model_id, model_info in models_to_test.items():\n", - " print(f\" - {model_info['name']} ({model_info['provider']})\")\n" + "data": { + "text/plain": [ + "3" ] - }, + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import asyncio\n", + "import json\n", + "import logging\n", + "import os\n", + "import time\n", + "from datetime import datetime\n", + "from pathlib import Path\n", + "from typing import Any, Dict, List, Tuple\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "from litellm import cost_per_token\n", + "from loguru import logger\n", + "\n", + "# Import eval protocol and tau2-bench\n", + "import eval_protocol as rk\n", + "from eval_protocol import EvaluateResult, reward_function\n", + "from eval_protocol.models import LLMUsageStats\n", + "from examples.tau2_mcp.tests.test_tau2_e2e import MCPServerManager\n", + "from vendor.tau2.data_model.message import (\n", + " AssistantMessage,\n", + " SystemMessage,\n", + " ToolMessage,\n", + " UserMessage,\n", + ")\n", + "from vendor.tau2.evaluator.evaluator_nl_assertions import NLAssertionsEvaluator\n", + "\n", + "print(\"โœ… All imports successful!\")\n", + "\n", + "logging.basicConfig(level=logging.WARNING, force=True)\n", + "\n", + "logger.remove() # Remove default handler\n", + "logger.add(lambda _: None, level=\"ERROR\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "## 1. Set Up Evaluation Benchmark\n", + "\n", + "First, let's load the evaluation dataset we want to benchmark our models on." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, - "source": [ - "## 4. Run Evaluations\n", - "\n", - "Now we'll run the airline evaluation on both models and compare their performance.\n", - "\n", - "First, let's set up some code to manager our MCP server. We will run this server later on for our MCP tools to make calls to." - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "โœ… Loaded airline dataset with 50 scenarios\n" + ] + } + ], + "source": [ + "with open(\"datasets/airline.json\", \"r\") as f:\n", + " tau2_eval_dataset = json.load(f)\n", + " # TODO: something here is broken\n", + "\n", + "print(f\"โœ… Loaded airline dataset with {len(tau2_eval_dataset)} scenarios\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "## 2. Evaluation Function: Tau2-Bench\n", + "\n", + "Now, let's implement the actual evaluation function (also called a reward function), based on Tau2-Bench. If you haven't heard of Tau2-Bench, it's a customer support benchmark from Sierra AI. Check out more information here: https://github.com/sierra-research/tau2-bench" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@reward_function\n", + "async def airline_eval(messages: List[Any], nl_assertions: List[str] = None, **kwargs) -> EvaluateResult:\n", + " \"\"\"\n", + " Evaluate airline conversation using tau2-bench NLAssertionsEvaluator.\n", + "\n", + " Args:\n", + " messages: Conversation between agent and customer\n", + " nl_assertions: List of natural language assertions to evaluate\n", + " **kwargs: Additional parameters\n", + "\n", + " Returns:\n", + " EvaluateResult with binary pass/fail and detailed assertion breakdown\n", + " \"\"\"\n", + " # Default assertions if none provided\n", + " if nl_assertions is None:\n", + " nl_assertions = [\"The agent handled the customer request appropriately according to airline policy\"]\n", + "\n", + " # Convert Message objects directly to tau2-bench message objects\n", + " trajectory_objects = []\n", + " for msg in messages:\n", + " role = msg.role\n", + " content = msg.content\n", + "\n", + " if role == \"system\":\n", + " trajectory_objects.append(SystemMessage(role=role, content=content))\n", + " elif role == \"assistant\":\n", + " trajectory_objects.append(AssistantMessage(role=role, content=content))\n", + " elif role == \"user\":\n", + " trajectory_objects.append(UserMessage(role=role, content=content))\n", + " elif role == \"tool\":\n", + " tool_id = msg.tool_call_id\n", + " trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content))\n", + "\n", + " # Run the synchronous tau2-bench evaluation in a thread pool to avoid blocking\n", + " loop = asyncio.get_event_loop()\n", + " nl_assertions_checks = await loop.run_in_executor(\n", + " None, NLAssertionsEvaluator.evaluate_nl_assertions, trajectory_objects, nl_assertions\n", + " )\n", + "\n", + " all_expectations_met = all(result.met for result in nl_assertions_checks)\n", + " reward = 1.0 if all_expectations_met else 0.0\n", + "\n", + " # Build reason string\n", + " if all_expectations_met:\n", + " reason = f\"All {len(nl_assertions)} natural language assertions passed\"\n", + " else:\n", + " failed_assertions = [nl_assertions[i] for i, result in enumerate(nl_assertions_checks) if not result.met]\n", + " reason = f\"Failed assertions: {failed_assertions}\"\n", + "\n", + " return EvaluateResult(\n", + " score=reward,\n", + " reason=reason,\n", + " metrics={},\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "## 3. Set Up Model Policies\n", + "\n", + "Configure the three models we want to compare: Claude 4 Opus, GPT-4.1, and Kimi K2.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Before we get into the main logic, we'd like to track quality and cost across the different models, so this is a bit of setup for tracking cost. For Kimi K2, we're using the official pricing from Firework's website, since litellm doesn't contain it." - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "โœ… All required API keys are set\n" + ] + } + ], + "source": [ + "# Check for required API keys (set these as environment variables)\n", + "# Example: export ANTHROPIC_API_KEY=your-key-here\n", + "\n", + "required_keys = [\"ANTHROPIC_API_KEY\", \"OPENAI_API_KEY\", \"FIREWORKS_API_KEY\"]\n", + "missing_keys = [key for key in required_keys if not os.getenv(key)]\n", + "\n", + "if missing_keys:\n", + " print(f\"โš ๏ธ Missing API keys: {missing_keys}\")\n", + " print(\"Please set these environment variables:\")\n", + " for key in missing_keys:\n", + " print(f\" export {key}='your-key-here'\")\n", + "else:\n", + " print(\"โœ… All required API keys are set\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "MANUAL_PRICING = {\n", - " \"accounts/fireworks/models/kimi-k2-instruct\": {\n", - " \"input_cost_per_1m\": 0.60, # Estimated based on Fireworks pricing\n", - " \"output_cost_per_1m\": 2.50, # Estimated - Fireworks often uses same price for input/output\n", - " }\n", - "}\n", - "\n", - "def calculate_evaluation_cost(model_id: str, llm_usage_summary: LLMUsageStats) -> Dict[str, Any]:\n", - " input_tokens = llm_usage_summary.prompt_tokens or 0\n", - " output_tokens = llm_usage_summary.completion_tokens or 0\n", - " total_tokens = llm_usage_summary.total_tokens or (input_tokens + output_tokens)\n", - " \n", - " if model_id in MANUAL_PRICING:\n", - " pricing = MANUAL_PRICING[model_id]\n", - " \n", - " input_cost = input_tokens * pricing[\"input_cost_per_1m\"] / 1000000\n", - " output_cost = output_tokens * pricing[\"output_cost_per_1m\"] / 1000000\n", - " total_cost = input_cost + output_cost\n", - " \n", - " cost_source = \"manual_pricing\"\n", - "\n", - " else:\n", - " input_cost, output_cost = cost_per_token(\n", - " model=model_id,\n", - " prompt_tokens=input_tokens,\n", - " completion_tokens=output_tokens\n", - " )\n", - " total_cost = input_cost + output_cost\n", - " \n", - " cost_source = \"litellm\"\n", - " \n", - " return {\n", - " \"total_cost\": total_cost,\n", - " \"input_cost\": input_cost,\n", - " \"output_cost\": output_cost,\n", - " \"total_tokens\": total_tokens,\n", - " \"input_tokens\": input_tokens,\n", - " \"output_tokens\": output_tokens,\n", - " \"model_id\": model_id,\n", - " \"cost_source\": cost_source,\n", - " }" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "โœ… Model policies created:\n", + " - Claude 4 Sonnet (Anthropic)\n", + " - Kimi K2 (Fireworks)\n" + ] + } + ], + "source": [ + "# Create model policies\n", + "openai_policy = rk.OpenAIPolicy(\n", + " model_id=\"gpt-4.1\",\n", + " temperature=0.1,\n", + " max_tokens=4096,\n", + ")\n", + "\n", + "anthropic_policy = rk.AnthropicPolicy(\n", + " model_id=\"claude-sonnet-4-20250514\",\n", + " temperature=0.1,\n", + " max_tokens=4096,\n", + ")\n", + "\n", + "kimi_policy = rk.FireworksPolicy(\n", + " model_id=\"accounts/fireworks/models/kimi-k2-instruct\",\n", + " temperature=0.1,\n", + " max_tokens=4096,\n", + ")\n", + "\n", + "models_to_test = {\n", + " # \"gpt-4.1\": {\n", + " # \"policy\": openai_policy,\n", + " # \"name\": \"GPT-4.1\",\n", + " # \"provider\": \"OpenAI\"\n", + " # },\n", + " \"claude-sonnet-4\": {\"policy\": anthropic_policy, \"name\": \"Claude 4 Sonnet\", \"provider\": \"Anthropic\"},\n", + " \"kimi-k2\": {\"policy\": kimi_policy, \"name\": \"Kimi K2\", \"provider\": \"Fireworks\"},\n", + "}\n", + "\n", + "print(\"โœ… Model policies created:\")\n", + "for model_id, model_info in models_to_test.items():\n", + " print(f\" - {model_info['name']} ({model_info['provider']})\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "source": [ + "## 4. Run Evaluations\n", + "\n", + "Now we'll run the airline evaluation on both models and compare their performance.\n", + "\n", + "First, let's set up some code to manager our MCP server. We will run this server later on for our MCP tools to make calls to." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before we get into the main logic, we'd like to track quality and cost across the different models, so this is a bit of setup for tracking cost. For Kimi K2, we're using the official pricing from Firework's website, since litellm doesn't contain it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MANUAL_PRICING = {\n", + " \"accounts/fireworks/models/kimi-k2-instruct\": {\n", + " \"input_cost_per_1m\": 0.60, # Estimated based on Fireworks pricing\n", + " \"output_cost_per_1m\": 2.50, # Estimated - Fireworks often uses same price for input/output\n", + " }\n", + "}\n", + "\n", + "\n", + "def calculate_evaluation_cost(model_id: str, llm_usage_summary: LLMUsageStats) -> Dict[str, Any]:\n", + " input_tokens = llm_usage_summary.prompt_tokens or 0\n", + " output_tokens = llm_usage_summary.completion_tokens or 0\n", + " total_tokens = llm_usage_summary.total_tokens or (input_tokens + output_tokens)\n", + "\n", + " if model_id in MANUAL_PRICING:\n", + " pricing = MANUAL_PRICING[model_id]\n", + "\n", + " input_cost = input_tokens * pricing[\"input_cost_per_1m\"] / 1000000\n", + " output_cost = output_tokens * pricing[\"output_cost_per_1m\"] / 1000000\n", + " total_cost = input_cost + output_cost\n", + "\n", + " cost_source = \"manual_pricing\"\n", + "\n", + " else:\n", + " input_cost, output_cost = cost_per_token(\n", + " model=model_id, prompt_tokens=input_tokens, completion_tokens=output_tokens\n", + " )\n", + " total_cost = input_cost + output_cost\n", + "\n", + " cost_source = \"litellm\"\n", + "\n", + " return {\n", + " \"total_cost\": total_cost,\n", + " \"input_cost\": input_cost,\n", + " \"output_cost\": output_cost,\n", + " \"total_tokens\": total_tokens,\n", + " \"input_tokens\": input_tokens,\n", + " \"output_tokens\": output_tokens,\n", + " \"model_id\": model_id,\n", + " \"cost_source\": cost_source,\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Below is our core logic for running the Tau2-bench eval for a single model. We use the eval protocol framework to do rk.make() and rk.rollout(), " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def run_model_evaluation(model_id: str, model_info: Dict, dataset: List[Dict]) -> Tuple[List[Dict], List[Dict]]:\n", + " \"\"\"\n", + " Run evaluation for a single model on the airline dataset.\n", + "\n", + " Returns:\n", + " Tuple of (evaluation_results, evaluation_records)\n", + " \"\"\"\n", + " print(f\"\\n๐Ÿงช Starting evaluation for {model_info['name']}...\")\n", + "\n", + " # Use context manager for automatic cleanup even on exceptions\n", + " with MCPServerManager(\"../examples/tau2_mcp/server.py\", port=8000, domain=\"airline\") as server:\n", + " policy = model_info[\"policy\"]\n", + "\n", + " envs = await rk.make(\n", + " \"http://localhost:8000/mcp/\",\n", + " dataset=dataset,\n", + " model_id=policy.model_id,\n", + " )\n", + "\n", + " print(f\"๐Ÿ“Š Created {len(envs.sessions)} environment sessions\")\n", + "\n", + " start_time = time.time()\n", + " evaluation_rows = await rk.rollout(envs, policy=policy, steps=30, max_concurrent_rollouts=8)\n", + " duration = time.time() - start_time\n", + "\n", + " print(f\"โœ… Completed {len(evaluation_rows)} evaluation rows in {duration:.2f}s\")\n", + "\n", + " # Create a helper function to process each evaluation row\n", + " async def process_evaluation_row(i: int, eval_row, dataset_item):\n", + " nl_assertions = dataset_item[\"assertions\"]\n", + "\n", + " # Run tau2-bench evaluation (now async and parallelizable!)\n", + " eval_result = await airline_eval(eval_row.messages, nl_assertions)\n", + "\n", + " # Calculate cost using existing LLMUsageStats and LiteLLM/manual pricing\n", + " llm_usage = eval_row.llm_usage_summary\n", + " print(f\" ๐Ÿ“Š LLM Usage for {dataset_item['id']}: {llm_usage}\") # Debug: show actual usage\n", + " cost_info = calculate_evaluation_cost(policy.model_id, llm_usage)\n", + "\n", + " num_assertions = len(nl_assertions)\n", + "\n", + " # Create evaluation result\n", + " result = {\n", + " \"scenario_id\": dataset_item[\"id\"],\n", + " \"model_id\": policy.model_id,\n", + " \"score\": eval_result.score,\n", + " \"num_assertions\": num_assertions,\n", + " \"cost_info\": cost_info, # Include cost information in results\n", + " }\n", + "\n", + " # Create comprehensive evaluation record\n", + " evaluation_record = {\n", + " \"model_id\": policy.model_id,\n", + " \"scenario_id\": dataset_item[\"id\"],\n", + " \"conversation_history\": eval_row.messages,\n", + " \"evaluation\": {\n", + " \"score\": eval_result.score,\n", + " \"num_assertions\": num_assertions,\n", + " \"reason\": eval_result.reason,\n", + " \"assertions\": [\n", + " {\n", + " \"assertion\": assertion,\n", + " \"passed\": eval_result.score > 0, # All pass or all fail for this simple implementation\n", + " }\n", + " for assertion in nl_assertions\n", + " ],\n", + " },\n", + " \"cost_info\": cost_info, # Add cost information to evaluation record\n", + " \"timestamp\": datetime.now().isoformat(),\n", + " }\n", + "\n", + " print(f\" ๐Ÿ“‹ {result['scenario_id']}: {result['score']:.1f}, total {result['num_assertions']} assertions)\")\n", + " return result, evaluation_record\n", + "\n", + " # Process all evaluation rows in parallel using asyncio.gather\n", + " print(f\"๐Ÿš€ Processing {len(evaluation_rows)} evaluation row evaluations in parallel...\")\n", + " eval_start_time = time.time()\n", + "\n", + " tasks = [process_evaluation_row(i, eval_row, dataset[i]) for i, eval_row in enumerate(evaluation_rows)]\n", + "\n", + " # Run all evaluations concurrently\n", + " results_and_records = await asyncio.gather(*tasks)\n", + "\n", + " eval_duration = time.time() - eval_start_time\n", + " print(f\"โœ… Completed parallel evaluations in {eval_duration:.2f}s\")\n", + "\n", + " # Separate results and evaluation records\n", + " results = []\n", + " evaluation_records = []\n", + " for result, evaluation_record in results_and_records:\n", + " results.append(result)\n", + " evaluation_records.append(evaluation_record)\n", + "\n", + " await envs.close()\n", + " # Server cleanup happens automatically via context manager\n", + "\n", + " return results, evaluation_records" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Below is our core logic for running the Tau2-bench eval for a single model. We use the eval protocol framework to do rk.make() and rk.rollout(), " - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "๐Ÿงช Starting evaluation for Claude 4 Sonnet...\n", + "โœ… Server started successfully on port 8000\n", + "๐Ÿ“Š Created 50 environment sessions\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "async def run_model_evaluation(model_id: str, model_info: Dict, dataset: List[Dict]) -> Tuple[List[Dict], List[Dict]]:\n", - " \"\"\"\n", - " Run evaluation for a single model on the airline dataset.\n", - " \n", - " Returns:\n", - " Tuple of (evaluation_results, evaluation_records)\n", - " \"\"\"\n", - " print(f\"\\n๐Ÿงช Starting evaluation for {model_info['name']}...\")\n", - "\n", - " # Use context manager for automatic cleanup even on exceptions\n", - " with MCPServerManager(\"../examples/tau2_mcp/server.py\", port=8000, domain=\"airline\") as server:\n", - " policy = model_info[\"policy\"]\n", - " \n", - " envs = await rk.make(\n", - " \"http://localhost:8000/mcp/\",\n", - " dataset=dataset, \n", - " model_id=policy.model_id,\n", - " )\n", - " \n", - " print(f\"๐Ÿ“Š Created {len(envs.sessions)} environment sessions\")\n", - " \n", - " start_time = time.time()\n", - " evaluation_rows = await rk.rollout(envs, policy=policy, steps=30, max_concurrent_rollouts=8)\n", - " duration = time.time() - start_time\n", - " \n", - " print(f\"โœ… Completed {len(evaluation_rows)} evaluation rows in {duration:.2f}s\")\n", - " \n", - " # Create a helper function to process each evaluation row\n", - " async def process_evaluation_row(i: int, eval_row, dataset_item):\n", - " nl_assertions = dataset_item[\"assertions\"]\n", - " \n", - " # Run tau2-bench evaluation (now async and parallelizable!)\n", - " eval_result = await airline_eval(eval_row.messages, nl_assertions)\n", - " \n", - " # Calculate cost using existing LLMUsageStats and LiteLLM/manual pricing\n", - " llm_usage = eval_row.llm_usage_summary\n", - " print(f\" ๐Ÿ“Š LLM Usage for {dataset_item['id']}: {llm_usage}\") # Debug: show actual usage\n", - " cost_info = calculate_evaluation_cost(policy.model_id, llm_usage)\n", - "\n", - " num_assertions = len(nl_assertions)\n", - "\n", - " # Create evaluation result\n", - " result = {\n", - " \"scenario_id\": dataset_item[\"id\"],\n", - " \"model_id\": policy.model_id,\n", - " \"score\": eval_result.score,\n", - " \"num_assertions\": num_assertions,\n", - " \"cost_info\": cost_info, # Include cost information in results\n", - " }\n", - " \n", - " # Create comprehensive evaluation record\n", - " evaluation_record = {\n", - " \"model_id\": policy.model_id,\n", - " \"scenario_id\": dataset_item[\"id\"],\n", - " \"conversation_history\": eval_row.messages,\n", - " \"evaluation\": {\n", - " \"score\": eval_result.score,\n", - " \"num_assertions\": num_assertions,\n", - " \"reason\": eval_result.reason,\n", - " \"assertions\": [\n", - " {\n", - " \"assertion\": assertion,\n", - " \"passed\": eval_result.score > 0 # All pass or all fail for this simple implementation\n", - " }\n", - " for assertion in nl_assertions\n", - " ]\n", - " },\n", - " \"cost_info\": cost_info, # Add cost information to evaluation record\n", - " \"timestamp\": datetime.now().isoformat(),\n", - " }\n", - " \n", - " print(f\" ๐Ÿ“‹ {result['scenario_id']}: {result['score']:.1f}, total {result['num_assertions']} assertions)\")\n", - " return result, evaluation_record\n", - " \n", - " # Process all evaluation rows in parallel using asyncio.gather\n", - " print(f\"๐Ÿš€ Processing {len(evaluation_rows)} evaluation row evaluations in parallel...\")\n", - " eval_start_time = time.time()\n", - " \n", - " tasks = [\n", - " process_evaluation_row(i, eval_row, dataset[i]) \n", - " for i, eval_row in enumerate(evaluation_rows)\n", - " ]\n", - " \n", - " # Run all evaluations concurrently\n", - " results_and_records = await asyncio.gather(*tasks)\n", - " \n", - " eval_duration = time.time() - eval_start_time\n", - " print(f\"โœ… Completed parallel evaluations in {eval_duration:.2f}s\")\n", - " \n", - " # Separate results and evaluation records\n", - " results = []\n", - " evaluation_records = []\n", - " for result, evaluation_record in results_and_records:\n", - " results.append(result)\n", - " evaluation_records.append(evaluation_record)\n", - " \n", - " await envs.close()\n", - " # Server cleanup happens automatically via context manager\n", - " \n", - " return results, evaluation_records" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:eval_protocol.mcp.client.connection:Session c581b1937dfd10fa4e177cc027a41035: Invalid JSON from update_reservation_flights: Error executing tool update_reservation_flights: Flight HAT030 not available on date 2024-05-13. Error: Expecting value: line 1 column 1 (char 0)\n", + "WARNING:eval_protocol.mcp.client.connection:Session 84de98d36b6446b307f60fe1e534a067: Invalid JSON from update_reservation_baggages: Error executing tool update_reservation_baggages: Gift card balance is not enough. Error: Expecting value: line 1 column 1 (char 0)\n", + "WARNING:eval_protocol.mcp.client.connection:Session 84de98d36b6446b307f60fe1e534a067: Invalid JSON from update_reservation_baggages: Error executing tool update_reservation_baggages: Gift card balance is not enough. Error: Expecting value: line 1 column 1 (char 0)\n", + "WARNING:eval_protocol.mcp.client.connection:Session aa8e35d6b8cfee9df34e24b405b60f94: Invalid JSON from update_reservation_flights: Error executing tool update_reservation_flights: Gift card balance is not enough. Error: Expecting value: line 1 column 1 (char 0)\n", + "WARNING:eval_protocol.mcp.client.connection:Session 84de98d36b6446b307f60fe1e534a067: Invalid JSON from update_reservation_baggages: Error executing tool update_reservation_baggages: Gift card balance is not enough. Error: Expecting value: line 1 column 1 (char 0)\n", + "WARNING:eval_protocol.mcp.client.connection:Session 41dba4c12d152158564c1e49f986c220: Invalid JSON from update_reservation_flights: Error executing tool update_reservation_flights: Certificate cannot be used to update reservation. Error: Expecting value: line 1 column 1 (char 0)\n", + "WARNING:eval_protocol.mcp.client.connection:Session 105e5b441bcc0be055a231d0189ee750: Invalid JSON from book_reservation: Error executing tool book_reservation: Payment amount does not add up, total price is 290, but paid 304. Error: Expecting value: line 1 column 1 (char 0)\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "๐Ÿงช Starting evaluation for Claude 4 Sonnet...\n", - "โœ… Server started successfully on port 8000\n", - "๐Ÿ“Š Created 50 environment sessions\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:eval_protocol.mcp.client.connection:Session c581b1937dfd10fa4e177cc027a41035: Invalid JSON from update_reservation_flights: Error executing tool update_reservation_flights: Flight HAT030 not available on date 2024-05-13. Error: Expecting value: line 1 column 1 (char 0)\n", - "WARNING:eval_protocol.mcp.client.connection:Session 84de98d36b6446b307f60fe1e534a067: Invalid JSON from update_reservation_baggages: Error executing tool update_reservation_baggages: Gift card balance is not enough. Error: Expecting value: line 1 column 1 (char 0)\n", - "WARNING:eval_protocol.mcp.client.connection:Session 84de98d36b6446b307f60fe1e534a067: Invalid JSON from update_reservation_baggages: Error executing tool update_reservation_baggages: Gift card balance is not enough. Error: Expecting value: line 1 column 1 (char 0)\n", - "WARNING:eval_protocol.mcp.client.connection:Session aa8e35d6b8cfee9df34e24b405b60f94: Invalid JSON from update_reservation_flights: Error executing tool update_reservation_flights: Gift card balance is not enough. Error: Expecting value: line 1 column 1 (char 0)\n", - "WARNING:eval_protocol.mcp.client.connection:Session 84de98d36b6446b307f60fe1e534a067: Invalid JSON from update_reservation_baggages: Error executing tool update_reservation_baggages: Gift card balance is not enough. Error: Expecting value: line 1 column 1 (char 0)\n", - "WARNING:eval_protocol.mcp.client.connection:Session 41dba4c12d152158564c1e49f986c220: Invalid JSON from update_reservation_flights: Error executing tool update_reservation_flights: Certificate cannot be used to update reservation. Error: Expecting value: line 1 column 1 (char 0)\n", - "WARNING:eval_protocol.mcp.client.connection:Session 105e5b441bcc0be055a231d0189ee750: Invalid JSON from book_reservation: Error executing tool book_reservation: Payment amount does not add up, total price is 290, but paid 304. Error: Expecting value: line 1 column 1 (char 0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "๐Ÿงน Closing 50 MCP sessions...\n", - "โœ… All MCP sessions closed.\n", - "โœ… Completed 50 trajectories in 438.92s\n", - "๐Ÿš€ Processing 50 trajectory evaluations in parallel...\n", - " ๐Ÿ“Š LLM Usage for airline_task_6: {'prompt_tokens': 11809, 'completion_tokens': 439, 'total_tokens': 12248}\n", - " ๐Ÿ“‹ airline_task_6: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_1: {'prompt_tokens': 48521, 'completion_tokens': 465, 'total_tokens': 48986}\n", - " ๐Ÿ“‹ airline_task_1: 0.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_0: {'prompt_tokens': 18067, 'completion_tokens': 255, 'total_tokens': 18322}\n", - " ๐Ÿ“‹ airline_task_0: 0.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_10: {'prompt_tokens': 70113, 'completion_tokens': 1132, 'total_tokens': 71245}\n", - " ๐Ÿ“‹ airline_task_10: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_13: {'prompt_tokens': 73350, 'completion_tokens': 1136, 'total_tokens': 74486}\n", - " ๐Ÿ“‹ airline_task_13: 0.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_5: {'prompt_tokens': 31643, 'completion_tokens': 416, 'total_tokens': 32059}\n", - " ๐Ÿ“‹ airline_task_5: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_3: {'prompt_tokens': 18289, 'completion_tokens': 281, 'total_tokens': 18570}\n", - " ๐Ÿ“‹ airline_task_3: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_4: {'prompt_tokens': 47856, 'completion_tokens': 838, 'total_tokens': 48694}\n", - " ๐Ÿ“‹ airline_task_4: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_12: {'prompt_tokens': 32974, 'completion_tokens': 545, 'total_tokens': 33519}\n", - " ๐Ÿ“‹ airline_task_12: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_19: {'prompt_tokens': 31917, 'completion_tokens': 452, 'total_tokens': 32369}\n", - " ๐Ÿ“‹ airline_task_19: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_16: {'prompt_tokens': 34448, 'completion_tokens': 748, 'total_tokens': 35196}\n", - " ๐Ÿ“‹ airline_task_16: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_9: {'prompt_tokens': 48912, 'completion_tokens': 825, 'total_tokens': 49737}\n", - " ๐Ÿ“‹ airline_task_9: 0.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_15: {'prompt_tokens': 56489, 'completion_tokens': 949, 'total_tokens': 57438}\n", - " ๐Ÿ“‹ airline_task_15: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_11: {'prompt_tokens': 32715, 'completion_tokens': 395, 'total_tokens': 33110}\n", - " ๐Ÿ“‹ airline_task_11: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_2: {'prompt_tokens': 64625, 'completion_tokens': 925, 'total_tokens': 65550}\n", - " ๐Ÿ“‹ airline_task_2: 1.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_14: {'prompt_tokens': 105792, 'completion_tokens': 1297, 'total_tokens': 107089}\n", - " ๐Ÿ“‹ airline_task_14: 0.0, total 5 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_7: {'prompt_tokens': 49692, 'completion_tokens': 518, 'total_tokens': 50210}\n", - " ๐Ÿ“‹ airline_task_7: 0.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_26: {'prompt_tokens': 18258, 'completion_tokens': 498, 'total_tokens': 18756}\n", - " ๐Ÿ“‹ airline_task_26: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_8: {'prompt_tokens': 56129, 'completion_tokens': 939, 'total_tokens': 57068}\n", - " ๐Ÿ“‹ airline_task_8: 1.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_17: {'prompt_tokens': 57622, 'completion_tokens': 710, 'total_tokens': 58332}\n", - " ๐Ÿ“‹ airline_task_17: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_28: {'prompt_tokens': 18305, 'completion_tokens': 519, 'total_tokens': 18824}\n", - " ๐Ÿ“‹ airline_task_28: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_22: {'prompt_tokens': 102233, 'completion_tokens': 1630, 'total_tokens': 103863}\n", - " ๐Ÿ“‹ airline_task_22: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_20: {'prompt_tokens': 75987, 'completion_tokens': 1169, 'total_tokens': 77156}\n", - " ๐Ÿ“‹ airline_task_20: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_31: {'prompt_tokens': 17973, 'completion_tokens': 508, 'total_tokens': 18481}\n", - " ๐Ÿ“‹ airline_task_31: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_18: {'prompt_tokens': 290410, 'completion_tokens': 3217, 'total_tokens': 293627}\n", - " ๐Ÿ“‹ airline_task_18: 0.0, total 6 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_27: {'prompt_tokens': 91697, 'completion_tokens': 900, 'total_tokens': 92597}\n", - " ๐Ÿ“‹ airline_task_27: 1.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_29: {'prompt_tokens': 45098, 'completion_tokens': 921, 'total_tokens': 46019}\n", - " ๐Ÿ“‹ airline_task_29: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_30: {'prompt_tokens': 34284, 'completion_tokens': 851, 'total_tokens': 35135}\n", - " ๐Ÿ“‹ airline_task_30: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_25: {'prompt_tokens': 34098, 'completion_tokens': 791, 'total_tokens': 34889}\n", - " ๐Ÿ“‹ airline_task_25: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_32: {'prompt_tokens': 25880, 'completion_tokens': 505, 'total_tokens': 26385}\n", - " ๐Ÿ“‹ airline_task_32: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_21: {'prompt_tokens': 91852, 'completion_tokens': 1140, 'total_tokens': 92992}\n", - " ๐Ÿ“‹ airline_task_21: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_34: {'prompt_tokens': 33951, 'completion_tokens': 1036, 'total_tokens': 34987}\n", - " ๐Ÿ“‹ airline_task_34: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_24: {'prompt_tokens': 64616, 'completion_tokens': 1547, 'total_tokens': 66163}\n", - " ๐Ÿ“‹ airline_task_24: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_36: {'prompt_tokens': 27295, 'completion_tokens': 347, 'total_tokens': 27642}\n", - " ๐Ÿ“‹ airline_task_36: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_41: {'prompt_tokens': 66964, 'completion_tokens': 577, 'total_tokens': 67541}\n", - " ๐Ÿ“‹ airline_task_41: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_42: {'prompt_tokens': 86379, 'completion_tokens': 951, 'total_tokens': 87330}\n", - " ๐Ÿ“‹ airline_task_42: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_40: {'prompt_tokens': 11763, 'completion_tokens': 229, 'total_tokens': 11992}\n", - " ๐Ÿ“‹ airline_task_40: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_33: {'prompt_tokens': 43097, 'completion_tokens': 814, 'total_tokens': 43911}\n", - " ๐Ÿ“‹ airline_task_33: 1.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_45: {'prompt_tokens': 5584, 'completion_tokens': 62, 'total_tokens': 5646}\n", - " ๐Ÿ“‹ airline_task_45: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_46: {'prompt_tokens': 12057, 'completion_tokens': 322, 'total_tokens': 12379}\n", - " ๐Ÿ“‹ airline_task_46: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_23: {'prompt_tokens': 134261, 'completion_tokens': 2398, 'total_tokens': 136659}\n", - " ๐Ÿ“‹ airline_task_23: 1.0, total 8 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_47: {'prompt_tokens': 17987, 'completion_tokens': 221, 'total_tokens': 18208}\n", - " ๐Ÿ“‹ airline_task_47: 0.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_38: {'prompt_tokens': 40105, 'completion_tokens': 852, 'total_tokens': 40957}\n", - " ๐Ÿ“‹ airline_task_38: 0.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_35: {'prompt_tokens': 56123, 'completion_tokens': 1211, 'total_tokens': 57334}\n", - " ๐Ÿ“‹ airline_task_35: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_49: {'prompt_tokens': 11756, 'completion_tokens': 332, 'total_tokens': 12088}\n", - " ๐Ÿ“‹ airline_task_49: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_43: {'prompt_tokens': 56223, 'completion_tokens': 554, 'total_tokens': 56777}\n", - " ๐Ÿ“‹ airline_task_43: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_48: {'prompt_tokens': 18950, 'completion_tokens': 592, 'total_tokens': 19542}\n", - " ๐Ÿ“‹ airline_task_48: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_37: {'prompt_tokens': 50314, 'completion_tokens': 870, 'total_tokens': 51184}\n", - " ๐Ÿ“‹ airline_task_37: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_39: {'prompt_tokens': 100087, 'completion_tokens': 1072, 'total_tokens': 101159}\n", - " ๐Ÿ“‹ airline_task_39: 1.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_44: {'prompt_tokens': 146859, 'completion_tokens': 1910, 'total_tokens': 148769}\n", - " ๐Ÿ“‹ airline_task_44: 0.0, total 5 assertions)\n", - "โœ… Completed parallel evaluations in 16.19s\n", - "๐Ÿงน Closing 50 MCP sessions...\n", - "โœ… All MCP sessions closed.\n", - "๐Ÿ›‘ Stopping server on port 8000...\n", - "๐Ÿงน Cleaned up log file: /Users/derekxu/Documents/code/python-sdk/local_evals/server_output_airline_8000.log\n", - "\n", - "๐Ÿงช Starting evaluation for Kimi K2...\n", - "โœ… Server started successfully on port 8000\n", - "๐Ÿ“Š Created 50 environment sessions\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:eval_protocol.mcp.client.connection:Session 77f6847cd7f4eaa908955b68fc08b75e: Invalid JSON from get_reservation_details: Error executing tool get_reservation_details: Reservation L7X4P9 not found. Error: Expecting value: line 1 column 1 (char 0)\n", - "WARNING:eval_protocol.mcp.client.connection:Control plane status endpoint timed out after 3.0s\n", - "WARNING:eval_protocol.mcp.client.connection:Session 0de216038acb0986989909c3b22b5373: Invalid JSON from get_reservation_details: Error executing tool get_reservation_details: Reservation 45698213 not found. Error: Expecting value: line 1 column 1 (char 0)\n", - "WARNING:eval_protocol.mcp.client.connection:Session 0de216038acb0986989909c3b22b5373: Invalid JSON from get_reservation_details: Error executing tool get_reservation_details: Reservation Q7ZB34 not found. Error: Expecting value: line 1 column 1 (char 0)\n", - "WARNING:eval_protocol.mcp.client.connection:Control plane reward endpoint timed out after 3.0s\n", - "WARNING:eval_protocol.mcp.client.connection:Session 20dd3f68f9165c4cc4bd81aec770c9d4: Invalid JSON from update_reservation_flights: Error executing tool update_reservation_flights: Payment method not found. Error: Expecting value: line 1 column 1 (char 0)\n", - "WARNING:eval_protocol.mcp.client.connection:Control plane status endpoint timed out after 3.0s\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "๐Ÿงน Closing 50 MCP sessions...\n", - "โœ… All MCP sessions closed.\n", - "โœ… Completed 50 trajectories in 373.16s\n", - "๐Ÿš€ Processing 50 trajectory evaluations in parallel...\n", - " ๐Ÿ“Š LLM Usage for airline_task_0: {'prompt_tokens': 10394, 'completion_tokens': 348, 'total_tokens': 10742}\n", - " ๐Ÿ“‹ airline_task_0: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_1: {'prompt_tokens': 42103, 'completion_tokens': 192, 'total_tokens': 42295}\n", - " ๐Ÿ“‹ airline_task_1: 0.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_6: {'prompt_tokens': 4932, 'completion_tokens': 38, 'total_tokens': 4970}\n", - " ๐Ÿ“‹ airline_task_6: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_13: {'prompt_tokens': 38788, 'completion_tokens': 663, 'total_tokens': 39451}\n", - " ๐Ÿ“‹ airline_task_13: 0.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_10: {'prompt_tokens': 43693, 'completion_tokens': 366, 'total_tokens': 44059}\n", - " ๐Ÿ“‹ airline_task_10: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_5: {'prompt_tokens': 33828, 'completion_tokens': 479, 'total_tokens': 34307}\n", - " ๐Ÿ“‹ airline_task_5: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_3: {'prompt_tokens': 15963, 'completion_tokens': 230, 'total_tokens': 16193}\n", - " ๐Ÿ“‹ airline_task_3: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_12: {'prompt_tokens': 37553, 'completion_tokens': 577, 'total_tokens': 38130}\n", - " ๐Ÿ“‹ airline_task_12: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_4: {'prompt_tokens': 55540, 'completion_tokens': 610, 'total_tokens': 56150}\n", - " ๐Ÿ“‹ airline_task_4: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_19: {'prompt_tokens': 16020, 'completion_tokens': 118, 'total_tokens': 16138}\n", - " ๐Ÿ“‹ airline_task_19: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_2: {'prompt_tokens': 37009, 'completion_tokens': 243, 'total_tokens': 37252}\n", - " ๐Ÿ“‹ airline_task_2: 1.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_11: {'prompt_tokens': 4766, 'completion_tokens': 146, 'total_tokens': 4912}\n", - " ๐Ÿ“‹ airline_task_11: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_15: {'prompt_tokens': 57885, 'completion_tokens': 533, 'total_tokens': 58418}\n", - " ๐Ÿ“‹ airline_task_15: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_16: {'prompt_tokens': 30439, 'completion_tokens': 429, 'total_tokens': 30868}\n", - " ๐Ÿ“‹ airline_task_16: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_9: {'prompt_tokens': 35496, 'completion_tokens': 376, 'total_tokens': 35872}\n", - " ๐Ÿ“‹ airline_task_9: 1.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_17: {'prompt_tokens': 49670, 'completion_tokens': 610, 'total_tokens': 50280}\n", - " ๐Ÿ“‹ airline_task_17: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_8: {'prompt_tokens': 49402, 'completion_tokens': 510, 'total_tokens': 49912}\n", - " ๐Ÿ“‹ airline_task_8: 0.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_7: {'prompt_tokens': 51445, 'completion_tokens': 275, 'total_tokens': 51720}\n", - " ๐Ÿ“‹ airline_task_7: 0.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_26: {'prompt_tokens': 15790, 'completion_tokens': 426, 'total_tokens': 16216}\n", - " ๐Ÿ“‹ airline_task_26: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_14: {'prompt_tokens': 59997, 'completion_tokens': 1031, 'total_tokens': 61028}\n", - " ๐Ÿ“‹ airline_task_14: 0.0, total 5 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_21: {'prompt_tokens': 96182, 'completion_tokens': 603, 'total_tokens': 96785}\n", - " ๐Ÿ“‹ airline_task_21: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_22: {'prompt_tokens': 28531, 'completion_tokens': 143, 'total_tokens': 28674}\n", - " ๐Ÿ“‹ airline_task_22: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_25: {'prompt_tokens': 10238, 'completion_tokens': 138, 'total_tokens': 10376}\n", - " ๐Ÿ“‹ airline_task_25: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_18: {'prompt_tokens': 85908, 'completion_tokens': 689, 'total_tokens': 86597}\n", - " ๐Ÿ“‹ airline_task_18: 1.0, total 6 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_20: {'prompt_tokens': 42172, 'completion_tokens': 548, 'total_tokens': 42720}\n", - " ๐Ÿ“‹ airline_task_20: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_27: {'prompt_tokens': 80408, 'completion_tokens': 488, 'total_tokens': 80896}\n", - " ๐Ÿ“‹ airline_task_27: 1.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_30: {'prompt_tokens': 30393, 'completion_tokens': 396, 'total_tokens': 30789}\n", - " ๐Ÿ“‹ airline_task_30: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_24: {'prompt_tokens': 63523, 'completion_tokens': 731, 'total_tokens': 64254}\n", - " ๐Ÿ“‹ airline_task_24: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_34: {'prompt_tokens': 23049, 'completion_tokens': 485, 'total_tokens': 23534}\n", - " ๐Ÿ“‹ airline_task_34: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_28: {'prompt_tokens': 10066, 'completion_tokens': 360, 'total_tokens': 10426}\n", - " ๐Ÿ“‹ airline_task_28: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_31: {'prompt_tokens': 16464, 'completion_tokens': 226, 'total_tokens': 16690}\n", - " ๐Ÿ“‹ airline_task_31: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_36: {'prompt_tokens': 10772, 'completion_tokens': 202, 'total_tokens': 10974}\n", - " ๐Ÿ“‹ airline_task_36: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_40: {'prompt_tokens': 10333, 'completion_tokens': 114, 'total_tokens': 10447}\n", - " ๐Ÿ“‹ airline_task_40: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_33: {'prompt_tokens': 30350, 'completion_tokens': 393, 'total_tokens': 30743}\n", - " ๐Ÿ“‹ airline_task_33: 1.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_29: {'prompt_tokens': 38672, 'completion_tokens': 501, 'total_tokens': 39173}\n", - " ๐Ÿ“‹ airline_task_29: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_41: {'prompt_tokens': 49179, 'completion_tokens': 227, 'total_tokens': 49406}\n", - " ๐Ÿ“‹ airline_task_41: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_32: {'prompt_tokens': 22825, 'completion_tokens': 297, 'total_tokens': 23122}\n", - " ๐Ÿ“‹ airline_task_32: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_42: {'prompt_tokens': 84720, 'completion_tokens': 491, 'total_tokens': 85211}\n", - " ๐Ÿ“‹ airline_task_42: 1.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_37: {'prompt_tokens': 61432, 'completion_tokens': 572, 'total_tokens': 62004}\n", - " ๐Ÿ“‹ airline_task_37: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_46: {'prompt_tokens': 5083, 'completion_tokens': 83, 'total_tokens': 5166}\n", - " ๐Ÿ“‹ airline_task_46: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_47: {'prompt_tokens': 10303, 'completion_tokens': 74, 'total_tokens': 10377}\n", - " ๐Ÿ“‹ airline_task_47: 0.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_45: {'prompt_tokens': 10672, 'completion_tokens': 107, 'total_tokens': 10779}\n", - " ๐Ÿ“‹ airline_task_45: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_38: {'prompt_tokens': 35391, 'completion_tokens': 495, 'total_tokens': 35886}\n", - " ๐Ÿ“‹ airline_task_38: 0.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_39: {'prompt_tokens': 77165, 'completion_tokens': 331, 'total_tokens': 77496}\n", - " ๐Ÿ“‹ airline_task_39: 1.0, total 4 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_48: {'prompt_tokens': 10259, 'completion_tokens': 330, 'total_tokens': 10589}\n", - " ๐Ÿ“‹ airline_task_48: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_49: {'prompt_tokens': 10091, 'completion_tokens': 257, 'total_tokens': 10348}\n", - " ๐Ÿ“‹ airline_task_49: 1.0, total 1 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_35: {'prompt_tokens': 32788, 'completion_tokens': 406, 'total_tokens': 33194}\n", - " ๐Ÿ“‹ airline_task_35: 0.0, total 3 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_43: {'prompt_tokens': 43329, 'completion_tokens': 275, 'total_tokens': 43604}\n", - " ๐Ÿ“‹ airline_task_43: 0.0, total 2 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_23: {'prompt_tokens': 50198, 'completion_tokens': 921, 'total_tokens': 51119}\n", - " ๐Ÿ“‹ airline_task_23: 0.0, total 8 assertions)\n", - " ๐Ÿ“Š LLM Usage for airline_task_44: {'prompt_tokens': 41578, 'completion_tokens': 345, 'total_tokens': 41923}\n", - " ๐Ÿ“‹ airline_task_44: 0.0, total 5 assertions)\n", - "โœ… Completed parallel evaluations in 17.52s\n", - "๐Ÿงน Closing 50 MCP sessions...\n", - "โœ… All MCP sessions closed.\n", - "๐Ÿ›‘ Stopping server on port 8000...\n", - "๐Ÿงน Cleaned up log file: /Users/derekxu/Documents/code/python-sdk/local_evals/server_output_airline_8000.log\n", - "\n", - "โœ… Completed evaluations for 2 models\n", - "๐Ÿ“Š Total results: 100\n", - "๐Ÿ“Š Total trajectories: 100\n" - ] - } - ], - "source": [ - "all_results = []\n", - "all_evaluation_records = []\n", - "\n", - "for model_id, model_info in models_to_test.items():\n", - " model_results, evaluation_records = await run_model_evaluation(model_id, model_info, tau2_eval_dataset)\n", - " all_results.extend(model_results)\n", - " all_evaluation_records.extend(evaluation_records)\n", - "\n", - "print(f\"\\nโœ… Completed evaluations for {len(models_to_test)} models\")\n", - "print(f\"๐Ÿ“Š Total results: {len(all_results)}\")\n", - "print(f\"๐Ÿ“Š Total evaluation records: {len(all_evaluation_records)}\")" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "๐Ÿงน Closing 50 MCP sessions...\n", + "โœ… All MCP sessions closed.\n", + "โœ… Completed 50 trajectories in 438.92s\n", + "๐Ÿš€ Processing 50 trajectory evaluations in parallel...\n", + " ๐Ÿ“Š LLM Usage for airline_task_6: {'prompt_tokens': 11809, 'completion_tokens': 439, 'total_tokens': 12248}\n", + " ๐Ÿ“‹ airline_task_6: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_1: {'prompt_tokens': 48521, 'completion_tokens': 465, 'total_tokens': 48986}\n", + " ๐Ÿ“‹ airline_task_1: 0.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_0: {'prompt_tokens': 18067, 'completion_tokens': 255, 'total_tokens': 18322}\n", + " ๐Ÿ“‹ airline_task_0: 0.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_10: {'prompt_tokens': 70113, 'completion_tokens': 1132, 'total_tokens': 71245}\n", + " ๐Ÿ“‹ airline_task_10: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_13: {'prompt_tokens': 73350, 'completion_tokens': 1136, 'total_tokens': 74486}\n", + " ๐Ÿ“‹ airline_task_13: 0.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_5: {'prompt_tokens': 31643, 'completion_tokens': 416, 'total_tokens': 32059}\n", + " ๐Ÿ“‹ airline_task_5: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_3: {'prompt_tokens': 18289, 'completion_tokens': 281, 'total_tokens': 18570}\n", + " ๐Ÿ“‹ airline_task_3: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_4: {'prompt_tokens': 47856, 'completion_tokens': 838, 'total_tokens': 48694}\n", + " ๐Ÿ“‹ airline_task_4: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_12: {'prompt_tokens': 32974, 'completion_tokens': 545, 'total_tokens': 33519}\n", + " ๐Ÿ“‹ airline_task_12: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_19: {'prompt_tokens': 31917, 'completion_tokens': 452, 'total_tokens': 32369}\n", + " ๐Ÿ“‹ airline_task_19: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_16: {'prompt_tokens': 34448, 'completion_tokens': 748, 'total_tokens': 35196}\n", + " ๐Ÿ“‹ airline_task_16: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_9: {'prompt_tokens': 48912, 'completion_tokens': 825, 'total_tokens': 49737}\n", + " ๐Ÿ“‹ airline_task_9: 0.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_15: {'prompt_tokens': 56489, 'completion_tokens': 949, 'total_tokens': 57438}\n", + " ๐Ÿ“‹ airline_task_15: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_11: {'prompt_tokens': 32715, 'completion_tokens': 395, 'total_tokens': 33110}\n", + " ๐Ÿ“‹ airline_task_11: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_2: {'prompt_tokens': 64625, 'completion_tokens': 925, 'total_tokens': 65550}\n", + " ๐Ÿ“‹ airline_task_2: 1.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_14: {'prompt_tokens': 105792, 'completion_tokens': 1297, 'total_tokens': 107089}\n", + " ๐Ÿ“‹ airline_task_14: 0.0, total 5 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_7: {'prompt_tokens': 49692, 'completion_tokens': 518, 'total_tokens': 50210}\n", + " ๐Ÿ“‹ airline_task_7: 0.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_26: {'prompt_tokens': 18258, 'completion_tokens': 498, 'total_tokens': 18756}\n", + " ๐Ÿ“‹ airline_task_26: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_8: {'prompt_tokens': 56129, 'completion_tokens': 939, 'total_tokens': 57068}\n", + " ๐Ÿ“‹ airline_task_8: 1.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_17: {'prompt_tokens': 57622, 'completion_tokens': 710, 'total_tokens': 58332}\n", + " ๐Ÿ“‹ airline_task_17: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_28: {'prompt_tokens': 18305, 'completion_tokens': 519, 'total_tokens': 18824}\n", + " ๐Ÿ“‹ airline_task_28: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_22: {'prompt_tokens': 102233, 'completion_tokens': 1630, 'total_tokens': 103863}\n", + " ๐Ÿ“‹ airline_task_22: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_20: {'prompt_tokens': 75987, 'completion_tokens': 1169, 'total_tokens': 77156}\n", + " ๐Ÿ“‹ airline_task_20: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_31: {'prompt_tokens': 17973, 'completion_tokens': 508, 'total_tokens': 18481}\n", + " ๐Ÿ“‹ airline_task_31: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_18: {'prompt_tokens': 290410, 'completion_tokens': 3217, 'total_tokens': 293627}\n", + " ๐Ÿ“‹ airline_task_18: 0.0, total 6 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_27: {'prompt_tokens': 91697, 'completion_tokens': 900, 'total_tokens': 92597}\n", + " ๐Ÿ“‹ airline_task_27: 1.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_29: {'prompt_tokens': 45098, 'completion_tokens': 921, 'total_tokens': 46019}\n", + " ๐Ÿ“‹ airline_task_29: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_30: {'prompt_tokens': 34284, 'completion_tokens': 851, 'total_tokens': 35135}\n", + " ๐Ÿ“‹ airline_task_30: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_25: {'prompt_tokens': 34098, 'completion_tokens': 791, 'total_tokens': 34889}\n", + " ๐Ÿ“‹ airline_task_25: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_32: {'prompt_tokens': 25880, 'completion_tokens': 505, 'total_tokens': 26385}\n", + " ๐Ÿ“‹ airline_task_32: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_21: {'prompt_tokens': 91852, 'completion_tokens': 1140, 'total_tokens': 92992}\n", + " ๐Ÿ“‹ airline_task_21: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_34: {'prompt_tokens': 33951, 'completion_tokens': 1036, 'total_tokens': 34987}\n", + " ๐Ÿ“‹ airline_task_34: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_24: {'prompt_tokens': 64616, 'completion_tokens': 1547, 'total_tokens': 66163}\n", + " ๐Ÿ“‹ airline_task_24: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_36: {'prompt_tokens': 27295, 'completion_tokens': 347, 'total_tokens': 27642}\n", + " ๐Ÿ“‹ airline_task_36: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_41: {'prompt_tokens': 66964, 'completion_tokens': 577, 'total_tokens': 67541}\n", + " ๐Ÿ“‹ airline_task_41: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_42: {'prompt_tokens': 86379, 'completion_tokens': 951, 'total_tokens': 87330}\n", + " ๐Ÿ“‹ airline_task_42: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_40: {'prompt_tokens': 11763, 'completion_tokens': 229, 'total_tokens': 11992}\n", + " ๐Ÿ“‹ airline_task_40: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_33: {'prompt_tokens': 43097, 'completion_tokens': 814, 'total_tokens': 43911}\n", + " ๐Ÿ“‹ airline_task_33: 1.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_45: {'prompt_tokens': 5584, 'completion_tokens': 62, 'total_tokens': 5646}\n", + " ๐Ÿ“‹ airline_task_45: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_46: {'prompt_tokens': 12057, 'completion_tokens': 322, 'total_tokens': 12379}\n", + " ๐Ÿ“‹ airline_task_46: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_23: {'prompt_tokens': 134261, 'completion_tokens': 2398, 'total_tokens': 136659}\n", + " ๐Ÿ“‹ airline_task_23: 1.0, total 8 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_47: {'prompt_tokens': 17987, 'completion_tokens': 221, 'total_tokens': 18208}\n", + " ๐Ÿ“‹ airline_task_47: 0.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_38: {'prompt_tokens': 40105, 'completion_tokens': 852, 'total_tokens': 40957}\n", + " ๐Ÿ“‹ airline_task_38: 0.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_35: {'prompt_tokens': 56123, 'completion_tokens': 1211, 'total_tokens': 57334}\n", + " ๐Ÿ“‹ airline_task_35: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_49: {'prompt_tokens': 11756, 'completion_tokens': 332, 'total_tokens': 12088}\n", + " ๐Ÿ“‹ airline_task_49: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_43: {'prompt_tokens': 56223, 'completion_tokens': 554, 'total_tokens': 56777}\n", + " ๐Ÿ“‹ airline_task_43: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_48: {'prompt_tokens': 18950, 'completion_tokens': 592, 'total_tokens': 19542}\n", + " ๐Ÿ“‹ airline_task_48: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_37: {'prompt_tokens': 50314, 'completion_tokens': 870, 'total_tokens': 51184}\n", + " ๐Ÿ“‹ airline_task_37: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_39: {'prompt_tokens': 100087, 'completion_tokens': 1072, 'total_tokens': 101159}\n", + " ๐Ÿ“‹ airline_task_39: 1.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_44: {'prompt_tokens': 146859, 'completion_tokens': 1910, 'total_tokens': 148769}\n", + " ๐Ÿ“‹ airline_task_44: 0.0, total 5 assertions)\n", + "โœ… Completed parallel evaluations in 16.19s\n", + "๐Ÿงน Closing 50 MCP sessions...\n", + "โœ… All MCP sessions closed.\n", + "๐Ÿ›‘ Stopping server on port 8000...\n", + "๐Ÿงน Cleaned up log file: /Users/derekxu/Documents/code/python-sdk/local_evals/server_output_airline_8000.log\n", + "\n", + "๐Ÿงช Starting evaluation for Kimi K2...\n", + "โœ… Server started successfully on port 8000\n", + "๐Ÿ“Š Created 50 environment sessions\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "vscode": { - "languageId": "raw" - } - }, - "source": [ - "## 5. Analyze Results\n", - "\n", - "Let's analyze and visualize the comparison between Claude 4 Opus, GPT-4.1, and Kimi K2.\n" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:eval_protocol.mcp.client.connection:Session 77f6847cd7f4eaa908955b68fc08b75e: Invalid JSON from get_reservation_details: Error executing tool get_reservation_details: Reservation L7X4P9 not found. Error: Expecting value: line 1 column 1 (char 0)\n", + "WARNING:eval_protocol.mcp.client.connection:Control plane status endpoint timed out after 3.0s\n", + "WARNING:eval_protocol.mcp.client.connection:Session 0de216038acb0986989909c3b22b5373: Invalid JSON from get_reservation_details: Error executing tool get_reservation_details: Reservation 45698213 not found. Error: Expecting value: line 1 column 1 (char 0)\n", + "WARNING:eval_protocol.mcp.client.connection:Session 0de216038acb0986989909c3b22b5373: Invalid JSON from get_reservation_details: Error executing tool get_reservation_details: Reservation Q7ZB34 not found. Error: Expecting value: line 1 column 1 (char 0)\n", + "WARNING:eval_protocol.mcp.client.connection:Control plane reward endpoint timed out after 3.0s\n", + "WARNING:eval_protocol.mcp.client.connection:Session 20dd3f68f9165c4cc4bd81aec770c9d4: Invalid JSON from update_reservation_flights: Error executing tool update_reservation_flights: Payment method not found. Error: Expecting value: line 1 column 1 (char 0)\n", + "WARNING:eval_protocol.mcp.client.connection:Control plane status endpoint timed out after 3.0s\n" + ] }, { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "๐Ÿ“ˆ Summary Statistics:\n", - " Claude 4 Sonnet: 54.00% success rate (27.0/50) - Cost: $8.79 (via litellm)\n", - " Kimi K2: 46.00% success rate (23.0/50) - Cost: $1.14 (via manual_pricing)\n", - "\n", - "๐Ÿ’ฐ Total evaluation cost: $9.93\n", - "๐Ÿ“Š Cost calculation uses actual API usage data from LLMUsageStats\n" - ] - } - ], - "source": [ - "model_id_to_config = {}\n", - "for config_key, model_info in models_to_test.items():\n", - " actual_model_id = model_info[\"policy\"].model_id\n", - " model_id_to_config[actual_model_id] = model_info\n", - "\n", - "print(f\"\\n๐Ÿ“ˆ Summary Statistics:\")\n", - "total_cost = 0.0\n", - "for actual_model_id, model_info in model_id_to_config.items():\n", - " model_results_subset = [r for r in all_results if r['model_id'] == actual_model_id]\n", - " avg_score = sum(r['score'] for r in model_results_subset) / len(model_results_subset) if model_results_subset else 0\n", - " \n", - " # Calculate total cost for this model\n", - " model_total_cost = sum(r['cost_info']['total_cost'] for r in model_results_subset if 'cost_info' in r)\n", - " total_cost += model_total_cost\n", - " \n", - " # Show cost source info\n", - " cost_sources = [r['cost_info'].get('cost_source', 'unknown') for r in model_results_subset if 'cost_info' in r]\n", - " cost_source_summary = f\" (via {cost_sources[0]})\" if cost_sources else \"\"\n", - " \n", - " print(f\" {model_info['name']}: {avg_score:.2%} success rate ({sum(r['score'] for r in model_results_subset)}/{len(model_results_subset)}) - Cost: ${model_total_cost:.2f}{cost_source_summary}\")\n", - "\n", - "print(f\"\\n๐Ÿ’ฐ Total evaluation cost: ${total_cost:.2f}\")\n", - "print(f\"๐Ÿ“Š Cost calculation uses actual API usage data from LLMUsageStats\")" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "๐Ÿงน Closing 50 MCP sessions...\n", + "โœ… All MCP sessions closed.\n", + "โœ… Completed 50 trajectories in 373.16s\n", + "๐Ÿš€ Processing 50 trajectory evaluations in parallel...\n", + " ๐Ÿ“Š LLM Usage for airline_task_0: {'prompt_tokens': 10394, 'completion_tokens': 348, 'total_tokens': 10742}\n", + " ๐Ÿ“‹ airline_task_0: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_1: {'prompt_tokens': 42103, 'completion_tokens': 192, 'total_tokens': 42295}\n", + " ๐Ÿ“‹ airline_task_1: 0.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_6: {'prompt_tokens': 4932, 'completion_tokens': 38, 'total_tokens': 4970}\n", + " ๐Ÿ“‹ airline_task_6: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_13: {'prompt_tokens': 38788, 'completion_tokens': 663, 'total_tokens': 39451}\n", + " ๐Ÿ“‹ airline_task_13: 0.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_10: {'prompt_tokens': 43693, 'completion_tokens': 366, 'total_tokens': 44059}\n", + " ๐Ÿ“‹ airline_task_10: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_5: {'prompt_tokens': 33828, 'completion_tokens': 479, 'total_tokens': 34307}\n", + " ๐Ÿ“‹ airline_task_5: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_3: {'prompt_tokens': 15963, 'completion_tokens': 230, 'total_tokens': 16193}\n", + " ๐Ÿ“‹ airline_task_3: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_12: {'prompt_tokens': 37553, 'completion_tokens': 577, 'total_tokens': 38130}\n", + " ๐Ÿ“‹ airline_task_12: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_4: {'prompt_tokens': 55540, 'completion_tokens': 610, 'total_tokens': 56150}\n", + " ๐Ÿ“‹ airline_task_4: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_19: {'prompt_tokens': 16020, 'completion_tokens': 118, 'total_tokens': 16138}\n", + " ๐Ÿ“‹ airline_task_19: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_2: {'prompt_tokens': 37009, 'completion_tokens': 243, 'total_tokens': 37252}\n", + " ๐Ÿ“‹ airline_task_2: 1.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_11: {'prompt_tokens': 4766, 'completion_tokens': 146, 'total_tokens': 4912}\n", + " ๐Ÿ“‹ airline_task_11: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_15: {'prompt_tokens': 57885, 'completion_tokens': 533, 'total_tokens': 58418}\n", + " ๐Ÿ“‹ airline_task_15: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_16: {'prompt_tokens': 30439, 'completion_tokens': 429, 'total_tokens': 30868}\n", + " ๐Ÿ“‹ airline_task_16: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_9: {'prompt_tokens': 35496, 'completion_tokens': 376, 'total_tokens': 35872}\n", + " ๐Ÿ“‹ airline_task_9: 1.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_17: {'prompt_tokens': 49670, 'completion_tokens': 610, 'total_tokens': 50280}\n", + " ๐Ÿ“‹ airline_task_17: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_8: {'prompt_tokens': 49402, 'completion_tokens': 510, 'total_tokens': 49912}\n", + " ๐Ÿ“‹ airline_task_8: 0.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_7: {'prompt_tokens': 51445, 'completion_tokens': 275, 'total_tokens': 51720}\n", + " ๐Ÿ“‹ airline_task_7: 0.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_26: {'prompt_tokens': 15790, 'completion_tokens': 426, 'total_tokens': 16216}\n", + " ๐Ÿ“‹ airline_task_26: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_14: {'prompt_tokens': 59997, 'completion_tokens': 1031, 'total_tokens': 61028}\n", + " ๐Ÿ“‹ airline_task_14: 0.0, total 5 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_21: {'prompt_tokens': 96182, 'completion_tokens': 603, 'total_tokens': 96785}\n", + " ๐Ÿ“‹ airline_task_21: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_22: {'prompt_tokens': 28531, 'completion_tokens': 143, 'total_tokens': 28674}\n", + " ๐Ÿ“‹ airline_task_22: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_25: {'prompt_tokens': 10238, 'completion_tokens': 138, 'total_tokens': 10376}\n", + " ๐Ÿ“‹ airline_task_25: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_18: {'prompt_tokens': 85908, 'completion_tokens': 689, 'total_tokens': 86597}\n", + " ๐Ÿ“‹ airline_task_18: 1.0, total 6 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_20: {'prompt_tokens': 42172, 'completion_tokens': 548, 'total_tokens': 42720}\n", + " ๐Ÿ“‹ airline_task_20: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_27: {'prompt_tokens': 80408, 'completion_tokens': 488, 'total_tokens': 80896}\n", + " ๐Ÿ“‹ airline_task_27: 1.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_30: {'prompt_tokens': 30393, 'completion_tokens': 396, 'total_tokens': 30789}\n", + " ๐Ÿ“‹ airline_task_30: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_24: {'prompt_tokens': 63523, 'completion_tokens': 731, 'total_tokens': 64254}\n", + " ๐Ÿ“‹ airline_task_24: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_34: {'prompt_tokens': 23049, 'completion_tokens': 485, 'total_tokens': 23534}\n", + " ๐Ÿ“‹ airline_task_34: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_28: {'prompt_tokens': 10066, 'completion_tokens': 360, 'total_tokens': 10426}\n", + " ๐Ÿ“‹ airline_task_28: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_31: {'prompt_tokens': 16464, 'completion_tokens': 226, 'total_tokens': 16690}\n", + " ๐Ÿ“‹ airline_task_31: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_36: {'prompt_tokens': 10772, 'completion_tokens': 202, 'total_tokens': 10974}\n", + " ๐Ÿ“‹ airline_task_36: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_40: {'prompt_tokens': 10333, 'completion_tokens': 114, 'total_tokens': 10447}\n", + " ๐Ÿ“‹ airline_task_40: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_33: {'prompt_tokens': 30350, 'completion_tokens': 393, 'total_tokens': 30743}\n", + " ๐Ÿ“‹ airline_task_33: 1.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_29: {'prompt_tokens': 38672, 'completion_tokens': 501, 'total_tokens': 39173}\n", + " ๐Ÿ“‹ airline_task_29: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_41: {'prompt_tokens': 49179, 'completion_tokens': 227, 'total_tokens': 49406}\n", + " ๐Ÿ“‹ airline_task_41: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_32: {'prompt_tokens': 22825, 'completion_tokens': 297, 'total_tokens': 23122}\n", + " ๐Ÿ“‹ airline_task_32: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_42: {'prompt_tokens': 84720, 'completion_tokens': 491, 'total_tokens': 85211}\n", + " ๐Ÿ“‹ airline_task_42: 1.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_37: {'prompt_tokens': 61432, 'completion_tokens': 572, 'total_tokens': 62004}\n", + " ๐Ÿ“‹ airline_task_37: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_46: {'prompt_tokens': 5083, 'completion_tokens': 83, 'total_tokens': 5166}\n", + " ๐Ÿ“‹ airline_task_46: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_47: {'prompt_tokens': 10303, 'completion_tokens': 74, 'total_tokens': 10377}\n", + " ๐Ÿ“‹ airline_task_47: 0.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_45: {'prompt_tokens': 10672, 'completion_tokens': 107, 'total_tokens': 10779}\n", + " ๐Ÿ“‹ airline_task_45: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_38: {'prompt_tokens': 35391, 'completion_tokens': 495, 'total_tokens': 35886}\n", + " ๐Ÿ“‹ airline_task_38: 0.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_39: {'prompt_tokens': 77165, 'completion_tokens': 331, 'total_tokens': 77496}\n", + " ๐Ÿ“‹ airline_task_39: 1.0, total 4 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_48: {'prompt_tokens': 10259, 'completion_tokens': 330, 'total_tokens': 10589}\n", + " ๐Ÿ“‹ airline_task_48: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_49: {'prompt_tokens': 10091, 'completion_tokens': 257, 'total_tokens': 10348}\n", + " ๐Ÿ“‹ airline_task_49: 1.0, total 1 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_35: {'prompt_tokens': 32788, 'completion_tokens': 406, 'total_tokens': 33194}\n", + " ๐Ÿ“‹ airline_task_35: 0.0, total 3 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_43: {'prompt_tokens': 43329, 'completion_tokens': 275, 'total_tokens': 43604}\n", + " ๐Ÿ“‹ airline_task_43: 0.0, total 2 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_23: {'prompt_tokens': 50198, 'completion_tokens': 921, 'total_tokens': 51119}\n", + " ๐Ÿ“‹ airline_task_23: 0.0, total 8 assertions)\n", + " ๐Ÿ“Š LLM Usage for airline_task_44: {'prompt_tokens': 41578, 'completion_tokens': 345, 'total_tokens': 41923}\n", + " ๐Ÿ“‹ airline_task_44: 0.0, total 5 assertions)\n", + "โœ… Completed parallel evaluations in 17.52s\n", + "๐Ÿงน Closing 50 MCP sessions...\n", + "โœ… All MCP sessions closed.\n", + "๐Ÿ›‘ Stopping server on port 8000...\n", + "๐Ÿงน Cleaned up log file: /Users/derekxu/Documents/code/python-sdk/local_evals/server_output_airline_8000.log\n", + "\n", + "โœ… Completed evaluations for 2 models\n", + "๐Ÿ“Š Total results: 100\n", + "๐Ÿ“Š Total trajectories: 100\n" + ] + } + ], + "source": [ + "all_results = []\n", + "all_evaluation_records = []\n", + "\n", + "for model_id, model_info in models_to_test.items():\n", + " model_results, evaluation_records = await run_model_evaluation(model_id, model_info, tau2_eval_dataset)\n", + " all_results.extend(model_results)\n", + " all_evaluation_records.extend(evaluation_records)\n", + "\n", + "print(f\"\\nโœ… Completed evaluations for {len(models_to_test)} models\")\n", + "print(f\"๐Ÿ“Š Total results: {len(all_results)}\")\n", + "print(f\"๐Ÿ“Š Total evaluation records: {len(all_evaluation_records)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "## 5. Analyze Results\n", + "\n", + "Let's analyze and visualize the comparison between Claude 4 Opus, GPT-4.1, and Kimi K2.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "๐Ÿ“„ Saved JSONL file: trajectory_outputs/all_trajectories.jsonl\n" - ] - }, - { - "data": { - "text/plain": [ - "PosixPath('trajectory_outputs/all_trajectories.jsonl')" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def save_results_jsonl(evaluation_records: List[Dict], output_file: str = \"evaluation_outputs/all_evaluations.jsonl\"):\n", - " \"\"\"Save all evaluation records in JSONL format (one JSON object per line).\"\"\"\n", - " output_path = Path(output_file)\n", - " output_path.parent.mkdir(exist_ok=True)\n", - " \n", - " with open(output_path, 'w') as f:\n", - " for record in evaluation_records:\n", - " json.dump(record, f, default=str)\n", - " f.write('\\n')\n", - " \n", - " print(f\"๐Ÿ“„ Saved JSONL file: {output_path}\")\n", - " return output_path\n", - "\n", - "save_results_jsonl(all_evaluation_records)" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "๐Ÿ“ˆ Summary Statistics:\n", + " Claude 4 Sonnet: 54.00% success rate (27.0/50) - Cost: $8.79 (via litellm)\n", + " Kimi K2: 46.00% success rate (23.0/50) - Cost: $1.14 (via manual_pricing)\n", + "\n", + "๐Ÿ’ฐ Total evaluation cost: $9.93\n", + "๐Ÿ“Š Cost calculation uses actual API usage data from LLMUsageStats\n" + ] + } + ], + "source": [ + "model_id_to_config = {}\n", + "for config_key, model_info in models_to_test.items():\n", + " actual_model_id = model_info[\"policy\"].model_id\n", + " model_id_to_config[actual_model_id] = model_info\n", + "\n", + "print(\"\\n๐Ÿ“ˆ Summary Statistics:\")\n", + "total_cost = 0.0\n", + "for actual_model_id, model_info in model_id_to_config.items():\n", + " model_results_subset = [r for r in all_results if r[\"model_id\"] == actual_model_id]\n", + " avg_score = (\n", + " sum(r[\"score\"] for r in model_results_subset) / len(model_results_subset) if model_results_subset else 0\n", + " )\n", + "\n", + " # Calculate total cost for this model\n", + " model_total_cost = sum(r[\"cost_info\"][\"total_cost\"] for r in model_results_subset if \"cost_info\" in r)\n", + " total_cost += model_total_cost\n", + "\n", + " # Show cost source info\n", + " cost_sources = [r[\"cost_info\"].get(\"cost_source\", \"unknown\") for r in model_results_subset if \"cost_info\" in r]\n", + " cost_source_summary = f\" (via {cost_sources[0]})\" if cost_sources else \"\"\n", + "\n", + " print(\n", + " f\" {model_info['name']}: {avg_score:.2%} success rate ({sum(r['score'] for r in model_results_subset)}/{len(model_results_subset)}) - Cost: ${model_total_cost:.2f}{cost_source_summary}\"\n", + " )\n", + "\n", + "print(f\"\\n๐Ÿ’ฐ Total evaluation cost: ${total_cost:.2f}\")\n", + "print(\"๐Ÿ“Š Cost calculation uses actual API usage data from LLMUsageStats\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "๐Ÿ“ Saved trajectory files to: trajectory_outputs\n", - " - 100 individual trajectory files\n", - " - 1 evaluation summary file\n" - ] - }, - { - "data": { - "text/plain": [ - "PosixPath('trajectory_outputs')" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def save_evaluation_files(evaluation_records: List[Dict], output_dir: str = \"evaluation_outputs\"):\n", - " \"\"\"Save evaluation records to individual files and create summary.\"\"\"\n", - " output_path = Path(output_dir)\n", - " output_path.mkdir(exist_ok=True)\n", - " \n", - " # Save individual evaluation files\n", - " for record in evaluation_records:\n", - " # Sanitize model_id for filename (replace slashes with underscores)\n", - " safe_model_id = record['model_id'].replace('/', '_').replace('\\\\', '_')\n", - " filename = f\"{safe_model_id}_{record['scenario_id']}_evaluation.json\"\n", - " filepath = output_path / filename\n", - " \n", - " with open(filepath, 'w') as f:\n", - " json.dump(record, f, indent=2, default=str)\n", - " \n", - " # Create summary file\n", - " summary = {\n", - " \"evaluation_summary\": {\n", - " \"total_evaluations\": len(evaluation_records),\n", - " \"models_evaluated\": list(set(r['model_id'] for r in evaluation_records)),\n", - " \"scenarios_evaluated\": list(set(r['scenario_id'] for r in evaluation_records)),\n", - " \"timestamp\": datetime.now().isoformat(),\n", - " },\n", - " \"model_performance\": {},\n", - " \"scenario_difficulty\": {}\n", - " }\n", - " \n", - " # Calculate model performance\n", - " for model_id in summary[\"evaluation_summary\"][\"models_evaluated\"]:\n", - " model_records = [r for r in evaluation_records if r['model_id'] == model_id]\n", - " total_score = sum(r['evaluation']['score'] for r in model_records)\n", - " avg_score = total_score / len(model_records) if model_records else 0\n", - " \n", - " # Calculate cost metrics\n", - " total_cost = sum(r.get('cost_info', {}).get('total_cost', 0) for r in model_records)\n", - " total_tokens = sum(r.get('cost_info', {}).get('total_tokens', 0) for r in model_records)\n", - " avg_cost_per_scenario = total_cost / len(model_records) if model_records else 0\n", - " \n", - " summary[\"model_performance\"][model_id] = {\n", - " \"total_scenarios\": len(model_records),\n", - " \"total_score\": total_score,\n", - " \"average_score\": avg_score,\n", - " \"pass_rate\": avg_score, # Since scores are 0 or 1\n", - " \"total_cost\": total_cost,\n", - " \"average_cost_per_scenario\": avg_cost_per_scenario,\n", - " \"total_tokens\": total_tokens,\n", - " \"cost_per_success\": total_cost / total_score if total_score > 0 else 0\n", - " }\n", - " \n", - " # Calculate scenario difficulty\n", - " for scenario_id in summary[\"evaluation_summary\"][\"scenarios_evaluated\"]:\n", - " scenario_records = [r for r in evaluation_records if r['scenario_id'] == scenario_id]\n", - " total_score = sum(r['evaluation']['score'] for r in scenario_records)\n", - " avg_score = total_score / len(scenario_records) if scenario_records else 0\n", - " \n", - " summary[\"scenario_difficulty\"][scenario_id] = {\n", - " \"models_tested\": len(scenario_records),\n", - " \"total_score\": total_score,\n", - " \"average_score\": avg_score,\n", - " \"difficulty\": \"easy\" if avg_score > 0.8 else \"medium\" if avg_score > 0.5 else \"hard\"\n", - " }\n", - " \n", - " # Save summary\n", - " summary_path = output_path / \"evaluation_summary.json\"\n", - " with open(summary_path, 'w') as f:\n", - " json.dump(summary, f, indent=2, default=str)\n", - " \n", - " print(f\"\\n๐Ÿ“ Saved evaluation files to: {output_path}\")\n", - " print(f\" - {len(evaluation_records)} individual evaluation files\")\n", - " print(f\" - 1 evaluation summary file\")\n", - " \n", - " return output_path\n", - "\n", - "save_evaluation_files(all_evaluation_records)" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "๐Ÿ“„ Saved JSONL file: trajectory_outputs/all_trajectories.jsonl\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, - "source": [ - "## 7. Share Results with Firectl\n", - "\n", - "Finally, let's create a dataset with our evaluation results to share using `firectl create dataset`.\n" + "data": { + "text/plain": [ + "PosixPath('trajectory_outputs/all_trajectories.jsonl')" ] - }, + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def save_results_jsonl(evaluation_records: List[Dict], output_file: str = \"evaluation_outputs/all_evaluations.jsonl\"):\n", + " \"\"\"Save all evaluation records in JSONL format (one JSON object per line).\"\"\"\n", + " output_path = Path(output_file)\n", + " output_path.parent.mkdir(exist_ok=True)\n", + "\n", + " with open(output_path, \"w\") as f:\n", + " for record in evaluation_records:\n", + " json.dump(record, f, default=str)\n", + " f.write(\"\\n\")\n", + "\n", + " print(f\"๐Ÿ“„ Saved JSONL file: {output_path}\")\n", + " return output_path\n", + "\n", + "\n", + "save_results_jsonl(all_evaluation_records)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# TODO" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "๐Ÿ“ Saved trajectory files to: trajectory_outputs\n", + " - 100 individual trajectory files\n", + " - 1 evaluation summary file\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, - "source": [ - "## Summary\n", - "\n", - "This notebook provides a complete eval harness for comparing models using tau2-bench airline evaluation with proper dataset structure:\n", - "\n", - "1. **Dataset Structure**: Following tau2-bench pattern with separate JSON datasets and markdown system prompts\n", - "2. **Models**: Configured Claude 4 Sonnet (AnthropicPolicy) and Kimi K2 (FireworksPolicy)\n", - "3. **Evaluation**: Used tau2-bench NLAssertionsEvaluator for objective scoring with EvaluationRow format\n", - "4. **Analysis**: Compared performance across multiple dimensions\n", - "5. **Sharing**: Prepared results for sharing via `firectl create dataset`\n", - "\n", - "### Key Features:\n", - "- **Clean Dataset Structure**: Separate JSON data and markdown prompts like the tau2 examples\n", - "- **Natural Language Evaluation**: Uses human-readable assertions instead of code-based metrics\n", - "- **Multi-Model Comparison**: Easy to add more models for comparison\n", - "- **Comprehensive Analysis**: Performance, accuracy, and efficiency metrics with cost tracking\n", - "- **EvaluationRow Support**: Updated to work with the new EvaluationRow format from eval_protocol\n", - "- **Reproducible**: Results can be shared and reproduced via firectl\n", - "\n", - "### Next Steps:\n", - "1. Set your API keys as environment variables:\n", - " ```bash\n", - " export ANTHROPIC_API_KEY=\"your-anthropic-key-here\"\n", - " export OPENAI_API_KEY=\"your-openai-key-here\"\n", - " export FIREWORKS_API_KEY=\"your-fireworks-key-here\"\n", - " ```\n", - "2. Start the tau2 MCP server: `cd examples/tau2_mcp && python server.py --port 8000`\n", - "3. Run the evaluation cells\n", - "4. Share results with the community using the provided firectl command\n", - "\n", - "### Expected Results:\n", - "Based on the tau2-bench framework, we expect different models to show varying performance on natural language assertion evaluation, demonstrating their ability to adhere to airline policy compliance and customer service protocols.\n", - "\n", - "This structure uses the updated EvaluationRow format and provides comprehensive cost analysis across different model providers." + "data": { + "text/plain": [ + "PosixPath('trajectory_outputs')" ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.11" + ], + "source": [ + "def save_evaluation_files(evaluation_records: List[Dict], output_dir: str = \"evaluation_outputs\"):\n", + " \"\"\"Save evaluation records to individual files and create summary.\"\"\"\n", + " output_path = Path(output_dir)\n", + " output_path.mkdir(exist_ok=True)\n", + "\n", + " # Save individual evaluation files\n", + " for record in evaluation_records:\n", + " # Sanitize model_id for filename (replace slashes with underscores)\n", + " safe_model_id = record[\"model_id\"].replace(\"/\", \"_\").replace(\"\\\\\", \"_\")\n", + " filename = f\"{safe_model_id}_{record['scenario_id']}_evaluation.json\"\n", + " filepath = output_path / filename\n", + "\n", + " with open(filepath, \"w\") as f:\n", + " json.dump(record, f, indent=2, default=str)\n", + "\n", + " # Create summary file\n", + " summary = {\n", + " \"evaluation_summary\": {\n", + " \"total_evaluations\": len(evaluation_records),\n", + " \"models_evaluated\": list(set(r[\"model_id\"] for r in evaluation_records)),\n", + " \"scenarios_evaluated\": list(set(r[\"scenario_id\"] for r in evaluation_records)),\n", + " \"timestamp\": datetime.now().isoformat(),\n", + " },\n", + " \"model_performance\": {},\n", + " \"scenario_difficulty\": {},\n", + " }\n", + "\n", + " # Calculate model performance\n", + " for model_id in summary[\"evaluation_summary\"][\"models_evaluated\"]:\n", + " model_records = [r for r in evaluation_records if r[\"model_id\"] == model_id]\n", + " total_score = sum(r[\"evaluation\"][\"score\"] for r in model_records)\n", + " avg_score = total_score / len(model_records) if model_records else 0\n", + "\n", + " # Calculate cost metrics\n", + " total_cost = sum(r.get(\"cost_info\", {}).get(\"total_cost\", 0) for r in model_records)\n", + " total_tokens = sum(r.get(\"cost_info\", {}).get(\"total_tokens\", 0) for r in model_records)\n", + " avg_cost_per_scenario = total_cost / len(model_records) if model_records else 0\n", + "\n", + " summary[\"model_performance\"][model_id] = {\n", + " \"total_scenarios\": len(model_records),\n", + " \"total_score\": total_score,\n", + " \"average_score\": avg_score,\n", + " \"pass_rate\": avg_score, # Since scores are 0 or 1\n", + " \"total_cost\": total_cost,\n", + " \"average_cost_per_scenario\": avg_cost_per_scenario,\n", + " \"total_tokens\": total_tokens,\n", + " \"cost_per_success\": total_cost / total_score if total_score > 0 else 0,\n", + " }\n", + "\n", + " # Calculate scenario difficulty\n", + " for scenario_id in summary[\"evaluation_summary\"][\"scenarios_evaluated\"]:\n", + " scenario_records = [r for r in evaluation_records if r[\"scenario_id\"] == scenario_id]\n", + " total_score = sum(r[\"evaluation\"][\"score\"] for r in scenario_records)\n", + " avg_score = total_score / len(scenario_records) if scenario_records else 0\n", + "\n", + " summary[\"scenario_difficulty\"][scenario_id] = {\n", + " \"models_tested\": len(scenario_records),\n", + " \"total_score\": total_score,\n", + " \"average_score\": avg_score,\n", + " \"difficulty\": \"easy\" if avg_score > 0.8 else \"medium\" if avg_score > 0.5 else \"hard\",\n", + " }\n", + "\n", + " # Save summary\n", + " summary_path = output_path / \"evaluation_summary.json\"\n", + " with open(summary_path, \"w\") as f:\n", + " json.dump(summary, f, indent=2, default=str)\n", + "\n", + " print(f\"\\n๐Ÿ“ Saved evaluation files to: {output_path}\")\n", + " print(f\" - {len(evaluation_records)} individual evaluation files\")\n", + " print(\" - 1 evaluation summary file\")\n", + "\n", + " return output_path\n", + "\n", + "\n", + "save_evaluation_files(all_evaluation_records)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "source": [ + "## 7. Share Results with Firectl\n", + "\n", + "Finally, let's create a dataset with our evaluation results to share using `firectl create dataset`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "plaintext" } + }, + "source": [ + "## Summary\n", + "\n", + "This notebook provides a complete eval harness for comparing models using tau2-bench airline evaluation with proper dataset structure:\n", + "\n", + "1. **Dataset Structure**: Following tau2-bench pattern with separate JSON datasets and markdown system prompts\n", + "2. **Models**: Configured Claude 4 Sonnet (AnthropicPolicy) and Kimi K2 (FireworksPolicy)\n", + "3. **Evaluation**: Used tau2-bench NLAssertionsEvaluator for objective scoring with EvaluationRow format\n", + "4. **Analysis**: Compared performance across multiple dimensions\n", + "5. **Sharing**: Prepared results for sharing via `firectl create dataset`\n", + "\n", + "### Key Features:\n", + "- **Clean Dataset Structure**: Separate JSON data and markdown prompts like the tau2 examples\n", + "- **Natural Language Evaluation**: Uses human-readable assertions instead of code-based metrics\n", + "- **Multi-Model Comparison**: Easy to add more models for comparison\n", + "- **Comprehensive Analysis**: Performance, accuracy, and efficiency metrics with cost tracking\n", + "- **EvaluationRow Support**: Updated to work with the new EvaluationRow format from eval_protocol\n", + "- **Reproducible**: Results can be shared and reproduced via firectl\n", + "\n", + "### Next Steps:\n", + "1. Set your API keys as environment variables:\n", + " ```bash\n", + " export ANTHROPIC_API_KEY=\"your-anthropic-key-here\"\n", + " export OPENAI_API_KEY=\"your-openai-key-here\"\n", + " export FIREWORKS_API_KEY=\"your-fireworks-key-here\"\n", + " ```\n", + "2. Start the tau2 MCP server: `cd examples/tau2_mcp && python server.py --port 8000`\n", + "3. Run the evaluation cells\n", + "4. Share results with the community using the provided firectl command\n", + "\n", + "### Expected Results:\n", + "Based on the tau2-bench framework, we expect different models to show varying performance on natural language assertion evaluation, demonstrating their ability to adhere to airline policy compliance and customer service protocols.\n", + "\n", + "This structure uses the updated EvaluationRow format and provides comprehensive cost analysis across different model providers." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 2 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 182ca82b..00000000 --- a/mypy.ini +++ /dev/null @@ -1,37 +0,0 @@ -[mypy] -# Global options -python_version = 3.10 -follow_imports = skip -explicit_package_bases = True -warn_return_any = False -warn_unused_configs = True -disallow_untyped_defs = False -disallow_incomplete_defs = False -check_untyped_defs = True -disallow_untyped_decorators = False -no_implicit_optional = True -strict_optional = True -ignore_missing_imports = True -disable_error_code = import-not-found, truthy-function, no-redef, assignment, union-attr, attr-defined, arg-type, method-assign, misc, return-value, var-annotated, operator, call-arg, index - -[mypy.plugins.pydantic.*] -follow_imports = skip - -# Specific package options -[mypy.eval_protocol.*] -# Be less strict when type checking Eval Protocol code -disallow_untyped_defs = False -disallow_incomplete_defs = False - -# Third-party packages -[mypy.numpy.*] -ignore_missing_imports = True - -[mypy.requests.*] -ignore_missing_imports = True - -[mypy.pytest.*] -ignore_missing_imports = True - -[mypy.docker.*] -ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index 73105fd5..30380dd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,11 +62,8 @@ dev = [ "pytest-asyncio", "pytest-httpserver", "werkzeug>=2.0.0", - "black>=21.5b2", - "isort>=5.0.0", - "mypy>=0.812", - "flake8>=3.9.2", - "autopep8>=1.5.0", + "ruff>=0.5.0", + "pyright>=1.1.365", "transformers>=4.0.0", "types-setuptools", "types-requests", @@ -109,18 +106,10 @@ huggingface = [ "datasets>=2.0.0", "transformers>=4.0.0", ] -bigquery = [ - "google-cloud-bigquery>=3.0.0", - "google-auth>=2.0.0", - "google-auth-oauthlib>=1.0.0", -] adapters = [ "langfuse>=2.0.0", "datasets>=2.0.0", "transformers>=4.0.0", - "google-cloud-bigquery>=3.0.0", - "google-auth>=2.0.0", - "google-auth-oauthlib>=1.0.0", ] svgbench = [ "selenium>=4.0.0", @@ -157,12 +146,7 @@ versionfile_build = "eval_protocol/_version.py" tag_prefix = "v" parentdir_prefix = "eval-protocol-" -[tool.black] -line-length = 119 - -[tool.isort] -profile = "black" -line_length = 119 +### Black and isort removed; Ruff provides formatting and imports [tool.uv.sources] tau2 = { git = "https://github.com/sierra-research/tau2-bench.git" } @@ -174,3 +158,64 @@ dev = [ "haikus==0.3.8", "pytest>=8.4.1", ] + +[tool.ruff] +line-length = 119 +target-version = "py310" +exclude = ["vite-app", "vendor", "local_evals"] + +[tool.ruff.lint] +# Relax: only enforce core errors (E/F); drop W/I for now +select = ["E", "F"] +# Preserve current flake8 ignore behavior to minimize churn +ignore = ["E203", "E402", "E501", "F401"] + +# Suppress noisy rules in tests/examples/vendor where readability patterns differ +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = ["F841", "E712", "E731", "F821"] +"examples/**/*.py" = ["F841", "E712", "E731", "F821"] +"eval_protocol/pytest/**/*.py" = ["F841", "E712", "E731", "F811"] +"vendor/**" = ["F841", "E712", "E731", "F811"] +"development/**/*.py" = ["F841", "E712", "E731"] +"eval_protocol/**/*.py" = ["F841", "F811", "E731", "E721"] + +[tool.ruff.lint.isort] +known-first-party = ["eval_protocol"] +combine-as-imports = true + +[tool.pyright] +typeCheckingMode = "basic" +pythonVersion = "3.10" +reportMissingImports = "none" +reportMissingTypeStubs = "none" +reportMissingModuleSource = "none" +include = ["eval_protocol", "examples", "tests"] +exclude = ["vite-app", "vendor"] +# Ignore diagnostics for vendored generator code +ignore = ["versioneer.py"] +# Relax noisy diagnostics commonly triggered in tests and dynamic libs +reportAttributeAccessIssue = "none" +reportCallIssue = "none" +reportUnknownMemberType = "none" +reportUnknownVariableType = "none" +reportPossiblyUnboundVariable = "none" +# Additional suppressions per request +reportOptionalMemberAccess = "none" +reportIndexIssue = "none" +reportReturnType = "none" +reportOptionalCall = "none" +reportGeneralTypeIssues = "none" +reportOperatorIssue = "none" +reportOptionalSubscript = "none" +reportUnsupportedDunderAll = "none" +reportOptionalContextManager = "none" +reportInvalidTypeForm = "none" +reportRedeclaration = "none" +reportUndefinedVariable = "none" +reportPrivateImportUsage = "none" +reportOptionalIterable = "none" +# Make incompatibilities and argument types warnings instead of errors for now +# and suppress warnings output entirely +reportIncompatibleVariableOverride = "none" +reportArgumentType = "none" +reportAssignmentType = "none" diff --git a/scripts/create_sample_gsm8k_jsonl.py b/scripts/create_sample_gsm8k_jsonl.py index 0bc78e46..8e561cb2 100644 --- a/scripts/create_sample_gsm8k_jsonl.py +++ b/scripts/create_sample_gsm8k_jsonl.py @@ -54,7 +54,7 @@ def create_sample_jsonl(): if samples_written > 0: print(f"Successfully wrote {samples_written} samples to {output_filepath}") else: - print(f"No samples were written. Check dataset loading and content.") + print("No samples were written. Check dataset loading and content.") if __name__ == "__main__": diff --git a/tests/cli_commands/test_deploy_cmd.py b/tests/cli_commands/test_deploy_cmd.py index 65e6471c..fbd38ae8 100644 --- a/tests/cli_commands/test_deploy_cmd.py +++ b/tests/cli_commands/test_deploy_cmd.py @@ -51,7 +51,6 @@ def mock_gcp_tools(): patch("eval_protocol.cli_commands.deploy.deploy_to_cloud_run") as mock_deploy_run, patch("eval_protocol.cli_commands.deploy.ensure_gcp_secret") as mock_ensure_gcp_secret, ): - mock_ensure_repo.return_value = True mock_gen_dockerfile.return_value = "DOCKERFILE CONTENT" mock_build_push.return_value = True @@ -67,7 +66,6 @@ def mock_gcp_tools(): class TestDeployCommandRemoteUrl: - @patch("eval_protocol.cli_commands.deploy.create_evaluation") def test_deploy_remote_url_success(self, mock_create_evaluation_call, mock_check_environment, capsys): """Test successful registration of a remote URL via create_evaluation.""" @@ -167,11 +165,10 @@ def test_deploy_remote_url_unexpected_error(self, mock_create_eval, mock_check_e captured = capsys.readouterr() # Updated error message to match common registration block - assert f"An unexpected error occurred during Fireworks AI registration: Something broke" in captured.out + assert "An unexpected error occurred during Fireworks AI registration: Something broke" in captured.out class TestDeployCommandLocalMode: # This class tests the "fireworks" target (packaging metrics) - @patch("eval_protocol.cli_commands.deploy.create_evaluation") def test_deploy_local_mode_success( # Renaming to reflect it tests "fireworks" target self, mock_create_eval, mock_check_environment, capsys @@ -276,7 +273,7 @@ def test_deploy_gcp_mode_success( captured = capsys.readouterr() # Check initial message from helper assert f"Starting GCP Cloud Run deployment for evaluator '{args.id}'..." in captured.out - assert f"Successfully built and pushed Docker image" in captured.out + assert "Successfully built and pushed Docker image" in captured.out assert ( f"Successfully deployed to Cloud Run. Service URL: {mock_gcp_tools['deploy_run'].return_value}" in captured.out diff --git a/tests/cli_commands/test_preview_cmd.py b/tests/cli_commands/test_preview_cmd.py index ec4ffa9e..20cf0416 100644 --- a/tests/cli_commands/test_preview_cmd.py +++ b/tests/cli_commands/test_preview_cmd.py @@ -53,7 +53,6 @@ def create_temp_jsonl(tmp_path: Path, samples_data: list) -> str: class TestPreviewCommandRemoteUrl: - @patch("requests.post") def test_preview_remote_url_success_with_file(self, mock_post, mock_check_environment, tmp_path, capsys): mock_response = MagicMock() diff --git a/tests/conftest.py b/tests/conftest.py index 6a3526a7..9c93cbf8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import sys from pathlib import Path + import pytest # Add the project root to the Python path diff --git a/tests/eval_protocol_tests/test_eval_protocol_simple.py b/tests/eval_protocol_tests/test_eval_protocol_simple.py index a8e127a6..4068cd3c 100644 --- a/tests/eval_protocol_tests/test_eval_protocol_simple.py +++ b/tests/eval_protocol_tests/test_eval_protocol_simple.py @@ -24,7 +24,6 @@ def test_basic_imports(): print(" โœ“ eval_protocol imported successfully") print(" Importing eval_protocol...") - import eval_protocol print(" โœ“ eval_protocol imported successfully") @@ -208,7 +207,7 @@ def main(): print(f"Test {test.__name__} crashed: {e}") failed += 1 - print(f"\n=== Results ===") + print("\n=== Results ===") print(f"Passed: {passed}") print(f"Failed: {failed}") print(f"Total: {passed + failed}") diff --git a/tests/eval_protocol_tests/test_minimal_structure.py b/tests/eval_protocol_tests/test_minimal_structure.py index aff78bff..c083e28f 100644 --- a/tests/eval_protocol_tests/test_minimal_structure.py +++ b/tests/eval_protocol_tests/test_minimal_structure.py @@ -267,7 +267,7 @@ def main(): print(f"Test {test.__name__} crashed: {e}") failed += 1 - print(f"\n=== Results ===") + print("\n=== Results ===") print(f"Passed: {passed}") print(f"Failed: {failed}") print(f"Total: {passed + failed}") diff --git a/tests/execution/test_pipeline.py b/tests/execution/test_pipeline.py index 9dba6b57..96ba8ef0 100644 --- a/tests/execution/test_pipeline.py +++ b/tests/execution/test_pipeline.py @@ -7,9 +7,9 @@ from eval_protocol.execution.pipeline import EvaluationPipeline from eval_protocol.generation.cache import ResponseCache -from eval_protocol.generation.clients import GenerationResult # Import GenerationResult from eval_protocol.generation.clients import ( # For type hinting and mocking FireworksModelClient, + GenerationResult, # Import GenerationResult ) from eval_protocol.models import EvaluateResult, Message, MetricResult diff --git a/tests/mcp_agent/orchestration/test_local_docker_client.py b/tests/mcp_agent/orchestration/test_local_docker_client.py index 1ecd3be7..13edefba 100644 --- a/tests/mcp_agent/orchestration/test_local_docker_client.py +++ b/tests/mcp_agent/orchestration/test_local_docker_client.py @@ -160,7 +160,6 @@ async def test_provision_deprovision_http_instance( "eval_protocol.mcp_agent.orchestration.local_docker_client.streamablehttp_client", new_callable=MagicMock, ) as mock_streamablehttp_client_func: - # Configure the mock Async Context Manager (ACM) that mock_streamablehttp_client_func will return mock_acm_instance = AsyncMock() # This object needs __aenter__ and __aexit__ @@ -327,7 +326,7 @@ async def test_provision_deprovision_stdio_instance( assert len(read_content_list) == 1 assert read_content_list[0].get("type") == "text" assert read_content_list[0].get("text") == test_file_content - logger.info(f"Successfully called 'read_file' via stdio, content matches template.") + logger.info("Successfully called 'read_file' via stdio, content matches template.") finally: if provisioned_instances: diff --git a/tests/mcp_agent/test_rl_filesystem_scenario.py b/tests/mcp_agent/test_rl_filesystem_scenario.py index cb9f70bb..4a394793 100644 --- a/tests/mcp_agent/test_rl_filesystem_scenario.py +++ b/tests/mcp_agent/test_rl_filesystem_scenario.py @@ -127,7 +127,7 @@ async def main(): if not fs_instance_id: raise ValueError(f"Instance ID not found for filesystem instance #{i}") - logger.info(f"\n--- Testing Filesystem Instance #{i+1} (ID: {fs_instance_id}) ---") + logger.info(f"\n--- Testing Filesystem Instance #{i + 1} (ID: {fs_instance_id}) ---") # --- Verify Initial State --- logger.info(f"[{fs_instance_id}] Verifying initial state...") diff --git a/tests/pytest/data/basic_coding_dataset.jsonl b/tests/pytest/data/basic_coding_dataset.jsonl index 27573c1b..fc25abcd 100644 --- a/tests/pytest/data/basic_coding_dataset.jsonl +++ b/tests/pytest/data/basic_coding_dataset.jsonl @@ -7,4 +7,4 @@ {"prompt": "Write a Python function `multiply_by_two` that takes an integer and returns the integer multiplied by 2.", "input": "10", "expected_output": "20"} {"prompt": "Write a Python function `get_length` that takes a list and returns its length.", "input": "[1, 2, 3]", "expected_output": "3"} {"prompt": "Write a Python function `get_length` that takes a list and returns its length.", "input": "[]", "expected_output": "0"} -{"prompt": "Write a Python function `get_length` that takes a list and returns its length.", "input": "['a', 'b', 'c', 'd']", "expected_output": "4"} \ No newline at end of file +{"prompt": "Write a Python function `get_length` that takes a list and returns its length.", "input": "['a', 'b', 'c', 'd']", "expected_output": "4"} diff --git a/tests/pytest/data/lunar_lander_dataset.jsonl b/tests/pytest/data/lunar_lander_dataset.jsonl index af396fc1..a3de90c6 100644 --- a/tests/pytest/data/lunar_lander_dataset.jsonl +++ b/tests/pytest/data/lunar_lander_dataset.jsonl @@ -1,3 +1,3 @@ {"id": "multi_env_test_001", "system_prompt": "You are controlling a lunar lander spacecraft. Use the lander_action tool with actions: NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT. Your goal is to land safely on the moon between the two flags without crashing.", "user_prompt_template": "Current state: {observation}. First, describe what is in the image attached and analyze the current state. You MUST explain your reasoning in picking the next best action (NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT) and call lander_action tool with it to land the spacecraft.", "environment_context": {"game": "LunarLander", "continuous": false, "gravity": -10.0, "enable_wind": false, "seed": 42}} {"id": "multi_env_test_002", "system_prompt": "You are controlling a lunar lander spacecraft. Use the lander_action tool with actions: NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT. Your goal is to land safely on the moon between the two flags without crashing.", "user_prompt_template": "Current state: {observation}. First, describe what is in the image attached and analyze the current state. You MUST explain your reasoning in picking the next best action (NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT) and call lander_action tool with it to land the spacecraft.", "environment_context": {"game": "LunarLander", "continuous": false, "gravity": -8.0, "enable_wind": false, "seed": 123}} -{"id": "multi_env_test_003", "system_prompt": "You are controlling a lunar lander spacecraft. Use the lander_action tool with actions: NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT. Your goal is to land safely on the moon between the two flags without crashing.", "user_prompt_template": "Current state: {observation}. First, describe what is in the image attached and analyze the current state. You MUST explain your reasoning in picking the next best action (NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT) and call lander_action tool with it to land the spacecraft.", "environment_context": {"game": "LunarLander", "continuous": false, "gravity": -12.0, "enable_wind": false, "seed": 456}} \ No newline at end of file +{"id": "multi_env_test_003", "system_prompt": "You are controlling a lunar lander spacecraft. Use the lander_action tool with actions: NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT. Your goal is to land safely on the moon between the two flags without crashing.", "user_prompt_template": "Current state: {observation}. First, describe what is in the image attached and analyze the current state. You MUST explain your reasoning in picking the next best action (NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT) and call lander_action tool with it to land the spacecraft.", "environment_context": {"game": "LunarLander", "continuous": false, "gravity": -12.0, "enable_wind": false, "seed": 456}} diff --git a/tests/pytest/helper/word_count_to_evaluation_row.py b/tests/pytest/helper/word_count_to_evaluation_row.py index f0517dd0..dbb05cc4 100644 --- a/tests/pytest/helper/word_count_to_evaluation_row.py +++ b/tests/pytest/helper/word_count_to_evaluation_row.py @@ -7,8 +7,7 @@ def word_count_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationR """Convert gsm8k dataset format to EvaluationRow for word_count evaluation.""" return [ EvaluationRow( - messages=[Message(role="user", content=row["user_query"])], - ground_truth=row["ground_truth_for_eval"] + messages=[Message(role="user", content=row["user_query"])], ground_truth=row["ground_truth_for_eval"] ) for row in data - ] \ No newline at end of file + ] diff --git a/tests/pytest/test_livesvgbench.py b/tests/pytest/test_livesvgbench.py index f105f8e0..44a8c8b8 100644 --- a/tests/pytest/test_livesvgbench.py +++ b/tests/pytest/test_livesvgbench.py @@ -47,10 +47,10 @@ def svgbench_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow for i, row in enumerate(data): # Format requirements as numbered list - requirements = "\n".join([f"{i+1}. {req}" for i, req in enumerate(row["requirements"])]) + requirements = "\n".join([f"{i + 1}. {req}" for i, req in enumerate(row["requirements"])]) # Create the generation prompt following SVGBench format - prompt = f"""{row['prompt']} Wrap the SVG code in an SVG code block following the example below. + prompt = f"""{row["prompt"]} Wrap the SVG code in an SVG code block following the example below. Example: ```svg @@ -166,7 +166,7 @@ def render_svg_to_png(svg_code: str, output_path: str) -> bool: chrome_options.add_argument("--no-sandbox") chrome_options.add_argument("--disable-dev-shm-usage") chrome_options.add_argument("--disable-gpu") - chrome_options.add_argument(f"--window-size={width+40},{height+40}") + chrome_options.add_argument(f"--window-size={width + 40},{height + 40}") # Create temporary HTML file with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f: @@ -209,7 +209,7 @@ def evaluate_with_llm_judge(image_path: str, requirements: List[str]) -> Dict[st Dictionary with evaluation results """ # Format requirements for evaluation (exactly as in original) - requirements_text = "\n".join([f"{i+1}. {req}" for i, req in enumerate(requirements)]) + requirements_text = "\n".join([f"{i + 1}. {req}" for i, req in enumerate(requirements)]) # Create evaluation prompt with JSON response format evaluate_prompt = f"""Examine the generated image. How many of the following {len(requirements)} requirements were fulfilled? @@ -343,7 +343,7 @@ def evaluate_with_human_preference_rubrics( For example, colored circles arranged in Google colors should score very low for intent matching and recognizability. Original Requirements (for context): -{chr(10).join([f"{i+1}. {req}" for i, req in enumerate(requirements)])} +{chr(10).join([f"{i + 1}. {req}" for i, req in enumerate(requirements)])} Respond with JSON in this exact format: {{ @@ -493,27 +493,27 @@ def test_svg_combined_evaluation(row: EvaluationRow) -> EvaluationRow: === REQUIREMENTS EVALUATION (Listwise - Row-Specific) === Score: {requirements_score:.3f} -{requirements_result.get('reasoning', 'No reasoning provided')} +{requirements_result.get("reasoning", "No reasoning provided")} === HUMAN PREFERENCE EVALUATION (Pointwise - Universal Rubrics) === Score: {human_pref_score:.3f} -๐ŸŽฏ Intent Matching: {human_pref_result.get('intent_matching_score', 0.0):.2f}/1.0 -{human_pref_result.get('intent_reasoning', 'No reasoning provided')} +๐ŸŽฏ Intent Matching: {human_pref_result.get("intent_matching_score", 0.0):.2f}/1.0 +{human_pref_result.get("intent_reasoning", "No reasoning provided")} -๐Ÿ‘๏ธ Content Recognizability: {human_pref_result.get('content_recognizability_score', 0.0):.2f}/1.0 -{human_pref_result.get('content_reasoning', 'No reasoning provided')} +๐Ÿ‘๏ธ Content Recognizability: {human_pref_result.get("content_recognizability_score", 0.0):.2f}/1.0 +{human_pref_result.get("content_reasoning", "No reasoning provided")} -๐Ÿ“ Spatial Design Quality: {human_pref_result.get('spatial_design_score', 0.0):.2f}/1.0 -{human_pref_result.get('spatial_reasoning', 'No reasoning provided')} +๐Ÿ“ Spatial Design Quality: {human_pref_result.get("spatial_design_score", 0.0):.2f}/1.0 +{human_pref_result.get("spatial_reasoning", "No reasoning provided")} -๐Ÿ‘ค User Experience: {human_pref_result.get('user_experience_score', 0.0):.2f}/1.0 -{human_pref_result.get('ux_reasoning', 'No reasoning provided')} +๐Ÿ‘ค User Experience: {human_pref_result.get("user_experience_score", 0.0):.2f}/1.0 +{human_pref_result.get("ux_reasoning", "No reasoning provided")} -๐ŸŽจ Visual Coherence: {human_pref_result.get('visual_coherence_score', 0.0):.2f}/1.0 -{human_pref_result.get('coherence_reasoning', 'No reasoning provided')} +๐ŸŽจ Visual Coherence: {human_pref_result.get("visual_coherence_score", 0.0):.2f}/1.0 +{human_pref_result.get("coherence_reasoning", "No reasoning provided")} -{human_pref_result.get('overall_reasoning', 'No overall reasoning provided')} +{human_pref_result.get("overall_reasoning", "No overall reasoning provided")} === FINAL COMBINED SCORE === Requirements: {requirements_score:.3f} ร— 30% = {requirements_score * 0.3:.3f} diff --git a/tests/pytest/test_svgbench.py b/tests/pytest/test_svgbench.py index 7ff08642..90d2f8f0 100644 --- a/tests/pytest/test_svgbench.py +++ b/tests/pytest/test_svgbench.py @@ -46,10 +46,10 @@ def svgbench_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow for i, row in enumerate(data): # Format requirements as numbered list - requirements = "\n".join([f"{i+1}. {req}" for i, req in enumerate(row["requirements"])]) + requirements = "\n".join([f"{i + 1}. {req}" for i, req in enumerate(row["requirements"])]) # Create the generation prompt following SVGBench format - prompt = f"""{row['prompt']} Wrap the SVG code in an SVG code block following the example below. + prompt = f"""{row["prompt"]} Wrap the SVG code in an SVG code block following the example below. Example: ```svg @@ -165,7 +165,7 @@ def render_svg_to_png(svg_code: str, output_path: str) -> bool: chrome_options.add_argument("--no-sandbox") chrome_options.add_argument("--disable-dev-shm-usage") chrome_options.add_argument("--disable-gpu") - chrome_options.add_argument(f"--window-size={width+40},{height+40}") + chrome_options.add_argument(f"--window-size={width + 40},{height + 40}") # Create temporary HTML file with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f: @@ -208,7 +208,7 @@ def evaluate_with_llm_judge(image_path: str, requirements: List[str]) -> Dict[st Dictionary with evaluation results """ # Format requirements for evaluation (exactly as in original) - requirements_text = "\n".join([f"{i+1}. {req}" for i, req in enumerate(requirements)]) + requirements_text = "\n".join([f"{i + 1}. {req}" for i, req in enumerate(requirements)]) # Create evaluation prompt with JSON response format evaluate_prompt = f"""Examine the generated image. How many of the following {len(requirements)} requirements were fulfilled? diff --git a/tests/test_adapters_e2e.py b/tests/test_adapters_e2e.py index 72449e8b..2c105315 100644 --- a/tests/test_adapters_e2e.py +++ b/tests/test_adapters_e2e.py @@ -554,7 +554,9 @@ def google_books_transform(row: Dict[str, Any]) -> Dict[str, Any]: frequency_desc = ( "high frequency" if term_frequency > 1000 - else "moderate frequency" if term_frequency > 100 else "low frequency" + else "moderate frequency" + if term_frequency > 100 + else "low frequency" ) document_desc = ( f"appears in {document_frequency} documents" if document_frequency > 0 else "rare occurrence" diff --git a/tests/test_agent_resources.py b/tests/test_agent_resources.py index 303f9c6b..56c23601 100644 --- a/tests/test_agent_resources.py +++ b/tests/test_agent_resources.py @@ -374,16 +374,16 @@ async def test_checkpoint_and_restore_docker(self, docker_resource: DockerResour await docker_resource.setup(config) create_file_command = "sh -c \"echo 'initial_data' > /data.txt\"" create_file_result = await docker_resource.step("exec_command", {"command": create_file_command}) - assert ( - create_file_result["exit_code"] == 0 - ), f"Failed to create /data.txt with '{create_file_command}': {create_file_result['output']}" + assert create_file_result["exit_code"] == 0, ( + f"Failed to create /data.txt with '{create_file_command}': {create_file_result['output']}" + ) # Optionally, verify file content immediately after creation in the source container verify_result = await docker_resource.step("exec_command", {"command": "cat /data.txt"}) assert verify_result["exit_code"] == 0, f"Failed to cat /data.txt after creation: {verify_result['output']}" - assert ( - "initial_data" in verify_result["output"] - ), f"/data.txt content mismatch after creation: {verify_result['output']}" + assert "initial_data" in verify_result["output"], ( + f"/data.txt content mismatch after creation: {verify_result['output']}" + ) checkpoint_info = await docker_resource.checkpoint() checkpoint_image_id = checkpoint_info["image_id"] diff --git a/tests/test_batch_evaluation.py b/tests/test_batch_evaluation.py index 9308f77f..772b8290 100644 --- a/tests/test_batch_evaluation.py +++ b/tests/test_batch_evaluation.py @@ -399,7 +399,6 @@ def smart_move_generator(**kwargs): patch.object(task_manager, "_start_resource_server", return_value=12345), patch.object(task_manager, "_wait_for_server_health", return_value=True), ): - # Execute the task with batch evaluation results = await task_manager.execute_tasks( task_ids=[task_id], @@ -413,9 +412,9 @@ def smart_move_generator(**kwargs): result = results[task_id] # Should not be an error result - assert not ( - isinstance(result, dict) and "error" in result - ), f"Task failed: {result.get('error', 'Unknown error')}" + assert not (isinstance(result, dict) and "error" in result), ( + f"Task failed: {result.get('error', 'Unknown error')}" + ) # Should be aggregated results assert isinstance(result, dict) @@ -563,7 +562,6 @@ async def test_batch_evaluation_task_manager_openai( patch.object(task_manager, "_start_resource_server", return_value=12346), patch.object(task_manager, "_wait_for_server_health", return_value=True), ): - # Execute the task with batch evaluation results = await task_manager.execute_tasks( task_ids=[task_id], @@ -577,9 +575,9 @@ async def test_batch_evaluation_task_manager_openai( result = results[task_id] # Should not be an error result - assert not ( - isinstance(result, dict) and "error" in result - ), f"Task failed: {result.get('error', 'Unknown error')}" + assert not (isinstance(result, dict) and "error" in result), ( + f"Task failed: {result.get('error', 'Unknown error')}" + ) # Should be aggregated results assert isinstance(result, dict) @@ -964,7 +962,6 @@ def smart_move_generator(**kwargs): patch.object(task_manager, "_start_resource_server", return_value=12347), patch.object(task_manager, "_wait_for_server_health", return_value=True), ): - # Execute with parallel enabled results = await task_manager.execute_tasks( task_ids=[task_id], @@ -1108,7 +1105,6 @@ def smart_move_generator(**kwargs): patch.object(task_manager, "_start_resource_server", return_value=12348), patch.object(task_manager, "_wait_for_server_health", return_value=True), ): - # Execute task results = await task_manager.execute_tasks(task_ids=[task_id], num_rollouts_override=2) diff --git a/tests/test_cli_agent.py b/tests/test_cli_agent.py index cc50376f..00763c3c 100644 --- a/tests/test_cli_agent.py +++ b/tests/test_cli_agent.py @@ -41,7 +41,7 @@ class TestAgentEvalCommand: def test_agent_eval_success_yaml(self, MockPath, MockTaskManager, caplog): # Configure caplog to capture logs from the agent_eval logger caplog.set_level(logging.INFO, logger="agent_eval") - + # Setup Path mock mock_path_instance = Mock() MockPath.return_value = mock_path_instance @@ -207,9 +207,7 @@ def test_agent_eval_orchestrator_execution_fails(self, MockPath, MockTaskManager mock_task_manager.register_task.return_value = "task1" # Make execute_tasks raise an exception - mock_task_manager.execute_tasks = AsyncMock( - side_effect=RuntimeError("Execution failed") - ) # type: ignore[attr-defined] + mock_task_manager.execute_tasks = AsyncMock(side_effect=RuntimeError("Execution failed")) # type: ignore[attr-defined] mock_task_manager.cleanup = AsyncMock() args = argparse.Namespace(task_def="dummy_task.yaml") diff --git a/tests/test_cli_args.py b/tests/test_cli_args.py index 2ecdbadb..21817879 100644 --- a/tests/test_cli_args.py +++ b/tests/test_cli_args.py @@ -7,7 +7,6 @@ class TestCliArgParsing: - # --- Tests for 'preview' command --- def test_preview_with_remote_url_and_samples(self): args_list = [ diff --git a/tests/test_code_execution.py b/tests/test_code_execution.py index 2290087d..714813b1 100644 --- a/tests/test_code_execution.py +++ b/tests/test_code_execution.py @@ -7,7 +7,6 @@ import pytest from eval_protocol.models import EvaluateResult, Message # Added for new tests -from eval_protocol.rewards.code_execution import fractional_code_reward # Added for new tests from eval_protocol.rewards.code_execution import ( _HAS_E2B, compare_outputs, @@ -16,6 +15,7 @@ execute_javascript_code, execute_python_code, extract_code_blocks, + fractional_code_reward, # Added for new tests local_code_execution_reward, string_similarity, ) @@ -464,9 +464,9 @@ def test_python_function_arg_parsing(self, test_input_str, expected_args_list, e and len(actual_test_run_details_list) > 0 ): actual_output_str = actual_test_run_details_list[0].get("actual_output") - assert actual_output_str == repr( - expected_return_val - ), f"Actual output '{actual_output_str}' did not match expected '{repr(expected_return_val)}' for input '{test_input_str}'" + assert actual_output_str == repr(expected_return_val), ( + f"Actual output '{actual_output_str}' did not match expected '{repr(expected_return_val)}' for input '{test_input_str}'" + ) except json.JSONDecodeError: # Catch specifically json.JSONDecodeError # Accessing reason from MetricResult object print( diff --git a/tests/test_data_driven_task_manager.py b/tests/test_data_driven_task_manager.py index f77051bc..a0697351 100644 --- a/tests/test_data_driven_task_manager.py +++ b/tests/test_data_driven_task_manager.py @@ -168,7 +168,6 @@ async def test_execute_data_driven_rollouts_basic(self): patch.object(self.task_manager, "_stop_resource_server"), patch("eval_protocol.agent.task_manager.Orchestrator") as mock_orchestrator_class, ): - # Set up mock orchestrator mock_orchestrator = AsyncMock() mock_orchestrator.setup_base_resource = AsyncMock() @@ -206,7 +205,6 @@ async def test_execute_multiple_rollouts_per_sample(self): patch.object(self.task_manager, "_stop_resource_server"), patch("eval_protocol.agent.task_manager.Orchestrator") as mock_orchestrator_class, ): - # Set up mock orchestrator to return different scores for each rollout mock_orchestrator = AsyncMock() mock_orchestrator.setup_base_resource = AsyncMock() @@ -250,7 +248,6 @@ async def test_execute_data_driven_with_failures(self): patch.object(self.task_manager, "_stop_resource_server"), patch("eval_protocol.agent.task_manager.Orchestrator") as mock_orchestrator_class, ): - # Set up mock orchestrator with one success and one failure mock_orchestrator = AsyncMock() mock_orchestrator.setup_base_resource = AsyncMock() @@ -307,7 +304,6 @@ async def mock_execute(*args, **kwargs): patch.object(self.task_manager, "_stop_resource_server"), patch("eval_protocol.agent.task_manager.Orchestrator") as mock_orchestrator_class, ): - mock_orchestrator = AsyncMock() mock_orchestrator.setup_base_resource = AsyncMock() mock_orchestrator.execute_task_poc = AsyncMock(side_effect=mock_execute) @@ -355,7 +351,6 @@ async def test_execute_tasks_data_driven_vs_traditional(self): patch.object(self.task_manager, "_execute_data_driven_rollouts") as mock_data_driven, patch.object(self.task_manager, "_execute_batch_rollouts") as mock_traditional, ): - mock_data_driven.return_value = [{"score": 1.0}] mock_traditional.return_value = [{"score": 0.5}] diff --git a/tests/test_deepeval_integration.py b/tests/test_deepeval_integration.py index 7115e740..16a0d86c 100644 --- a/tests/test_deepeval_integration.py +++ b/tests/test_deepeval_integration.py @@ -108,7 +108,11 @@ def is_successful(self) -> bool: class DummyGEval(BaseMetric): # type: ignore - evaluation_params = [LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT] if DEEPEVAL_AVAILABLE and hasattr(LLMTestCaseParams, "INPUT") else [] # type: ignore + evaluation_params = ( + [LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT] + if DEEPEVAL_AVAILABLE and hasattr(LLMTestCaseParams, "INPUT") + else [] + ) # type: ignore def __init__(self, threshold: float = 0.0) -> None: self.threshold = threshold @@ -199,7 +203,9 @@ def test_fireworks_geval_integration_with_mock(self) -> None: if parsed_fireworks_model_name not in valid_gpt_models: valid_gpt_models.append(parsed_fireworks_model_name) # type: ignore - actual_fireworks_model_for_geval = GPTModel(model=fireworks_model_name_for_api, _openai_api_key=fireworks_api_key) # type: ignore + actual_fireworks_model_for_geval = GPTModel( + model=fireworks_model_name_for_api, _openai_api_key=fireworks_api_key + ) # type: ignore actual_fireworks_model_for_geval.model_name = fireworks_model_name_for_api # type: ignore if fireworks_model_name_for_api not in model_pricing: @@ -340,7 +346,13 @@ async def mock_chat_completions_create(*args, messages: List[dict], model: str, "openai.resources.chat.completions.AsyncCompletions.create", new=mock_chat_completions_create, ): - geval_metric = GEval(name="Fireworks GEval Mocked", criteria="Evaluate the helpfulness and relevance of the actual output based on the input.", evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT], model=actual_fireworks_model_for_geval, strict_mode=False) # type: ignore + geval_metric = GEval( + name="Fireworks GEval Mocked", + criteria="Evaluate the helpfulness and relevance of the actual output based on the input.", + evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT], + model=actual_fireworks_model_for_geval, + strict_mode=False, + ) # type: ignore wrapped_metric = adapt_metric(geval_metric) messages_data = [ {"role": "user", "content": "What is the capital of France?"}, @@ -353,7 +365,11 @@ async def mock_chat_completions_create(*args, messages: List[dict], model: str, self.assertIsNotNone(result.score, "GEval score should not be None") self.assertEqual(result.score, 1.0, f"GEval score {result.score} was not 1.0 with mock.") expected_metric_key = f"{geval_metric.name} ({geval_metric.__class__.__name__})" # type: ignore - self.assertIn(expected_metric_key, result.metrics, f"Constructed metric key '{expected_metric_key}' not found. Keys: {list(result.metrics.keys())}") # type: ignore + self.assertIn( + expected_metric_key, + result.metrics, + f"Constructed metric key '{expected_metric_key}' not found. Keys: {list(result.metrics.keys())}", + ) # type: ignore self.assertIsNotNone(result.metrics[expected_metric_key].reason) # type: ignore diff --git a/tests/test_deploy_integration.py b/tests/test_deploy_integration.py index 8a8b6d26..6b7db1b0 100644 --- a/tests/test_deploy_integration.py +++ b/tests/test_deploy_integration.py @@ -174,7 +174,6 @@ def test_deploy_gcp_with_inline_requirements( return_value={"name": evaluator_id}, ) as mock_create_eval, ): - # Configure mock_get_config to return a basic config mock_config_instance = RewardKitConfig( gcp_cloud_run=GCPCloudRunConfig( diff --git a/tests/test_e2b_integration.py b/tests/test_e2b_integration.py index eafaeff8..eba8b585 100755 --- a/tests/test_e2b_integration.py +++ b/tests/test_e2b_integration.py @@ -62,7 +62,7 @@ def add(a, b): pytest.skip(f"Skipping due to E2B connection issue: {error_msg}") # Also check for error in main result reason if result.reason and "Invalid API key" in result.reason: - pytest.skip(f"Skipping due to invalid E2B API key") + pytest.skip("Skipping due to invalid E2B API key") # Otherwise, it should be a successful result assert result.score == 1.0 diff --git a/tests/test_e2b_js_integration.py b/tests/test_e2b_js_integration.py index 566c9005..5c266645 100755 --- a/tests/test_e2b_js_integration.py +++ b/tests/test_e2b_js_integration.py @@ -68,7 +68,7 @@ def test_e2b_javascript_integration(): pytest.skip(f"Skipping due to E2B connection issue: {error_msg}") # Also check for error in main result reason if result.reason and "Invalid API key" in result.reason: - pytest.skip(f"Skipping due to invalid E2B API key") + pytest.skip("Skipping due to invalid E2B API key") # Otherwise, it should be a successful result assert result.score == 1.0 diff --git a/tests/test_eval_protocol_import.py b/tests/test_eval_protocol_import.py index c0e17a99..c16b3927 100644 --- a/tests/test_eval_protocol_import.py +++ b/tests/test_eval_protocol_import.py @@ -36,14 +36,16 @@ def test_all_exports_consistency(self): def test_core_classes_available(self): """Test that core classes are available through both imports.""" - from eval_protocol import EvaluateResult - from eval_protocol import EvaluateResult as RPEvaluateResult - from eval_protocol import Message - from eval_protocol import Message as RPMessage - from eval_protocol import MetricResult - from eval_protocol import MetricResult as RPMetricResult - from eval_protocol import RewardFunction - from eval_protocol import RewardFunction as RPRewardFunction + from eval_protocol import ( + EvaluateResult, + EvaluateResult as RPEvaluateResult, + Message, + Message as RPMessage, + MetricResult, + MetricResult as RPMetricResult, + RewardFunction, + RewardFunction as RPRewardFunction, + ) # Classes should be the same assert RewardFunction is RPRewardFunction @@ -53,16 +55,18 @@ def test_core_classes_available(self): def test_functions_available(self): """Test that core functions are available through both imports.""" - from eval_protocol import load_jsonl - from eval_protocol import load_jsonl as rp_load_jsonl - from eval_protocol import make - from eval_protocol import make as rp_make - from eval_protocol import reward_function - from eval_protocol import reward_function as rp_reward_function - from eval_protocol import rollout - from eval_protocol import rollout as rp_rollout - from eval_protocol import test_mcp - from eval_protocol import test_mcp as rp_test_mcp + from eval_protocol import ( + load_jsonl, + load_jsonl as rp_load_jsonl, + make, + make as rp_make, + reward_function, + reward_function as rp_reward_function, + rollout, + rollout as rp_rollout, + test_mcp, + test_mcp as rp_test_mcp, + ) # Functions should be the same assert reward_function is rp_reward_function @@ -110,9 +114,11 @@ def test_star_import_works(self): def test_reward_function_decorator_works(self): """Test that the @reward_function decorator works through both imports.""" - from eval_protocol import EvaluateResult - from eval_protocol import reward_function as rk_reward_function - from eval_protocol import reward_function as rp_reward_function + from eval_protocol import ( + EvaluateResult, + reward_function as rk_reward_function, + reward_function as rp_reward_function, + ) # Create a simple reward function using eval_protocol @rk_reward_function @@ -147,8 +153,7 @@ def test_reward_rp(response: str, **kwargs) -> EvaluateResult: def test_message_class_works(self): """Test that Message class works through both imports.""" - from eval_protocol import Message as RKMessage - from eval_protocol import Message as RPMessage + from eval_protocol import Message as RKMessage, Message as RPMessage # They should be the same class assert RKMessage is RPMessage @@ -196,29 +201,28 @@ def test_deep_import_consistency(self): """Test that deep imports work consistently.""" try: # Test importing from submodules - from eval_protocol.models import Message as RKMessage - from eval_protocol.models import Message as RPMessage + from eval_protocol.models import Message as RKMessage, Message as RPMessage # Should be the same class assert RKMessage is RPMessage except ImportError: # If submodule imports don't work, that's expected in some install scenarios # Just verify the star import works - from eval_protocol import Message as RKMessage - from eval_protocol import Message as RPMessage + from eval_protocol import Message as RKMessage, Message as RPMessage assert RKMessage is RPMessage try: # Test another submodule - use a function that actually exists - from eval_protocol.auth import get_fireworks_account_id - from eval_protocol.auth import get_fireworks_account_id as rp_get_fireworks_account_id + from eval_protocol.auth import ( + get_fireworks_account_id, + get_fireworks_account_id as rp_get_fireworks_account_id, + ) assert get_fireworks_account_id is rp_get_fireworks_account_id except ImportError: # If submodule imports don't work, verify through star import - from eval_protocol import auth as rk_auth - from eval_protocol import auth as rp_auth + from eval_protocol import auth as rk_auth, auth as rp_auth assert rk_auth is rp_auth diff --git a/tests/test_examples_end_to_end.py b/tests/test_examples_end_to_end.py index d4edac17..fdb35496 100644 --- a/tests/test_examples_end_to_end.py +++ b/tests/test_examples_end_to_end.py @@ -47,7 +47,6 @@ def mock_requests(): patch("requests.get") as mock_get, patch("requests.delete") as mock_delete, ): - # Configure mock_post for different use cases def post_side_effect(*args, **kwargs): url = args[0] diff --git a/tests/test_function_calling.py b/tests/test_function_calling.py index 2c4d8cb3..343a755a 100644 --- a/tests/test_function_calling.py +++ b/tests/test_function_calling.py @@ -4,8 +4,10 @@ import pytest -from eval_protocol.models import EvaluateResult # Changed -from eval_protocol.models import Message # Added import +from eval_protocol.models import ( + EvaluateResult, # Changed + Message, # Added import +) from eval_protocol.rewards.function_calling import ( calculate_jaccard_similarity, composite_function_call_reward, @@ -93,16 +95,14 @@ def test_wrong_function_name(self): assert result.metrics["function_name_match"].score == 0.0 assert ( result.metrics["function_name_match"].reason is not None - and "Function name does not match" - in result.metrics["function_name_match"].reason # type: ignore[operator] + and "Function name does not match" in result.metrics["function_name_match"].reason # type: ignore[operator] ) # Dictionary access assert result["score"] < 1.0 assert result["metrics"]["function_name_match"]["score"] == 0.0 assert ( result["metrics"]["function_name_match"]["reason"] is not None - and "Function name does not match" - in result["metrics"]["function_name_match"]["reason"] # type: ignore[operator] + and "Function name does not match" in result["metrics"]["function_name_match"]["reason"] # type: ignore[operator] ) def test_missing_required_argument(self): diff --git a/tests/test_gcp_tools.py b/tests/test_gcp_tools.py index 9b8804fa..d1f490eb 100644 --- a/tests/test_gcp_tools.py +++ b/tests/test_gcp_tools.py @@ -91,7 +91,6 @@ def test_run_gcloud_command_success_with_stderr(self, mock_subprocess_run): ) # os.path.exists mock removed as not directly used by SUT for this path @patch("builtins.open", new_callable=mock_open) def test_build_and_push_docker_image_success(self, mock_open_file, mock_os_remove, mock_run_gcloud, MockGCPPath): - mock_path_instance = MockGCPPath.return_value mock_dockerfile_path_obj = MagicMock(spec=Path) @@ -276,7 +275,6 @@ def test_build_and_push_docker_image_success_dockerfile_vanishes( @patch("eval_protocol.gcp_tools._run_gcloud_command") def test_deploy_to_cloud_run_success(self, mock_run_gcloud): - # Mock for deploy command mock_run_gcloud.side_effect = [ (True, "Deploy success", ""), # For initial deploy diff --git a/tests/test_generic_server.py b/tests/test_generic_server.py index ee21906a..94beb0aa 100644 --- a/tests/test_generic_server.py +++ b/tests/test_generic_server.py @@ -102,8 +102,7 @@ def test_load_failure_resets_globals(self): # --- Tests for FastAPI app endpoints --- from fastapi.testclient import TestClient -from eval_protocol.generic_server import EvaluationRequest -from eval_protocol.generic_server import app as generic_fastapi_app +from eval_protocol.generic_server import EvaluationRequest, app as generic_fastapi_app from eval_protocol.models import EvaluateResult, Message, MetricResult diff --git a/tests/test_math.py b/tests/test_math.py index 8b53a4c4..5aad9135 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -8,8 +8,11 @@ import pytest # Removed: from eval_protocol.rewards.advanced_math import advanced_math_reward -from eval_protocol.models import Message # Added Message import -from eval_protocol.models import EvaluateResult, MetricResult +from eval_protocol.models import ( + EvaluateResult, + Message, # Added Message import + MetricResult, +) from eval_protocol.rewards.math import compare_numbers, extract_numbers, math_reward diff --git a/tests/test_models_rl.py b/tests/test_models_rl.py index ea753f77..f1eed5d4 100644 --- a/tests/test_models_rl.py +++ b/tests/test_models_rl.py @@ -6,16 +6,13 @@ from eval_protocol.agent.models import StepData # Assuming these are the correct import paths based on our plan -from eval_protocol.models import EvaluateResult -from eval_protocol.models import Message as RewardKitMessage -from eval_protocol.models import StepOutput +from eval_protocol.models import EvaluateResult, Message as RewardKitMessage, StepOutput # Minimal Message for StepData if direct import from eval_protocol.models is problematic in tests # For now, assume RewardKitMessage from eval_protocol.models works. class TestRLDataStructures: - def test_step_output_creation_valid(self): """Test valid creation of StepOutput.""" so = StepOutput(step_index=0, base_reward=0.5, reason="Good step", metrics={"accuracy": 0.9}) diff --git a/tests/test_packaging.py b/tests/test_packaging.py index a019edcd..ee44eb7c 100644 --- a/tests/test_packaging.py +++ b/tests/test_packaging.py @@ -19,7 +19,6 @@ class TestPackaging(unittest.TestCase): - @classmethod def setUpClass(cls): # Create a unique dummy reward function file for testing to avoid race conditions diff --git a/tests/test_parallel_rollouts.py b/tests/test_parallel_rollouts.py index ef5c83a6..8da83d27 100644 --- a/tests/test_parallel_rollouts.py +++ b/tests/test_parallel_rollouts.py @@ -105,7 +105,7 @@ async def _test_seed_handling_and_type_compatibility_impl(): stdout, stderr = server_process.communicate() # This is a CI environment issue, not a code issue - run a simplified test instead - print(f"โš ๏ธ Server startup failed in CI environment, running simplified test...") + print("โš ๏ธ Server startup failed in CI environment, running simplified test...") print(f"Server stdout: {stdout.decode()[:200]}") print(f"Server stderr: {stderr.decode()[:200]}") @@ -376,4 +376,4 @@ async def test_mcp_resource_type_compatibility(): print(f" - Seed 42 map: {map1}") print(f" - Seed 123 map: {map2}") print(f" - Seed 999 map: {map3}") - print(f" - JSON serialization: โœ…") + print(" - JSON serialization: โœ…") diff --git a/tests/test_platform_api.py b/tests/test_platform_api.py index 84f1dc0d..13451779 100644 --- a/tests/test_platform_api.py +++ b/tests/test_platform_api.py @@ -15,7 +15,6 @@ class TestPlatformAPI(unittest.TestCase): - def setUp(self): # Patch auth functions for isolation self.mock_api_key = "test_api_key" diff --git a/tests/test_readiness.py b/tests/test_readiness.py index 5fef142a..8712aeb0 100644 --- a/tests/test_readiness.py +++ b/tests/test_readiness.py @@ -43,7 +43,6 @@ def mock_requests_post(): # --- End-to-End Script Tests for Math Example --- class TestMathExampleEndToEndScripts: - BASE_MATH_EXAMPLE_PATH = os.path.join(os.path.dirname(__file__), "../examples/math_example") def run_script( @@ -201,7 +200,6 @@ def run_script( # --- End-to-End Script Tests for Math Example (OpenR1) --- class TestMathExampleOpenR1EndToEndScripts: - BASE_MATH_EXAMPLE_OPENR1_PATH = os.path.join(os.path.dirname(__file__), "../examples/math_example_openr1") def run_script( diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py index f00be4fc..8b55869f 100644 --- a/tests/test_retry_mechanism.py +++ b/tests/test_retry_mechanism.py @@ -109,7 +109,7 @@ def test_retry_mechanism_mock_verification(): # Get our mock tracker mock_tracker = shared_processor.mock_tracker - print(f"\n๐Ÿ”„ MOCK CALL ANALYSIS:") + print("\n๐Ÿ”„ MOCK CALL ANALYSIS:") print(f" Batch calls made: {mock_tracker.batch_call.call_count}") print(f" Total row processing calls: {mock_tracker.process_row_call.call_count}") @@ -125,7 +125,7 @@ def test_retry_mechanism_mock_verification(): call_counts = Counter(rollout_ids) print(f" Call counts per rollout_id: {dict(call_counts)}") - print(f" Individual calls:") + print(" Individual calls:") for i, call_arg in enumerate(call_args, 1): rollout_id = call_arg[0][0] attempt_num = rollout_ids[:i].count(rollout_id) @@ -133,9 +133,9 @@ def test_retry_mechanism_mock_verification(): # ASSERTIONS USING MOCK DATA # Should have exactly 6 total row processing calls (5 initial + 1 retry) - assert ( - mock_tracker.process_row_call.call_count == 6 - ), f"Expected 6 total calls, got {mock_tracker.process_row_call.call_count}" + assert mock_tracker.process_row_call.call_count == 6, ( + f"Expected 6 total calls, got {mock_tracker.process_row_call.call_count}" + ) # Should have exactly 2 batch calls (initial batch + retry batch) assert mock_tracker.batch_call.call_count == 2, f"Expected 2 batch calls, got {mock_tracker.batch_call.call_count}" @@ -147,11 +147,11 @@ def test_retry_mechanism_mock_verification(): # Exactly one rollout_id should be called twice, others called once call_count_values = list(call_counts.values()) - assert ( - call_count_values.count(2) == 1 - ), f"Expected exactly 1 rollout_id to be called twice, got counts: {dict(call_counts)}" - assert ( - call_count_values.count(1) == 4 - ), f"Expected exactly 4 rollout_ids to be called once, got counts: {dict(call_counts)}" + assert call_count_values.count(2) == 1, ( + f"Expected exactly 1 rollout_id to be called twice, got counts: {dict(call_counts)}" + ) + assert call_count_values.count(1) == 4, ( + f"Expected exactly 4 rollout_ids to be called once, got counts: {dict(call_counts)}" + ) print("โœ… All mock-based assertions passed! Retry mechanism is working correctly.") diff --git a/tests/test_reward_protocol_import.py b/tests/test_reward_protocol_import.py index 7466ed93..d643c483 100644 --- a/tests/test_reward_protocol_import.py +++ b/tests/test_reward_protocol_import.py @@ -36,14 +36,16 @@ def test_all_exports_consistency(self): def test_core_classes_available(self): """Test that core classes are available through both imports.""" - from eval_protocol import EvaluateResult - from eval_protocol import EvaluateResult as RPEvaluateResult - from eval_protocol import Message - from eval_protocol import Message as RPMessage - from eval_protocol import MetricResult - from eval_protocol import MetricResult as RPMetricResult - from eval_protocol import RewardFunction - from eval_protocol import RewardFunction as RPRewardFunction + from eval_protocol import ( + EvaluateResult, + EvaluateResult as RPEvaluateResult, + Message, + Message as RPMessage, + MetricResult, + MetricResult as RPMetricResult, + RewardFunction, + RewardFunction as RPRewardFunction, + ) # Classes should be the same assert RewardFunction is RPRewardFunction @@ -53,16 +55,18 @@ def test_core_classes_available(self): def test_functions_available(self): """Test that core functions are available through both imports.""" - from eval_protocol import load_jsonl - from eval_protocol import load_jsonl as rp_load_jsonl - from eval_protocol import make - from eval_protocol import make as rp_make - from eval_protocol import reward_function - from eval_protocol import reward_function as rp_reward_function - from eval_protocol import rollout - from eval_protocol import rollout as rp_rollout - from eval_protocol import test_mcp - from eval_protocol import test_mcp as rp_test_mcp + from eval_protocol import ( + load_jsonl, + load_jsonl as rp_load_jsonl, + make, + make as rp_make, + reward_function, + reward_function as rp_reward_function, + rollout, + rollout as rp_rollout, + test_mcp, + test_mcp as rp_test_mcp, + ) # Functions should be the same assert reward_function is rp_reward_function @@ -110,9 +114,11 @@ def test_star_import_works(self): def test_reward_function_decorator_works(self): """Test that the @reward_function decorator works through both imports.""" - from eval_protocol import EvaluateResult - from eval_protocol import reward_function as rk_reward_function - from eval_protocol import reward_function as rp_reward_function + from eval_protocol import ( + EvaluateResult, + reward_function as rk_reward_function, + reward_function as rp_reward_function, + ) # Create a simple reward function using eval_protocol @rk_reward_function @@ -147,8 +153,7 @@ def test_reward_rp(response: str, **kwargs) -> EvaluateResult: def test_message_class_works(self): """Test that Message class works through both imports.""" - from eval_protocol import Message as RKMessage - from eval_protocol import Message as RPMessage + from eval_protocol import Message as RKMessage, Message as RPMessage # They should be the same class assert RKMessage is RPMessage @@ -195,29 +200,28 @@ def test_deep_import_consistency(self): """Test that deep imports work consistently.""" try: # Test importing from submodules - from eval_protocol.models import Message as RKMessage - from eval_protocol.models import Message as RPMessage + from eval_protocol.models import Message as RKMessage, Message as RPMessage # Should be the same class assert RKMessage is RPMessage except ImportError: # If submodule imports don't work, that's expected in some install scenarios # Just verify the star import works - from eval_protocol import Message as RKMessage - from eval_protocol import Message as RPMessage + from eval_protocol import Message as RKMessage, Message as RPMessage assert RKMessage is RPMessage try: # Test another submodule - use a function that actually exists - from eval_protocol.auth import get_fireworks_account_id - from eval_protocol.auth import get_fireworks_account_id as rp_get_fireworks_account_id + from eval_protocol.auth import ( + get_fireworks_account_id, + get_fireworks_account_id as rp_get_fireworks_account_id, + ) assert get_fireworks_account_id is rp_get_fireworks_account_id except ImportError: # If submodule imports don't work, verify through star import - from eval_protocol import auth as rk_auth - from eval_protocol import auth as rp_auth + from eval_protocol import auth as rk_auth, auth as rp_auth assert rk_auth is rp_auth diff --git a/tests/test_rl_processing.py b/tests/test_rl_processing.py index 1086732c..93f9e7fc 100644 --- a/tests/test_rl_processing.py +++ b/tests/test_rl_processing.py @@ -9,14 +9,11 @@ import pytest from eval_protocol.agent.models import StepData -from eval_protocol.models import EvaluateResult -from eval_protocol.models import Message as RewardKitMessage -from eval_protocol.models import StepOutput +from eval_protocol.models import EvaluateResult, Message as RewardKitMessage, StepOutput from eval_protocol.rl_processing import RLDataAligner class TestRLDataAligner: - def create_mock_step_data( self, system_step_index: int, diff --git a/tests/test_rollout_control_plane_integration.py b/tests/test_rollout_control_plane_integration.py index 8d176780..e97769c7 100644 --- a/tests/test_rollout_control_plane_integration.py +++ b/tests/test_rollout_control_plane_integration.py @@ -89,7 +89,7 @@ def add_tool_response( { "role": "tool", "content": response, - "tool_call_id": tool_call.tool_call_id or f"call_{len(conversation_history)-1}", + "tool_call_id": tool_call.tool_call_id or f"call_{len(conversation_history) - 1}", "control_plane_step": { "step": env_index, "reward": reward, @@ -149,7 +149,6 @@ async def test_rollout_with_control_plane_separation(self): patch.object(GeneralMCPVectorEnv, "close") as mock_close, patch.object(GeneralMCPVectorEnv, "format_user_prompt") as mock_format_user_prompt, ): - # Setup mock vector environment mock_env = GeneralMCPVectorEnv(sessions, dataset_rows) mock_env.sessions = sessions @@ -269,9 +268,9 @@ def mock_step_side_effect(env_index, tool_call): # Tool responses should only contain data plane information content = msg.content or "" # The content should not directly contain rewards or termination (they're in control_plane_step) - assert ( - "reward" not in content.lower() or "reward_source" in content.lower() - ), "Tool response should not directly contain reward" + assert "reward" not in content.lower() or "reward_source" in content.lower(), ( + "Tool response should not directly contain reward" + ) # Validate control plane information from messages rewards = [msg.control_plane_step["reward"] for msg in messages_with_control_plane] @@ -433,7 +432,6 @@ async def test_rollout_handles_control_plane_failure_gracefully(self): patch.object(GeneralMCPVectorEnv, "close") as mock_close, patch.object(GeneralMCPVectorEnv, "format_user_prompt") as mock_format_user_prompt, ): - mock_env = GeneralMCPVectorEnv(sessions, dataset_rows) mock_env.sessions = sessions mock_env.dataset_rows = dataset_rows diff --git a/tests/test_typed_interface_rl.py b/tests/test_typed_interface_rl.py index 3b1bf975..5455f9f9 100644 --- a/tests/test_typed_interface_rl.py +++ b/tests/test_typed_interface_rl.py @@ -69,7 +69,6 @@ def batch_invalid_output_func(rollouts_messages: List[List[Message]], ground_tru class TestTypedInterfaceRL: - def test_pointwise_rl_rewards_valid_input(self): """Test pointwise RL reward function with valid dict messages.""" raw_messages = [ diff --git a/tests/test_url_handling.py b/tests/test_url_handling.py index 542b0f05..6c5d9f4f 100644 --- a/tests/test_url_handling.py +++ b/tests/test_url_handling.py @@ -1,4 +1,5 @@ from unittest.mock import AsyncMock, MagicMock, patch + import httpx import pytest from werkzeug.wrappers import Response diff --git a/uv.lock b/uv.lock index 0c27bafe..e1e74645 100644 --- a/uv.lock +++ b/uv.lock @@ -338,19 +338,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/58/cc6a08053f822f98f334d38a27687b69c6655fb05cd74a7a5e70a2aeed95/authlib-1.6.1-py2.py3-none-any.whl", hash = "sha256:e9d2031c34c6309373ab845afc24168fe9e93dc52d252631f52642f21f5ed06e", size = 239299, upload-time = "2025-07-20T07:38:39.259Z" }, ] -[[package]] -name = "autopep8" -version = "2.3.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pycodestyle" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/50/d8/30873d2b7b57dee9263e53d142da044c4600a46f2d28374b3e38b023df16/autopep8-2.3.2.tar.gz", hash = "sha256:89440a4f969197b69a995e4ce0661b031f455a9f776d2c5ba3dbd83466931758", size = 92210, upload-time = "2025-01-14T14:46:18.454Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/43/53afb8ba17218f19b77c7834128566c5bbb100a0ad9ba2e8e89d089d7079/autopep8-2.3.2-py2.py3-none-any.whl", hash = "sha256:ce8ad498672c845a0c3de2629c15b635ec2b05ef8177a6e7c91c74f3e9b51128", size = 45807, upload-time = "2025-01-14T14:46:15.466Z" }, -] - [[package]] name = "babel" version = "2.17.0" @@ -420,40 +407,6 @@ compiler = [ { name = "ruff" }, ] -[[package]] -name = "black" -version = "25.1.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "mypy-extensions" }, - { name = "packaging" }, - { name = "pathspec" }, - { name = "platformdirs" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449, upload-time = "2025-01-29T04:15:40.373Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/3b/4ba3f93ac8d90410423fdd31d7541ada9bcee1df32fb90d26de41ed40e1d/black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32", size = 1629419, upload-time = "2025-01-29T05:37:06.642Z" }, - { url = "https://files.pythonhosted.org/packages/b4/02/0bde0485146a8a5e694daed47561785e8b77a0466ccc1f3e485d5ef2925e/black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da", size = 1461080, upload-time = "2025-01-29T05:37:09.321Z" }, - { url = "https://files.pythonhosted.org/packages/52/0e/abdf75183c830eaca7589144ff96d49bce73d7ec6ad12ef62185cc0f79a2/black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7", size = 1766886, upload-time = "2025-01-29T04:18:24.432Z" }, - { url = "https://files.pythonhosted.org/packages/dc/a6/97d8bb65b1d8a41f8a6736222ba0a334db7b7b77b8023ab4568288f23973/black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9", size = 1419404, upload-time = "2025-01-29T04:19:04.296Z" }, - { url = "https://files.pythonhosted.org/packages/7e/4f/87f596aca05c3ce5b94b8663dbfe242a12843caaa82dd3f85f1ffdc3f177/black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0", size = 1614372, upload-time = "2025-01-29T05:37:11.71Z" }, - { url = "https://files.pythonhosted.org/packages/e7/d0/2c34c36190b741c59c901e56ab7f6e54dad8df05a6272a9747ecef7c6036/black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299", size = 1442865, upload-time = "2025-01-29T05:37:14.309Z" }, - { url = "https://files.pythonhosted.org/packages/21/d4/7518c72262468430ead45cf22bd86c883a6448b9eb43672765d69a8f1248/black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096", size = 1749699, upload-time = "2025-01-29T04:18:17.688Z" }, - { url = "https://files.pythonhosted.org/packages/58/db/4f5beb989b547f79096e035c4981ceb36ac2b552d0ac5f2620e941501c99/black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2", size = 1428028, upload-time = "2025-01-29T04:18:51.711Z" }, - { url = "https://files.pythonhosted.org/packages/83/71/3fe4741df7adf015ad8dfa082dd36c94ca86bb21f25608eb247b4afb15b2/black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b", size = 1650988, upload-time = "2025-01-29T05:37:16.707Z" }, - { url = "https://files.pythonhosted.org/packages/13/f3/89aac8a83d73937ccd39bbe8fc6ac8860c11cfa0af5b1c96d081facac844/black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc", size = 1453985, upload-time = "2025-01-29T05:37:18.273Z" }, - { url = "https://files.pythonhosted.org/packages/6f/22/b99efca33f1f3a1d2552c714b1e1b5ae92efac6c43e790ad539a163d1754/black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f", size = 1783816, upload-time = "2025-01-29T04:18:33.823Z" }, - { url = "https://files.pythonhosted.org/packages/18/7e/a27c3ad3822b6f2e0e00d63d58ff6299a99a5b3aee69fa77cd4b0076b261/black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba", size = 1440860, upload-time = "2025-01-29T04:19:12.944Z" }, - { url = "https://files.pythonhosted.org/packages/98/87/0edf98916640efa5d0696e1abb0a8357b52e69e82322628f25bf14d263d1/black-25.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f0b18a02996a836cc9c9c78e5babec10930862827b1b724ddfe98ccf2f2fe4f", size = 1650673, upload-time = "2025-01-29T05:37:20.574Z" }, - { url = "https://files.pythonhosted.org/packages/52/e5/f7bf17207cf87fa6e9b676576749c6b6ed0d70f179a3d812c997870291c3/black-25.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afebb7098bfbc70037a053b91ae8437c3857482d3a690fefc03e9ff7aa9a5fd3", size = 1453190, upload-time = "2025-01-29T05:37:22.106Z" }, - { url = "https://files.pythonhosted.org/packages/e3/ee/adda3d46d4a9120772fae6de454c8495603c37c4c3b9c60f25b1ab6401fe/black-25.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:030b9759066a4ee5e5aca28c3c77f9c64789cdd4de8ac1df642c40b708be6171", size = 1782926, upload-time = "2025-01-29T04:18:58.564Z" }, - { url = "https://files.pythonhosted.org/packages/cc/64/94eb5f45dcb997d2082f097a3944cfc7fe87e071907f677e80788a2d7b7a/black-25.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a22f402b410566e2d1c950708c77ebf5ebd5d0d88a6a2e87c86d9fb48afa0d18", size = 1442613, upload-time = "2025-01-29T04:19:27.63Z" }, - { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, -] - [[package]] name = "bleach" version = "6.2.0" @@ -585,15 +538,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/c2/80633736cd183ee4a62107413def345f7e6e3c01563dbca1417363cf957e/build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5", size = 22950, upload-time = "2024-10-06T17:22:23.299Z" }, ] -[[package]] -name = "cachetools" -version = "5.5.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380, upload-time = "2025-02-20T21:01:19.524Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080, upload-time = "2025-02-20T21:01:16.647Z" }, -] - [[package]] name = "certifi" version = "2025.7.14" @@ -1167,41 +1111,30 @@ dependencies = [ [package.optional-dependencies] adapters = [ { name = "datasets" }, - { name = "google-auth" }, - { name = "google-auth-oauthlib" }, - { name = "google-cloud-bigquery" }, { name = "langfuse" }, { name = "transformers" }, ] -bigquery = [ - { name = "google-auth" }, - { name = "google-auth-oauthlib" }, - { name = "google-cloud-bigquery" }, -] box2d = [ { name = "gymnasium", extra = ["box2d"] }, { name = "pillow" }, { name = "swig" }, ] dev = [ - { name = "autopep8" }, - { name = "black" }, { name = "build" }, { name = "docker" }, { name = "e2b" }, - { name = "flake8" }, { name = "haikus" }, { name = "ipykernel" }, - { name = "isort" }, { name = "jupyter" }, - { name = "mypy" }, { name = "openai" }, { name = "pip" }, { name = "pre-commit" }, + { name = "pyright" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-httpserver" }, { name = "pytest-xdist" }, + { name = "ruff" }, { name = "transformers" }, { name = "twine" }, { name = "types-docker" }, @@ -1250,8 +1183,6 @@ requires-dist = [ { name = "aiohttp" }, { name = "aiosqlite" }, { name = "anthropic", specifier = ">=0.59.0" }, - { name = "autopep8", marker = "extra == 'dev'", specifier = ">=1.5.0" }, - { name = "black", marker = "extra == 'dev'", specifier = ">=21.5b2" }, { name = "build", marker = "extra == 'dev'" }, { name = "dataclasses-json", specifier = ">=0.5.7" }, { name = "datasets" }, @@ -1263,14 +1194,7 @@ requires-dist = [ { name = "e2b", marker = "extra == 'dev'" }, { name = "fastapi", specifier = ">=0.116.1" }, { name = "fireworks-ai", marker = "extra == 'fireworks'", specifier = ">=0.19.12" }, - { name = "flake8", marker = "extra == 'dev'", specifier = ">=3.9.2" }, { name = "fsspec" }, - { name = "google-auth", marker = "extra == 'adapters'", specifier = ">=2.0.0" }, - { name = "google-auth", marker = "extra == 'bigquery'", specifier = ">=2.0.0" }, - { name = "google-auth-oauthlib", marker = "extra == 'adapters'", specifier = ">=1.0.0" }, - { name = "google-auth-oauthlib", marker = "extra == 'bigquery'", specifier = ">=1.0.0" }, - { name = "google-cloud-bigquery", marker = "extra == 'adapters'", specifier = ">=3.0.0" }, - { name = "google-cloud-bigquery", marker = "extra == 'bigquery'", specifier = ">=3.0.0" }, { name = "gymnasium", specifier = ">=0.29.0" }, { name = "gymnasium", extras = ["box2d"], marker = "extra == 'box2d'", specifier = ">=0.29.0" }, { name = "haikus", marker = "extra == 'dev'", specifier = "==0.3.8" }, @@ -1278,7 +1202,6 @@ requires-dist = [ { name = "hydra-core", specifier = ">=1.3.2" }, { name = "ipykernel", specifier = ">=6.30.0" }, { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.30.0" }, - { name = "isort", marker = "extra == 'dev'", specifier = ">=5.0.0" }, { name = "jupyter", specifier = ">=1.1.1" }, { name = "jupyter", marker = "extra == 'dev'", specifier = ">=1.1.1" }, { name = "langfuse", marker = "extra == 'adapters'", specifier = ">=2.0.0" }, @@ -1286,7 +1209,6 @@ requires-dist = [ { name = "litellm", specifier = ">=1.0.0" }, { name = "loguru", specifier = ">=0.6.0" }, { name = "mcp", specifier = ">=1.9.2" }, - { name = "mypy", marker = "extra == 'dev'", specifier = ">=0.812" }, { name = "omegaconf", specifier = ">=2.3.0" }, { name = "openai", specifier = "==1.78.1" }, { name = "openai", marker = "extra == 'dev'", specifier = "==1.78.1" }, @@ -1299,6 +1221,7 @@ requires-dist = [ { name = "pre-commit", marker = "extra == 'dev'" }, { name = "psutil", specifier = ">=5.8.0" }, { name = "pydantic", specifier = ">=2.0.0" }, + { name = "pyright", marker = "extra == 'dev'", specifier = ">=1.1.365" }, { name = "pytest", specifier = ">=6.0.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'" }, { name = "pytest-cov", marker = "extra == 'dev'" }, @@ -1308,6 +1231,7 @@ requires-dist = [ { name = "pyyaml", specifier = ">=5.0" }, { name = "requests", specifier = ">=2.25.0" }, { name = "rich", specifier = ">=12.0.0" }, + { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.5.0" }, { name = "selenium", marker = "extra == 'svgbench'", specifier = ">=4.0.0" }, { name = "swig", marker = "extra == 'box2d'" }, { name = "toml", specifier = ">=0.10.0" }, @@ -1327,7 +1251,7 @@ requires-dist = [ { name = "websockets", specifier = ">=15.0.1" }, { name = "werkzeug", marker = "extra == 'dev'", specifier = ">=2.0.0" }, ] -provides-extras = ["dev", "trl", "openevals", "fireworks", "box2d", "langfuse", "huggingface", "bigquery", "adapters", "svgbench"] +provides-extras = ["dev", "trl", "openevals", "fireworks", "box2d", "langfuse", "huggingface", "adapters", "svgbench"] [package.metadata.requires-dev] dev = [ @@ -1504,20 +1428,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/83/1d44379a2d60009e31d157673998e46be144fc842c9c9389523e42cdfa65/fireworks_ai-0.19.12-py3-none-any.whl", hash = "sha256:3a50d807a7ca274a62e37be3a0b0d93ce76c412fb58e24b3d4cbadf8b9be0b74", size = 584665, upload-time = "2025-07-28T16:09:48.132Z" }, ] -[[package]] -name = "flake8" -version = "7.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mccabe" }, - { name = "pycodestyle" }, - { name = "pyflakes" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9b/af/fbfe3c4b5a657d79e5c47a2827a362f9e1b763336a52f926126aa6dc7123/flake8-7.3.0.tar.gz", hash = "sha256:fe044858146b9fc69b551a4b490d69cf960fcb78ad1edcb84e7fbb1b4a8e3872", size = 48326, upload-time = "2025-06-20T19:31:35.838Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/56/13ab06b4f93ca7cac71078fbe37fcea175d3216f31f85c3168a6bbd0bb9a/flake8-7.3.0-py2.py3-none-any.whl", hash = "sha256:b9696257b9ce8beb888cdbe31cf885c90d31928fe202be0889a7cdafad32f01e", size = 57922, upload-time = "2025-06-20T19:31:34.425Z" }, -] - [[package]] name = "fqdn" version = "1.5.1" @@ -1635,133 +1545,6 @@ http = [ { name = "aiohttp" }, ] -[[package]] -name = "google-api-core" -version = "2.25.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-auth" }, - { name = "googleapis-common-protos" }, - { name = "proto-plus" }, - { name = "protobuf" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/dc/21/e9d043e88222317afdbdb567165fdbc3b0aad90064c7e0c9eb0ad9955ad8/google_api_core-2.25.1.tar.gz", hash = "sha256:d2aaa0b13c78c61cb3f4282c464c046e45fbd75755683c9c525e6e8f7ed0a5e8", size = 165443, upload-time = "2025-06-12T20:52:20.439Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/14/4b/ead00905132820b623732b175d66354e9d3e69fcf2a5dcdab780664e7896/google_api_core-2.25.1-py3-none-any.whl", hash = "sha256:8a2a56c1fef82987a524371f99f3bd0143702fecc670c72e600c1cda6bf8dbb7", size = 160807, upload-time = "2025-06-12T20:52:19.334Z" }, -] - -[package.optional-dependencies] -grpc = [ - { name = "grpcio" }, - { name = "grpcio-status" }, -] - -[[package]] -name = "google-auth" -version = "2.40.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cachetools" }, - { name = "pyasn1-modules" }, - { name = "rsa" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9e/9b/e92ef23b84fa10a64ce4831390b7a4c2e53c0132568d99d4ae61d04c8855/google_auth-2.40.3.tar.gz", hash = "sha256:500c3a29adedeb36ea9cf24b8d10858e152f2412e3ca37829b3fa18e33d63b77", size = 281029, upload-time = "2025-06-04T18:04:57.577Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/17/63/b19553b658a1692443c62bd07e5868adaa0ad746a0751ba62c59568cd45b/google_auth-2.40.3-py2.py3-none-any.whl", hash = "sha256:1370d4593e86213563547f97a92752fc658456fe4514c809544f330fed45a7ca", size = 216137, upload-time = "2025-06-04T18:04:55.573Z" }, -] - -[[package]] -name = "google-auth-oauthlib" -version = "1.2.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-auth" }, - { name = "requests-oauthlib" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fb/87/e10bf24f7bcffc1421b84d6f9c3377c30ec305d082cd737ddaa6d8f77f7c/google_auth_oauthlib-1.2.2.tar.gz", hash = "sha256:11046fb8d3348b296302dd939ace8af0a724042e8029c1b872d87fabc9f41684", size = 20955, upload-time = "2025-04-22T16:40:29.172Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/84/40ee070be95771acd2f4418981edb834979424565c3eec3cd88b6aa09d24/google_auth_oauthlib-1.2.2-py3-none-any.whl", hash = "sha256:fd619506f4b3908b5df17b65f39ca8d66ea56986e5472eb5978fd8f3786f00a2", size = 19072, upload-time = "2025-04-22T16:40:28.174Z" }, -] - -[[package]] -name = "google-cloud-bigquery" -version = "3.35.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core", extra = ["grpc"] }, - { name = "google-auth" }, - { name = "google-cloud-core" }, - { name = "google-resumable-media" }, - { name = "packaging" }, - { name = "python-dateutil" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/44/e4/9cf03fa81fefd1b9811a7cd6e398804ae0de3b6a4edef810e2acd45cabbc/google_cloud_bigquery-3.35.1.tar.gz", hash = "sha256:599f26cacf190acfe88000f6cc5f4bc9e6baac7899e4f406ca054f1906f71960", size = 496433, upload-time = "2025-07-24T15:09:04.108Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/63/50/96fe9bc5b83d3a421e91ed8edc2535de45957e9af398273e3ecb5c3a1094/google_cloud_bigquery-3.35.1-py3-none-any.whl", hash = "sha256:6739a6ba63c6d80735ca2b34b1df2090ff473b80c1a62354caa2debe6dbbd961", size = 256877, upload-time = "2025-07-24T15:09:02.443Z" }, -] - -[[package]] -name = "google-cloud-core" -version = "2.4.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d6/b8/2b53838d2acd6ec6168fd284a990c76695e84c65deee79c9f3a4276f6b4f/google_cloud_core-2.4.3.tar.gz", hash = "sha256:1fab62d7102844b278fe6dead3af32408b1df3eb06f5c7e8634cbd40edc4da53", size = 35861, upload-time = "2025-03-10T21:05:38.948Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/40/86/bda7241a8da2d28a754aad2ba0f6776e35b67e37c36ae0c45d49370f1014/google_cloud_core-2.4.3-py2.py3-none-any.whl", hash = "sha256:5130f9f4c14b4fafdff75c79448f9495cfade0d8775facf1b09c3bf67e027f6e", size = 29348, upload-time = "2025-03-10T21:05:37.785Z" }, -] - -[[package]] -name = "google-crc32c" -version = "1.7.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/19/ae/87802e6d9f9d69adfaedfcfd599266bf386a54d0be058b532d04c794f76d/google_crc32c-1.7.1.tar.gz", hash = "sha256:2bff2305f98846f3e825dbeec9ee406f89da7962accdb29356e4eadc251bd472", size = 14495, upload-time = "2025-03-26T14:29:13.32Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/69/b1b05cf415df0d86691d6a8b4b7e60ab3a6fb6efb783ee5cd3ed1382bfd3/google_crc32c-1.7.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:b07d48faf8292b4db7c3d64ab86f950c2e94e93a11fd47271c28ba458e4a0d76", size = 30467, upload-time = "2025-03-26T14:31:11.92Z" }, - { url = "https://files.pythonhosted.org/packages/44/3d/92f8928ecd671bd5b071756596971c79d252d09b835cdca5a44177fa87aa/google_crc32c-1.7.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:7cc81b3a2fbd932a4313eb53cc7d9dde424088ca3a0337160f35d91826880c1d", size = 30311, upload-time = "2025-03-26T14:53:14.161Z" }, - { url = "https://files.pythonhosted.org/packages/33/42/c2d15a73df79d45ed6b430b9e801d0bd8e28ac139a9012d7d58af50a385d/google_crc32c-1.7.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1c67ca0a1f5b56162951a9dae987988679a7db682d6f97ce0f6381ebf0fbea4c", size = 37889, upload-time = "2025-03-26T14:41:27.83Z" }, - { url = "https://files.pythonhosted.org/packages/57/ea/ac59c86a3c694afd117bb669bde32aaf17d0de4305d01d706495f09cbf19/google_crc32c-1.7.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc5319db92daa516b653600794d5b9f9439a9a121f3e162f94b0e1891c7933cb", size = 33028, upload-time = "2025-03-26T14:41:29.141Z" }, - { url = "https://files.pythonhosted.org/packages/60/44/87e77e8476767a4a93f6cf271157c6d948eacec63688c093580af13b04be/google_crc32c-1.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcdf5a64adb747610140572ed18d011896e3b9ae5195f2514b7ff678c80f1603", size = 38026, upload-time = "2025-03-26T14:41:29.921Z" }, - { url = "https://files.pythonhosted.org/packages/c8/bf/21ac7bb305cd7c1a6de9c52f71db0868e104a5b573a4977cd9d0ff830f82/google_crc32c-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:754561c6c66e89d55754106739e22fdaa93fafa8da7221b29c8b8e8270c6ec8a", size = 33476, upload-time = "2025-03-26T14:29:09.086Z" }, - { url = "https://files.pythonhosted.org/packages/f7/94/220139ea87822b6fdfdab4fb9ba81b3fff7ea2c82e2af34adc726085bffc/google_crc32c-1.7.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6fbab4b935989e2c3610371963ba1b86afb09537fd0c633049be82afe153ac06", size = 30468, upload-time = "2025-03-26T14:32:52.215Z" }, - { url = "https://files.pythonhosted.org/packages/94/97/789b23bdeeb9d15dc2904660463ad539d0318286d7633fe2760c10ed0c1c/google_crc32c-1.7.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:ed66cbe1ed9cbaaad9392b5259b3eba4a9e565420d734e6238813c428c3336c9", size = 30313, upload-time = "2025-03-26T14:57:38.758Z" }, - { url = "https://files.pythonhosted.org/packages/81/b8/976a2b843610c211e7ccb3e248996a61e87dbb2c09b1499847e295080aec/google_crc32c-1.7.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee6547b657621b6cbed3562ea7826c3e11cab01cd33b74e1f677690652883e77", size = 33048, upload-time = "2025-03-26T14:41:30.679Z" }, - { url = "https://files.pythonhosted.org/packages/c9/16/a3842c2cf591093b111d4a5e2bfb478ac6692d02f1b386d2a33283a19dc9/google_crc32c-1.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d68e17bad8f7dd9a49181a1f5a8f4b251c6dbc8cc96fb79f1d321dfd57d66f53", size = 32669, upload-time = "2025-03-26T14:41:31.432Z" }, - { url = "https://files.pythonhosted.org/packages/04/17/ed9aba495916fcf5fe4ecb2267ceb851fc5f273c4e4625ae453350cfd564/google_crc32c-1.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:6335de12921f06e1f774d0dd1fbea6bf610abe0887a1638f64d694013138be5d", size = 33476, upload-time = "2025-03-26T14:29:10.211Z" }, - { url = "https://files.pythonhosted.org/packages/dd/b7/787e2453cf8639c94b3d06c9d61f512234a82e1d12d13d18584bd3049904/google_crc32c-1.7.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2d73a68a653c57281401871dd4aeebbb6af3191dcac751a76ce430df4d403194", size = 30470, upload-time = "2025-03-26T14:34:31.655Z" }, - { url = "https://files.pythonhosted.org/packages/ed/b4/6042c2b0cbac3ec3a69bb4c49b28d2f517b7a0f4a0232603c42c58e22b44/google_crc32c-1.7.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:22beacf83baaf59f9d3ab2bbb4db0fb018da8e5aebdce07ef9f09fce8220285e", size = 30315, upload-time = "2025-03-26T15:01:54.634Z" }, - { url = "https://files.pythonhosted.org/packages/29/ad/01e7a61a5d059bc57b702d9ff6a18b2585ad97f720bd0a0dbe215df1ab0e/google_crc32c-1.7.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19eafa0e4af11b0a4eb3974483d55d2d77ad1911e6cf6f832e1574f6781fd337", size = 33180, upload-time = "2025-03-26T14:41:32.168Z" }, - { url = "https://files.pythonhosted.org/packages/3b/a5/7279055cf004561894ed3a7bfdf5bf90a53f28fadd01af7cd166e88ddf16/google_crc32c-1.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6d86616faaea68101195c6bdc40c494e4d76f41e07a37ffdef270879c15fb65", size = 32794, upload-time = "2025-03-26T14:41:33.264Z" }, - { url = "https://files.pythonhosted.org/packages/0f/d6/77060dbd140c624e42ae3ece3df53b9d811000729a5c821b9fd671ceaac6/google_crc32c-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:b7491bdc0c7564fcf48c0179d2048ab2f7c7ba36b84ccd3a3e1c3f7a72d3bba6", size = 33477, upload-time = "2025-03-26T14:29:10.94Z" }, - { url = "https://files.pythonhosted.org/packages/8b/72/b8d785e9184ba6297a8620c8a37cf6e39b81a8ca01bb0796d7cbb28b3386/google_crc32c-1.7.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:df8b38bdaf1629d62d51be8bdd04888f37c451564c2042d36e5812da9eff3c35", size = 30467, upload-time = "2025-03-26T14:36:06.909Z" }, - { url = "https://files.pythonhosted.org/packages/34/25/5f18076968212067c4e8ea95bf3b69669f9fc698476e5f5eb97d5b37999f/google_crc32c-1.7.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:e42e20a83a29aa2709a0cf271c7f8aefaa23b7ab52e53b322585297bb94d4638", size = 30309, upload-time = "2025-03-26T15:06:15.318Z" }, - { url = "https://files.pythonhosted.org/packages/92/83/9228fe65bf70e93e419f38bdf6c5ca5083fc6d32886ee79b450ceefd1dbd/google_crc32c-1.7.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:905a385140bf492ac300026717af339790921f411c0dfd9aa5a9e69a08ed32eb", size = 33133, upload-time = "2025-03-26T14:41:34.388Z" }, - { url = "https://files.pythonhosted.org/packages/c3/ca/1ea2fd13ff9f8955b85e7956872fdb7050c4ace8a2306a6d177edb9cf7fe/google_crc32c-1.7.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b211ddaf20f7ebeec5c333448582c224a7c90a9d98826fbab82c0ddc11348e6", size = 32773, upload-time = "2025-03-26T14:41:35.19Z" }, - { url = "https://files.pythonhosted.org/packages/89/32/a22a281806e3ef21b72db16f948cad22ec68e4bdd384139291e00ff82fe2/google_crc32c-1.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:0f99eaa09a9a7e642a61e06742856eec8b19fc0037832e03f941fe7cf0c8e4db", size = 33475, upload-time = "2025-03-26T14:29:11.771Z" }, - { url = "https://files.pythonhosted.org/packages/b8/c5/002975aff514e57fc084ba155697a049b3f9b52225ec3bc0f542871dd524/google_crc32c-1.7.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32d1da0d74ec5634a05f53ef7df18fc646666a25efaaca9fc7dcfd4caf1d98c3", size = 33243, upload-time = "2025-03-26T14:41:35.975Z" }, - { url = "https://files.pythonhosted.org/packages/61/cb/c585282a03a0cea70fcaa1bf55d5d702d0f2351094d663ec3be1c6c67c52/google_crc32c-1.7.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e10554d4abc5238823112c2ad7e4560f96c7bf3820b202660373d769d9e6e4c9", size = 32870, upload-time = "2025-03-26T14:41:37.08Z" }, - { url = "https://files.pythonhosted.org/packages/0b/43/31e57ce04530794917dfe25243860ec141de9fadf4aa9783dffe7dac7c39/google_crc32c-1.7.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8e9afc74168b0b2232fb32dd202c93e46b7d5e4bf03e66ba5dc273bb3559589", size = 28242, upload-time = "2025-03-26T14:41:42.858Z" }, - { url = "https://files.pythonhosted.org/packages/eb/f3/8b84cd4e0ad111e63e30eb89453f8dd308e3ad36f42305cf8c202461cdf0/google_crc32c-1.7.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa8136cc14dd27f34a3221c0f16fd42d8a40e4778273e61a3c19aedaa44daf6b", size = 28049, upload-time = "2025-03-26T14:41:44.651Z" }, - { url = "https://files.pythonhosted.org/packages/16/1b/1693372bf423ada422f80fd88260dbfd140754adb15cbc4d7e9a68b1cb8e/google_crc32c-1.7.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85fef7fae11494e747c9fd1359a527e5970fc9603c90764843caabd3a16a0a48", size = 28241, upload-time = "2025-03-26T14:41:45.898Z" }, - { url = "https://files.pythonhosted.org/packages/fd/3c/2a19a60a473de48717b4efb19398c3f914795b64a96cf3fbe82588044f78/google_crc32c-1.7.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6efb97eb4369d52593ad6f75e7e10d053cf00c48983f7a973105bc70b0ac4d82", size = 28048, upload-time = "2025-03-26T14:41:46.696Z" }, -] - -[[package]] -name = "google-resumable-media" -version = "2.7.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-crc32c" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/58/5a/0efdc02665dca14e0837b62c8a1a93132c264bd02054a15abb2218afe0ae/google_resumable_media-2.7.2.tar.gz", hash = "sha256:5280aed4629f2b60b847b0d42f9857fd4935c11af266744df33d8074cae92fe0", size = 2163099, upload-time = "2024-08-07T22:20:38.555Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/82/35/b8d3baf8c46695858cb9d8835a53baa1eeb9906ddaf2f728a5f5b640fd1e/google_resumable_media-2.7.2-py2.py3-none-any.whl", hash = "sha256:3ce7551e9fe6d99e9a126101d2536612bb73486721951e9562fee0f90c6ababa", size = 81251, upload-time = "2024-08-07T22:20:36.409Z" }, -] - [[package]] name = "googleapis-common-protos" version = "1.70.0" @@ -1873,20 +1656,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/34/80/de3eb55eb581815342d097214bed4c59e806b05f1b3110df03b2280d6dfd/grpcio-1.74.0-cp313-cp313-win_amd64.whl", hash = "sha256:fd3c71aeee838299c5887230b8a1822795325ddfea635edd82954c1eaa831e24", size = 4489214, upload-time = "2025-07-24T18:53:59.771Z" }, ] -[[package]] -name = "grpcio-status" -version = "1.71.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "googleapis-common-protos" }, - { name = "grpcio" }, - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fd/d1/b6e9877fedae3add1afdeae1f89d1927d296da9cf977eca0eb08fb8a460e/grpcio_status-1.71.2.tar.gz", hash = "sha256:c7a97e176df71cdc2c179cd1847d7fc86cca5832ad12e9798d7fed6b7a1aab50", size = 13677, upload-time = "2025-06-28T04:24:05.426Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/67/58/317b0134129b556a93a3b0afe00ee675b5657f0155509e22fcb853bafe2d/grpcio_status-1.71.2-py3-none-any.whl", hash = "sha256:803c98cb6a8b7dc6dbb785b1111aed739f241ab5e9da0bba96888aa74704cfd3", size = 14424, upload-time = "2025-06-28T04:23:42.136Z" }, -] - [[package]] name = "grpclib" version = "0.4.8" @@ -2280,15 +2049,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/55/e5326141505c5d5e34c5e0935d2908a74e4561eca44108fbfb9c13d2911a/isoduration-20.11.0-py3-none-any.whl", hash = "sha256:b2904c2a4228c3d44f409c8ae8e2370eb21a26f7ac2ec5446df141dde3452042", size = 11321, upload-time = "2020-11-01T10:59:58.02Z" }, ] -[[package]] -name = "isort" -version = "6.0.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b8/21/1e2a441f74a653a144224d7d21afe8f4169e6c7c20bb13aec3a2dc3815e0/isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450", size = 821955, upload-time = "2025-02-26T21:13:16.955Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/11/114d0a5f4dabbdcedc1125dee0888514c3c3b16d3e9facad87ed96fad97c/isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615", size = 94186, upload-time = "2025-02-26T21:13:14.911Z" }, -] - [[package]] name = "jaconv" version = "0.4.0" @@ -2965,15 +2725,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, ] -[[package]] -name = "mccabe" -version = "0.7.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e7/ff/0ffefdcac38932a54d2b5eed4e0ba8a408f215002cd178ad1df0f2806ff8/mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325", size = 9658, upload-time = "2022-01-24T01:14:51.113Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/27/1a/1f68f9ba0c207934b35b86a8ca3aad8395a3d6dd7921c0686e23853ff5a9/mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e", size = 7350, upload-time = "2022-01-24T01:14:49.62Z" }, -] - [[package]] name = "mcp" version = "1.12.2" @@ -3227,45 +2978,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, ] -[[package]] -name = "mypy" -version = "1.17.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mypy-extensions" }, - { name = "pathspec" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1e/e3/034322d5a779685218ed69286c32faa505247f1f096251ef66c8fd203b08/mypy-1.17.0.tar.gz", hash = "sha256:e5d7ccc08ba089c06e2f5629c660388ef1fee708444f1dee0b9203fa031dee03", size = 3352114, upload-time = "2025-07-14T20:34:30.181Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/31/e762baa3b73905c856d45ab77b4af850e8159dffffd86a52879539a08c6b/mypy-1.17.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f8e08de6138043108b3b18f09d3f817a4783912e48828ab397ecf183135d84d6", size = 10998313, upload-time = "2025-07-14T20:33:24.519Z" }, - { url = "https://files.pythonhosted.org/packages/1c/c1/25b2f0d46fb7e0b5e2bee61ec3a47fe13eff9e3c2f2234f144858bbe6485/mypy-1.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ce4a17920ec144647d448fc43725b5873548b1aae6c603225626747ededf582d", size = 10128922, upload-time = "2025-07-14T20:34:06.414Z" }, - { url = "https://files.pythonhosted.org/packages/02/78/6d646603a57aa8a2886df1b8881fe777ea60f28098790c1089230cd9c61d/mypy-1.17.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ff25d151cc057fdddb1cb1881ef36e9c41fa2a5e78d8dd71bee6e4dcd2bc05b", size = 11913524, upload-time = "2025-07-14T20:33:19.109Z" }, - { url = "https://files.pythonhosted.org/packages/4f/19/dae6c55e87ee426fb76980f7e78484450cad1c01c55a1dc4e91c930bea01/mypy-1.17.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:93468cf29aa9a132bceb103bd8475f78cacde2b1b9a94fd978d50d4bdf616c9a", size = 12650527, upload-time = "2025-07-14T20:32:44.095Z" }, - { url = "https://files.pythonhosted.org/packages/86/e1/f916845a235235a6c1e4d4d065a3930113767001d491b8b2e1b61ca56647/mypy-1.17.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:98189382b310f16343151f65dd7e6867386d3e35f7878c45cfa11383d175d91f", size = 12897284, upload-time = "2025-07-14T20:33:38.168Z" }, - { url = "https://files.pythonhosted.org/packages/ae/dc/414760708a4ea1b096bd214d26a24e30ac5e917ef293bc33cdb6fe22d2da/mypy-1.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:c004135a300ab06a045c1c0d8e3f10215e71d7b4f5bb9a42ab80236364429937", size = 9506493, upload-time = "2025-07-14T20:34:01.093Z" }, - { url = "https://files.pythonhosted.org/packages/d4/24/82efb502b0b0f661c49aa21cfe3e1999ddf64bf5500fc03b5a1536a39d39/mypy-1.17.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9d4fe5c72fd262d9c2c91c1117d16aac555e05f5beb2bae6a755274c6eec42be", size = 10914150, upload-time = "2025-07-14T20:31:51.985Z" }, - { url = "https://files.pythonhosted.org/packages/03/96/8ef9a6ff8cedadff4400e2254689ca1dc4b420b92c55255b44573de10c54/mypy-1.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d96b196e5c16f41b4f7736840e8455958e832871990c7ba26bf58175e357ed61", size = 10039845, upload-time = "2025-07-14T20:32:30.527Z" }, - { url = "https://files.pythonhosted.org/packages/df/32/7ce359a56be779d38021d07941cfbb099b41411d72d827230a36203dbb81/mypy-1.17.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:73a0ff2dd10337ceb521c080d4147755ee302dcde6e1a913babd59473904615f", size = 11837246, upload-time = "2025-07-14T20:32:01.28Z" }, - { url = "https://files.pythonhosted.org/packages/82/16/b775047054de4d8dbd668df9137707e54b07fe18c7923839cd1e524bf756/mypy-1.17.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:24cfcc1179c4447854e9e406d3af0f77736d631ec87d31c6281ecd5025df625d", size = 12571106, upload-time = "2025-07-14T20:34:26.942Z" }, - { url = "https://files.pythonhosted.org/packages/a1/cf/fa33eaf29a606102c8d9ffa45a386a04c2203d9ad18bf4eef3e20c43ebc8/mypy-1.17.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3c56f180ff6430e6373db7a1d569317675b0a451caf5fef6ce4ab365f5f2f6c3", size = 12759960, upload-time = "2025-07-14T20:33:42.882Z" }, - { url = "https://files.pythonhosted.org/packages/94/75/3f5a29209f27e739ca57e6350bc6b783a38c7621bdf9cac3ab8a08665801/mypy-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:eafaf8b9252734400f9b77df98b4eee3d2eecab16104680d51341c75702cad70", size = 9503888, upload-time = "2025-07-14T20:32:34.392Z" }, - { url = "https://files.pythonhosted.org/packages/12/e9/e6824ed620bbf51d3bf4d6cbbe4953e83eaf31a448d1b3cfb3620ccb641c/mypy-1.17.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f986f1cab8dbec39ba6e0eaa42d4d3ac6686516a5d3dccd64be095db05ebc6bb", size = 11086395, upload-time = "2025-07-14T20:34:11.452Z" }, - { url = "https://files.pythonhosted.org/packages/ba/51/a4afd1ae279707953be175d303f04a5a7bd7e28dc62463ad29c1c857927e/mypy-1.17.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:51e455a54d199dd6e931cd7ea987d061c2afbaf0960f7f66deef47c90d1b304d", size = 10120052, upload-time = "2025-07-14T20:33:09.897Z" }, - { url = "https://files.pythonhosted.org/packages/8a/71/19adfeac926ba8205f1d1466d0d360d07b46486bf64360c54cb5a2bd86a8/mypy-1.17.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3204d773bab5ff4ebbd1f8efa11b498027cd57017c003ae970f310e5b96be8d8", size = 11861806, upload-time = "2025-07-14T20:32:16.028Z" }, - { url = "https://files.pythonhosted.org/packages/0b/64/d6120eca3835baf7179e6797a0b61d6c47e0bc2324b1f6819d8428d5b9ba/mypy-1.17.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1051df7ec0886fa246a530ae917c473491e9a0ba6938cfd0ec2abc1076495c3e", size = 12744371, upload-time = "2025-07-14T20:33:33.503Z" }, - { url = "https://files.pythonhosted.org/packages/1f/dc/56f53b5255a166f5bd0f137eed960e5065f2744509dfe69474ff0ba772a5/mypy-1.17.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f773c6d14dcc108a5b141b4456b0871df638eb411a89cd1c0c001fc4a9d08fc8", size = 12914558, upload-time = "2025-07-14T20:33:56.961Z" }, - { url = "https://files.pythonhosted.org/packages/69/ac/070bad311171badc9add2910e7f89271695a25c136de24bbafc7eded56d5/mypy-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:1619a485fd0e9c959b943c7b519ed26b712de3002d7de43154a489a2d0fd817d", size = 9585447, upload-time = "2025-07-14T20:32:20.594Z" }, - { url = "https://files.pythonhosted.org/packages/be/7b/5f8ab461369b9e62157072156935cec9d272196556bdc7c2ff5f4c7c0f9b/mypy-1.17.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2c41aa59211e49d717d92b3bb1238c06d387c9325d3122085113c79118bebb06", size = 11070019, upload-time = "2025-07-14T20:32:07.99Z" }, - { url = "https://files.pythonhosted.org/packages/9c/f8/c49c9e5a2ac0badcc54beb24e774d2499748302c9568f7f09e8730e953fa/mypy-1.17.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0e69db1fb65b3114f98c753e3930a00514f5b68794ba80590eb02090d54a5d4a", size = 10114457, upload-time = "2025-07-14T20:33:47.285Z" }, - { url = "https://files.pythonhosted.org/packages/89/0c/fb3f9c939ad9beed3e328008b3fb90b20fda2cddc0f7e4c20dbefefc3b33/mypy-1.17.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:03ba330b76710f83d6ac500053f7727270b6b8553b0423348ffb3af6f2f7b889", size = 11857838, upload-time = "2025-07-14T20:33:14.462Z" }, - { url = "https://files.pythonhosted.org/packages/4c/66/85607ab5137d65e4f54d9797b77d5a038ef34f714929cf8ad30b03f628df/mypy-1.17.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:037bc0f0b124ce46bfde955c647f3e395c6174476a968c0f22c95a8d2f589bba", size = 12731358, upload-time = "2025-07-14T20:32:25.579Z" }, - { url = "https://files.pythonhosted.org/packages/73/d0/341dbbfb35ce53d01f8f2969facbb66486cee9804048bf6c01b048127501/mypy-1.17.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c38876106cb6132259683632b287238858bd58de267d80defb6f418e9ee50658", size = 12917480, upload-time = "2025-07-14T20:34:21.868Z" }, - { url = "https://files.pythonhosted.org/packages/64/63/70c8b7dbfc520089ac48d01367a97e8acd734f65bd07813081f508a8c94c/mypy-1.17.0-cp313-cp313-win_amd64.whl", hash = "sha256:d30ba01c0f151998f367506fab31c2ac4527e6a7b2690107c7a7f9e3cb419a9c", size = 9589666, upload-time = "2025-07-14T20:34:16.841Z" }, - { url = "https://files.pythonhosted.org/packages/e3/fc/ee058cc4316f219078464555873e99d170bde1d9569abd833300dbeb484a/mypy-1.17.0-py3-none-any.whl", hash = "sha256:15d9d0018237ab058e5de3d8fce61b6fa72cc59cc78fd91f1b474bce12abf496", size = 2283195, upload-time = "2025-07-14T20:31:54.753Z" }, -] - [[package]] name = "mypy-extensions" version = "1.1.0" @@ -3719,15 +3431,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/4e/0d0c945463719429b7bd21dece907ad0bde437a2ff12b9b12fee94722ab0/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1", size = 89265, upload-time = "2024-10-01T17:00:38.172Z" }, ] -[[package]] -name = "oauthlib" -version = "3.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0b/5f/19930f824ffeb0ad4372da4812c50edbd1434f678c90c2733e1188edfc63/oauthlib-3.3.1.tar.gz", hash = "sha256:0f0f8aa759826a193cf66c12ea1af1637f87b9b4622d46e866952bb022e538c9", size = 185918, upload-time = "2025-06-19T22:48:08.269Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" }, -] - [[package]] name = "omegaconf" version = "2.3.0" @@ -4072,15 +3775,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650, upload-time = "2024-04-05T09:43:53.299Z" }, ] -[[package]] -name = "pathspec" -version = "0.12.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, -] - [[package]] name = "peewee" version = "3.18.2" @@ -4376,18 +4070,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/35/cc0aaecf278bb4575b8555f2b137de5ab821595ddae9da9d3cd1da4072c7/propcache-0.3.2-py3-none-any.whl", hash = "sha256:98f1ec44fb675f5052cccc8e609c46ed23a35a1cfd18545ad4e29002d858a43f", size = 12663, upload-time = "2025-06-09T22:56:04.484Z" }, ] -[[package]] -name = "proto-plus" -version = "1.26.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f4/ac/87285f15f7cce6d4a008f33f1757fb5a13611ea8914eb58c3d0d26243468/proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012", size = 56142, upload-time = "2025-03-10T15:54:38.843Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4e/6d/280c4c2ce28b1593a19ad5239c8b826871fc6ec275c21afc8e1820108039/proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66", size = 50163, upload-time = "2025-03-10T15:54:37.335Z" }, -] - [[package]] name = "protobuf" version = "5.29.3" @@ -4478,27 +4160,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/4e/519c1bc1876625fe6b71e9a28287c43ec2f20f73c658b9ae1d485c0c206e/pyarrow-21.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:222c39e2c70113543982c6b34f3077962b44fca38c0bd9e68bb6781534425c10", size = 26371006, upload-time = "2025-07-18T00:56:56.379Z" }, ] -[[package]] -name = "pyasn1" -version = "0.6.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322, upload-time = "2024-09-10T22:41:42.55Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135, upload-time = "2024-09-11T16:00:36.122Z" }, -] - -[[package]] -name = "pyasn1-modules" -version = "0.4.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, -] - [[package]] name = "pycares" version = "4.9.0" @@ -4570,15 +4231,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/f8/b4d4bf71ae92727a0b3a9b9092c2e722833c1ca50ebd0414824843cb84fd/pycares-4.9.0-cp313-cp313-win_arm64.whl", hash = "sha256:faa9de8e647ed06757a2c117b70a7645a755561def814da6aca0d766cf71a402", size = 115646, upload-time = "2025-06-13T00:37:33.251Z" }, ] -[[package]] -name = "pycodestyle" -version = "2.14.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/11/e0/abfd2a0d2efe47670df87f3e3a0e2edda42f055053c85361f19c0e2c1ca8/pycodestyle-2.14.0.tar.gz", hash = "sha256:c4b5b517d278089ff9d0abdec919cd97262a3367449ea1c8b49b91529167b783", size = 39472, upload-time = "2025-06-20T18:49:48.75Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d7/27/a58ddaf8c588a3ef080db9d0b7e0b97215cee3a45df74f3a94dbbf5c893a/pycodestyle-2.14.0-py2.py3-none-any.whl", hash = "sha256:dd6bf7cb4ee77f8e016f9c8e74a35ddd9f67e1d5fd4184d86c3b98e07099f42d", size = 31594, upload-time = "2025-06-20T18:49:47.491Z" }, -] - [[package]] name = "pycparser" version = "2.22" @@ -4709,15 +4361,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/58/f0/427018098906416f580e3cf1366d3b1abfb408a0652e9f31600c24a1903c/pydantic_settings-2.10.1-py3-none-any.whl", hash = "sha256:a60952460b99cf661dc25c29c0ef171721f98bfcb52ef8d9ea4c943d7c8cc796", size = 45235, upload-time = "2025-06-24T13:26:45.485Z" }, ] -[[package]] -name = "pyflakes" -version = "3.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/45/dc/fd034dc20b4b264b3d015808458391acbf9df40b1e54750ef175d39180b1/pyflakes-3.4.0.tar.gz", hash = "sha256:b24f96fafb7d2ab0ec5075b7350b3d2d2218eab42003821c06344973d3ea2f58", size = 64669, upload-time = "2025-06-20T18:45:27.834Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/2f/81d580a0fb83baeb066698975cb14a618bdbed7720678566f1b046a95fe8/pyflakes-3.4.0-py2.py3-none-any.whl", hash = "sha256:f742a7dbd0d9cb9ea41e9a24a918996e8170c799fa528688d40dd582c8265f4f", size = 63551, upload-time = "2025-06-20T18:45:26.937Z" }, -] - [[package]] name = "pygame" version = "2.6.1" @@ -4791,6 +4434,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bd/24/12818598c362d7f300f18e74db45963dbcb85150324092410c8b49405e42/pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913", size = 10216, upload-time = "2024-09-29T09:24:11.978Z" }, ] +[[package]] +name = "pyright" +version = "1.1.403" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodeenv" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fe/f6/35f885264ff08c960b23d1542038d8da86971c5d8c955cfab195a4f672d7/pyright-1.1.403.tar.gz", hash = "sha256:3ab69b9f41c67fb5bbb4d7a36243256f0d549ed3608678d381d5f51863921104", size = 3913526, upload-time = "2025-07-09T07:15:52.882Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/b6/b04e5c2f41a5ccad74a1a4759da41adb20b4bc9d59a5e08d29ba60084d07/pyright-1.1.403-py3-none-any.whl", hash = "sha256:c0eeca5aa76cbef3fcc271259bbd785753c7ad7bcac99a9162b4c4c7daed23b3", size = 5684504, upload-time = "2025-07-09T07:15:50.958Z" }, +] + [[package]] name = "pysocks" version = "1.7.1" @@ -5178,19 +4834,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c", size = 64847, upload-time = "2025-06-09T16:43:05.728Z" }, ] -[[package]] -name = "requests-oauthlib" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "oauthlib" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/42/f2/05f29bc3913aea15eb670be136045bf5c5bbf4b99ecb839da9b422bb2c85/requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9", size = 55650, upload-time = "2024-03-22T20:32:29.939Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179, upload-time = "2024-03-22T20:32:28.055Z" }, -] - [[package]] name = "requests-toolbelt" version = "1.0.0" @@ -5535,18 +5178,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/ed/9de62c2150ca8e2e5858acf3f4f4d0d180a38feef9fdab4078bea63d8dba/rpds_py-0.26.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:e99685fc95d386da368013e7fb4269dd39c30d99f812a8372d62f244f662709c", size = 555334, upload-time = "2025-07-01T15:56:51.703Z" }, ] -[[package]] -name = "rsa" -version = "4.9.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, -] - [[package]] name = "ruff" version = "0.9.10" diff --git a/vendor/tau2/__init__.py b/vendor/tau2/__init__.py index 8b137891..e69de29b 100644 --- a/vendor/tau2/__init__.py +++ b/vendor/tau2/__init__.py @@ -1 +0,0 @@ - diff --git a/vendor/tau2/agent/README.md b/vendor/tau2/agent/README.md index fa201945..ee77cfb6 100644 --- a/vendor/tau2/agent/README.md +++ b/vendor/tau2/agent/README.md @@ -32,4 +32,4 @@ tau2 run \ --agent-llm \ --user-llm \ ... -``` \ No newline at end of file +``` diff --git a/vendor/tau2/agent/base.py b/vendor/tau2/agent/base.py index 7a345432..4d6d9dbd 100644 --- a/vendor/tau2/agent/base.py +++ b/vendor/tau2/agent/base.py @@ -73,9 +73,7 @@ def set_seed(self, seed: int): """ Set the seed for the agent. [Optional] """ - logger.warning( - f"Setting seed for agent is not implemented for class {self.__class__.__name__}" - ) + logger.warning(f"Setting seed for agent is not implemented for class {self.__class__.__name__}") class LocalAgent(BaseAgent[AgentState]): diff --git a/vendor/tau2/agent/llm_agent.py b/vendor/tau2/agent/llm_agent.py index b2fdee99..dffc106d 100644 --- a/vendor/tau2/agent/llm_agent.py +++ b/vendor/tau2/agent/llm_agent.py @@ -82,9 +82,9 @@ def get_init_state(self, message_history: Optional[list[Message]] = None) -> LLM """ if message_history is None: message_history = [] - assert all( - is_valid_agent_history_message(m) for m in message_history - ), "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." + assert all(is_valid_agent_history_message(m) for m in message_history), ( + "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." + ) return LLMAgentState( system_messages=[SystemMessage(role="system", content=self.system_prompt)], messages=message_history, @@ -206,9 +206,9 @@ def get_init_state(self, message_history: Optional[list[Message]] = None) -> LLM """ if message_history is None: message_history = [] - assert all( - is_valid_agent_history_message(m) for m in message_history - ), "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." + assert all(is_valid_agent_history_message(m) for m in message_history), ( + "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." + ) return LLMAgentState( system_messages=[SystemMessage(role="system", content=self.system_prompt)], messages=message_history, @@ -416,9 +416,9 @@ def get_init_state(self, message_history: Optional[list[Message]] = None) -> LLM """ if message_history is None: message_history = [] - assert all( - is_valid_agent_history_message(m) for m in message_history - ), "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." + assert all(is_valid_agent_history_message(m) for m in message_history), ( + "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." + ) return LLMAgentState( system_messages=[SystemMessage(role="system", content=self.system_prompt)], messages=message_history, diff --git a/vendor/tau2/cli.py b/vendor/tau2/cli.py index 65b2d115..89056109 100644 --- a/vendor/tau2/cli.py +++ b/vendor/tau2/cli.py @@ -193,9 +193,7 @@ def main(): start_parser.set_defaults(func=lambda args: run_start_servers()) # Check data command - check_data_parser = subparsers.add_parser( - "check-data", help="Check if data directory is properly configured" - ) + check_data_parser = subparsers.add_parser("check-data", help="Check if data directory is properly configured") check_data_parser.set_defaults(func=lambda args: run_check_data()) args = parser.parse_args() diff --git a/vendor/tau2/data/user_simulator/simulation_guidelines.md b/vendor/tau2/data/user_simulator/simulation_guidelines.md index 8bf34059..f7a559fd 100644 --- a/vendor/tau2/data/user_simulator/simulation_guidelines.md +++ b/vendor/tau2/data/user_simulator/simulation_guidelines.md @@ -1,5 +1,5 @@ # User Simulation Guidelines -You are playing the role of a customer contacting a customer service representative. +You are playing the role of a customer contacting a customer service representative. Your goal is to simulate realistic customer interactions while following specific scenario instructions. ## Core Principles @@ -15,4 +15,4 @@ Your goal is to simulate realistic customer interactions while following specifi - If you are transferred to another agent, generate the '###TRANSFER###' token to indicate the transfer. - If you find yourself in a situation in which the scenario does not provide enough information for you to continue the conversation, generate the '###OUT-OF-SCOPE###' token to end the conversation. -Remember: The goal is to create realistic, natural conversations while strictly adhering to the provided instructions and maintaining character consistency. \ No newline at end of file +Remember: The goal is to create realistic, natural conversations while strictly adhering to the provided instructions and maintaining character consistency. diff --git a/vendor/tau2/data/user_simulator/simulation_guidelines_tools.md b/vendor/tau2/data/user_simulator/simulation_guidelines_tools.md index 09f85a50..33908510 100644 --- a/vendor/tau2/data/user_simulator/simulation_guidelines_tools.md +++ b/vendor/tau2/data/user_simulator/simulation_guidelines_tools.md @@ -1,6 +1,6 @@ # User Simulation Guidelines -You are playing the role of a customer contacting a customer service representative agent. +You are playing the role of a customer contacting a customer service representative agent. Your goal is to simulate realistic customer interactions while following specific scenario instructions. You have some tools to perform the actions on your end that might be requested by the agent to diagnose and resolve your issue. @@ -27,4 +27,4 @@ You have some tools to perform the actions on your end that might be requested b - If you have been transferred to another agent, generate the '###TRANSFER###' token to indicate the transfer. Only do this after the agent has clearly indicated that you are being transferred. - If you find yourself in a situation in which the scenario does not provide enough information for you to continue the conversation, generate the '###OUT-OF-SCOPE###' token to end the conversation. -Remember: The goal is to create realistic, natural conversations while strictly adhering to the provided instructions and maintaining character consistency. \ No newline at end of file +Remember: The goal is to create realistic, natural conversations while strictly adhering to the provided instructions and maintaining character consistency. diff --git a/vendor/tau2/data_model/__init__.py b/vendor/tau2/data_model/__init__.py index 8b137891..e69de29b 100644 --- a/vendor/tau2/data_model/__init__.py +++ b/vendor/tau2/data_model/__init__.py @@ -1 +0,0 @@ - diff --git a/vendor/tau2/data_model/message.py b/vendor/tau2/data_model/message.py index ef5f1f7b..077c176d 100644 --- a/vendor/tau2/data_model/message.py +++ b/vendor/tau2/data_model/message.py @@ -18,15 +18,9 @@ class SystemMessage(BaseModel): """ role: SystemRole = Field(description="The role of the message sender.") - content: Optional[str] = Field( - description="The content of the message.", default=None - ) - turn_idx: Optional[int] = Field( - description="The index of the turn in the conversation.", default=None - ) - timestamp: Optional[str] = Field( - description="The timestamp of the message.", default_factory=get_now - ) + content: Optional[str] = Field(description="The content of the message.", default=None) + turn_idx: Optional[int] = Field(description="The index of the turn in the conversation.", default=None) + timestamp: Optional[str] = Field(description="The timestamp of the message.", default_factory=get_now) def __str__(self) -> str: lines = [ @@ -87,35 +81,21 @@ class ParticipantMessageBase(BaseModel): role: str = Field(description="The role of the message sender.") - content: Optional[str] = Field( - description="The content of the message.", default=None - ) - tool_calls: Optional[list[ToolCall]] = Field( - description="The tool calls made in the message.", default=None - ) - turn_idx: Optional[int] = Field( - description="The index of the turn in the conversation.", default=None - ) - timestamp: Optional[str] = Field( - description="The timestamp of the message.", default_factory=get_now - ) + content: Optional[str] = Field(description="The content of the message.", default=None) + tool_calls: Optional[list[ToolCall]] = Field(description="The tool calls made in the message.", default=None) + turn_idx: Optional[int] = Field(description="The index of the turn in the conversation.", default=None) + timestamp: Optional[str] = Field(description="The timestamp of the message.", default_factory=get_now) cost: Optional[float] = Field(description="The cost of the message.", default=None) - usage: Optional[dict] = Field( - description="The token usage of the message.", default=None - ) - raw_data: Optional[dict] = Field( - description="The raw data of the message.", default=None - ) + usage: Optional[dict] = Field(description="The token usage of the message.", default=None) + raw_data: Optional[dict] = Field(description="The raw data of the message.", default=None) def validate(self): # NOTE: It would be better to do this in the Pydantic model """ Validate the message. """ if not (self.has_text_content() or self.is_tool_call()): - raise ValueError( - f"AssistantMessage must have either content or tool calls. Got {self}" - ) + raise ValueError(f"AssistantMessage must have either content or tool calls. Got {self}") def has_text_content(self) -> bool: """ @@ -151,11 +131,7 @@ def __str__(self) -> str: def __eq__(self, other: object) -> bool: if type(other) is not type(self): return False - return ( - self.role == other.role - and self.content == other.content - and self.tool_calls == other.tool_calls - ) + return self.role == other.role and self.content == other.content and self.tool_calls == other.tool_calls class AssistantMessage(ParticipantMessageBase): @@ -187,12 +163,8 @@ class ToolMessage(BaseModel): description="The requestor of the tool call.", ) error: bool = Field(description="Whether the tool call failed.", default=False) - turn_idx: Optional[int] = Field( - description="The index of the turn in the conversation.", default=None - ) - timestamp: Optional[str] = Field( - description="The timestamp of the message.", default_factory=get_now - ) + turn_idx: Optional[int] = Field(description="The index of the turn in the conversation.", default=None) + timestamp: Optional[str] = Field(description="The timestamp of the message.", default_factory=get_now) def __str__(self) -> str: lines = [f"ToolMessage (responding to {self.requestor})"] @@ -228,6 +200,4 @@ class MultiToolMessage(BaseModel): APICompatibleMessage = SystemMessage | AssistantMessage | UserMessage | ToolMessage -Message = ( - SystemMessage | AssistantMessage | UserMessage | ToolMessage | MultiToolMessage -) +Message = SystemMessage | AssistantMessage | UserMessage | ToolMessage | MultiToolMessage diff --git a/vendor/tau2/data_model/simulation.py b/vendor/tau2/data_model/simulation.py index 41de1d72..ee630cf7 100644 --- a/vendor/tau2/data_model/simulation.py +++ b/vendor/tau2/data_model/simulation.py @@ -216,9 +216,7 @@ class RewardInfo(BaseModel): """ reward: Annotated[float, Field(description="The reward received by the agent.")] - db_check: Annotated[ - Optional[DBCheck], Field(description="The database check.", default=None) - ] + db_check: Annotated[Optional[DBCheck], Field(description="The database check.", default=None)] env_assertions: Annotated[ Optional[list[EnvAssertionCheck]], Field(description="The environment assertions.", default=None), @@ -265,9 +263,7 @@ class AgentInfo(BaseModel): implementation: str = Field(description="The type of agent.") llm: Optional[str] = Field(description="The LLM used by the agent.", default=None) - llm_args: Optional[dict] = Field( - description="The arguments to pass to the LLM for the agent.", default=None - ) + llm_args: Optional[dict] = Field(description="The arguments to pass to the LLM for the agent.", default=None) class UserInfo(BaseModel): @@ -277,9 +273,7 @@ class UserInfo(BaseModel): implementation: str = Field(description="The type of user.") llm: Optional[str] = Field(description="The LLM used by the user.", default=None) - llm_args: Optional[dict] = Field( - description="The arguments to pass to the LLM for the user.", default=None - ) + llm_args: Optional[dict] = Field(description="The arguments to pass to the LLM for the user.", default=None) global_simulation_guidelines: Optional[str] = Field( description="The global simulation guidelines for the user.", default=None ) @@ -295,9 +289,7 @@ class Info(BaseModel): user_info: UserInfo = Field(description="User information.") agent_info: AgentInfo = Field(description="Agent information.") environment_info: EnvironmentInfo = Field(description="Environment information.") - seed: Optional[int] = Field( - description="The seed used for the simulation.", default=None - ) + seed: Optional[int] = Field(description="The seed used for the simulation.", default=None) class TerminationReason(str, Enum): @@ -314,31 +306,17 @@ class SimulationRun(BaseModel): id: str = Field(description="The unique identifier for the simulation run.") task_id: str = Field(description="The unique identifier for the task.") - timestamp: str = Field( - description="The timestamp of the simulation.", default_factory=get_now - ) + timestamp: str = Field(description="The timestamp of the simulation.", default_factory=get_now) start_time: str = Field(description="The start time of the simulation.") end_time: str = Field(description="The end time of the simulation.") duration: float = Field(description="The duration of the simulation.") - termination_reason: TerminationReason = Field( - description="The reason for the termination of the simulation." - ) - agent_cost: Optional[float] = Field( - description="The cost of the agent.", default=None - ) - user_cost: Optional[float] = Field( - description="The cost of the user.", default=None - ) - reward_info: Optional[RewardInfo] = Field( - description="The reward received by the agent.", default=None - ) - messages: list[Message] = Field( - description="The messages exchanged between the user, agent and environment." - ) + termination_reason: TerminationReason = Field(description="The reason for the termination of the simulation.") + agent_cost: Optional[float] = Field(description="The cost of the agent.", default=None) + user_cost: Optional[float] = Field(description="The cost of the user.", default=None) + reward_info: Optional[RewardInfo] = Field(description="The reward received by the agent.", default=None) + messages: list[Message] = Field(description="The messages exchanged between the user, agent and environment.") trial: Optional[int] = Field(description="Trial number", default=None) - seed: Optional[int] = Field( - description="Seed used for the simulation.", default=None - ) + seed: Optional[int] = Field(description="Seed used for the simulation.", default=None) class Results(BaseModel): @@ -346,9 +324,7 @@ class Results(BaseModel): Run results """ - timestamp: Optional[str] = Field( - description="The timestamp of the simulation.", default_factory=get_now - ) + timestamp: Optional[str] = Field(description="The timestamp of the simulation.", default_factory=get_now) info: Info = Field(description="Information.") tasks: list[Task] = Field(description="The list of tasks.") simulations: list[SimulationRun] = Field(description="The list of simulations.") @@ -387,14 +363,8 @@ def transfer_only(task: Task) -> bool: return False def get_task_metrics(task: Task) -> dict: - eval_metrics = ( - task.evaluation_criteria.info() - if task.evaluation_criteria is not None - else {} - ) - num_actions = ( - eval_metrics["num_agent_actions"] + eval_metrics["num_user_actions"] - ) + eval_metrics = task.evaluation_criteria.info() if task.evaluation_criteria is not None else {} + num_actions = eval_metrics["num_agent_actions"] + eval_metrics["num_user_actions"] if transfer_only(task): num_actions = -1 info = { diff --git a/vendor/tau2/data_model/tasks.py b/vendor/tau2/data_model/tasks.py index ef17dc3c..c3105557 100644 --- a/vendor/tau2/data_model/tasks.py +++ b/vendor/tau2/data_model/tasks.py @@ -18,9 +18,7 @@ class StructuredUserInstructions(BaseModel): """ domain: Annotated[str, Field(description="The domain of the task.")] - reason_for_call: Annotated[ - str, Field(description="The reason for the user to call the agent.") - ] + reason_for_call: Annotated[str, Field(description="The reason for the user to call the agent.")] known_info: Annotated[ Optional[str], Field(description="Known information about the user.", default=None), @@ -40,9 +38,7 @@ def __str__(self) -> str: lines.append(f"Known info:\n{textwrap.indent(self.known_info, tab)}") if self.unknown_info is not None: lines.append(f"Unknown info:\n{textwrap.indent(self.unknown_info, tab)}") - lines.append( - f"Task instructions:\n{textwrap.indent(self.task_instructions, tab)}" - ) + lines.append(f"Task instructions:\n{textwrap.indent(self.task_instructions, tab)}") return "\n".join(lines) @@ -128,18 +124,14 @@ class Action(BaseModel): If compare_args is None, will check all the arguments. """ - action_id: str = Field( - description="The unique identifier for the action within a scenario." - ) + action_id: str = Field(description="The unique identifier for the action within a scenario.") requestor: ToolRequestor = Field( description="The requestor of the action.", default="assistant", ) name: str = Field(description="The name of the action.") arguments: dict = Field(description="The arguments for the action.") - info: Optional[str] = Field( - description="Information about the action.", default=None - ) + info: Optional[str] = Field(description="Information about the action.", default=None) compare_args: Optional[list[str]] = Field( description="The arguments to check in tool call. If None, will check all the arguments.", default=None, @@ -159,9 +151,7 @@ def get_func_format(self) -> str: """ Get the function format of the action. """ - return ( - f"{self.name}({', '.join([f'{k}={v}' for k, v in self.arguments.items()])})" - ) + return f"{self.name}({', '.join([f'{k}={v}' for k, v in self.arguments.items()])})" def compare_with_tool_call(self, tool_call: ToolCall) -> bool: """ @@ -193,9 +183,7 @@ class EnvFunctionCall(BaseModel): Field(description="The type of environment to call the function on."), ] func_name: Annotated[str, Field(description="The name of the function to call.")] - arguments: Annotated[ - dict, Field(description="The arguments to pass to the function.") - ] + arguments: Annotated[dict, Field(description="The arguments to pass to the function.")] def __str__(self) -> str: lines = [] @@ -210,9 +198,7 @@ class EnvAssertion(EnvFunctionCall): An assertion on the agent or user environment. """ - assert_value: Annotated[ - bool, Field(default=True, description="The value to assert on.") - ] + assert_value: Annotated[bool, Field(default=True, description="The value to assert on.")] message: Annotated[ Optional[str], Field( @@ -279,27 +265,16 @@ def __str__(self) -> str: lines = [] if self.actions is not None: lines.append("Actions:") - lines.extend( - [textwrap.indent(str(action), "\t") for action in self.actions] - ) + lines.extend([textwrap.indent(str(action), "\t") for action in self.actions]) if self.env_assertions is not None: lines.append("Env Assertions:") - lines.extend( - [ - textwrap.indent(str(assertion), "\t") - for assertion in self.env_assertions - ] - ) + lines.extend([textwrap.indent(str(assertion), "\t") for assertion in self.env_assertions]) if self.communicate_info is not None: lines.append("Communicate Info:") - lines.extend( - [textwrap.indent(info, "\t") for info in self.communicate_info] - ) + lines.extend([textwrap.indent(info, "\t") for info in self.communicate_info]) if self.nl_assertions is not None: lines.append("NL Assertions:") - lines.extend( - [textwrap.indent(assertion, "\t") for assertion in self.nl_assertions] - ) + lines.extend([textwrap.indent(assertion, "\t") for assertion in self.nl_assertions]) return "\n".join(lines) def info(self) -> dict: @@ -309,16 +284,10 @@ def info(self) -> dict: else 0 ) num_user_actions = ( - len([action for action in self.actions if action.requestor == "user"]) - if self.actions is not None - else 0 - ) - num_env_assertions = ( - len(self.env_assertions) if self.env_assertions is not None else 0 - ) - num_nl_assertions = ( - len(self.nl_assertions) if self.nl_assertions is not None else 0 + len([action for action in self.actions if action.requestor == "user"]) if self.actions is not None else 0 ) + num_env_assertions = len(self.env_assertions) if self.env_assertions is not None else 0 + num_nl_assertions = len(self.nl_assertions) if self.nl_assertions is not None else 0 return { "num_agent_actions": num_agent_actions, "num_user_actions": num_user_actions, @@ -354,9 +323,7 @@ class InitialState(BaseModel): ] initialization_actions: Annotated[ Optional[list[EnvFunctionCall]], - Field( - description="Initial actions to be taken on the environment.", default=None - ), + Field(description="Initial actions to be taken on the environment.", default=None), ] message_history: Annotated[ Optional[list[Message]], @@ -370,29 +337,13 @@ def __str__(self) -> str: lines = [] if self.initialization_data is not None: lines.append("Initialization Data:") - lines.extend( - [ - textwrap.indent( - self.initialization_data.model_dump_json(indent=2), "\t" - ) - ] - ) + lines.extend([textwrap.indent(self.initialization_data.model_dump_json(indent=2), "\t")]) if self.initialization_actions is not None: lines.append("Initialization Actions:") - lines.extend( - [ - textwrap.indent(str(action), "\t") - for action in self.initialization_actions - ] - ) + lines.extend([textwrap.indent(str(action), "\t") for action in self.initialization_actions]) if self.message_history is not None: lines.append("Message History:") - lines.extend( - [ - textwrap.indent(str(message), "\t") - for message in self.message_history - ] - ) + lines.extend([textwrap.indent(str(message), "\t") for message in self.message_history]) return "\n".join(lines) @@ -411,9 +362,7 @@ class Task(BaseModel): ] user_scenario: Annotated[ UserScenario, - Field( - description="User scenario. This information will be sent to the user simulator." - ), + Field(description="User scenario. This information will be sent to the user simulator."), ] ticket: Annotated[ Optional[str], @@ -478,11 +427,7 @@ def make_task( if message_history is not None: # Patch to consider empty list of tool calls as None. for message in message_history: - if ( - message.role == "assistant" - and isinstance(message.tool_calls, list) - and len(message.tool_calls) == 0 - ): + if message.role == "assistant" and isinstance(message.tool_calls, list) and len(message.tool_calls) == 0: message.tool_calls = None initial_state = InitialState( diff --git a/vendor/tau2/domains/airline/data_model.py b/vendor/tau2/domains/airline/data_model.py index f2733727..c046d228 100644 --- a/vendor/tau2/domains/airline/data_model.py +++ b/vendor/tau2/domains/airline/data_model.py @@ -10,9 +10,7 @@ Insurance = Literal["yes", "no"] -MembershipLevel = Annotated[ - Literal["gold", "silver", "regular"], Field(description="Membership level") -] +MembershipLevel = Annotated[Literal["gold", "silver", "regular"], Field(description="Membership level")] class AirportCode(BaseModel): @@ -30,9 +28,7 @@ class Name(BaseModel): class Address(BaseModel): address1: str = Field(description="Primary address line") - address2: Optional[str] = Field( - None, description="Secondary address line (optional)" - ) + address2: Optional[str] = Field(None, description="Secondary address line (optional)") city: str = Field(description="City name") country: str = Field(description="Country name") state: str = Field(description="State or province name") @@ -51,25 +47,19 @@ class PaymentMethodBase(BaseModel): class CreditCard(PaymentMethodBase): - source: Literal["credit_card"] = Field( - description="Indicates this is a credit card payment method" - ) + source: Literal["credit_card"] = Field(description="Indicates this is a credit card payment method") brand: str = Field(description="Credit card brand (e.g., visa, mastercard)") last_four: str = Field(description="Last four digits of the credit card") class GiftCard(PaymentMethodBase): - source: Literal["gift_card"] = Field( - description="Indicates this is a gift card payment method" - ) + source: Literal["gift_card"] = Field(description="Indicates this is a gift card payment method") amount: float = Field(description="Gift card value amount") id: str = Field(description="Unique identifier for the gift card") class Certificate(PaymentMethodBase): - source: Literal["certificate"] = Field( - description="Indicates this is a certificate payment method" - ) + source: Literal["certificate"] = Field(description="Indicates this is a certificate payment method") amount: float = Field(description="Certificate value amount") @@ -82,9 +72,7 @@ class Passenger(BaseModel): dob: str = Field(description="Date of birth in YYYY-MM-DD format") -SeatPrices = Annotated[ - dict[CabinClass, int], Field(description="Prices for different cabin classes") -] +SeatPrices = Annotated[dict[CabinClass, int], Field(description="Prices for different cabin classes")] AvailableSeats = Annotated[ dict[CabinClass, int], Field(description="Available seats for different cabin classes"), @@ -92,9 +80,7 @@ class Passenger(BaseModel): class FlightDateStatusAvailable(BaseModel): - status: Literal["available"] = Field( - description="Indicates flight is available for booking" - ) + status: Literal["available"] = Field(description="Indicates flight is available for booking") available_seats: AvailableSeats = Field(description="Available seats by class") prices: SeatPrices = Field(description="Current prices by class") @@ -166,24 +152,18 @@ class Flight(FlightBase): scheduled_arrival_time_est: str = Field( description="Scheduled arrival time in EST in the format HH:MM:SS, e.g 07:00:00" ) - dates: Dict[str, FlightDateStatus] = Field( - description="Flight status by date (YYYY-MM-DD)" - ) + dates: Dict[str, FlightDateStatus] = Field(description="Flight status by date (YYYY-MM-DD)") class DirectFlight(FlightBase): - status: Literal["available"] = Field( - description="Indicates flight is available for booking" - ) + status: Literal["available"] = Field(description="Indicates flight is available for booking") scheduled_departure_time_est: str = Field( description="Scheduled departure time in EST in the format HH:MM:SS, e.g 06:00:00" ) scheduled_arrival_time_est: str = Field( description="Scheduled arrival time in EST in the format HH:MM:SS, e.g 07:00:00" ) - date: Optional[str] = Field( - description="Flight date in YYYY-MM-DD format", default=None - ) + date: Optional[str] = Field(description="Flight date in YYYY-MM-DD format", default=None) available_seats: AvailableSeats = Field(description="Available seats by class") prices: SeatPrices = Field(description="Current prices by class") @@ -195,9 +175,7 @@ class ReservationFlight(FlightBase): class FlightInfo(BaseModel): flight_number: str = Field(description="Flight number, such as 'HAT001'.") - date: str = Field( - description="The date for the flight in the format 'YYYY-MM-DD', such as '2024-05-01'." - ) + date: str = Field(description="The date for the flight in the format 'YYYY-MM-DD', such as '2024-05-01'.") class User(BaseModel): @@ -205,15 +183,9 @@ class User(BaseModel): name: Name = Field(description="User's full name") address: Address = Field(description="User's address information") email: str = Field(description="User's email address") - dob: str = Field( - description="User's date of birth in the format YYYY-MM-DD, e.g 1990-04-05" - ) - payment_methods: Dict[str, PaymentMethod] = Field( - description="User's saved payment methods" - ) - saved_passengers: List[Passenger] = Field( - description="User's saved passenger information" - ) + dob: str = Field(description="User's date of birth in the format YYYY-MM-DD, e.g 1990-04-05") + payment_methods: Dict[str, PaymentMethod] = Field(description="User's saved payment methods") + saved_passengers: List[Passenger] = Field(description="User's saved passenger information") membership: MembershipLevel = Field(description="User's membership level") reservations: List[str] = Field(description="List of user's reservation IDs") @@ -226,35 +198,21 @@ class Reservation(BaseModel): destination: str = Field(description="IATA code for trip destination") flight_type: FlightType = Field(description="Type of trip") cabin: CabinClass = Field(description="Selected cabin class") - flights: List[ReservationFlight] = Field( - description="List of flights in the reservation" - ) - passengers: List[Passenger] = Field( - description="List of passengers on the reservation" - ) - payment_history: List[Payment] = Field( - description="History of payments for this reservation" - ) - created_at: str = Field( - description="Timestamp when reservation was created in the format YYYY-MM-DDTHH:MM:SS" - ) + flights: List[ReservationFlight] = Field(description="List of flights in the reservation") + passengers: List[Passenger] = Field(description="List of passengers on the reservation") + payment_history: List[Payment] = Field(description="History of payments for this reservation") + created_at: str = Field(description="Timestamp when reservation was created in the format YYYY-MM-DDTHH:MM:SS") total_baggages: int = Field(description="Total number of bags in reservation") nonfree_baggages: int = Field(description="Number of paid bags in reservation") insurance: Insurance = Field(description="Whether travel insurance was purchased") - status: Optional[Literal["cancelled"]] = Field( - description="Status of the reservation", default=None - ) + status: Optional[Literal["cancelled"]] = Field(description="Status of the reservation", default=None) class FlightDB(DB): """Database of all flights, users, and reservations.""" - flights: Dict[str, Flight] = Field( - description="Dictionary of all flights indexed by flight number" - ) - users: Dict[str, User] = Field( - description="Dictionary of all users indexed by user ID" - ) + flights: Dict[str, Flight] = Field(description="Dictionary of all flights indexed by flight number") + users: Dict[str, User] = Field(description="Dictionary of all users indexed by user ID") reservations: Dict[str, Reservation] = Field( description="Dictionary of all reservations indexed by reservation ID" ) @@ -262,9 +220,7 @@ class FlightDB(DB): def get_statistics(self) -> dict[str, Any]: """Get the statistics of the database.""" num_flights = len(self.flights) - num_flights_instances = sum( - len(flight.dates) for flight in self.flights.values() - ) + num_flights_instances = sum(len(flight.dates) for flight in self.flights.values()) num_users = len(self.users) num_reservations = len(self.reservations) return { diff --git a/vendor/tau2/domains/airline/tools.py b/vendor/tau2/domains/airline/tools.py index d4f45694..d854f725 100644 --- a/vendor/tau2/domains/airline/tools.py +++ b/vendor/tau2/domains/airline/tools.py @@ -62,15 +62,11 @@ def _get_flight_instance(self, flight_number: str, date: str) -> FlightDateStatu raise ValueError(f"Flight {flight_number} not found on date {date}") return flight.dates[date] - def _get_flights_from_flight_infos( - self, flight_infos: List[FlightInfo] - ) -> list[FlightDateStatus]: + def _get_flights_from_flight_infos(self, flight_infos: List[FlightInfo]) -> list[FlightDateStatus]: """Get the flight from the reservation.""" flights = [] for flight_info in flight_infos: - flights.append( - self._get_flight_instance(flight_info.flight_number, flight_info.date) - ) + flights.append(self._get_flight_instance(flight_info.flight_number, flight_info.date)) return flights def _get_new_reservation_id(self) -> str: @@ -123,10 +119,7 @@ def _search_direct_flight( and (destination is None or flight.destination == destination) and (date in flight.dates) and (flight.dates[date].status == "available") - and ( - leave_after is None - or flight.scheduled_departure_time_est >= leave_after - ) + and (leave_after is None or flight.scheduled_departure_time_est >= leave_after) ) if check: direct_flight = DirectFlight( @@ -142,9 +135,7 @@ def _search_direct_flight( results.append(direct_flight) return results - def _payment_for_update( - self, user: User, payment_id: str, total_price: int - ) -> Optional[Payment]: + def _payment_for_update(self, user: User, payment_id: str, total_price: int) -> Optional[Payment]: """ Process payment for update reservation @@ -165,9 +156,7 @@ def _payment_for_update( payment_method = user.payment_methods[payment_id] if payment_method.source == "certificate": raise ValueError("Certificate cannot be used to update reservation") - elif ( - payment_method.source == "gift_card" and payment_method.amount < total_price - ): + elif payment_method.source == "gift_card" and payment_method.amount < total_price: raise ValueError("Gift card balance is not enough") # Deduct payment @@ -219,9 +208,7 @@ def book_reservation( if all(isinstance(passenger, dict) for passenger in passengers): passengers = [Passenger(**passenger) for passenger in passengers] if all(isinstance(payment_method, dict) for payment_method in payment_methods): - payment_methods = [ - Payment(**payment_method) for payment_method in payment_methods - ] + payment_methods = [Payment(**payment_method) for payment_method in payment_methods] user = self._get_user(user_id) reservation_id = self._get_new_reservation_id() @@ -248,14 +235,10 @@ def book_reservation( for flight_info in flights: flight_number = flight_info.flight_number flight = self._get_flight(flight_number) - flight_date_data = self._get_flight_instance( - flight_number=flight_number, date=flight_info.date - ) + flight_date_data = self._get_flight_instance(flight_number=flight_number, date=flight_info.date) # Checking flight availability if not isinstance(flight_date_data, FlightDateStatusAvailable): - raise ValueError( - f"Flight {flight_number} not available on date {flight_info.date}" - ) + raise ValueError(f"Flight {flight_number} not available on date {flight_info.date}") # Checking seat availability if flight_date_data.available_seats[cabin] < len(passengers): raise ValueError(f"Not enough seats on flight {flight_number}") @@ -290,15 +273,11 @@ def book_reservation( user_payment_method = user.payment_methods[payment_id] if user_payment_method.source in {"gift_card", "certificate"}: if user_payment_method.amount < amount: - raise ValueError( - f"Not enough balance in payment method {payment_id}" - ) + raise ValueError(f"Not enough balance in payment method {payment_id}") total_payment = sum(payment.amount for payment in payment_methods) if total_payment != total_price: - raise ValueError( - f"Payment amount does not add up, total price is {total_price}, but paid {total_payment}" - ) + raise ValueError(f"Payment amount does not add up, total price is {total_price}, but paid {total_payment}") # if checks pass, deduct payment for payment_method in payment_methods: @@ -430,9 +409,7 @@ def list_all_airports(self) -> AirportInfo: # DONE ] @is_tool(ToolType.READ) - def search_direct_flight( - self, origin: str, destination: str, date: str - ) -> list[DirectFlight]: + def search_direct_flight(self, origin: str, destination: str, date: str) -> list[DirectFlight]: """ Search for direct flights between two cities on a specific date. @@ -444,9 +421,7 @@ def search_direct_flight( Returns: The direct flights between the two cities on the specific date. """ - return self._search_direct_flight( - date=date, origin=origin, destination=destination - ) + return self._search_direct_flight(date=date, origin=origin, destination=destination) @is_tool(ToolType.READ) def search_onestop_flight( @@ -464,15 +439,9 @@ def search_onestop_flight( A list of pairs of DirectFlight objects. """ results = [] - for result1 in self._search_direct_flight( - date=date, origin=origin, destination=None - ): + for result1 in self._search_direct_flight(date=date, origin=origin, destination=None): result1.date = date - date2 = ( - f"2024-05-{int(date[-2:]) + 1}" - if "+1" in result1.scheduled_arrival_time_est - else date - ) + date2 = f"2024-05-{int(date[-2:]) + 1}" if "+1" in result1.scheduled_arrival_time_est else date # TODO: flight1.scheduled_arrival_time_est could have a +1? for result2 in self._search_direct_flight( date=date2, @@ -637,9 +606,7 @@ def update_reservation_flights( None, ) if matching_reservation_flight: - total_price += matching_reservation_flight.price * len( - reservation.passengers - ) + total_price += matching_reservation_flight.price * len(reservation.passengers) reservation_flights.append(matching_reservation_flight) continue @@ -651,15 +618,11 @@ def update_reservation_flights( date=flight_info.date, ) if not isinstance(flight_date_data, FlightDateStatusAvailable): - raise ValueError( - f"Flight {flight_info.flight_number} not available on date {flight_info.date}" - ) + raise ValueError(f"Flight {flight_info.flight_number} not available on date {flight_info.date}") # Check seat availability if flight_date_data.available_seats[cabin] < len(reservation.passengers): - raise ValueError( - f"Not enough seats on flight {flight_info.flight_number}" - ) + raise ValueError(f"Not enough seats on flight {flight_info.flight_number}") # Calculate price and add to reservation reservation_flight = ReservationFlight( @@ -673,9 +636,7 @@ def update_reservation_flights( reservation_flights.append(reservation_flight) # Deduct amount already paid for reservation - total_price -= sum(flight.price for flight in reservation.flights) * len( - reservation.passengers - ) + total_price -= sum(flight.price for flight in reservation.flights) * len(reservation.passengers) # Create payment payment = self._payment_for_update(user, payment_id, total_price) @@ -690,9 +651,7 @@ def update_reservation_flights( return reservation @is_tool(ToolType.WRITE) - def update_reservation_passengers( - self, reservation_id: str, passengers: List[Passenger | dict] - ) -> Reservation: + def update_reservation_passengers(self, reservation_id: str, passengers: List[Passenger | dict]) -> Reservation: """ Update the passenger information of a reservation. diff --git a/vendor/tau2/domains/mock/data_model.py b/vendor/tau2/domains/mock/data_model.py index f643d3e0..bff026df 100644 --- a/vendor/tau2/domains/mock/data_model.py +++ b/vendor/tau2/domains/mock/data_model.py @@ -24,12 +24,8 @@ class User(BaseModel): class MockDB(DB): """Simple database with users and their tasks.""" - tasks: Dict[str, Task] = Field( - description="Dictionary of all tasks indexed by task ID" - ) - users: Dict[str, User] = Field( - description="Dictionary of all users indexed by user ID" - ) + tasks: Dict[str, Task] = Field(description="Dictionary of all tasks indexed by task ID") + users: Dict[str, User] = Field(description="Dictionary of all users indexed by user ID") def get_db(): diff --git a/vendor/tau2/domains/mock/environment.py b/vendor/tau2/domains/mock/environment.py index d7063315..925fd297 100644 --- a/vendor/tau2/domains/mock/environment.py +++ b/vendor/tau2/domains/mock/environment.py @@ -13,9 +13,7 @@ from vendor.tau2.environment.environment import Environment -def get_environment( - db: Optional[MockDB] = None, solo_mode: bool = False -) -> Environment: +def get_environment(db: Optional[MockDB] = None, solo_mode: bool = False) -> Environment: if db is None: db = MockDB.load(MOCK_DB_PATH) tools = MockTools(db) diff --git a/vendor/tau2/domains/mock/tools.py b/vendor/tau2/domains/mock/tools.py index b36f46af..7c2ab361 100644 --- a/vendor/tau2/domains/mock/tools.py +++ b/vendor/tau2/domains/mock/tools.py @@ -30,9 +30,7 @@ def create_task(self, user_id: str, title: str, description: str = None) -> Task raise ValueError(f"User {user_id} not found") task_id = f"task_{len(self.db.tasks) + 1}" - task = Task( - task_id=task_id, title=title, description=description, status="pending" - ) + task = Task(task_id=task_id, title=title, description=description, status="pending") self.db.tasks[task_id] = task self.db.users[user_id].tasks.append(task_id) diff --git a/vendor/tau2/domains/retail/data_model.py b/vendor/tau2/domains/retail/data_model.py index ddb45e3d..d0415856 100644 --- a/vendor/tau2/domains/retail/data_model.py +++ b/vendor/tau2/domains/retail/data_model.py @@ -22,9 +22,7 @@ class Product(BaseModel): name: str = Field(description="Name of the product") product_id: str = Field(description="Unique identifier for the product") - variants: Dict[str, Variant] = Field( - description="Dictionary of variants indexed by variant ID" - ) + variants: Dict[str, Variant] = Field(description="Dictionary of variants indexed by variant ID") class UserName(BaseModel): @@ -51,23 +49,17 @@ class PaymentMethodBase(BaseModel): class CreditCard(PaymentMethodBase): - source: Literal["credit_card"] = Field( - description="Indicates this is a credit card payment method" - ) + source: Literal["credit_card"] = Field(description="Indicates this is a credit card payment method") brand: str = Field(description="Credit card brand (e.g., visa, mastercard)") last_four: str = Field(description="Last four digits of the credit card") class Paypal(PaymentMethodBase): - source: Literal["paypal"] = Field( - description="Indicates this is a paypal payment method" - ) + source: Literal["paypal"] = Field(description="Indicates this is a paypal payment method") class GiftCard(PaymentMethodBase): - source: Literal["gift_card"] = Field( - description="Indicates this is a gift card payment method" - ) + source: Literal["gift_card"] = Field(description="Indicates this is a gift card payment method") balance: float = Field(description="Gift card value amount") id: str = Field(description="Unique identifier for the gift card") @@ -92,9 +84,7 @@ class OrderFullfilment(BaseModel): """Represents the fulfillment details for items in an order""" tracking_id: list[str] = Field(description="List of tracking IDs for shipments") - item_ids: list[str] = Field( - description="List of item IDs included in this fulfillment" - ) + item_ids: list[str] = Field(description="List of item IDs included in this fulfillment") class OrderItem(BaseModel): @@ -113,9 +103,7 @@ class OrderItem(BaseModel): class OrderPayment(BaseModel): """Represents a payment or refund transaction for an order""" - transaction_type: OrderPaymentType = Field( - description="Type of transaction (payment or refund)" - ) + transaction_type: OrderPaymentType = Field(description="Type of transaction (payment or refund)") amount: float = Field(description="Amount of the transaction") payment_method_id: str = Field(description="ID of the payment method used") @@ -141,32 +129,18 @@ class BaseOrder(BaseModel): address: UserAddress = Field(description="Address of the user") items: List[OrderItem] = Field(description="Items in the order") status: OrderStatus = Field(description="Status of the order") - fulfillments: List[OrderFullfilment] = Field( - description="Fulfillments of the order" - ) + fulfillments: List[OrderFullfilment] = Field(description="Fulfillments of the order") payment_history: List[OrderPayment] = Field(description="Payments of the order") cancel_reason: Optional[CancelReason] = Field( description="Reason for cancelling the order. Can'no longer needed' or 'ordered by mistake'", default=None, ) - exchange_items: Optional[List[str]] = Field( - description="Items to be exchanged", default=None - ) - exchange_new_items: Optional[List[str]] = Field( - description="Items exchanged for", default=None - ) - exchange_payment_method_id: Optional[str] = Field( - description="Payment method ID for the exchange", default=None - ) - exchange_price_difference: Optional[float] = Field( - description="Price difference for the exchange", default=None - ) - return_items: Optional[List[str]] = Field( - description="Items to be returned", default=None - ) - return_payment_method_id: Optional[str] = Field( - description="Payment method ID for the return", default=None - ) + exchange_items: Optional[List[str]] = Field(description="Items to be exchanged", default=None) + exchange_new_items: Optional[List[str]] = Field(description="Items exchanged for", default=None) + exchange_payment_method_id: Optional[str] = Field(description="Payment method ID for the exchange", default=None) + exchange_price_difference: Optional[float] = Field(description="Price difference for the exchange", default=None) + return_items: Optional[List[str]] = Field(description="Items to be returned", default=None) + return_payment_method_id: Optional[str] = Field(description="Payment method ID for the return", default=None) class Order(BaseModel): @@ -177,55 +151,33 @@ class Order(BaseModel): address: UserAddress = Field(description="Address of the user") items: List[OrderItem] = Field(description="Items in the order") status: OrderStatus = Field(description="Status of the order") - fulfillments: List[OrderFullfilment] = Field( - description="Fulfillments of the order" - ) + fulfillments: List[OrderFullfilment] = Field(description="Fulfillments of the order") payment_history: List[OrderPayment] = Field(description="Payments of the order") cancel_reason: Optional[CancelReason] = Field( description="Reason for cancelling the order. Should be 'no longer needed' or 'ordered by mistake'", default=None, ) - exchange_items: Optional[List[str]] = Field( - description="Items to be exchanged", default=None - ) - exchange_new_items: Optional[List[str]] = Field( - description="Items exchanged for", default=None - ) - exchange_payment_method_id: Optional[str] = Field( - description="Payment method ID for the exchange", default=None - ) - exchange_price_difference: Optional[float] = Field( - description="Price difference for the exchange", default=None - ) - return_items: Optional[List[str]] = Field( - description="Items to be returned", default=None - ) - return_payment_method_id: Optional[str] = Field( - description="Payment method ID for the return", default=None - ) + exchange_items: Optional[List[str]] = Field(description="Items to be exchanged", default=None) + exchange_new_items: Optional[List[str]] = Field(description="Items exchanged for", default=None) + exchange_payment_method_id: Optional[str] = Field(description="Payment method ID for the exchange", default=None) + exchange_price_difference: Optional[float] = Field(description="Price difference for the exchange", default=None) + return_items: Optional[List[str]] = Field(description="Items to be returned", default=None) + return_payment_method_id: Optional[str] = Field(description="Payment method ID for the return", default=None) class RetailDB(DB): """Database containing all retail-related data including products, users and orders""" - products: Dict[str, Product] = Field( - description="Dictionary of all products indexed by product ID" - ) - users: Dict[str, User] = Field( - description="Dictionary of all users indexed by user ID" - ) - orders: Dict[str, Order] = Field( - description="Dictionary of all orders indexed by order ID" - ) + products: Dict[str, Product] = Field(description="Dictionary of all products indexed by product ID") + users: Dict[str, User] = Field(description="Dictionary of all users indexed by user ID") + orders: Dict[str, Order] = Field(description="Dictionary of all orders indexed by order ID") def get_statistics(self) -> dict[str, Any]: """Get the statistics of the database.""" num_products = len(self.products) num_users = len(self.users) num_orders = len(self.orders) - total_num_items = sum( - len(product.variants) for product in self.products.values() - ) + total_num_items = sum(len(product.variants) for product in self.products.values()) return { "num_products": num_products, "num_users": num_users, diff --git a/vendor/tau2/domains/retail/tools.py b/vendor/tau2/domains/retail/tools.py index 944e206c..6fc91e15 100644 --- a/vendor/tau2/domains/retail/tools.py +++ b/vendor/tau2/domains/retail/tools.py @@ -92,9 +92,7 @@ def _get_variant(self, product_id: str, variant_id: str) -> Variant: raise ValueError("Variant not found") return product.variants[variant_id] - def _get_payment_method( - self, user_id: str, payment_method_id: str - ) -> PaymentMethod: + def _get_payment_method(self, user_id: str, payment_method_id: str) -> PaymentMethod: """Get the payment method from the database. Args: @@ -252,9 +250,7 @@ def exchange_delivered_order_items( payment_method = self._get_payment_method(order.user_id, payment_method_id) if isinstance(payment_method, GiftCard) and payment_method.balance < diff_price: - raise ValueError( - "Insufficient gift card balance to pay for the price difference" - ) + raise ValueError("Insufficient gift card balance to pay for the price difference") # modify the order order.status = "exchange requested" @@ -266,9 +262,7 @@ def exchange_delivered_order_items( return order @is_tool(ToolType.READ) - def find_user_id_by_name_zip( - self, first_name: str, last_name: str, zip: str - ) -> str: + def find_user_id_by_name_zip(self, first_name: str, last_name: str, zip: str) -> str: """Find user id by first name, last name, and zip code. If the user is not found, the function will return an error message. By default, find user id by email, and only call this function if the user is not found by email or cannot remember email. @@ -368,9 +362,7 @@ def list_all_product_types(self) -> str: Returns: str: A JSON string mapping product names to their product IDs, sorted alphabetically by name. """ - product_dict = { - product.name: product.product_id for product in self.db.products.values() - } + product_dict = {product.name: product.product_id for product in self.db.products.values()} return json.dumps(product_dict, sort_keys=True) @is_tool(ToolType.WRITE) @@ -461,9 +453,7 @@ def modify_pending_order_items( diff_price = 0 for item_id, new_item_id in zip(item_ids, new_item_ids): if item_id == new_item_id: - raise ValueError( - "The new item id should be different from the old item id" - ) + raise ValueError("The new item id should be different from the old item id") item = next((item for item in order.items if item.item_id == item_id), None) if item is None: raise ValueError(f"Item {item_id} not found") @@ -538,17 +528,12 @@ def modify_pending_order_payment( payment_method = self._get_payment_method(order.user_id, payment_method_id) # Check that the payment history should only have one payment - if ( - len(order.payment_history) != 1 - or order.payment_history[0].transaction_type != "payment" - ): + if len(order.payment_history) != 1 or order.payment_history[0].transaction_type != "payment": raise ValueError("There should be exactly one payment for a pending order") # Check that the payment method is different if order.payment_history[0].payment_method_id == payment_method_id: - raise ValueError( - "The new payment method should be different from the current one" - ) + raise ValueError("The new payment method should be different from the current one") amount = order.payment_history[0].amount @@ -578,9 +563,7 @@ def modify_pending_order_payment( payment_method.balance = round(payment_method.balance, 2) # If refund is made to a gift card, update the balance - old_payment_method = self._get_payment_method( - order.user_id, order.payment_history[0].payment_method_id - ) + old_payment_method = self._get_payment_method(order.user_id, order.payment_history[0].payment_method_id) if isinstance(old_payment_method, GiftCard): old_payment_method.balance += amount old_payment_method.balance = round(old_payment_method.balance, 2) diff --git a/vendor/tau2/domains/telecom/data_model.py b/vendor/tau2/domains/telecom/data_model.py index b5dea830..ffc8b1a0 100644 --- a/vendor/tau2/domains/telecom/data_model.py +++ b/vendor/tau2/domains/telecom/data_model.py @@ -23,9 +23,7 @@ class Plan(BaseModelNoExtra): name: str = Field(description="Display name of the plan") data_limit_gb: float = Field(description="Monthly data allowance in gigabytes (GB)") price_per_month: float = Field(description="Monthly price of the plan in USD") - data_refueling_price_per_gb: float = Field( - description="Price per gigabyte for data refueling" - ) + data_refueling_price_per_gb: float = Field(description="Price per gigabyte for data refueling") class DeviceType(str, Enum): @@ -40,15 +38,9 @@ class Device(BaseModelNoExtra): device_id: str = Field(description="Unique identifier for the device") device_type: DeviceType = Field(description="Type/category of the device") model: str = Field(description="Model name/number of the device") - imei: Optional[str] = Field( - None, description="International Mobile Equipment Identity number" - ) - is_esim_capable: bool = Field( - description="Whether the device supports eSIM technology" - ) - activated: bool = Field( - False, description="Whether the device has been activated on the network" - ) + imei: Optional[str] = Field(None, description="International Mobile Equipment Identity number") + is_esim_capable: bool = Field(description="Whether the device supports eSIM technology") + activated: bool = Field(False, description="Whether the device has been activated on the network") activation_date: Optional[datetime.datetime] = Field( None, description="Date and time when the device was activated (format: YYYY-MM-DDTHH:MM:SS, timezone: EST)", @@ -69,22 +61,12 @@ class LineStatus(str, Enum): class Line(BaseModelNoExtra): line_id: str = Field(description="Unique identifier for the line") phone_number: str = Field(description="Phone number associated with the line") - status: LineStatus = Field( - LineStatus.PENDING_ACTIVATION, description="Current status of the line" - ) + status: LineStatus = Field(LineStatus.PENDING_ACTIVATION, description="Current status of the line") plan_id: str = Field(description="Plan associated with this line") - device_id: Optional[str] = Field( - None, description="Device associated with this line" - ) - data_used_gb: float = Field( - 0.0, description="Data used in the current billing cycle in gigabytes (GB)" - ) - data_refueling_gb: float = Field( - 0.0, description="Data refueled in the current billing cycle in gigabytes (GB)" - ) - roaming_enabled: bool = Field( - False, description="Whether international roaming is enabled for this line" - ) + device_id: Optional[str] = Field(None, description="Device associated with this line") + data_used_gb: float = Field(0.0, description="Data used in the current billing cycle in gigabytes (GB)") + data_refueling_gb: float = Field(0.0, description="Data refueled in the current billing cycle in gigabytes (GB)") + roaming_enabled: bool = Field(False, description="Whether international roaming is enabled for this line") contract_end_date: Optional[datetime.date] = Field( None, description="End date of the current contract, if applicable (format: YYYY-MM-DD, timezone: EST)", @@ -105,15 +87,9 @@ class Line(BaseModelNoExtra): class LineItem(BaseModelNoExtra): description: str = Field(description="Descriptive text for the line item") - amount: float = Field( - description="Monetary amount in USD (positive for charges, negative for credits)" - ) - date: datetime.date = Field( - description="Date the line item was applied (format: YYYY-MM-DD, timezone: EST)" - ) - item_type: str = Field( - description="Category of the line item (e.g., Plan Charge, Overage, Fee, Credit, Payment)" - ) + amount: float = Field(description="Monetary amount in USD (positive for charges, negative for credits)") + date: datetime.date = Field(description="Date the line item was applied (format: YYYY-MM-DD, timezone: EST)") + item_type: str = Field(description="Category of the line item (e.g., Plan Charge, Overage, Fee, Credit, Payment)") class BillStatus(str, Enum): @@ -131,23 +107,17 @@ class Bill(BaseModelNoExtra): period_start: datetime.date = Field( description="Start date of the billing period (format: YYYY-MM-DD, timezone: EST)" ) - period_end: datetime.date = Field( - description="End date of the billing period (format: YYYY-MM-DD, timezone: EST)" - ) + period_end: datetime.date = Field(description="End date of the billing period (format: YYYY-MM-DD, timezone: EST)") issue_date: datetime.date = Field( description="Date the bill was issued/generated (format: YYYY-MM-DD, timezone: EST)" ) total_due: float = Field(description="Total amount due in USD") - due_date: datetime.date = Field( - description="Date by which payment is due (format: YYYY-MM-DD, timezone: EST)" - ) + due_date: datetime.date = Field(description="Date by which payment is due (format: YYYY-MM-DD, timezone: EST)") line_items: List[LineItem] = Field( default_factory=list, description="Individual charges, credits, and payments on this bill", ) - status: BillStatus = Field( - BillStatus.DRAFT, description="Current status of the bill" - ) + status: BillStatus = Field(BillStatus.DRAFT, description="Current status of the bill") class AccountStatus(str, Enum): @@ -165,20 +135,14 @@ class PaymentMethodType(str, Enum): class PaymentMethod(BaseModelNoExtra): method_type: PaymentMethodType = Field(description="Type of payment method") - account_number_last_4: str = Field( - description="Last 4 digits of the account number" - ) - expiration_date: str = Field( - description="The expiration date of the payment method in the format MM/YYYY" - ) + account_number_last_4: str = Field(description="Last 4 digits of the account number") + expiration_date: str = Field(description="The expiration date of the payment method in the format MM/YYYY") class Customer(BaseModelNoExtra): customer_id: str = Field(description="Unique identifier for the customer") full_name: str = Field(description="Customer's full name") - date_of_birth: str = Field( - description="Customer's date of birth for identity verification (format: YYYY-MM-DD)" - ) + date_of_birth: str = Field(description="Customer's date of birth for identity verification (format: YYYY-MM-DD)") email: str = Field(description="Customer's email address") phone_number: str = Field(description="Customer's primary contact phone number") address: Address = Field(description="Customer's billing address") @@ -189,12 +153,8 @@ class Customer(BaseModelNoExtra): payment_methods: List[PaymentMethod] = Field( default_factory=list, description="Stored payment methods for this customer" ) - line_ids: List[str] = Field( - default_factory=list, description="Phone/data lines owned by this customer" - ) - bill_ids: List[str] = Field( - default_factory=list, description="Bills associated with this customer" - ) + line_ids: List[str] = Field(default_factory=list, description="Phone/data lines owned by this customer") + bill_ids: List[str] = Field(default_factory=list, description="Bills associated with this customer") created_at: datetime.datetime = Field( DEFAULT_START_DATE, description="Date and time when the customer account was created (format: YYYY-MM-DDTHH:MM:SS, timezone: EST)", @@ -211,21 +171,11 @@ class Customer(BaseModelNoExtra): class TelecomDB(DB): """Database interface for telecom domain.""" - plans: List[Plan] = Field( - default_factory=list, description="Available service plans" - ) - customers: List[Customer] = Field( - default_factory=list, description="All customers in the system" - ) - lines: List[Line] = Field( - default_factory=list, description="All lines in the system" - ) - bills: List[Bill] = Field( - default_factory=list, description="All bills in the system" - ) - devices: List[Device] = Field( - default_factory=list, description="All devices in the system" - ) + plans: List[Plan] = Field(default_factory=list, description="Available service plans") + customers: List[Customer] = Field(default_factory=list, description="All customers in the system") + lines: List[Line] = Field(default_factory=list, description="All lines in the system") + bills: List[Bill] = Field(default_factory=list, description="All bills in the system") + devices: List[Device] = Field(default_factory=list, description="All devices in the system") def get_statistics(self) -> Dict[str, Any]: """Get the statistics of the database.""" @@ -234,9 +184,7 @@ def get_statistics(self) -> Dict[str, Any]: num_lines = len(self.lines) num_bills = len(self.bills) num_devices = len(self.devices) - num_payment_methods = sum( - len(customer.payment_methods) for customer in self.customers - ) + num_payment_methods = sum(len(customer.payment_methods) for customer in self.customers) return { "num_plans": num_plans, diff --git a/vendor/tau2/domains/telecom/environment.py b/vendor/tau2/domains/telecom/environment.py index 00ab1b9e..5dcbaa79 100644 --- a/vendor/tau2/domains/telecom/environment.py +++ b/vendor/tau2/domains/telecom/environment.py @@ -47,9 +47,7 @@ def sync_tools(self): phone_number = self.user_tools.db.surroundings.phone_number line = self.tools._get_line_by_phone(phone_number) if line is None: - raise ValueError( - f"Wrong scenario, line not found for phone number: {phone_number}" - ) + raise ValueError(f"Wrong scenario, line not found for phone number: {phone_number}") # Check if the line is active if line.status == LineStatus.ACTIVE: self.user_tools.db.surroundings.line_active = True @@ -65,9 +63,7 @@ def sync_tools(self): # Check if the user has exceeded their data usage limit plan = self.tools._get_plan_by_id(line.plan_id) if plan is None: - raise ValueError( - f"Wrong scenario, invalid plan id ({line.plan_id}) for the phone number {phone_number}" - ) + raise ValueError(f"Wrong scenario, invalid plan id ({line.plan_id}) for the phone number {phone_number}") if line.data_used_gb >= plan.data_limit_gb + line.data_refueling_gb: self.user_tools.db.surroundings.mobile_data_usage_exceeded = True else: @@ -82,9 +78,7 @@ def sync_tools(self): # Check if the user has a payment request current_payment_request = self.user_tools.db.surroundings.payment_request - if ( - current_payment_request is None - ): # If there already is a payment request, do nothing + if current_payment_request is None: # If there already is a payment request, do nothing customer = self.tools.get_customer_by_phone(phone_number) bills = self.tools._get_bills_awaiting_payment(customer) if len(bills) != 0: diff --git a/vendor/tau2/domains/telecom/tasks/const.py b/vendor/tau2/domains/telecom/tasks/const.py index 8c41d7dc..555d935d 100644 --- a/vendor/tau2/domains/telecom/tasks/const.py +++ b/vendor/tau2/domains/telecom/tasks/const.py @@ -1,6 +1,6 @@ TOOL_CALL_INFO_CHECK = "If the tool call does not return updated status information, you might need to perform another tool call to get the updated status." TOOL_CALL_GROUNDING = """ -Whenever the agent asks you about your device, always ground your responses on the results of tool calls. +Whenever the agent asks you about your device, always ground your responses on the results of tool calls. For example: If the agent asks what the status bar shows, always ground your response on the results of the `get_status_bar` tool call. If the agent asks if you are able to send an MMS message, always ground your response on the results of the `can_send_mms` tool call. Never make up the results of tool calls, always ground your responses on the results of tool calls. If you are unsure about whether an action is necessary, always ask the agent for clarification. diff --git a/vendor/tau2/domains/telecom/tasks/create_tasks.py b/vendor/tau2/domains/telecom/tasks/create_tasks.py index aeaf6f40..e6a297f8 100644 --- a/vendor/tau2/domains/telecom/tasks/create_tasks.py +++ b/vendor/tau2/domains/telecom/tasks/create_tasks.py @@ -27,7 +27,7 @@ def create_tasks(save_tasks: bool = True, max_count_per_bin: int = 3) -> list[Ta print(f"Number of tasks: {len(tasks)}") - file = DATA_DIR / "domains" / "telecom" / f"tasks_full.json" + file = DATA_DIR / "domains" / "telecom" / "tasks_full.json" if save_tasks: with open(file, "w") as f: json.dump([t.model_dump() for t in tasks], f, indent=2) @@ -50,14 +50,14 @@ def create_tasks(save_tasks: bool = True, max_count_per_bin: int = 3) -> list[Ta } ) - file_small = DATA_DIR / "domains" / "telecom" / f"tasks_small.json" + file_small = DATA_DIR / "domains" / "telecom" / "tasks_small.json" small_tasks = [t["task"] for t in tasks_with_attrs if t["num_subtasks"] == 1] print(f"Number of tasks in small set: {len(small_tasks)}") if save_tasks: with open(file_small, "w") as f: json.dump([t.model_dump() for t in small_tasks], f, indent=2) - file_sampled = DATA_DIR / "domains" / "telecom" / f"tasks.json" + file_sampled = DATA_DIR / "domains" / "telecom" / "tasks.json" tasks_by_bins = defaultdict(list) for task in tasks_with_attrs: if task["num_subtasks"] < 2: # We only keep tasks with at least 2 subtasks diff --git a/vendor/tau2/domains/telecom/tasks/mms_issues.py b/vendor/tau2/domains/telecom/tasks/mms_issues.py index 13319476..313ce272 100644 --- a/vendor/tau2/domains/telecom/tasks/mms_issues.py +++ b/vendor/tau2/domains/telecom/tasks/mms_issues.py @@ -92,9 +92,7 @@ def break_apn_mms_setting(*args, **kwargs) -> list[EnvFunctionCall]: ] -def _get_remove_app_permission_actions( - app_name: str = "messaging", permission: str = "sms" -): +def _get_remove_app_permission_actions(app_name: str = "messaging", permission: str = "sms"): """ Get the remove app permission actions for the mms issue task. """ @@ -116,9 +114,7 @@ def break_app_storage_permission(*args, **kwargs) -> list[EnvFunctionCall]: """ Break the app storage permission for the mms issue task. """ - return [ - _get_remove_app_permission_actions(app_name="messaging", permission="storage") - ] + return [_get_remove_app_permission_actions(app_name="messaging", permission="storage")] def break_app_both_permissions(*args, **kwargs) -> list[EnvFunctionCall]: @@ -163,9 +159,7 @@ def fix_break_apn_mms_setting(*args, **kwargs) -> list[ToolCall]: ] -def _get_grant_app_permission_actions( - app_name: str = "messaging", permission: str = "sms" -) -> ToolCall: +def _get_grant_app_permission_actions(app_name: str = "messaging", permission: str = "sms") -> ToolCall: """ Get the grant app permission actions for the mms issue task. """ @@ -187,9 +181,7 @@ def fix_break_app_storage_permission(*args, **kwargs) -> list[ToolCall]: """ Fix the break app storage permission issue. """ - return [ - _get_grant_app_permission_actions(app_name="messaging", permission="storage") - ] + return [_get_grant_app_permission_actions(app_name="messaging", permission="storage")] def fix_break_app_both_permissions(*args, **kwargs) -> list[ToolCall]: @@ -277,11 +269,7 @@ def fix_break_app_both_permissions(*args, **kwargs) -> list[ToolCall]: app_permission_issues, # Step3.5 ] -selection_sets = ( - service_issues_sample_sets - + mobile_data_issues_sample_sets - + mms_issues_selection_sets -) +selection_sets = service_issues_sample_sets + mobile_data_issues_sample_sets + mms_issues_selection_sets def task_validator(tasks: list[Optional[BaseTask]]): @@ -304,9 +292,7 @@ def task_validator(tasks: list[Optional[BaseTask]]): num_tasks_mms_issues = len( [ task - for task in tasks[ - len(service_issues_sample_sets) + len(mobile_data_issues_sample_sets) : - ] + for task in tasks[len(service_issues_sample_sets) + len(mobile_data_issues_sample_sets) :] if task is not None ] ) diff --git a/vendor/tau2/domains/telecom/tasks/mobile_data_issues.py b/vendor/tau2/domains/telecom/tasks/mobile_data_issues.py index 8e1caa70..b5405d44 100644 --- a/vendor/tau2/domains/telecom/tasks/mobile_data_issues.py +++ b/vendor/tau2/domains/telecom/tasks/mobile_data_issues.py @@ -485,9 +485,7 @@ def assert_data_refueling_amount(env: TelecomEnvironment) -> list[EnvAssertion]: # Path 2.2: Slow Mobile Data # Requires workflow Step 2.2.1 -data_usage_exceeded_issues = SelectionSet( - tasks=[data_usage_exceeded_task, data_usage_exceeded_no_refuel_task] -) +data_usage_exceeded_issues = SelectionSet(tasks=[data_usage_exceeded_task, data_usage_exceeded_no_refuel_task]) # Requires workflow Step 2.2.2 data_saver_mode_issues = SelectionSet(tasks=[data_saver_mode_on_task]) @@ -518,9 +516,7 @@ def task_validator(tasks: list[Optional[BaseTask]]): # num_tasks_service_issues = len( # [task for task in tasks[: len(service_issues_sample_sets)] if task is not None] # ) - num_tasks_mobile_data_issues = len( - [task for task in tasks[len(service_issues_sample_sets) :] if task is not None] - ) + num_tasks_mobile_data_issues = len([task for task in tasks[len(service_issues_sample_sets) :] if task is not None]) return num_tasks_mobile_data_issues > 0 diff --git a/vendor/tau2/domains/telecom/tasks/utils.py b/vendor/tau2/domains/telecom/tasks/utils.py index 73c22b37..360660e5 100644 --- a/vendor/tau2/domains/telecom/tasks/utils.py +++ b/vendor/tau2/domains/telecom/tasks/utils.py @@ -66,9 +66,7 @@ def compose_tasks( Return all the combinations of selecting 0 or more tasks from the selection sets """ - product_tasks = list( - product(*[selection_set.tasks + [None] for selection_set in selection_sets]) - ) + product_tasks = list(product(*[selection_set.tasks + [None] for selection_set in selection_sets])) composed_tasks = [] for tasks in product_tasks: if task_validator is not None: diff --git a/vendor/tau2/domains/telecom/tools.py b/vendor/tau2/domains/telecom/tools.py index 0a7fd05d..c092d35e 100644 --- a/vendor/tau2/domains/telecom/tools.py +++ b/vendor/tau2/domains/telecom/tools.py @@ -102,10 +102,7 @@ def get_customer_by_name(self, full_name: str, dob: str) -> List[Customer]: matching_customers = [] for customer in self.db.customers: - if ( - customer.full_name.lower() == full_name.lower() - and customer.date_of_birth == dob - ): + if customer.full_name.lower() == full_name.lower() and customer.date_of_birth == dob: matching_customers.append(customer) return matching_customers @@ -259,9 +256,7 @@ def get_details_by_id(self, id: str) -> Dict[str, Any]: raise ValueError(f"Unknown ID format or type: {id}") @is_tool(ToolType.WRITE) - def suspend_line( - self, customer_id: str, line_id: str, reason: str - ) -> Dict[str, Any]: + def suspend_line(self, customer_id: str, line_id: str, reason: str) -> Dict[str, Any]: """ Suspends a specific line (max 6 months). Checks: Line status must be Active. @@ -411,9 +406,7 @@ def _set_bill_to_paid(self, bill_id: str) -> None: bill.status = BillStatus.PAID return f"Bill {bill_id} set to paid" - def _apply_one_time_charge( - self, customer_id: str, amount: float, description: str - ) -> None: + def _apply_one_time_charge(self, customer_id: str, amount: float, description: str) -> None: """ Internal function to add a specific charge LineItem to the customer's next bill. Creates a pending bill if none exists. @@ -453,11 +446,7 @@ def _apply_one_time_charge( period_start=next_month, period_end=next_month.replace( month=next_month.month + 1 if next_month.month < 12 else 1, - year=( - next_month.year - if next_month.month < 12 - else next_month.year + 1 - ), + year=(next_month.year if next_month.month < 12 else next_month.year + 1), ) - timedelta(days=1), issue_date=next_month, @@ -501,9 +490,7 @@ def get_data_usage(self, customer_id: str, line_id: str) -> Dict[str, Any]: plan = self._get_plan_by_id(target_line.plan_id) today = get_today() - cycle_end_date = date( - today.year, today.month + 1 if today.month < 12 else 1, 1 - ) - timedelta(days=1) + cycle_end_date = date(today.year, today.month + 1 if today.month < 12 else 1, 1) - timedelta(days=1) return { "line_id": line_id, @@ -513,9 +500,7 @@ def get_data_usage(self, customer_id: str, line_id: str) -> Dict[str, Any]: "cycle_end_date": cycle_end_date, } - def set_data_usage( - self, customer_id: str, line_id: str, data_used_gb: float - ) -> str: + def set_data_usage(self, customer_id: str, line_id: str, data_used_gb: float) -> str: """ Sets the data usage for a line. Note: This method is not decorated as a tool but follows similar error handling. @@ -605,9 +590,7 @@ def transfer_to_human_agents(self, summary: str) -> str: return "Transfer successful" @is_tool(ToolType.WRITE) - def refuel_data( - self, customer_id: str, line_id: str, gb_amount: float - ) -> Dict[str, Any]: + def refuel_data(self, customer_id: str, line_id: str, gb_amount: float) -> Dict[str, Any]: """ Refuels data for a specific line, adding to the customer's bill. Checks: Line status must be Active, Customer owns the line. @@ -646,9 +629,7 @@ def refuel_data( f"Data refueling: {gb_amount} GB at ${plan.data_refueling_price_per_gb}/GB", ) - logger.info( - f"Data refueled for line {line_id}: {gb_amount} GB added, charge: ${charge_amount:.2f}" - ) + logger.info(f"Data refueled for line {line_id}: {gb_amount} GB added, charge: ${charge_amount:.2f}") return { "message": f"Successfully added {gb_amount} GB of data for line {line_id} for ${charge_amount:.2f}", @@ -721,27 +702,21 @@ def suspend_line_for_overdue_bill( return f"Line {line_id} suspended for unpaid bill {new_bill_id}. Contract ended: {contract_ended}" ### Assertions - def assert_data_refueling_amount( - self, customer_id: str, line_id: str, expected_amount: float - ) -> bool: + def assert_data_refueling_amount(self, customer_id: str, line_id: str, expected_amount: float) -> bool: """ Assert that the data refueling amount is as expected. """ target_line = self._get_target_line(customer_id, line_id) return abs(target_line.data_refueling_gb - expected_amount) < 1e-6 - def assert_line_status( - self, customer_id: str, line_id: str, expected_status: LineStatus - ) -> bool: + def assert_line_status(self, customer_id: str, line_id: str, expected_status: LineStatus) -> bool: """ Assert that the line status is as expected. """ target_line = self._get_target_line(customer_id, line_id) return target_line.status == expected_status - def assert_overdue_bill_exists( - self, customer_id: str, overdue_bill_id: str - ) -> bool: + def assert_overdue_bill_exists(self, customer_id: str, overdue_bill_id: str) -> bool: """ Assert that the overdue bill exists. """ diff --git a/vendor/tau2/domains/telecom/user_data_model.py b/vendor/tau2/domains/telecom/user_data_model.py index 77d838f9..218bbf93 100644 --- a/vendor/tau2/domains/telecom/user_data_model.py +++ b/vendor/tau2/domains/telecom/user_data_model.py @@ -99,12 +99,8 @@ def is_mms_basic_configured(self) -> bool: class VpnDetails(BaseModelNoExtra): """Holds details about the VPN connection if active.""" - server_address: Optional[str] = Field( - None, description="Address of the connected VPN server." - ) - protocol: Optional[str] = Field( - None, description="VPN protocol being used (e.g., WireGuard, OpenVPN)." - ) + server_address: Optional[str] = Field(None, description="Address of the connected VPN server.") + protocol: Optional[str] = Field(None, description="VPN protocol being used (e.g., WireGuard, OpenVPN).") server_performance: PerformanceLevel = Field( default=PerformanceLevel.UNKNOWN, validate_default=True, @@ -118,9 +114,7 @@ class AppPermissions(BaseModelNoExtra): sms: bool = Field(False, description="Permission to send/read SMS/MMS.") storage: bool = Field(False, description="Permission to access device storage.") phone: bool = Field(False, description="Permission to make/manage phone calls.") - network: bool = Field( - False, description="Permission to access network state/internet." - ) + network: bool = Field(False, description="Permission to access network state/internet.") class AppStatus(BaseModelNoExtra): @@ -146,22 +140,14 @@ class StatusBar(BaseModelNoExtra): validate_default=True, description="The network technology (2G, 3G, 4G, etc.) shown in the status bar.", ) - wifi_connected: bool = Field( - False, description="Whether WiFi is connected and shown in the status bar." - ) - airplane_mode: bool = Field( - False, description="Whether airplane mode is on and shown in the status bar." - ) - vpn_active: bool = Field( - False, description="Whether a VPN is active and shown in the status bar." - ) + wifi_connected: bool = Field(False, description="Whether WiFi is connected and shown in the status bar.") + airplane_mode: bool = Field(False, description="Whether airplane mode is on and shown in the status bar.") + vpn_active: bool = Field(False, description="Whether a VPN is active and shown in the status bar.") data_saver_active: bool = Field( False, description="Whether data saver mode is active and shown in the status bar.", ) - battery_level: int = Field( - 100, description="The battery level (0-100) shown in the status bar." - ) + battery_level: int = Field(100, description="The battery level (0-100) shown in the status bar.") # --- Main Device State Model --- @@ -201,9 +187,7 @@ class MockPhoneAttributes(BaseModelNoExtra): ) # --- Battery --- - battery_level: int = Field( - 80, description="The current battery level, from 0 to 100 percent." - ) + battery_level: int = Field(80, description="The current battery level, from 0 to 100 percent.") # --- Mobile Data --- data_enabled: bool = Field( @@ -230,9 +214,7 @@ class MockPhoneAttributes(BaseModelNoExtra): False, description="Whether the device is currently connected to a Wi-Fi network.", ) - wifi_ssid: Optional[str] = Field( - None, description="The name (SSID) of the connected Wi-Fi network, if any." - ) + wifi_ssid: Optional[str] = Field(None, description="The name (SSID) of the connected Wi-Fi network, if any.") wifi_signal_strength: SignalStrength = Field( default=SignalStrength.NONE, validate_default=True, @@ -240,9 +222,7 @@ class MockPhoneAttributes(BaseModelNoExtra): ) # --- Calling Features --- - wifi_calling_enabled: bool = Field( - False, description="Whether the Wi-Fi Calling feature is enabled." - ) + wifi_calling_enabled: bool = Field(False, description="Whether the Wi-Fi Calling feature is enabled.") wifi_calling_mms_over_wifi: bool = Field( False, description="Preference/capability to send/receive MMS over Wi-Fi (depends on carrier and device support).", @@ -259,9 +239,7 @@ class MockPhoneAttributes(BaseModelNoExtra): False, description="Whether a VPN profile is configured and potentially set to be 'always on' or manually enabled in settings.", ) - vpn_connected: bool = Field( - False, description="Whether there currently is an active VPN connection tunnel." - ) + vpn_connected: bool = Field(False, description="Whether there currently is an active VPN connection tunnel.") vpn_details: Optional[VpnDetails] = Field( None, description="Details about the active VPN connection, if connected." ) @@ -321,13 +299,9 @@ class UserSurroundings(BaseModelNoExtra): """Represents the physical surroundings of the user.""" name: Optional[str] = Field(None, description="The name of the user.") - phone_number: Optional[str] = Field( - None, description="The phone number of the user." - ) + phone_number: Optional[str] = Field(None, description="The phone number of the user.") is_abroad: bool = Field(False, description="Whether the user is currently abroad.") - roaming_allowed: bool = Field( - False, description="Whether the user is allowed to roam." - ) + roaming_allowed: bool = Field(False, description="Whether the user is allowed to roam.") signal_strength: dict[NetworkTechnology, SignalStrength] = Field( default_factory=lambda: { NetworkTechnology.TWO_G: SignalStrength.POOR, @@ -341,17 +315,13 @@ class UserSurroundings(BaseModelNoExtra): False, description="Whether the user has exceeded their data usage limit." ) line_active: bool = Field(True, description="Whether the user has an active line.") - payment_request: Optional[PaymentRequest] = Field( - None, description="The payment that the agent has requested." - ) + payment_request: Optional[PaymentRequest] = Field(None, description="The payment that the agent has requested.") class TelecomUserDB(DB): """Database interface for telecom domain.""" - device: MockPhoneAttributes = Field( - default_factory=MockPhoneAttributes, description="Mock phone device" - ) + device: MockPhoneAttributes = Field(default_factory=MockPhoneAttributes, description="Mock phone device") surroundings: UserSurroundings = Field( default_factory=UserSurroundings, description="User's physical surroundings" ) @@ -381,24 +351,16 @@ def main(): print("\n--- State after enabling Airplane Mode ---") print(f"Airplane Mode: {db.device.airplane_mode}") print(f"Network Status: {db.device.network_connection_status}") - print( - f"Helper - Potentially Online Mobile: {db.device.is_potentially_online_mobile()}" - ) + print(f"Helper - Potentially Online Mobile: {db.device.is_potentially_online_mobile()}") # 3. Simulate another problem: User disables Mobile Data and has wrong APN MMS URL # Start from default state again for clarity db = TelecomUserDB() update_2 = { "data_enabled": False, - "active_apn_settings": { # Update nested model - "mmsc_url": None # Simulate missing MMS config - }, + "active_apn_settings": {"mmsc_url": None}, # Update nested model # Simulate missing MMS config "app_statuses": { # Update nested dictionary/model - "messaging": { - "permissions": { - "storage": False - } # Update nested AppPermissions model field - } + "messaging": {"permissions": {"storage": False}} # Update nested AppPermissions model field }, } db.update_device(update_2) diff --git a/vendor/tau2/domains/telecom/user_tools.py b/vendor/tau2/domains/telecom/user_tools.py index aedf9fc4..aca28745 100644 --- a/vendor/tau2/domains/telecom/user_tools.py +++ b/vendor/tau2/domains/telecom/user_tools.py @@ -28,9 +28,7 @@ class TelecomUserTools(ToolKitBase): db: TelecomUserDB - network_mode_preference: NetworkModePreference = ( - NetworkModePreference.FOUR_G_5G_PREFERRED - ) + network_mode_preference: NetworkModePreference = NetworkModePreference.FOUR_G_5G_PREFERRED default_vpn_details: VpnDetails = VpnDetails( server_address="192.168.1.1", @@ -100,19 +98,14 @@ def _check_status_bar(self) -> str: SignalStrength.GOOD: "๐Ÿ“ถยณ Good", SignalStrength.EXCELLENT: "๐Ÿ“ถโด Excellent", } - indicators.append( - signal_map.get(device.network_signal_strength, "๐Ÿ“ต No Signal") - ) + indicators.append(signal_map.get(device.network_signal_strength, "๐Ÿ“ต No Signal")) # Network technology if device.network_technology_connected != NetworkTechnology.NONE: indicators.append(device.network_technology_connected.value) # Data enabled indicator - if ( - device.data_enabled - and device.network_technology_connected != NetworkTechnology.NONE - ): + if device.data_enabled and device.network_technology_connected != NetworkTechnology.NONE: indicators.append("๐Ÿ“ฑ Data Enabled") if device.data_saver_mode: indicators.append("๐Ÿ”ฝ Data Saver") @@ -186,9 +179,7 @@ def _check_network_mode_preference(self) -> NetworkModePreference: return self.device.network_mode_preference @is_tool(ToolType.WRITE) - def set_network_mode_preference( - self, mode: Union[NetworkModePreference, str] - ) -> str: + def set_network_mode_preference(self, mode: Union[NetworkModePreference, str]) -> str: """Changes the type of cellular network your phone prefers to connect to (e.g., 5G, LTE/4G, 3G). Higher-speed networks (LTE/5G) provide faster data but may use more battery.""" valid_mode = self._set_network_mode_preference(mode) if valid_mode is None: @@ -196,9 +187,7 @@ def set_network_mode_preference( status_update = f"Preferred Network Mode set to: {valid_mode.value}" return f"{status_update}\nStatus Bar: {self._check_status_bar()}" - def _set_network_mode_preference( - self, mode: Union[NetworkModePreference, str] - ) -> Optional[NetworkModePreference]: + def _set_network_mode_preference(self, mode: Union[NetworkModePreference, str]) -> Optional[NetworkModePreference]: """Sets the preferred network mode. This will trigger a network search. """ @@ -222,10 +211,7 @@ def _get_mobile_data_working(self) -> bool: - Data is not enabled - Data usage is exceeded """ - if ( - self.device.airplane_mode - or self.device.network_signal_strength == SignalStrength.NONE - ): + if self.device.airplane_mode or self.device.network_signal_strength == SignalStrength.NONE: return False if self.device.network_connection_status == NetworkStatus.NO_SERVICE: @@ -255,9 +241,7 @@ def run_speed_test(self) -> str: if description == "Very Poor": advice = "Connection is very slow. Basic web browsing might be difficult." elif description == "Poor": - advice = ( - "Connection is slow. Web browsing may be sluggish, streaming difficult." - ) + advice = "Connection is slow. Web browsing may be sluggish, streaming difficult." elif description == "Fair": advice = "Connection is okay for web browsing and some standard definition streaming." elif description == "Good": @@ -328,9 +312,7 @@ def _run_speed_test(self) -> Tuple[Optional[float], Optional[str]]: NetworkTechnology.FIVE_G: (50.0, 500.0), NetworkTechnology.NONE: (0.0, 0.0), } - min_speed, max_speed = tech_speed_map.get( - self.device.network_technology_connected, (0.0, 0.0) - ) + min_speed, max_speed = tech_speed_map.get(self.device.network_technology_connected, (0.0, 0.0)) # Adjust speed based on signal strength signal_factor_map = { @@ -343,9 +325,7 @@ def _run_speed_test(self) -> Tuple[Optional[float], Optional[str]]: signal_factor = signal_factor_map.get(self.device.network_signal_strength, 0.0) # Calculate simulated speed - simulated_speed = ( - (min_speed + max_speed) / 2.0 * signal_factor * base_speed_factor - ) + simulated_speed = (min_speed + max_speed) / 2.0 * signal_factor * base_speed_factor simulated_speed = round(simulated_speed, 2) # Determine description @@ -611,7 +591,7 @@ def reset_apn_settings(self) -> str: def _reset_apn_settings(self): """Resets your APN settings to the default settings. This will be applied at the next reboot.""" self.device.active_apn_settings.reset_at_reboot = True - return f"APN settings will reset at reboot." + return "APN settings will reset at reboot." def break_apn_settings(self) -> str: """Breaks the APN settings. This is fixed by calling reset_apn_settings().""" @@ -634,7 +614,9 @@ def check_wifi_status(self) -> str: if not status["enabled"]: return "Wi-Fi is turned OFF." if status["connected"]: - return f"Wi-Fi is ON and connected to '{status['ssid']}'. Signal strength: {status['signal_strength'].value}." + return ( + f"Wi-Fi is ON and connected to '{status['ssid']}'. Signal strength: {status['signal_strength'].value}." + ) else: return "Wi-Fi is ON but not connected to any network." @@ -702,9 +684,7 @@ def _toggle_wifi_calling(self) -> bool: self.device.wifi_calling_enabled = new_state return new_state - def set_wifi_calling( - self, enabled: bool, mms_over_wifi: Optional[bool] = None - ) -> str: + def set_wifi_calling(self, enabled: bool, mms_over_wifi: Optional[bool] = None) -> str: """Set the Wi-Fi Calling setting. Set MMS over WIFI accordingly if provided.""" if self.device.wifi_calling_enabled != enabled: self._toggle_wifi_calling() @@ -736,9 +716,7 @@ def _check_vpn_status(self) -> Dict[str, Any]: "enabled_setting": self.device.vpn_enabled_setting, "connected": self.device.vpn_connected, "details": ( - self.device.vpn_details.model_dump() - if self.device.vpn_details and self.device.vpn_connected - else None + self.device.vpn_details.model_dump() if self.device.vpn_details and self.device.vpn_connected else None ), } @@ -748,11 +726,7 @@ def connect_vpn(self) -> str: connected = self._connect_vpn() if connected is None: return "VPN already connected." - status_update = ( - "VPN connected successfully." - if connected - else "No VPN connection to connect." - ) + status_update = "VPN connected successfully." if connected else "No VPN connection to connect." return f"{status_update}\nStatus Bar: {self._check_status_bar()}" def _connect_vpn(self) -> Optional[bool]: @@ -769,11 +743,7 @@ def _connect_vpn(self) -> Optional[bool]: def disconnect_vpn(self) -> str: """Disconnects any active VPN (Virtual Private Network) connection. Stops routing your internet traffic through a VPN server, which might affect connection speed or access to content.""" disconnected = self._disconnect_vpn() - status_update = ( - "VPN disconnected successfully." - if disconnected - else "No active VPN connection to disconnect." - ) + status_update = "VPN disconnected successfully." if disconnected else "No active VPN connection to disconnect." return f"{status_update}\nStatus Bar: {self._check_status_bar()}" def _disconnect_vpn(self) -> bool: @@ -975,46 +945,34 @@ def simulate_network_search(self): self.device.network_connection_status = NetworkStatus.CONNECTED pref = self.device.network_mode_preference if pref == NetworkModePreference.FOUR_G_5G_PREFERRED: - five_g_signal = self.surroundings.signal_strength.get( - NetworkTechnology.FIVE_G, SignalStrength.NONE - ) + five_g_signal = self.surroundings.signal_strength.get(NetworkTechnology.FIVE_G, SignalStrength.NONE) if five_g_signal == SignalStrength.NONE: self.device.network_technology_connected = NetworkTechnology.FOUR_G - self.device.network_signal_strength = ( - self.surroundings.signal_strength.get( - NetworkTechnology.FOUR_G, SignalStrength.NONE - ) + self.device.network_signal_strength = self.surroundings.signal_strength.get( + NetworkTechnology.FOUR_G, SignalStrength.NONE ) else: self.device.network_technology_connected = NetworkTechnology.FIVE_G self.device.network_signal_strength = five_g_signal elif pref == NetworkModePreference.FOUR_G_ONLY: self.device.network_technology_connected = NetworkTechnology.FOUR_G - self.device.network_signal_strength = ( - self.surroundings.signal_strength.get( - NetworkTechnology.FOUR_G, SignalStrength.NONE - ) + self.device.network_signal_strength = self.surroundings.signal_strength.get( + NetworkTechnology.FOUR_G, SignalStrength.NONE ) elif pref == NetworkModePreference.THREE_G_ONLY: self.device.network_technology_connected = NetworkTechnology.THREE_G - self.device.network_signal_strength = ( - self.surroundings.signal_strength.get( - NetworkTechnology.THREE_G, SignalStrength.NONE - ) + self.device.network_signal_strength = self.surroundings.signal_strength.get( + NetworkTechnology.THREE_G, SignalStrength.NONE ) elif pref == NetworkModePreference.TWO_G_ONLY: self.device.network_technology_connected = NetworkTechnology.TWO_G - self.device.network_signal_strength = ( - self.surroundings.signal_strength.get( - NetworkTechnology.TWO_G, SignalStrength.NONE - ) + self.device.network_signal_strength = self.surroundings.signal_strength.get( + NetworkTechnology.TWO_G, SignalStrength.NONE ) else: # Default fallback self.device.network_technology_connected = NetworkTechnology.FOUR_G - self.device.network_signal_strength = ( - self.surroundings.signal_strength.get( - NetworkTechnology.FOUR_G, SignalStrength.NONE - ) + self.device.network_signal_strength = self.surroundings.signal_strength.get( + NetworkTechnology.FOUR_G, SignalStrength.NONE ) elif sim_status in [SimStatus.MISSING]: @@ -1120,9 +1078,7 @@ def assert_mobile_data_saver_mode_status(self, expected_status: bool) -> bool: """ return self.device.data_saver_mode == expected_status - def assert_internet_speed( - self, expected_speed: float, expected_desc: Optional[str] = None - ) -> bool: + def assert_internet_speed(self, expected_speed: float, expected_desc: Optional[str] = None) -> bool: """ Assert that the internet speed is as expected. """ diff --git a/vendor/tau2/environment/environment.py b/vendor/tau2/environment/environment.py index 07799bb7..2a69c176 100644 --- a/vendor/tau2/environment/environment.py +++ b/vendor/tau2/environment/environment.py @@ -258,9 +258,9 @@ def set_state( Set the state of the environment given initialization data and a list of messages. """ if self.solo_mode: - assert all( - [not isinstance(message, UserMessage) for message in message_history] - ), "User messages are not allowed in solo mode" + assert all([not isinstance(message, UserMessage) for message in message_history]), ( + "User messages are not allowed in solo mode" + ) def get_actions_from_messages( messages: list[Message], diff --git a/vendor/tau2/environment/server.py b/vendor/tau2/environment/server.py index 8faddd13..cce3fb2b 100644 --- a/vendor/tau2/environment/server.py +++ b/vendor/tau2/environment/server.py @@ -85,7 +85,8 @@ def _format_description(self, policy: str) -> str: description.append(content) # Add the tools section - description.append(""" + description.append( + """ ## Tools @@ -98,7 +99,8 @@ def _format_description(self, policy: str) -> str: ### Response Format All successful responses will return the tool's output directly. Errors will return a 400 status code with an error message. -""") +""" + ) return "\n".join(description) @@ -161,20 +163,14 @@ async def tool_endpoint( ) -> Any: try: if route_prefix == "user_tools": - result = self.environment.use_user_tool( - tool_name=tool_name, **request.model_dump() - ) + result = self.environment.use_user_tool(tool_name=tool_name, **request.model_dump()) else: - result = self.environment.use_tool( - tool_name=tool_name, **request.model_dump() - ) + result = self.environment.use_tool(tool_name=tool_name, **request.model_dump()) return result except Exception as e: raise HTTPException(status_code=400, detail=str(e)) - def _format_tool_description( - self, doc: str, returns: Optional[dict] = None, is_user_tool: bool = False - ) -> str: + def _format_tool_description(self, doc: str, returns: Optional[dict] = None, is_user_tool: bool = False) -> str: """Format tool documentation for better ReDoc rendering""" import re diff --git a/vendor/tau2/environment/tool.py b/vendor/tau2/environment/tool.py index cf55463c..fa63bdcb 100644 --- a/vendor/tau2/environment/tool.py +++ b/vendor/tau2/environment/tool.py @@ -49,9 +49,7 @@ class Tool(BaseTool): """The parameters of the Tool.""" returns: type[BaseModel] = Field(..., description="The return of the Tool") """The return of the Tool.""" - raises: List[Dict[str, Optional[str]]] = Field( - [], description="The exceptions raised by the Tool" - ) + raises: List[Dict[str, Optional[str]]] = Field([], description="The exceptions raised by the Tool") """The exceptions raised by the Tool.""" examples: List[str] = Field([], description="The examples of the Tool") """The examples of the Tool.""" @@ -79,9 +77,7 @@ def __init__(self, func: Callable, use_short_desc: bool = False, **predefined: A self.__doc__ = doc # overwrite the doc string @classmethod - def parse_data( - cls, sig: Signature, docstring: Optional[str], predefined: Dict[str, Any] - ) -> Dict[str, Any]: + def parse_data(cls, sig: Signature, docstring: Optional[str], predefined: Dict[str, Any]) -> Dict[str, Any]: """Parse data from the signature and docstring of a function.""" doc = parse(docstring or "") data: Dict[str, Any] = { @@ -127,9 +123,7 @@ def parse_data( data["returns"] = create_model("returns", returns=(anno, default)) # build raises - data["raises"] = [ - {"type": exc.type_name, "desc": exc.description} for exc in doc.raises - ] + data["raises"] = [{"type": exc.type_name, "desc": exc.description} for exc in doc.raises] # build examples data["examples"] = doc.examples diff --git a/vendor/tau2/environment/toolkit.py b/vendor/tau2/environment/toolkit.py index fff04d60..d21923e2 100644 --- a/vendor/tau2/environment/toolkit.py +++ b/vendor/tau2/environment/toolkit.py @@ -102,18 +102,10 @@ def tool_type(self, tool_name: str) -> ToolType: def get_statistics(self) -> dict[str, Any]: """Get the statistics of the ToolKit.""" num_tools = len(self.tools) - num_read_tools = sum( - self.tool_type(name) == ToolType.READ for name in self.tools - ) - num_write_tools = sum( - self.tool_type(name) == ToolType.WRITE for name in self.tools - ) - num_think_tools = sum( - self.tool_type(name) == ToolType.THINK for name in self.tools - ) - num_generic_tools = sum( - self.tool_type(name) == ToolType.GENERIC for name in self.tools - ) + num_read_tools = sum(self.tool_type(name) == ToolType.READ for name in self.tools) + num_write_tools = sum(self.tool_type(name) == ToolType.WRITE for name in self.tools) + num_think_tools = sum(self.tool_type(name) == ToolType.THINK for name in self.tools) + num_generic_tools = sum(self.tool_type(name) == ToolType.GENERIC for name in self.tools) return { "num_tools": num_tools, "num_read_tools": num_read_tools, diff --git a/vendor/tau2/environment/utils/interface_agent.py b/vendor/tau2/environment/utils/interface_agent.py index 015b9cd2..773bbfb9 100644 --- a/vendor/tau2/environment/utils/interface_agent.py +++ b/vendor/tau2/environment/utils/interface_agent.py @@ -216,15 +216,11 @@ def get_prompt_text() -> str: if message == ":n": console.print("[info]Starting new session...[/]") - interface_agent, message_history = init_session( - current_domain - ) + interface_agent, message_history = init_session(current_domain) continue with console.status("[info]Processing query...[/]"): - response, message_history = interface_agent.respond( - message, message_history - ) + response, message_history = interface_agent.respond(message, message_history) # Try to parse response as markdown for better formatting try: @@ -232,9 +228,7 @@ def get_prompt_text() -> str: console.print("\n[bold]Response:[/]") console.print(md) except Exception as e: - console.print( - f"\n[error]Error parsing response:[/] {str(e)}" - ) + console.print(f"\n[error]Error parsing response:[/] {str(e)}") console.print("\n[bold]Response:[/]", response.content) except KeyboardInterrupt: @@ -244,9 +238,7 @@ def get_prompt_text() -> str: console.print(f"\n[error]Error processing message:[/] {str(e)}") except Exception as e: - console.print( - f"\n[error]Error initializing domain '{current_domain}':[/] {str(e)}" - ) + console.print(f"\n[error]Error initializing domain '{current_domain}':[/] {str(e)}") new_domain = change_domain(console) if new_domain is None: return diff --git a/vendor/tau2/evaluator/__init__.py b/vendor/tau2/evaluator/__init__.py index 8b137891..e69de29b 100644 --- a/vendor/tau2/evaluator/__init__.py +++ b/vendor/tau2/evaluator/__init__.py @@ -1 +0,0 @@ - diff --git a/vendor/tau2/evaluator/evaluator.py b/vendor/tau2/evaluator/evaluator.py index fa3b4791..f206c3dc 100644 --- a/vendor/tau2/evaluator/evaluator.py +++ b/vendor/tau2/evaluator/evaluator.py @@ -33,9 +33,7 @@ def evaluate_simulation( }: return RewardInfo( reward=0.0, - info={ - "note": f"Simulation terminated prematurely. Termination reason: {simulation.termination_reason}" - }, + info={"note": f"Simulation terminated prematurely. Termination reason: {simulation.termination_reason}"}, ) if task.evaluation_criteria is None: return RewardInfo( diff --git a/vendor/tau2/evaluator/evaluator_action.py b/vendor/tau2/evaluator/evaluator_action.py index 518475ba..3576ce2e 100644 --- a/vendor/tau2/evaluator/evaluator_action.py +++ b/vendor/tau2/evaluator/evaluator_action.py @@ -59,10 +59,7 @@ def evaluate_actions( predicted_tool_calls: list[ToolCall] = [] for message in full_trajectory: - if ( - isinstance(message, AssistantMessage) - or isinstance(message, UserMessage) - ) and message.is_tool_call(): + if (isinstance(message, AssistantMessage) or isinstance(message, UserMessage)) and message.is_tool_call(): predicted_tool_calls.extend(message.tool_calls) # Check if all the gold actions are in the predicted actions diff --git a/vendor/tau2/evaluator/evaluator_communicate.py b/vendor/tau2/evaluator/evaluator_communicate.py index 43eecebf..8eada207 100644 --- a/vendor/tau2/evaluator/evaluator_communicate.py +++ b/vendor/tau2/evaluator/evaluator_communicate.py @@ -32,9 +32,7 @@ def calculate_reward( reward_breakdown={RewardType.COMMUNICATE: 1.0}, ) - communicate_info_checks = cls.evaluate_communicate_info( - full_trajectory, communicate_info - ) + communicate_info_checks = cls.evaluate_communicate_info(full_trajectory, communicate_info) # Calculate reward: 1 if all expectations are met, 0 otherwise all_expectations_met = all(result.met for result in communicate_info_checks) @@ -66,9 +64,7 @@ def evaluate_communicate_info( continue if not message.has_text_content(): continue - if info_str.lower() in message.content.lower().replace( - ",", "" - ): # TODO: This could be improved! + if info_str.lower() in message.content.lower().replace(",", ""): # TODO: This could be improved! found = True break if found: diff --git a/vendor/tau2/evaluator/evaluator_env.py b/vendor/tau2/evaluator/evaluator_env.py index 46c3e296..d2b724a4 100644 --- a/vendor/tau2/evaluator/evaluator_env.py +++ b/vendor/tau2/evaluator/evaluator_env.py @@ -49,24 +49,15 @@ def calculate_reward( ) initialization_data = None - if ( - task.initial_state is not None - and task.initial_state.initialization_data is not None - ): + if task.initial_state is not None and task.initial_state.initialization_data is not None: initialization_data = task.initial_state.initialization_data initialization_actions = None - if ( - task.initial_state is not None - and task.initial_state.initialization_actions is not None - ): + if task.initial_state is not None and task.initial_state.initialization_actions is not None: initialization_actions = task.initial_state.initialization_actions message_history = [] - if ( - task.initial_state is not None - and task.initial_state.message_history is not None - ): + if task.initial_state is not None and task.initial_state.message_history is not None: message_history = task.initial_state.message_history predicted_environment = environment_constructor(solo_mode=solo_mode) @@ -77,10 +68,7 @@ def calculate_reward( ) predicted_tool_calls: list[ToolCall] = [] for message in full_trajectory: - if ( - isinstance(message, AssistantMessage) - or isinstance(message, UserMessage) - ) and message.is_tool_call(): + if (isinstance(message, AssistantMessage) or isinstance(message, UserMessage)) and message.is_tool_call(): predicted_tool_calls.extend(message.tool_calls) # Setting up gold environment @@ -99,9 +87,7 @@ def calculate_reward( **action.arguments, ) except Exception as e: - logger.warning( - f"Error in golden actions {action.name}({action.arguments}): {e}" - ) + logger.warning(f"Error in golden actions {action.name}({action.arguments}): {e}") # Comparing the environments agent_db_hash = gold_environment.get_db_hash() diff --git a/vendor/tau2/metrics/agent_metrics.py b/vendor/tau2/metrics/agent_metrics.py index 3192e698..8f8c30b8 100644 --- a/vendor/tau2/metrics/agent_metrics.py +++ b/vendor/tau2/metrics/agent_metrics.py @@ -55,9 +55,7 @@ def get_metrics_df(results: Results) -> tuple[pd.DataFrame, int]: df = results.to_df() df["success"] = df.reward.apply(is_successful) if len(df.info_num_trials.unique()) > 1: - logger.warning( - f"All simulations must have the same number of trials. Found {df.info_num_trials.unique()}" - ) + logger.warning(f"All simulations must have the same number of trials. Found {df.info_num_trials.unique()}") max_k = df.info_num_trials.max() task_ids_counts = [(tid, count) for tid, count in df.task_id.value_counts().items()] @@ -78,9 +76,7 @@ def get_tasks_pass_hat_k(results: Results) -> pd.DataFrame: df, max_k = get_metrics_df(results) dfs = [] for k in range(1, max_k + 1): - res = df.groupby("task_id")["success"].apply( - lambda df: pass_hat_k(len(df), df.sum(), k) - ) + res = df.groupby("task_id")["success"].apply(lambda df: pass_hat_k(len(df), df.sum(), k)) res.name = f"pass^{k}" dfs.append(res) df_pass_hat_k = pd.concat(dfs, axis=1) diff --git a/vendor/tau2/metrics/break_down_metrics.py b/vendor/tau2/metrics/break_down_metrics.py index 3b6e9571..3ecf4e35 100644 --- a/vendor/tau2/metrics/break_down_metrics.py +++ b/vendor/tau2/metrics/break_down_metrics.py @@ -24,9 +24,7 @@ def get_write_tools(domain): return set(agent_write_tools), set(user_write_tools) -def analyze_reward( - reward_info: RewardInfo, agent_write_tools: set[str], user_write_tools: set[str] -): +def analyze_reward(reward_info: RewardInfo, agent_write_tools: set[str], user_write_tools: set[str]): """ Analyze the reward breakdown. """ @@ -34,26 +32,18 @@ def analyze_reward( try: if RewardType.COMMUNICATE in reward_info.reward_basis: communicate_success = ( - is_successful(reward_breakdown[RewardType.COMMUNICATE]) - if reward_breakdown is not None - else 0 + is_successful(reward_breakdown[RewardType.COMMUNICATE]) if reward_breakdown is not None else 0 ) else: communicate_success = None if RewardType.ENV_ASSERTION in reward_info.reward_basis: env_success = ( - is_successful(reward_breakdown[RewardType.ENV_ASSERTION]) - if reward_breakdown is not None - else 0 + is_successful(reward_breakdown[RewardType.ENV_ASSERTION]) if reward_breakdown is not None else 0 ) else: env_success = None if RewardType.DB in reward_info.reward_basis: - db_success = ( - is_successful(reward_breakdown[RewardType.DB]) - if reward_breakdown is not None - else 0 - ) + db_success = is_successful(reward_breakdown[RewardType.DB]) if reward_breakdown is not None else 0 else: db_success = None except Exception as e: @@ -110,13 +100,9 @@ def result_reward_analysis(results: Results): Analyze the reward breakdown. """ rows = [] - agent_write_tools, user_write_tools = get_write_tools( - results.info.environment_info.domain_name - ) + agent_write_tools, user_write_tools = get_write_tools(results.info.environment_info.domain_name) for simulation in results.simulations: - reward_analysis = analyze_reward( - simulation.reward_info, agent_write_tools, user_write_tools - ) + reward_analysis = analyze_reward(simulation.reward_info, agent_write_tools, user_write_tools) reward_analysis["task_id"] = simulation.task_id reward_analysis["trial"] = simulation.trial rows.append(reward_analysis) diff --git a/vendor/tau2/orchestrator/environment_manager.py b/vendor/tau2/orchestrator/environment_manager.py index 1b5b7bc9..f34d0758 100644 --- a/vendor/tau2/orchestrator/environment_manager.py +++ b/vendor/tau2/orchestrator/environment_manager.py @@ -145,9 +145,7 @@ async def status(): @self.app.post("/start_environment") async def start_env(request: StartEnvironmentRequest) -> EnvironmentResponse: - env_id = self.start_environment( - domain=request.domain, env_id=request.env_id - ) + env_id = self.start_environment(domain=request.domain, env_id=request.env_id) return EnvironmentResponse(env_id=env_id) @self.app.post("/{env_id}/set_state") @@ -169,9 +167,7 @@ async def get_info(env_id: str) -> EnvironmentInfo: return self.get_environment_info(env_id) @self.app.post("/{env_id}/tools/{tool_name}") - async def execute_tool( - env_id: str, tool_name: str, request: ToolCall - ) -> ToolMessage: + async def execute_tool(env_id: str, tool_name: str, request: ToolCall) -> ToolMessage: return self.execute_tool(env_id=env_id, tool_call=request) def get_environment_id(self) -> str: @@ -210,12 +206,8 @@ def set_environment_state( Set the state of an environment. """ - self.environments[env_id].set_state( - initialization_data, initialization_actions, message_history - ) - self.trajectories[env_id] = [ - msg for msg in message_history if is_valid_environment_message(msg) - ] + self.environments[env_id].set_state(initialization_data, initialization_actions, message_history) + self.trajectories[env_id] = [msg for msg in message_history if is_valid_environment_message(msg)] def stop_environment(self, env_id: str): """ @@ -225,9 +217,7 @@ def stop_environment(self, env_id: str): # Get the router instance router = self.app.router # Filter out the routes we want to remove - router.routes = [ - route for route in router.routes if route not in self.routes[env_id] - ] + router.routes = [route for route in router.routes if route not in self.routes[env_id]] del self.routes[env_id] if env_id in self.environments: diff --git a/vendor/tau2/orchestrator/orchestrator.py b/vendor/tau2/orchestrator/orchestrator.py index 172519c1..ebd4252a 100644 --- a/vendor/tau2/orchestrator/orchestrator.py +++ b/vendor/tau2/orchestrator/orchestrator.py @@ -31,9 +31,7 @@ class Role(str, Enum): ENV = "env" -DEFAULT_FIRST_AGENT_MESSAGE = AssistantMessage( - role="assistant", content="Hi! How can I help you today?", cost=0.0 -) +DEFAULT_FIRST_AGENT_MESSAGE = AssistantMessage(role="assistant", content="Hi! How can I help you today?", cost=0.0) class Orchestrator: @@ -82,12 +80,8 @@ def initialize(self): - Send the first message (default message from the agent to the user). """ initial_state = self.task.initial_state - initialization_data = ( - initial_state.initialization_data if initial_state is not None else None - ) - initialization_actions = ( - initial_state.initialization_actions if initial_state is not None else None - ) + initialization_data = initial_state.initialization_data if initial_state is not None else None + initialization_actions = initial_state.initialization_actions if initial_state is not None else None message_history = ( deepcopy(initial_state.message_history) if initial_state is not None and initial_state.message_history is not None @@ -101,12 +95,8 @@ def initialize(self): if self.solo_mode: assert self.environment.solo_mode, "Environment should be in solo mode" - assert isinstance(self.agent, LLMSoloAgent), ( - "Agent must be a LLMSoloAgent in solo mode" - ) - assert isinstance(self.user, DummyUser), ( - "User must be a DummyUser in solo mode" - ) + assert isinstance(self.agent, LLMSoloAgent), "Agent must be a LLMSoloAgent in solo mode" + assert isinstance(self.user, DummyUser), "User must be a DummyUser in solo mode" # Initialize Environment state self._initialize_environment( @@ -133,18 +123,10 @@ def initialize(self): else: # Last message is for the environment self.to_role = Role.ENV self.agent_state = self.agent.get_init_state( - message_history=[ - msg - for msg in message_history - if is_valid_agent_history_message(msg) - ] + message_history=[msg for msg in message_history if is_valid_agent_history_message(msg)] ) self.user_state = self.user.get_init_state( - message_history=[ - msg - for msg in message_history[:-1] - if is_valid_user_history_message(msg) - ] + message_history=[msg for msg in message_history[:-1] if is_valid_user_history_message(msg)] ) self.message = last_message if self.agent.is_stop(last_message): @@ -158,18 +140,10 @@ def initialize(self): else: # Last message is for the environment self.to_role = Role.ENV self.user_state = self.user.get_init_state( - message_history=[ - msg - for msg in message_history - if is_valid_user_history_message(msg) - ] + message_history=[msg for msg in message_history if is_valid_user_history_message(msg)] ) self.agent_state = self.agent.get_init_state( - message_history=[ - msg - for msg in message_history[:-1] - if is_valid_agent_history_message(msg) - ] + message_history=[msg for msg in message_history[:-1] if is_valid_agent_history_message(msg)] ) self.message = last_message self.done = UserSimulator.is_stop(last_message) @@ -181,34 +155,18 @@ def initialize(self): if last_message.requestor == "assistant": self.to_role = Role.AGENT self.agent_state = self.agent.get_init_state( - message_history=[ - msg - for msg in message_history[:-1] - if is_valid_agent_history_message(msg) - ] + message_history=[msg for msg in message_history[:-1] if is_valid_agent_history_message(msg)] ) self.user_state = self.user.get_init_state( - message_history=[ - msg - for msg in message_history - if is_valid_user_history_message(msg) - ] + message_history=[msg for msg in message_history if is_valid_user_history_message(msg)] ) else: self.to_role = Role.USER self.agent_state = self.agent.get_init_state( - message_history=[ - msg - for msg in message_history - if is_valid_agent_history_message(msg) - ] + message_history=[msg for msg in message_history if is_valid_agent_history_message(msg)] ) self.user_state = self.user.get_init_state( - message_history=[ - msg - for msg in message_history[:-1] - if is_valid_user_history_message(msg) - ] + message_history=[msg for msg in message_history[:-1] if is_valid_user_history_message(msg)] ) self.message = last_message else: @@ -228,9 +186,7 @@ def initialize(self): self.from_role = Role.AGENT self.to_role = Role.USER else: - first_message, agent_state = self.agent.generate_next_message( - None, self.agent_state - ) + first_message, agent_state = self.agent.generate_next_message(None, self.agent_state) self.trajectory = [first_message] self.message = first_message self.from_role = Role.AGENT @@ -290,17 +246,13 @@ def step(self): """ if self.done: raise ValueError("Simulation is done") - logger.debug( - f"Step {self.step_count}. Sending message from {self.from_role} to {self.to_role}" - ) + logger.debug(f"Step {self.step_count}. Sending message from {self.from_role} to {self.to_role}") logger.debug( f"Step {self.step_count}.\nFrom role: {self.from_role}\nTo role: {self.to_role}\nMessage: {self.message}" ) # AGENT/ENV -> USER if self.from_role in [Role.AGENT, Role.ENV] and self.to_role == Role.USER: - user_msg, self.user_state = self.user.generate_next_message( - self.message, self.user_state - ) + user_msg, self.user_state = self.user.generate_next_message(self.message, self.user_state) user_msg.validate() if UserSimulator.is_stop(user_msg): self.done = True @@ -313,12 +265,8 @@ def step(self): else: self.to_role = Role.AGENT # USER/ENV -> AGENT - elif ( - self.from_role == Role.USER or self.from_role == Role.ENV - ) and self.to_role == Role.AGENT: - agent_msg, self.agent_state = self.agent.generate_next_message( - self.message, self.agent_state - ) + elif (self.from_role == Role.USER or self.from_role == Role.ENV) and self.to_role == Role.AGENT: + agent_msg, self.agent_state = self.agent.generate_next_message(self.message, self.agent_state) agent_msg.validate() if self.agent.is_stop(agent_msg): self.done = True @@ -342,9 +290,7 @@ def step(self): "Number of tool calls and tool messages should be the same" ) self.trajectory.extend(tool_msgs) - if ( - len(tool_msgs) > 1 - ): # Packaging multiple tool messages into a MultiToolMessage + if len(tool_msgs) > 1: # Packaging multiple tool messages into a MultiToolMessage self.message = MultiToolMessage( role="tool", tool_messages=tool_msgs, @@ -354,9 +300,7 @@ def step(self): self.to_role = self.from_role self.from_role = Role.ENV else: - raise ValueError( - f"Invalid role combination. From role: {self.from_role}, To role: {self.to_role}" - ) + raise ValueError(f"Invalid role combination. From role: {self.from_role}, To role: {self.to_role}") self.step_count += 1 self.environment.sync_tools() @@ -403,9 +347,7 @@ def validate_message_history(cls, message_history: list[Message]): if num_expected_tool_messages == 0 or requestor is None: raise ValueError("No tool messages expected.") if requestor != msg.requestor: - raise ValueError( - f"Got tool message from {msg.requestor}, expected {requestor}." - ) + raise ValueError(f"Got tool message from {msg.requestor}, expected {requestor}.") num_expected_tool_messages -= 1 else: raise ValueError(f"Invalid message type: {type(msg)}") @@ -435,13 +377,9 @@ def _count_errors(self, message_history: list[Message]) -> int: """ Count the number of errors in the message history. """ - return sum( - 1 for msg in message_history if isinstance(msg, ToolMessage) and msg.error - ) + return sum(1 for msg in message_history if isinstance(msg, ToolMessage) and msg.error) - def _add_timestamps( - self, message_history: list[Message] - ) -> list[tuple[str, Message]]: + def _add_timestamps(self, message_history: list[Message]) -> list[tuple[str, Message]]: """ Add timestamps to the message history. This is used to sort the messages by timestamp. diff --git a/vendor/tau2/orchestrator/utils.py b/vendor/tau2/orchestrator/utils.py index 5b119813..d15b6435 100644 --- a/vendor/tau2/orchestrator/utils.py +++ b/vendor/tau2/orchestrator/utils.py @@ -5,6 +5,4 @@ def is_valid_environment_message(msg: Message) -> bool: """ Check if the message is valid to the environment. """ - return isinstance(msg, ToolMessage) or ( - isinstance(msg, AssistantMessage) and msg.is_tool_call() - ) + return isinstance(msg, ToolMessage) or (isinstance(msg, AssistantMessage) and msg.is_tool_call()) diff --git a/vendor/tau2/registry.py b/vendor/tau2/registry.py index 764b917c..fbdf45c0 100644 --- a/vendor/tau2/registry.py +++ b/vendor/tau2/registry.py @@ -7,29 +7,25 @@ from vendor.tau2.agent.base import BaseAgent from vendor.tau2.agent.llm_agent import LLMAgent, LLMGTAgent, LLMSoloAgent from vendor.tau2.data_model.tasks import Task -from vendor.tau2.domains.airline.environment import \ - get_environment as airline_domain_get_environment -from vendor.tau2.domains.airline.environment import \ - get_tasks as airline_domain_get_tasks -from vendor.tau2.domains.mock.environment import \ - get_environment as mock_domain_get_environment -from vendor.tau2.domains.mock.environment import get_tasks as mock_domain_get_tasks -from vendor.tau2.domains.retail.environment import \ - get_environment as retail_domain_get_environment -from vendor.tau2.domains.retail.environment import \ - get_tasks as retail_domain_get_tasks -from vendor.tau2.domains.telecom.environment import \ - get_environment_manual_policy as \ - telecom_domain_get_environment_manual_policy -from vendor.tau2.domains.telecom.environment import \ - get_environment_workflow_policy as \ - telecom_domain_get_environment_workflow_policy -from vendor.tau2.domains.telecom.environment import \ - get_tasks as telecom_domain_get_tasks -from vendor.tau2.domains.telecom.environment import \ - get_tasks_full as telecom_domain_get_tasks_full -from vendor.tau2.domains.telecom.environment import \ - get_tasks_small as telecom_domain_get_tasks_small +from vendor.tau2.domains.airline.environment import ( + get_environment as airline_domain_get_environment, + get_tasks as airline_domain_get_tasks, +) +from vendor.tau2.domains.mock.environment import ( + get_environment as mock_domain_get_environment, + get_tasks as mock_domain_get_tasks, +) +from vendor.tau2.domains.retail.environment import ( + get_environment as retail_domain_get_environment, + get_tasks as retail_domain_get_tasks, +) +from vendor.tau2.domains.telecom.environment import ( + get_environment_manual_policy as telecom_domain_get_environment_manual_policy, + get_environment_workflow_policy as telecom_domain_get_environment_workflow_policy, + get_tasks as telecom_domain_get_tasks, + get_tasks_full as telecom_domain_get_tasks_full, + get_tasks_small as telecom_domain_get_tasks_small, +) from vendor.tau2.environment.environment import Environment from vendor.tau2.user.base import BaseUser from vendor.tau2.user.user_simulator import DummyUser, UserSimulator @@ -184,13 +180,13 @@ def get_info(self) -> RegistryInfo: registry.register_domain(retail_domain_get_environment, "retail") registry.register_tasks(retail_domain_get_tasks, "retail") registry.register_domain(telecom_domain_get_environment_manual_policy, "telecom") - registry.register_domain( - telecom_domain_get_environment_workflow_policy, "telecom-workflow" - ) + registry.register_domain(telecom_domain_get_environment_workflow_policy, "telecom-workflow") registry.register_tasks(telecom_domain_get_tasks_full, "telecom_full") registry.register_tasks(telecom_domain_get_tasks_small, "telecom_small") registry.register_tasks(telecom_domain_get_tasks, "telecom") registry.register_tasks(telecom_domain_get_tasks, "telecom-workflow") - logger.debug(f"Default components registered successfully. Registry info: {json.dumps(registry.get_info().model_dump(), indent=2)}") + logger.debug( + f"Default components registered successfully. Registry info: {json.dumps(registry.get_info().model_dump(), indent=2)}" + ) except Exception as e: logger.error(f"Error initializing registry: {str(e)}") diff --git a/vendor/tau2/run.py b/vendor/tau2/run.py index c2813d34..d49521ab 100644 --- a/vendor/tau2/run.py +++ b/vendor/tau2/run.py @@ -8,8 +8,7 @@ from loguru import logger from vendor.tau2.agent.llm_agent import LLMAgent, LLMGTAgent, LLMSoloAgent -from vendor.tau2.data_model.simulation import (AgentInfo, Info, Results, RunConfig, - SimulationRun, UserInfo) +from vendor.tau2.data_model.simulation import AgentInfo, Info, Results, RunConfig, SimulationRun, UserInfo from vendor.tau2.data_model.tasks import Task from vendor.tau2.environment.environment import Environment, EnvironmentInfo from vendor.tau2.evaluator.evaluator import EvaluationType, evaluate_simulation @@ -29,9 +28,7 @@ def get_options() -> RegistryInfo: return registry.get_info() -def get_environment_info( - domain_name: str, include_tool_info: bool = False -) -> EnvironmentInfo: +def get_environment_info(domain_name: str, include_tool_info: bool = False) -> EnvironmentInfo: """Get information about the environment for a registered Domain""" global registry env_constructor = registry.get_env_constructor(domain_name) @@ -59,14 +56,10 @@ def get_tasks( if task_ids is None: tasks = load_tasks(task_set_name=task_set_name) else: - tasks = [ - task for task in load_tasks(task_set_name=task_set_name) if task.id in task_ids - ] + tasks = [task for task in load_tasks(task_set_name=task_set_name) if task.id in task_ids] if task_ids is not None and len(tasks) != len(task_ids): missing_tasks = set(task_ids) - set([task.id for task in tasks]) - raise ValueError( - f"Not all tasks were found for task set {task_set_name}: {missing_tasks}" - ) + raise ValueError(f"Not all tasks were found for task set {task_set_name}: {missing_tasks}") if num_tasks is not None: tasks = tasks[:num_tasks] return tasks @@ -100,13 +93,17 @@ def run_domain(config: RunConfig) -> Results: total_num_tasks = len(tasks) tasks = [task for task in tasks if LLMGTAgent.check_valid_task(task)] num_tasks = len(tasks) - console_text = Text(text=f"Running {num_tasks} out of {total_num_tasks} tasks for GT agent.", style="bold green") + console_text = Text( + text=f"Running {num_tasks} out of {total_num_tasks} tasks for GT agent.", style="bold green" + ) ConsoleDisplay.console.print(console_text) if "solo" in config.agent: total_num_tasks = len(tasks) tasks = [task for task in tasks if LLMSoloAgent.check_valid_task(task)] num_tasks = len(tasks) - console_text = Text(text=f"Running {num_tasks} out of {total_num_tasks} tasks for solo agent.", style="bold green") + console_text = Text( + text=f"Running {num_tasks} out of {total_num_tasks} tasks for solo agent.", style="bold green" + ) ConsoleDisplay.console.print(console_text) num_trials = config.num_trials @@ -244,9 +241,7 @@ def run_tasks( with open(save_to, "r") as fp: prev_simulation_results = Results.model_validate_json(fp.read()) # Check if the run config has changed - if get_pydantic_hash(prev_simulation_results.info) != get_pydantic_hash( - simulation_results.info - ): + if get_pydantic_hash(prev_simulation_results.info) != get_pydantic_hash(simulation_results.info): diff = show_dict_diff( prev_simulation_results.info.model_dump(), simulation_results.info.model_dump(), @@ -279,14 +274,12 @@ def run_tasks( "The task set has changed. Please delete the existing file or use a different save_to name." ) # Check which of the runs have already been done - done_runs = set( - [ - (sim.trial, sim.task_id, sim.seed) - for sim in prev_simulation_results.simulations - ] - ) + done_runs = set([(sim.trial, sim.task_id, sim.seed) for sim in prev_simulation_results.simulations]) simulation_results = prev_simulation_results - console_text = Text(text=f"Resuming run from {len(done_runs)} runs. {len(tasks) * num_trials - len(done_runs)} runs remaining.", style="bold yellow") + console_text = Text( + text=f"Resuming run from {len(done_runs)} runs. {len(tasks) * num_trials - len(done_runs)} runs remaining.", + style="bold yellow", + ) ConsoleDisplay.console.print(console_text) # Create new save file else: @@ -338,7 +331,10 @@ def _run(task: Task, trial: int, seed: int, progress_str: str) -> SimulationRun: for trial in range(num_trials): for i, task in enumerate(tasks): if (trial, task.id, seeds[trial]) in done_runs: - console_text = Text(text=f"Skipping task {task.id}, trial {trial} because it has already been run.", style="bold yellow") + console_text = Text( + text=f"Skipping task {task.id}, trial {trial} because it has already been run.", + style="bold yellow", + ) ConsoleDisplay.console.print(console_text) continue progress_str = f"{i}/{len(tasks)} (trial {trial + 1}/{num_trials})" @@ -394,9 +390,7 @@ def run_task( if max_errors <= 0: raise ValueError("Max errors must be greater than 0") global registry - logger.info( - f"STARTING SIMULATION: Domain: {domain}, Task: {task.id}, Agent: {agent}, User: {user}" - ) + logger.info(f"STARTING SIMULATION: Domain: {domain}, Task: {task.id}, Agent: {agent}, User: {user}") environment_constructor = registry.get_env_constructor(domain) environment = environment_constructor() AgentConstructor = registry.get_agent_constructor(agent) @@ -429,9 +423,7 @@ def run_task( task=task, ) else: - raise ValueError( - f"Unknown agent type: {AgentConstructor}. Should be LLMAgent or LLMSoloAgent" - ) + raise ValueError(f"Unknown agent type: {AgentConstructor}. Should be LLMAgent or LLMSoloAgent") try: user_tools = environment.get_user_tools() except Exception: @@ -439,9 +431,7 @@ def run_task( UserConstructor = registry.get_user_constructor(user) if issubclass(UserConstructor, DummyUser): - assert isinstance(agent, LLMSoloAgent), ( - "Dummy user can only be used with solo agent" - ) + assert isinstance(agent, LLMSoloAgent), "Dummy user can only be used with solo agent" user = UserConstructor( tools=user_tools, diff --git a/vendor/tau2/scripts/show_domain_doc.py b/vendor/tau2/scripts/show_domain_doc.py index a7d55cdd..b1b4c8ee 100755 --- a/vendor/tau2/scripts/show_domain_doc.py +++ b/vendor/tau2/scripts/show_domain_doc.py @@ -59,9 +59,7 @@ def main(domain: str): except KeyError: available_domains = registry.get_domains() - logger.error( - f"Domain '{domain}' not found. Available domains: {available_domains}" - ) + logger.error(f"Domain '{domain}' not found. Available domains: {available_domains}") exit(1) except Exception as e: logger.error(f"Failed to start domain documentation server: {str(e)}") diff --git a/vendor/tau2/scripts/start_servers.py b/vendor/tau2/scripts/start_servers.py index 7a7596be..ecbec88c 100755 --- a/vendor/tau2/scripts/start_servers.py +++ b/vendor/tau2/scripts/start_servers.py @@ -18,9 +18,7 @@ def kill_process_on_port(port): connections = proc.net_connections() for conn in connections: if hasattr(conn, "laddr") and conn.laddr.port == port: - logger.warning( - f"Killing existing process {proc.pid} on port {port}" - ) + logger.warning(f"Killing existing process {proc.pid} on port {port}") proc.terminate() time.sleep(0.5) # Give it a moment to terminate if proc.is_running(): # If still running @@ -82,9 +80,7 @@ def signal_handler(signum, frame): try: with ThreadPoolExecutor(max_workers=len(servers)) as executor: # Start each server in a separate thread - futures = [ - executor.submit(run_server, command, port) for command, port in servers - ] + futures = [executor.submit(run_server, command, port) for command, port in servers] # Wait for all servers to complete for future in futures: diff --git a/vendor/tau2/scripts/view_simulations.py b/vendor/tau2/scripts/view_simulations.py index b8fad2be..e3357477 100644 --- a/vendor/tau2/scripts/view_simulations.py +++ b/vendor/tau2/scripts/view_simulations.py @@ -23,9 +23,7 @@ def get_available_simulations(): return sorted([f for f in sim_dir.glob("*.json")]) -def display_simulation_list( - results: Results, only_show_failed: bool = False, only_show_all_failed: bool = False -): +def display_simulation_list(results: Results, only_show_failed: bool = False, only_show_all_failed: bool = False): """Display a numbered list of simulations with basic info.""" ConsoleDisplay.console.print("\n[bold blue]Available Simulations:[/]") @@ -74,9 +72,7 @@ def display_available_files(files): ConsoleDisplay.console.print(f"[cyan]{i}.[/] {file.name}") -def display_simulation_with_task( - simulation, task, results_file: str, sim_index: int, show_details: bool = True -): +def display_simulation_with_task(simulation, task, results_file: str, sim_index: int, show_details: bool = True): """Display a simulation along with its associated task.""" ConsoleDisplay.console.print("\n" + "=" * 80) # Separator ConsoleDisplay.console.print("[bold blue]Task Details:[/]") @@ -113,18 +109,12 @@ def find_task_by_id(tasks, task_id): def find_simulation_by_task_id_and_trial(results, task_id, trial): """Get a simulation by its task ID and trial number.""" return next( - ( - sim - for sim in results.simulations - if sim.task_id == task_id and sim.trial == trial - ), + (sim for sim in results.simulations if sim.task_id == task_id and sim.trial == trial), None, ) -def save_simulation_note( - simulation, task, note: str, results_file: str, sim_index: int -): +def save_simulation_note(simulation, task, note: str, results_file: str, sim_index: int): """Save a note about a simulation to a CSV file.""" notes_file = Path(DATA_DIR) / "simulations" / "simulation_notes.csv" file_exists = notes_file.exists() @@ -137,9 +127,11 @@ def save_simulation_note( "trial": simulation.trial, "duration": simulation.duration, "reward": simulation.reward_info.reward if simulation.reward_info else None, - "db_match": simulation.reward_info.db_check.db_match - if simulation.reward_info and simulation.reward_info.db_check - else None, + "db_match": ( + simulation.reward_info.db_check.db_match + if simulation.reward_info and simulation.reward_info.db_check + else None + ), "results_file": results_file, "sim_index": sim_index, "note": note, @@ -165,9 +157,7 @@ def main( sim_files = [Path(sim_file)] if not sim_files: - ConsoleDisplay.console.print( - "[red]No simulation files found in data/simulations/[/]" - ) + ConsoleDisplay.console.print("[red]No simulation files found in data/simulations/[/]") return results = None @@ -176,20 +166,14 @@ def main( # Show main menu ConsoleDisplay.console.print("\n[bold yellow]Main Menu:[/]") ConsoleDisplay.console.print("1. Select simulation file") - ConsoleDisplay.console.print( - " [dim]Choose a simulation results file to load and analyze[/]" - ) + ConsoleDisplay.console.print(" [dim]Choose a simulation results file to load and analyze[/]") if results: ConsoleDisplay.console.print("2. View agent performance metrics") ConsoleDisplay.console.print(" [dim]Display agent performance metrics[/]") ConsoleDisplay.console.print("3. View simulation") - ConsoleDisplay.console.print( - " [dim]Examine a specific simulation in detail with all its data[/]" - ) + ConsoleDisplay.console.print(" [dim]Examine a specific simulation in detail with all its data[/]") ConsoleDisplay.console.print("4. View task details") - ConsoleDisplay.console.print( - " [dim]Look at the configuration and parameters of a specific task[/]" - ) + ConsoleDisplay.console.print(" [dim]Look at the configuration and parameters of a specific task[/]") ConsoleDisplay.console.print("5. Exit") ConsoleDisplay.console.print(" [dim]Close the simulation viewer[/]") choices = ["1", "2", "3", "4", "5"] @@ -200,17 +184,13 @@ def main( choices = ["1", "2"] default_choice = "1" - choice = Prompt.ask( - "\nWhat would you like to do?", choices=choices, default=default_choice - ) + choice = Prompt.ask("\nWhat would you like to do?", choices=choices, default=default_choice) if choice == "1": # Show available files and get selection display_available_files(sim_files) # default to view the last file - file_num = IntPrompt.ask( - f"\nSelect file number (1-{len(sim_files)})", default=len(sim_files) - ) + file_num = IntPrompt.ask(f"\nSelect file number (1-{len(sim_files)})", default=len(sim_files)) if 1 <= file_num <= len(sim_files): try: @@ -219,13 +199,9 @@ def main( ConsoleDisplay.console.print( f"\n[bold green]Loaded {len(results.simulations)} simulations from {current_file}[/]" ) - results.simulations = sorted( - results.simulations, key=lambda x: (x.task_id, x.trial) - ) + results.simulations = sorted(results.simulations, key=lambda x: (x.task_id, x.trial)) except Exception as e: - ConsoleDisplay.console.print( - f"[red]Error loading results:[/] {str(e)}" - ) + ConsoleDisplay.console.print(f"[red]Error loading results:[/] {str(e)}") else: ConsoleDisplay.console.print("[red]Invalid file number[/]") @@ -245,21 +221,15 @@ def main( # Get simulation selection by index sim_count = len(results.simulations) - sim_index = IntPrompt.ask( - f"\nEnter simulation number (1-{sim_count})", default=1 - ) + sim_index = IntPrompt.ask(f"\nEnter simulation number (1-{sim_count})", default=1) if 1 <= sim_index <= sim_count: sim = results.simulations[sim_index - 1] task = find_task_by_id(results.tasks, sim.task_id) if task: - display_simulation_with_task( - sim, task, current_file, sim_index, show_details=True - ) + display_simulation_with_task(sim, task, current_file, sim_index, show_details=True) else: - ConsoleDisplay.console.print( - f"[red]Warning: Could not find task for simulation {sim.id}[/]" - ) + ConsoleDisplay.console.print(f"[red]Warning: Could not find task for simulation {sim.id}[/]") ConsoleDisplay.display_simulation(sim, show_details=True) continue else: diff --git a/vendor/tau2/user/base.py b/vendor/tau2/user/base.py index b47bdb9a..1fe018b6 100644 --- a/vendor/tau2/user/base.py +++ b/vendor/tau2/user/base.py @@ -63,9 +63,7 @@ def flip_roles(self) -> list[APICompatibleMessage]: ) ) else: - raise ValueError( - f"Tool calls are not supported in the flipped messages: {message}" - ) + raise ValueError(f"Tool calls are not supported in the flipped messages: {message}") elif isinstance(message, ToolMessage): if message.requestor == "user": # Only add tool messages for the user @@ -77,9 +75,7 @@ def flip_roles(self) -> list[APICompatibleMessage]: ) ) else: - raise ValueError( - f"Tool messages should be sent to the user in this message history: {message}" - ) + raise ValueError(f"Tool messages should be sent to the user in this message history: {message}") else: print(message, type(message)) raise ValueError(f"Unknown message role: {message.role}") @@ -100,9 +96,7 @@ def __init__( self.instructions = instructions @abstractmethod - async def get_init_state( - self, message_history: Optional[list[Message]] = None - ) -> UserState: + async def get_init_state(self, message_history: Optional[list[Message]] = None) -> UserState: """Get the initial state of the user simulator. Args: diff --git a/vendor/tau2/user/user_simulator.py b/vendor/tau2/user/user_simulator.py index d5508409..31581bdc 100644 --- a/vendor/tau2/user/user_simulator.py +++ b/vendor/tau2/user/user_simulator.py @@ -100,9 +100,9 @@ def get_init_state(self, message_history: Optional[list[Message]] = None) -> Use """ if message_history is None: message_history = [] - assert all( - is_valid_user_history_message(m) for m in message_history - ), "Invalid user message history. User messages must be of type UserMessage, AssistantMessage, or ToolMessage to User." + assert all(is_valid_user_history_message(m) for m in message_history), ( + "Invalid user message history. User messages must be of type UserMessage, AssistantMessage, or ToolMessage to User." + ) user_state = UserState( system_messages=[SystemMessage(role="system", content=self.system_prompt)], diff --git a/vendor/tau2/utils/display.py b/vendor/tau2/utils/display.py index 674efb92..8e83e2cd 100644 --- a/vendor/tau2/utils/display.py +++ b/vendor/tau2/utils/display.py @@ -95,9 +95,7 @@ def display_task(cls, task: Task): if task.description.purpose: content_parts.append(f"[white]Purpose:[/] {task.description.purpose}") if task.description.relevant_policies: - content_parts.append( - f"[white]Relevant Policies:[/] {task.description.relevant_policies}" - ) + content_parts.append(f"[white]Relevant Policies:[/] {task.description.relevant_policies}") if task.description.notes: content_parts.append(f"[white]Notes:[/] {task.description.notes}") @@ -108,14 +106,10 @@ def display_task(cls, task: Task): scenario_parts.append(f"[white]Persona:[/] {task.user_scenario.persona}") # User Instruction - scenario_parts.append( - f"[white]Task Instructions:[/] {task.user_scenario.instructions}" - ) + scenario_parts.append(f"[white]Task Instructions:[/] {task.user_scenario.instructions}") if scenario_parts: - content_parts.append( - "[bold cyan]User Scenario:[/]\n" + "\n".join(scenario_parts) - ) + content_parts.append("[bold cyan]User Scenario:[/]\n" + "\n".join(scenario_parts)) # Initial State section if task.initial_state: @@ -134,9 +128,7 @@ def display_task(cls, task: Task): ) if initial_state_parts: - content_parts.append( - "[bold cyan]Initial State:[/]\n" + "\n".join(initial_state_parts) - ) + content_parts.append("[bold cyan]Initial State:[/]\n" + "\n".join(initial_state_parts)) # Evaluation Criteria section if task.evaluation_criteria: @@ -154,15 +146,11 @@ def display_task(cls, task: Task): f"[white]Information to Communicate:[/]\n{json.dumps(task.evaluation_criteria.communicate_info, indent=2)}" ) if eval_parts: - content_parts.append( - "[bold cyan]Evaluation Criteria:[/]\n" + "\n".join(eval_parts) - ) + content_parts.append("[bold cyan]Evaluation Criteria:[/]\n" + "\n".join(eval_parts)) content = "\n\n".join(content_parts) # Create and display panel - task_panel = Panel( - content, title="[bold blue]Task Details", border_style="blue", expand=True - ) + task_panel = Panel(content, title="[bold blue]Task Details", border_style="blue", expand=True) cls.console.print(task_panel) @@ -203,18 +191,11 @@ def display_simulation(cls, simulation: SimulationRun, show_details: bool = True marker = "โœ…" if is_successful(simulation.reward_info.reward) else "โŒ" sim_info.append("Reward: ", style="bold cyan") if simulation.reward_info.reward_breakdown: - breakdown = sorted( - [ - f"{k.value}: {v:.1f}" - for k, v in simulation.reward_info.reward_breakdown.items() - ] - ) + breakdown = sorted([f"{k.value}: {v:.1f}" for k, v in simulation.reward_info.reward_breakdown.items()]) else: breakdown = [] - sim_info.append( - f"{marker} {simulation.reward_info.reward:.4f} ({', '.join(breakdown)})\n" - ) + sim_info.append(f"{marker} {simulation.reward_info.reward:.4f} ({', '.join(breakdown)})\n") # Add DB check info if present if simulation.reward_info.db_check: @@ -243,9 +224,7 @@ def display_simulation(cls, simulation: SimulationRun, show_details: bool = True if simulation.reward_info.communicate_checks: sim_info.append("\nCommunicate Checks:\n", style="bold magenta") for i, check in enumerate(simulation.reward_info.communicate_checks): - sim_info.append( - f"- {i}: {check.info} {'โœ…' if check.met else 'โŒ'}\n" - ) + sim_info.append(f"- {i}: {check.info} {'โœ…' if check.met else 'โŒ'}\n") # Add NL assertions if present if simulation.reward_info.nl_assertions: @@ -261,9 +240,7 @@ def display_simulation(cls, simulation: SimulationRun, show_details: bool = True for key, value in simulation.reward_info.info.items(): sim_info.append(f"{key}: {value}\n") - cls.console.print( - Panel(sim_info, title="Simulation Overview", border_style="blue") - ) + cls.console.print(Panel(sim_info, title="Simulation Overview", border_style="blue")) # Create messages table if simulation.messages: @@ -390,15 +367,8 @@ def display_simulation(cls, sim: SimulationRun) -> str: # Add reward info if present if sim.reward_info: - breakdown = sorted( - [ - f"{k.value}: {v:.1f}" - for k, v in sim.reward_info.reward_breakdown.items() - ] - ) - output.append( - f"**Reward**: {sim.reward_info.reward:.4f} ({', '.join(breakdown)})\n" - ) + breakdown = sorted([f"{k.value}: {v:.1f}" for k, v in sim.reward_info.reward_breakdown.items()]) + output.append(f"**Reward**: {sim.reward_info.reward:.4f} ({', '.join(breakdown)})\n") output.append(f"**Reward**: {sim.reward_info.reward:.4f}") # Add DB check info if present @@ -428,9 +398,7 @@ def display_simulation(cls, sim: SimulationRun) -> str: if sim.reward_info.communicate_checks: output.append("\n**Communicate Checks**") for i, check in enumerate(sim.reward_info.communicate_checks): - output.append( - f"- {i}: {check.info} {'โœ…' if check.met else 'โŒ'} {check.justification}" - ) + output.append(f"- {i}: {check.info} {'โœ…' if check.met else 'โŒ'} {check.justification}") # Add NL assertions if present if sim.reward_info.nl_assertions: diff --git a/vendor/tau2/utils/llm_utils.py b/vendor/tau2/utils/llm_utils.py index 750c75fe..895d3f56 100644 --- a/vendor/tau2/utils/llm_utils.py +++ b/vendor/tau2/utils/llm_utils.py @@ -1,4 +1,6 @@ import json +import logging +import os import re from typing import Any, Optional @@ -7,8 +9,6 @@ from litellm.caching.caching import Cache from litellm.main import ModelResponse, Usage from loguru import logger -import logging -import os from vendor.tau2.config import ( DEFAULT_LLM_CACHE_TYPE, diff --git a/vendor/tau2/utils/pydantic_utils.py b/vendor/tau2/utils/pydantic_utils.py index 5c34acd8..e46454eb 100644 --- a/vendor/tau2/utils/pydantic_utils.py +++ b/vendor/tau2/utils/pydantic_utils.py @@ -21,9 +21,7 @@ def get_pydantic_hash(obj: BaseModel) -> str: return get_dict_hash(hash_dict) -def update_pydantic_model_with_dict( - model_instance: T, update_data: Dict[str, Any] -) -> T: +def update_pydantic_model_with_dict(model_instance: T, update_data: Dict[str, Any]) -> T: """ Return an updated BaseModel instance based on the update_data. """ diff --git a/vendor/tau2/utils/utils.py b/vendor/tau2/utils/utils.py index c1103fe9..33c9b511 100644 --- a/vendor/tau2/utils/utils.py +++ b/vendor/tau2/utils/utils.py @@ -29,9 +29,7 @@ # Check if data directory exists and is accessible if not DATA_DIR.exists(): logger.warning(f"Data directory does not exist: {DATA_DIR}") - logger.warning( - "Set TAU2_DATA_DIR environment variable to point to your data directory" - ) + logger.warning("Set TAU2_DATA_DIR environment variable to point to your data directory") logger.warning("Or ensure the data directory exists in the expected location") @@ -72,11 +70,7 @@ def get_commit_hash() -> str: Get the commit hash of the current directory. """ try: - commit_hash = ( - subprocess.check_output(["git", "rev-parse", "HEAD"], text=True) - .strip() - .split("\n")[0] - ) + commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip().split("\n")[0] except Exception as e: logger.error(f"Failed to get git hash: {e}") commit_hash = "unknown" diff --git a/versioneer.py b/versioneer.py index f8092765..4611fc7d 100644 --- a/versioneer.py +++ b/versioneer.py @@ -503,9 +503,7 @@ def run_command( return stdout, process.returncode -LONG_VERSION_PY[ - "git" -] = r''' +LONG_VERSION_PY["git"] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build