|
17 | 17 | import json |
18 | 18 | import unittest |
19 | 19 | import uuid |
| 20 | +from unittest.mock import Mock, patch |
20 | 21 |
|
| 22 | +from google.protobuf.struct_pb2 import Struct |
21 | 23 | from google.rpc import code_pb2, status_pb2 |
22 | 24 |
|
23 | 25 | from dapr.aio.clients import DaprClient as AsyncDaprClient |
|
37 | 39 | ConversationResponseAlpha2, |
38 | 40 | ConversationResultAlpha2, |
39 | 41 | ConversationResultAlpha2Choices, |
| 42 | + ConversationResultAlpha2CompletionUsage, |
| 43 | + ConversationResultAlpha2CompletionUsageCompletionTokensDetails, |
| 44 | + ConversationResultAlpha2CompletionUsagePromptTokensDetails, |
40 | 45 | ConversationResultAlpha2Message, |
41 | 46 | ConversationToolCalls, |
42 | 47 | ConversationToolCallsOfFunction, |
43 | 48 | ConversationTools, |
44 | 49 | ConversationToolsFunction, |
45 | 50 | FunctionBackend, |
| 51 | + _get_outputs_from_grpc_response, |
46 | 52 | create_assistant_message, |
47 | 53 | create_system_message, |
48 | 54 | create_tool_message, |
@@ -248,6 +254,14 @@ def test_basic_conversation_alpha2(self): |
248 | 254 | self.assertEqual(choice.finish_reason, 'stop') |
249 | 255 | self.assertIn('Hello Alpha2!', choice.message.content) |
250 | 256 |
|
| 257 | + out = response.outputs[0] |
| 258 | + if out.model is not None: |
| 259 | + self.assertEqual(out.model, 'test-llm') |
| 260 | + if out.usage is not None: |
| 261 | + self.assertGreaterEqual(out.usage.total_tokens, 15) |
| 262 | + self.assertGreaterEqual(out.usage.prompt_tokens, 5) |
| 263 | + self.assertGreaterEqual(out.usage.completion_tokens, 10) |
| 264 | + |
251 | 265 | def test_conversation_alpha2_with_system_message(self): |
252 | 266 | """Test Alpha2 conversation with system message.""" |
253 | 267 | system_message = create_system_message('You are a helpful assistant.') |
@@ -1107,6 +1121,186 @@ def test_empty_and_none_outputs(self): |
1107 | 1121 | self.assertEqual(response_none.to_assistant_messages(), []) |
1108 | 1122 |
|
1109 | 1123 |
|
| 1124 | +class TestConversationResultAlpha2ModelAndUsage(unittest.TestCase): |
| 1125 | + """Tests for model and usage fields on ConversationResultAlpha2 and related types.""" |
| 1126 | + |
| 1127 | + def test_result_alpha2_has_model_and_usage_attributes(self): |
| 1128 | + """ConversationResultAlpha2 accepts and exposes model and usage.""" |
| 1129 | + msg = ConversationResultAlpha2Message(content='Hi', tool_calls=[]) |
| 1130 | + choice = ConversationResultAlpha2Choices(finish_reason='stop', index=0, message=msg) |
| 1131 | + usage = ConversationResultAlpha2CompletionUsage( |
| 1132 | + completion_tokens=10, |
| 1133 | + prompt_tokens=5, |
| 1134 | + total_tokens=15, |
| 1135 | + ) |
| 1136 | + result = ConversationResultAlpha2( |
| 1137 | + choices=[choice], |
| 1138 | + model='test-model-1', |
| 1139 | + usage=usage, |
| 1140 | + ) |
| 1141 | + self.assertEqual(result.model, 'test-model-1') |
| 1142 | + self.assertIsNotNone(result.usage) |
| 1143 | + self.assertEqual(result.usage.completion_tokens, 10) |
| 1144 | + self.assertEqual(result.usage.prompt_tokens, 5) |
| 1145 | + self.assertEqual(result.usage.total_tokens, 15) |
| 1146 | + |
| 1147 | + def test_result_alpha2_model_and_usage_default_none(self): |
| 1148 | + """ConversationResultAlpha2 optional fields default to None when not provided. |
| 1149 | +
|
| 1150 | + When the API returns a response, model and usage are set from the conversation |
| 1151 | + component. This test only checks that the dataclass defaults are None when |
| 1152 | + constructing with choices only. |
| 1153 | + """ |
| 1154 | + msg = ConversationResultAlpha2Message(content='Hi', tool_calls=[]) |
| 1155 | + choice = ConversationResultAlpha2Choices(finish_reason='stop', index=0, message=msg) |
| 1156 | + result = ConversationResultAlpha2(choices=[choice]) |
| 1157 | + self.assertIsNone(result.model) |
| 1158 | + self.assertIsNone(result.usage) |
| 1159 | + |
| 1160 | + def test_usage_completion_and_prompt_details(self): |
| 1161 | + """ConversationResultAlpha2CompletionUsage supports details.""" |
| 1162 | + completion_details = ConversationResultAlpha2CompletionUsageCompletionTokensDetails( |
| 1163 | + accepted_prediction_tokens=1, |
| 1164 | + audio_tokens=2, |
| 1165 | + reasoning_tokens=3, |
| 1166 | + rejected_prediction_tokens=0, |
| 1167 | + ) |
| 1168 | + prompt_details = ConversationResultAlpha2CompletionUsagePromptTokensDetails( |
| 1169 | + audio_tokens=0, |
| 1170 | + cached_tokens=4, |
| 1171 | + ) |
| 1172 | + usage = ConversationResultAlpha2CompletionUsage( |
| 1173 | + completion_tokens=10, |
| 1174 | + prompt_tokens=5, |
| 1175 | + total_tokens=15, |
| 1176 | + completion_tokens_details=completion_details, |
| 1177 | + prompt_tokens_details=prompt_details, |
| 1178 | + ) |
| 1179 | + self.assertEqual(usage.completion_tokens_details.accepted_prediction_tokens, 1) |
| 1180 | + self.assertEqual(usage.completion_tokens_details.audio_tokens, 2) |
| 1181 | + self.assertEqual(usage.completion_tokens_details.reasoning_tokens, 3) |
| 1182 | + self.assertEqual(usage.completion_tokens_details.rejected_prediction_tokens, 0) |
| 1183 | + self.assertEqual(usage.prompt_tokens_details.audio_tokens, 0) |
| 1184 | + self.assertEqual(usage.prompt_tokens_details.cached_tokens, 4) |
| 1185 | + self.assertEqual(usage.total_tokens, 15) |
| 1186 | + self.assertEqual(usage.completion_tokens, 10) |
| 1187 | + self.assertEqual(usage.prompt_tokens, 5) |
| 1188 | + |
| 1189 | + def test_get_outputs_from_grpc_response_populates_model_and_usage(self): |
| 1190 | + """_get_outputs_from_grpc_response sets model and usage when present on proto.""" |
| 1191 | + from unittest import mock |
| 1192 | + |
| 1193 | + # Build a mock proto response with one output that has model and usage |
| 1194 | + mock_usage = mock.Mock() |
| 1195 | + mock_usage.completion_tokens = 20 |
| 1196 | + mock_usage.prompt_tokens = 8 |
| 1197 | + mock_usage.total_tokens = 28 |
| 1198 | + mock_usage.completion_tokens_details = None |
| 1199 | + mock_usage.prompt_tokens_details = None |
| 1200 | + |
| 1201 | + mock_choice_msg = mock.Mock() |
| 1202 | + mock_choice_msg.content = 'Hello' |
| 1203 | + mock_choice_msg.tool_calls = [] |
| 1204 | + |
| 1205 | + mock_choice = mock.Mock() |
| 1206 | + mock_choice.finish_reason = 'stop' |
| 1207 | + mock_choice.index = 0 |
| 1208 | + mock_choice.message = mock_choice_msg |
| 1209 | + |
| 1210 | + mock_output = mock.Mock() |
| 1211 | + mock_output.model = 'gpt-4o-mini' |
| 1212 | + mock_output.usage = mock_usage |
| 1213 | + mock_output.choices = [mock_choice] |
| 1214 | + |
| 1215 | + mock_response = mock.Mock() |
| 1216 | + mock_response.outputs = [mock_output] |
| 1217 | + |
| 1218 | + outputs = _get_outputs_from_grpc_response(mock_response) |
| 1219 | + self.assertEqual(len(outputs), 1) |
| 1220 | + out = outputs[0] |
| 1221 | + self.assertEqual(out.model, 'gpt-4o-mini') |
| 1222 | + self.assertIsNotNone(out.usage) |
| 1223 | + self.assertEqual(out.usage.completion_tokens, 20) |
| 1224 | + self.assertEqual(out.usage.prompt_tokens, 8) |
| 1225 | + self.assertEqual(out.usage.total_tokens, 28) |
| 1226 | + self.assertEqual(len(out.choices), 1) |
| 1227 | + self.assertEqual(out.choices[0].message.content, 'Hello') |
| 1228 | + |
| 1229 | + def test_get_outputs_from_grpc_response_without_model_usage(self): |
| 1230 | + """_get_outputs_from_grpc_response leaves model and usage None when absent.""" |
| 1231 | + from unittest import mock |
| 1232 | + |
| 1233 | + mock_choice_msg = mock.Mock() |
| 1234 | + mock_choice_msg.content = 'Echo' |
| 1235 | + mock_choice_msg.tool_calls = [] |
| 1236 | + |
| 1237 | + mock_choice = mock.Mock() |
| 1238 | + mock_choice.finish_reason = 'stop' |
| 1239 | + mock_choice.index = 0 |
| 1240 | + mock_choice.message = mock_choice_msg |
| 1241 | + |
| 1242 | + mock_output = mock.Mock(spec=['choices']) |
| 1243 | + mock_output.choices = [mock_choice] |
| 1244 | + # No model or usage attributes |
| 1245 | + |
| 1246 | + mock_response = mock.Mock() |
| 1247 | + mock_response.outputs = [mock_output] |
| 1248 | + |
| 1249 | + outputs = _get_outputs_from_grpc_response(mock_response) |
| 1250 | + self.assertEqual(len(outputs), 1) |
| 1251 | + out = outputs[0] |
| 1252 | + self.assertIsNone(out.model) |
| 1253 | + self.assertIsNone(out.usage) |
| 1254 | + self.assertEqual(out.choices[0].message.content, 'Echo') |
| 1255 | + |
| 1256 | + |
| 1257 | +class ConverseAlpha2ResponseFormatTests(unittest.TestCase): |
| 1258 | + """Unit tests for converse_alpha2 response_format parameter.""" |
| 1259 | + |
| 1260 | + def test_converse_alpha2_passes_response_format_on_request(self): |
| 1261 | + """converse_alpha2 sets response_format on the gRPC request when provided.""" |
| 1262 | + user_message = create_user_message('Structured output please') |
| 1263 | + input_alpha2 = ConversationInputAlpha2(messages=[user_message]) |
| 1264 | + response_format = Struct() |
| 1265 | + response_format.update( |
| 1266 | + {'type': 'json_schema', 'json_schema': {'name': 'test', 'schema': {}}} |
| 1267 | + ) |
| 1268 | + |
| 1269 | + captured_requests = [] |
| 1270 | + mock_choice_msg = Mock() |
| 1271 | + mock_choice_msg.content = 'ok' |
| 1272 | + mock_choice_msg.tool_calls = [] |
| 1273 | + mock_choice = Mock() |
| 1274 | + mock_choice.finish_reason = 'stop' |
| 1275 | + mock_choice.index = 0 |
| 1276 | + mock_choice.message = mock_choice_msg |
| 1277 | + mock_output = Mock() |
| 1278 | + mock_output.choices = [mock_choice] |
| 1279 | + mock_response = Mock() |
| 1280 | + mock_response.outputs = [mock_output] |
| 1281 | + mock_response.context_id = '' |
| 1282 | + mock_call = Mock() |
| 1283 | + |
| 1284 | + def capture_run_rpc(rpc, request, *args, **kwargs): |
| 1285 | + captured_requests.append(request) |
| 1286 | + return (mock_response, mock_call) |
| 1287 | + |
| 1288 | + with patch('dapr.clients.health.DaprHealth.wait_for_sidecar'): |
| 1289 | + client = DaprClient('localhost:50011') |
| 1290 | + with patch.object(client.retry_policy, 'run_rpc', side_effect=capture_run_rpc): |
| 1291 | + client.converse_alpha2( |
| 1292 | + name='test-llm', |
| 1293 | + inputs=[input_alpha2], |
| 1294 | + response_format=response_format, |
| 1295 | + ) |
| 1296 | + |
| 1297 | + self.assertEqual(len(captured_requests), 1) |
| 1298 | + req = captured_requests[0] |
| 1299 | + self.assertTrue(hasattr(req, 'response_format')) |
| 1300 | + self.assertEqual(req.response_format['type'], 'json_schema') |
| 1301 | + self.assertEqual(req.response_format['json_schema']['name'], 'test') |
| 1302 | + |
| 1303 | + |
1110 | 1304 | class ExecuteRegisteredToolSyncTests(unittest.TestCase): |
1111 | 1305 | def tearDown(self): |
1112 | 1306 | # Cleanup all tools we may have registered by name prefix |
|
0 commit comments