11import asyncio
2+ from typing import Any , Dict , List , Optional
23
34import pytest
45
5- from eval_protocol .integrations .fireworks_v1_completions_client import FireworksV1CompletionsClient
6+ from eval_protocol .integrations .fireworks_v1_completions_client import (
7+ FireworksV1CompletionsClient ,
8+ ParsedToolCall ,
9+ to_openai_tool_calls ,
10+ strip_chat_special_tokens ,
11+ )
612
713
8- def test_plaintext_fallback_disabled_raises_on_non_json ():
9- client = FireworksV1CompletionsClient (
10- model_id = "accounts/fireworks/models/qwen3-0p6b" ,
11- tokenizer_name_or_path = "Qwen/Qwen3-0.6B" ,
12- allow_plaintext_action_fallback = False ,
13- )
14- with pytest .raises (ValueError ):
15- client ._parse_tool_call_with_optional_fallback ("move RIGHT next" )
16- asyncio .run (client .close ())
14+ def test_parsed_tool_call_to_openai_format ():
15+ tc = ParsedToolCall (tool_call_id = "call_1" , name = "lake_move" , arguments = {"action" : "RIGHT" })
16+ payload = to_openai_tool_calls (tc )
17+ assert len (payload ) == 1
18+ assert payload [0 ]["function" ]["name" ] == "lake_move"
19+ assert '"action":"RIGHT"' in payload [0 ]["function" ]["arguments" ]
1720
1821
19- def test_plaintext_fallback_extracts_action_when_enabled ():
20- client = FireworksV1CompletionsClient (
21- model_id = "accounts/fireworks/models/qwen3-0p6b" ,
22- tokenizer_name_or_path = "Qwen/Qwen3-0.6B" ,
23- allow_plaintext_action_fallback = True ,
24- )
25- parsed = client ._parse_tool_call_with_optional_fallback ("The best move is RIGHT." )
26- assert parsed .arguments ["action" ] == "RIGHT"
27- asyncio .run (client .close ())
22+ def test_strip_chat_special_tokens ():
23+ assert strip_chat_special_tokens ("<|im_start|>assistant\n hello<|im_end|>" ) == "assistant\n hello"
24+ assert strip_chat_special_tokens ("" ) == ""
25+ assert strip_chat_special_tokens (None ) == ""
26+
27+
28+ def test_tool_call_parser_is_invoked ():
29+ """When a tool_call_parser is provided, create_completion_from_prompt_ids uses it."""
2830
31+ def fake_parser (
32+ text : str , ids : List [int ], tools : Optional [List [Dict [str , Any ]]]
33+ ) -> Dict [str , Any ]:
34+ return {
35+ "parsed_tool_call" : ParsedToolCall (
36+ tool_call_id = "call_0" , name = "test_tool" , arguments = {"x" : 1 }
37+ ),
38+ "assistant_content" : "thought" ,
39+ "parser" : "fake" ,
40+ }
2941
30- def test_plaintext_fallback_raises_when_no_action_found ():
3142 client = FireworksV1CompletionsClient (
32- model_id = "accounts/fireworks/models/qwen3-0p6b " ,
43+ model_id = "test-model " ,
3344 tokenizer_name_or_path = "Qwen/Qwen3-0.6B" ,
34- allow_plaintext_action_fallback = True ,
45+ tool_call_parser = fake_parser ,
3546 )
36- with pytest .raises (ValueError ):
37- client ._parse_tool_call_with_optional_fallback ("I cannot decide from this state." )
47+
48+ result = fake_parser ("some text" , [1 , 2 ], None )
49+ assert result ["parsed_tool_call" ].name == "test_tool"
50+ assert result ["assistant_content" ] == "thought"
3851 asyncio .run (client .close ())
3952
4053
41- def test_parse_assistant_output_preserves_non_tool_content (monkeypatch ):
54+ def test_no_parser_returns_raw_content ():
55+ """When no tool_call_parser is provided, message contains raw content."""
4256 client = FireworksV1CompletionsClient (
43- model_id = "accounts/fireworks/models/qwen3-0p6b " ,
57+ model_id = "test-model " ,
4458 tokenizer_name_or_path = "Qwen/Qwen3-0.6B" ,
4559 )
46- monkeypatch .setattr (client , "_parse_tool_call_with_vllm_parser" , lambda ** kwargs : None )
47- parsed = client ._parse_assistant_output (
48- completion_text = '<think>\n \n </think>\n {"tool_calls":[{"name":"lake_move","arguments":{"action":"RIGHT"}}]}' ,
49- completion_token_ids = [1 , 2 , 3 ],
50- tools = [{"type" : "function" , "function" : {"name" : "lake_move" }}],
51- )
52- assert parsed ["parsed_tool_call" ].arguments == {"action" : "RIGHT" }
53- assert parsed ["assistant_content" ] == "<think>\n \n </think>"
54- assert parsed ["non_tool_content" ] == "<think>\n \n </think>"
55- assert parsed ["parser" ] == "json_schema"
60+ assert client .tool_call_parser is None
5661 asyncio .run (client .close ())
5762
5863
59- def test_parse_assistant_output_uses_vllm_parser_when_available (monkeypatch ):
64+ def test_default_tools_not_used_when_tools_is_empty_list ():
65+ """Passing tools=[] should not fall back to default_tools."""
6066 client = FireworksV1CompletionsClient (
61- model_id = "accounts/fireworks/models/qwen3-0p6b " ,
67+ model_id = "test-model " ,
6268 tokenizer_name_or_path = "Qwen/Qwen3-0.6B" ,
69+ default_tools = [{"type" : "function" , "function" : {"name" : "my_tool" }}],
6370 )
64-
65- class _Parsed :
66- arguments = {"action" : "DOWN" }
67-
68- monkeypatch .setattr (
69- client ,
70- "_parse_tool_call_with_vllm_parser" ,
71- lambda ** kwargs : {"parsed_tool_call" : _Parsed (), "assistant_content" : "thought" , "parser" : "vllm:qwen3xml" },
72- )
73- parsed = client ._parse_assistant_output (
74- completion_text = '{"tool_calls":[{"name":"lake_move","arguments":{"action":"DOWN"}}]}' ,
75- completion_token_ids = [1 , 2 , 3 ],
76- tools = [{"type" : "function" , "function" : {"name" : "lake_move" }}],
77- )
78- assert parsed ["assistant_content" ] == "thought"
79- assert parsed ["non_tool_content" ] == "thought"
80- assert parsed ["parser" ] == "vllm:qwen3xml"
81- assert parsed ["parsed_tool_call" ].arguments == {"action" : "DOWN" }
71+ assert client .default_tools == [{"type" : "function" , "function" : {"name" : "my_tool" }}]
8272 asyncio .run (client .close ())
8373
8474
@@ -98,7 +88,7 @@ def apply_chat_template(self, messages, **kwargs):
9888 raise RuntimeError ("tools unsupported" )
9989 return [11 , 22 , 33 ]
10090
101- def encode (self , text , add_special_tokens = False ): # pragma: no cover
91+ def encode (self , text , add_special_tokens = False ):
10292 return [99 ]
10393
10494 fake_tokenizer = FakeTokenizer ()
@@ -122,7 +112,7 @@ class FakeTokenizer:
122112 def apply_chat_template (self , messages , ** kwargs ):
123113 return {"input_ids" : [[101 , 102 , 103 ]]}
124114
125- def encode (self , text , add_special_tokens = False ): # pragma: no cover
115+ def encode (self , text , add_special_tokens = False ):
126116 return [99 ]
127117
128118 monkeypatch .setattr (client , "_get_tokenizer" , lambda : FakeTokenizer ())
@@ -132,3 +122,25 @@ def encode(self, text, add_special_tokens=False): # pragma: no cover
132122 )
133123 assert token_ids == [101 , 102 , 103 ]
134124 asyncio .run (client .close ())
125+
126+
127+ def test_thinking_kwargs_respects_enable_thinking ():
128+ client_none = FireworksV1CompletionsClient (
129+ model_id = "test" , tokenizer_name_or_path = "Qwen/Qwen3-0.6B" ,
130+ )
131+ assert client_none ._thinking_kwargs () == {}
132+
133+ client_false = FireworksV1CompletionsClient (
134+ model_id = "test" , tokenizer_name_or_path = "Qwen/Qwen3-0.6B" ,
135+ enable_thinking = False ,
136+ )
137+ assert client_false ._thinking_kwargs () == {"enable_thinking" : False }
138+
139+ client_true = FireworksV1CompletionsClient (
140+ model_id = "test" , tokenizer_name_or_path = "Qwen/Qwen3-0.6B" ,
141+ enable_thinking = True ,
142+ )
143+ assert client_true ._thinking_kwargs () == {"enable_thinking" : True }
144+ asyncio .run (client_none .close ())
145+ asyncio .run (client_false .close ())
146+ asyncio .run (client_true .close ())
0 commit comments