Skip to content

Commit 7f444a0

Browse files
committed
feat(convo): add new fields to conversation api (dapr#902)
* feat(convo): add new fields to conversation api Signed-off-by: Samantha Coyle <sam@diagrid.io> * fix: update proto/grpc code generator and add more tests Signed-off-by: Samantha Coyle <sam@diagrid.io> * style: appease linter Signed-off-by: Samantha Coyle <sam@diagrid.io> * style: tox -e type fixes Signed-off-by: Samantha Coyle <sam@diagrid.io> --------- Signed-off-by: Samantha Coyle <sam@diagrid.io>
1 parent 652b9a1 commit 7f444a0

4 files changed

Lines changed: 287 additions & 1 deletion

File tree

dapr/clients/grpc/client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525

2626
import grpc # type: ignore
2727
from google.protobuf.any_pb2 import Any as GrpcAny
28+
from google.protobuf.duration_pb2 import Duration as GrpcDuration
2829
from google.protobuf.empty_pb2 import Empty as GrpcEmpty
2930
from google.protobuf.message import Message as GrpcMessage
31+
from google.protobuf.struct_pb2 import Struct as GrpcStruct
3032
from grpc import ( # type: ignore
3133
RpcError,
3234
StatusCode,
@@ -1880,6 +1882,8 @@ def converse_alpha2(
18801882
temperature: Optional[float] = None,
18811883
tools: Optional[List[conversation.ConversationTools]] = None,
18821884
tool_choice: Optional[str] = None,
1885+
response_format: Optional[GrpcStruct] = None,
1886+
prompt_cache_retention: Optional[GrpcDuration] = None,
18831887
) -> conversation.ConversationResponseAlpha2:
18841888
"""Invoke an LLM using the conversation API (Alpha2) with tool calling support.
18851889
@@ -1893,6 +1897,8 @@ def converse_alpha2(
18931897
temperature: Optional temperature setting for the LLM to optimize for creativity or predictability
18941898
tools: Optional list of tools available for the LLM to call
18951899
tool_choice: Optional control over which tools can be called ('none', 'auto', 'required', or specific tool name)
1900+
response_format: Optional response format (google.protobuf.struct_pb2.Struct, ex: json_schema for structured output)
1901+
prompt_cache_retention: Optional retention for prompt cache (google.protobuf.duration_pb2.Duration)
18961902
18971903
Returns:
18981904
ConversationResponseAlpha2 containing the conversation results with choices and tool calls
@@ -1949,6 +1955,10 @@ def converse_alpha2(
19491955
request.temperature = temperature
19501956
if tool_choice is not None:
19511957
request.tool_choice = tool_choice
1958+
if response_format is not None and hasattr(request, 'response_format'):
1959+
request.response_format.CopyFrom(response_format)
1960+
if prompt_cache_retention is not None and hasattr(request, 'prompt_cache_retention'):
1961+
request.prompt_cache_retention.CopyFrom(prompt_cache_retention)
19521962

19531963
try:
19541964
response, call = self.retry_policy.run_rpc(self._stub.ConverseAlpha2.with_call, request)

dapr/clients/grpc/conversation.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,11 +338,46 @@ class ConversationResultAlpha2Choices:
338338
message: ConversationResultAlpha2Message
339339

340340

341+
@dataclass
342+
class ConversationResultAlpha2CompletionUsageCompletionTokensDetails:
343+
"""Breakdown of tokens used in the completion."""
344+
345+
accepted_prediction_tokens: int = 0
346+
audio_tokens: int = 0
347+
reasoning_tokens: int = 0
348+
rejected_prediction_tokens: int = 0
349+
350+
351+
@dataclass
352+
class ConversationResultAlpha2CompletionUsagePromptTokensDetails:
353+
"""Breakdown of tokens used in the prompt."""
354+
355+
audio_tokens: int = 0
356+
cached_tokens: int = 0
357+
358+
359+
@dataclass
360+
class ConversationResultAlpha2CompletionUsage:
361+
"""Token usage for one Alpha2 conversation result."""
362+
363+
completion_tokens: int = 0
364+
prompt_tokens: int = 0
365+
total_tokens: int = 0
366+
completion_tokens_details: Optional[
367+
ConversationResultAlpha2CompletionUsageCompletionTokensDetails
368+
] = None
369+
prompt_tokens_details: Optional[ConversationResultAlpha2CompletionUsagePromptTokensDetails] = (
370+
None
371+
)
372+
373+
341374
@dataclass
342375
class ConversationResultAlpha2:
343376
"""One of the outputs in Alpha2 response from conversation input."""
344377

345378
choices: List[ConversationResultAlpha2Choices] = field(default_factory=list)
379+
model: Optional[str] = None
380+
usage: Optional[ConversationResultAlpha2CompletionUsage] = None
346381

347382

348383
@dataclass
@@ -657,5 +692,38 @@ def _get_outputs_from_grpc_response(
657692
)
658693
)
659694

660-
outputs.append(ConversationResultAlpha2(choices=choices))
695+
model: Optional[str] = None
696+
usage: Optional[ConversationResultAlpha2CompletionUsage] = None
697+
if hasattr(output, 'model') and getattr(output, 'model', None):
698+
model = output.model
699+
if hasattr(output, 'usage') and output.usage:
700+
u = output.usage
701+
completion_details: Optional[
702+
ConversationResultAlpha2CompletionUsageCompletionTokensDetails
703+
] = None
704+
prompt_details: Optional[ConversationResultAlpha2CompletionUsagePromptTokensDetails] = (
705+
None
706+
)
707+
if hasattr(u, 'completion_tokens_details') and u.completion_tokens_details:
708+
cd = u.completion_tokens_details
709+
completion_details = ConversationResultAlpha2CompletionUsageCompletionTokensDetails(
710+
accepted_prediction_tokens=getattr(cd, 'accepted_prediction_tokens', 0) or 0,
711+
audio_tokens=getattr(cd, 'audio_tokens', 0) or 0,
712+
reasoning_tokens=getattr(cd, 'reasoning_tokens', 0) or 0,
713+
rejected_prediction_tokens=getattr(cd, 'rejected_prediction_tokens', 0) or 0,
714+
)
715+
if hasattr(u, 'prompt_tokens_details') and u.prompt_tokens_details:
716+
pd = u.prompt_tokens_details
717+
prompt_details = ConversationResultAlpha2CompletionUsagePromptTokensDetails(
718+
audio_tokens=getattr(pd, 'audio_tokens', 0) or 0,
719+
cached_tokens=getattr(pd, 'cached_tokens', 0) or 0,
720+
)
721+
usage = ConversationResultAlpha2CompletionUsage(
722+
completion_tokens=getattr(u, 'completion_tokens', 0) or 0,
723+
prompt_tokens=getattr(u, 'prompt_tokens', 0) or 0,
724+
total_tokens=getattr(u, 'total_tokens', 0) or 0,
725+
completion_tokens_details=completion_details,
726+
prompt_tokens_details=prompt_details,
727+
)
728+
outputs.append(ConversationResultAlpha2(choices=choices, model=model, usage=usage))
661729
return outputs

tests/clients/fake_dapr_server.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,20 @@ def ConverseAlpha2(self, request, context):
636636

637637
# Create result for this input
638638
result = api_v1.ConversationResultAlpha2(choices=choices)
639+
if hasattr(result, 'model'):
640+
result.model = 'test-llm'
641+
if hasattr(result, 'usage'):
642+
try:
643+
usage_cls = getattr(api_v1, 'ConversationResultAlpha2CompletionUsage', None)
644+
if usage_cls is not None:
645+
u = usage_cls(
646+
completion_tokens=10,
647+
prompt_tokens=5,
648+
total_tokens=15,
649+
)
650+
result.usage.CopyFrom(u)
651+
except Exception:
652+
pass
639653
outputs.append(result)
640654

641655
return api_v1.ConversationResponseAlpha2(

tests/clients/test_conversation.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
import json
1818
import unittest
1919
import uuid
20+
from unittest.mock import Mock, patch
2021

22+
from google.protobuf.struct_pb2 import Struct
2123
from google.rpc import code_pb2, status_pb2
2224

2325
from dapr.aio.clients import DaprClient as AsyncDaprClient
@@ -37,12 +39,16 @@
3739
ConversationResponseAlpha2,
3840
ConversationResultAlpha2,
3941
ConversationResultAlpha2Choices,
42+
ConversationResultAlpha2CompletionUsage,
43+
ConversationResultAlpha2CompletionUsageCompletionTokensDetails,
44+
ConversationResultAlpha2CompletionUsagePromptTokensDetails,
4045
ConversationResultAlpha2Message,
4146
ConversationToolCalls,
4247
ConversationToolCallsOfFunction,
4348
ConversationTools,
4449
ConversationToolsFunction,
4550
FunctionBackend,
51+
_get_outputs_from_grpc_response,
4652
create_assistant_message,
4753
create_system_message,
4854
create_tool_message,
@@ -248,6 +254,14 @@ def test_basic_conversation_alpha2(self):
248254
self.assertEqual(choice.finish_reason, 'stop')
249255
self.assertIn('Hello Alpha2!', choice.message.content)
250256

257+
out = response.outputs[0]
258+
if out.model is not None:
259+
self.assertEqual(out.model, 'test-llm')
260+
if out.usage is not None:
261+
self.assertGreaterEqual(out.usage.total_tokens, 15)
262+
self.assertGreaterEqual(out.usage.prompt_tokens, 5)
263+
self.assertGreaterEqual(out.usage.completion_tokens, 10)
264+
251265
def test_conversation_alpha2_with_system_message(self):
252266
"""Test Alpha2 conversation with system message."""
253267
system_message = create_system_message('You are a helpful assistant.')
@@ -1107,6 +1121,186 @@ def test_empty_and_none_outputs(self):
11071121
self.assertEqual(response_none.to_assistant_messages(), [])
11081122

11091123

1124+
class TestConversationResultAlpha2ModelAndUsage(unittest.TestCase):
1125+
"""Tests for model and usage fields on ConversationResultAlpha2 and related types."""
1126+
1127+
def test_result_alpha2_has_model_and_usage_attributes(self):
1128+
"""ConversationResultAlpha2 accepts and exposes model and usage."""
1129+
msg = ConversationResultAlpha2Message(content='Hi', tool_calls=[])
1130+
choice = ConversationResultAlpha2Choices(finish_reason='stop', index=0, message=msg)
1131+
usage = ConversationResultAlpha2CompletionUsage(
1132+
completion_tokens=10,
1133+
prompt_tokens=5,
1134+
total_tokens=15,
1135+
)
1136+
result = ConversationResultAlpha2(
1137+
choices=[choice],
1138+
model='test-model-1',
1139+
usage=usage,
1140+
)
1141+
self.assertEqual(result.model, 'test-model-1')
1142+
self.assertIsNotNone(result.usage)
1143+
self.assertEqual(result.usage.completion_tokens, 10)
1144+
self.assertEqual(result.usage.prompt_tokens, 5)
1145+
self.assertEqual(result.usage.total_tokens, 15)
1146+
1147+
def test_result_alpha2_model_and_usage_default_none(self):
1148+
"""ConversationResultAlpha2 optional fields default to None when not provided.
1149+
1150+
When the API returns a response, model and usage are set from the conversation
1151+
component. This test only checks that the dataclass defaults are None when
1152+
constructing with choices only.
1153+
"""
1154+
msg = ConversationResultAlpha2Message(content='Hi', tool_calls=[])
1155+
choice = ConversationResultAlpha2Choices(finish_reason='stop', index=0, message=msg)
1156+
result = ConversationResultAlpha2(choices=[choice])
1157+
self.assertIsNone(result.model)
1158+
self.assertIsNone(result.usage)
1159+
1160+
def test_usage_completion_and_prompt_details(self):
1161+
"""ConversationResultAlpha2CompletionUsage supports details."""
1162+
completion_details = ConversationResultAlpha2CompletionUsageCompletionTokensDetails(
1163+
accepted_prediction_tokens=1,
1164+
audio_tokens=2,
1165+
reasoning_tokens=3,
1166+
rejected_prediction_tokens=0,
1167+
)
1168+
prompt_details = ConversationResultAlpha2CompletionUsagePromptTokensDetails(
1169+
audio_tokens=0,
1170+
cached_tokens=4,
1171+
)
1172+
usage = ConversationResultAlpha2CompletionUsage(
1173+
completion_tokens=10,
1174+
prompt_tokens=5,
1175+
total_tokens=15,
1176+
completion_tokens_details=completion_details,
1177+
prompt_tokens_details=prompt_details,
1178+
)
1179+
self.assertEqual(usage.completion_tokens_details.accepted_prediction_tokens, 1)
1180+
self.assertEqual(usage.completion_tokens_details.audio_tokens, 2)
1181+
self.assertEqual(usage.completion_tokens_details.reasoning_tokens, 3)
1182+
self.assertEqual(usage.completion_tokens_details.rejected_prediction_tokens, 0)
1183+
self.assertEqual(usage.prompt_tokens_details.audio_tokens, 0)
1184+
self.assertEqual(usage.prompt_tokens_details.cached_tokens, 4)
1185+
self.assertEqual(usage.total_tokens, 15)
1186+
self.assertEqual(usage.completion_tokens, 10)
1187+
self.assertEqual(usage.prompt_tokens, 5)
1188+
1189+
def test_get_outputs_from_grpc_response_populates_model_and_usage(self):
1190+
"""_get_outputs_from_grpc_response sets model and usage when present on proto."""
1191+
from unittest import mock
1192+
1193+
# Build a mock proto response with one output that has model and usage
1194+
mock_usage = mock.Mock()
1195+
mock_usage.completion_tokens = 20
1196+
mock_usage.prompt_tokens = 8
1197+
mock_usage.total_tokens = 28
1198+
mock_usage.completion_tokens_details = None
1199+
mock_usage.prompt_tokens_details = None
1200+
1201+
mock_choice_msg = mock.Mock()
1202+
mock_choice_msg.content = 'Hello'
1203+
mock_choice_msg.tool_calls = []
1204+
1205+
mock_choice = mock.Mock()
1206+
mock_choice.finish_reason = 'stop'
1207+
mock_choice.index = 0
1208+
mock_choice.message = mock_choice_msg
1209+
1210+
mock_output = mock.Mock()
1211+
mock_output.model = 'gpt-4o-mini'
1212+
mock_output.usage = mock_usage
1213+
mock_output.choices = [mock_choice]
1214+
1215+
mock_response = mock.Mock()
1216+
mock_response.outputs = [mock_output]
1217+
1218+
outputs = _get_outputs_from_grpc_response(mock_response)
1219+
self.assertEqual(len(outputs), 1)
1220+
out = outputs[0]
1221+
self.assertEqual(out.model, 'gpt-4o-mini')
1222+
self.assertIsNotNone(out.usage)
1223+
self.assertEqual(out.usage.completion_tokens, 20)
1224+
self.assertEqual(out.usage.prompt_tokens, 8)
1225+
self.assertEqual(out.usage.total_tokens, 28)
1226+
self.assertEqual(len(out.choices), 1)
1227+
self.assertEqual(out.choices[0].message.content, 'Hello')
1228+
1229+
def test_get_outputs_from_grpc_response_without_model_usage(self):
1230+
"""_get_outputs_from_grpc_response leaves model and usage None when absent."""
1231+
from unittest import mock
1232+
1233+
mock_choice_msg = mock.Mock()
1234+
mock_choice_msg.content = 'Echo'
1235+
mock_choice_msg.tool_calls = []
1236+
1237+
mock_choice = mock.Mock()
1238+
mock_choice.finish_reason = 'stop'
1239+
mock_choice.index = 0
1240+
mock_choice.message = mock_choice_msg
1241+
1242+
mock_output = mock.Mock(spec=['choices'])
1243+
mock_output.choices = [mock_choice]
1244+
# No model or usage attributes
1245+
1246+
mock_response = mock.Mock()
1247+
mock_response.outputs = [mock_output]
1248+
1249+
outputs = _get_outputs_from_grpc_response(mock_response)
1250+
self.assertEqual(len(outputs), 1)
1251+
out = outputs[0]
1252+
self.assertIsNone(out.model)
1253+
self.assertIsNone(out.usage)
1254+
self.assertEqual(out.choices[0].message.content, 'Echo')
1255+
1256+
1257+
class ConverseAlpha2ResponseFormatTests(unittest.TestCase):
1258+
"""Unit tests for converse_alpha2 response_format parameter."""
1259+
1260+
def test_converse_alpha2_passes_response_format_on_request(self):
1261+
"""converse_alpha2 sets response_format on the gRPC request when provided."""
1262+
user_message = create_user_message('Structured output please')
1263+
input_alpha2 = ConversationInputAlpha2(messages=[user_message])
1264+
response_format = Struct()
1265+
response_format.update(
1266+
{'type': 'json_schema', 'json_schema': {'name': 'test', 'schema': {}}}
1267+
)
1268+
1269+
captured_requests = []
1270+
mock_choice_msg = Mock()
1271+
mock_choice_msg.content = 'ok'
1272+
mock_choice_msg.tool_calls = []
1273+
mock_choice = Mock()
1274+
mock_choice.finish_reason = 'stop'
1275+
mock_choice.index = 0
1276+
mock_choice.message = mock_choice_msg
1277+
mock_output = Mock()
1278+
mock_output.choices = [mock_choice]
1279+
mock_response = Mock()
1280+
mock_response.outputs = [mock_output]
1281+
mock_response.context_id = ''
1282+
mock_call = Mock()
1283+
1284+
def capture_run_rpc(rpc, request, *args, **kwargs):
1285+
captured_requests.append(request)
1286+
return (mock_response, mock_call)
1287+
1288+
with patch('dapr.clients.health.DaprHealth.wait_for_sidecar'):
1289+
client = DaprClient('localhost:50011')
1290+
with patch.object(client.retry_policy, 'run_rpc', side_effect=capture_run_rpc):
1291+
client.converse_alpha2(
1292+
name='test-llm',
1293+
inputs=[input_alpha2],
1294+
response_format=response_format,
1295+
)
1296+
1297+
self.assertEqual(len(captured_requests), 1)
1298+
req = captured_requests[0]
1299+
self.assertTrue(hasattr(req, 'response_format'))
1300+
self.assertEqual(req.response_format['type'], 'json_schema')
1301+
self.assertEqual(req.response_format['json_schema']['name'], 'test')
1302+
1303+
11101304
class ExecuteRegisteredToolSyncTests(unittest.TestCase):
11111305
def tearDown(self):
11121306
# Cleanup all tools we may have registered by name prefix

0 commit comments

Comments
 (0)