Skip to content

Commit ed722e3

Browse files
committed
Refactor TgptUpdateProcessor to switch from ChatCompletion API to Responses API. Add support for web search and improve content and token handling logic.
1 parent 6e2f4d6 commit ed722e3

4 files changed

Lines changed: 157 additions & 163 deletions

File tree

tgpt/src/main/kotlin/io/heapy/tgpt/bot/TgptUpdateProcessor.kt

Lines changed: 144 additions & 152 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,15 @@
11
package io.heapy.tgpt.bot
22

33
import com.openai.models.ChatModel
4-
import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam
5-
import com.openai.models.chat.completions.ChatCompletionContentPart
6-
import com.openai.models.chat.completions.ChatCompletionContentPartImage
7-
import com.openai.models.chat.completions.ChatCompletionContentPartText
8-
import com.openai.models.chat.completions.ChatCompletionCreateParams
9-
import com.openai.models.chat.completions.ChatCompletionMessageParam
10-
import com.openai.models.chat.completions.ChatCompletionSystemMessageParam
11-
import com.openai.models.chat.completions.ChatCompletionUserMessageParam
4+
import com.openai.models.responses.EasyInputMessage
5+
import com.openai.models.responses.Response
6+
import com.openai.models.responses.ResponseCreateParams
7+
import com.openai.models.responses.ResponseInputContent
8+
import com.openai.models.responses.ResponseInputFile
9+
import com.openai.models.responses.ResponseInputImage
10+
import com.openai.models.responses.ResponseInputItem
11+
import com.openai.models.responses.ResponseInputText
12+
import com.openai.models.responses.WebSearchTool
1213
import io.heapy.komok.tech.logging.Logger
1314
import io.heapy.kotbot.bot.Kotbot
1415
import io.heapy.kotbot.bot.TelegramApiError
@@ -148,18 +149,28 @@ class TgptUpdateProcessor(
148149
threadMessageDao.getThreadMessages(threadId)
149150
}
150151

151-
val openAiMessages = prependTelegramMarkdownPrompt(
152-
buildOpenAiMessages(threadMessages),
153-
)
152+
val inputItems = buildResponseInput(threadMessages)
153+
val instructions = TELEGRAM_MARKDOWN_SYSTEM_PROMPT + "\n\n" + openAiService.systemPrompt
154154

155-
val params = ChatCompletionCreateParams.builder()
155+
val paramsBuilder = ResponseCreateParams.builder()
156156
.model(ChatModel.of(openAiService.model()))
157-
.messages(openAiMessages)
158-
.maxCompletionTokens(openAiService.maxTokens().toLong())
159-
.build()
157+
.inputOfResponse(inputItems)
158+
.instructions(instructions)
159+
.maxOutputTokens(openAiService.maxTokens().toLong())
160+
161+
if (openAiService.webSearchEnabled) {
162+
paramsBuilder.addTool(
163+
WebSearchTool.builder()
164+
.type(WebSearchTool.Type.WEB_SEARCH)
165+
.searchContextSize(WebSearchTool.SearchContextSize.MEDIUM)
166+
.build()
167+
)
168+
}
160169

161-
val completion = try {
162-
openAiService.chatCompletion(params)
170+
val params = paramsBuilder.build()
171+
172+
val response = try {
173+
openAiService.createResponse(params)
163174
} catch (e: Exception) {
164175
log.error("OpenAI API call failed", e)
165176
replyToMessage(
@@ -169,8 +180,7 @@ class TgptUpdateProcessor(
169180
return
170181
}
171182

172-
val responseText = completion.choices().firstOrNull()
173-
?.message()?.content()?.orElse(null)
183+
val responseText = extractOutputText(response)
174184
?: "No response from the model."
175185

176186
// Send reply to Telegram
@@ -191,24 +201,24 @@ class TgptUpdateProcessor(
191201
}
192202

193203
// Record API call with cost
194-
val usage = completion.usage()
204+
val usage = response.usage()
195205
if (usage.isPresent) {
196206
val usageVal = usage.get()
197-
val promptTokens = usageVal.promptTokens().toInt()
198-
val completionTokens = usageVal.completionTokens().toInt()
207+
val inputTokens = usageVal.inputTokens().toInt()
208+
val outputTokens = usageVal.outputTokens().toInt()
199209
val totalTokens = usageVal.totalTokens().toInt()
200210
val cost = CostCalculator.estimateCost(
201211
model = openAiService.model(),
202-
promptTokens = promptTokens,
203-
completionTokens = completionTokens,
212+
inputTokens = inputTokens,
213+
outputTokens = outputTokens,
204214
)
205215
transactionProvider.transaction {
206216
apiCallDao.recordApiCall(
207217
threadId = threadId,
208218
telegramUserId = userId,
209219
model = openAiService.model(),
210-
promptTokens = promptTokens,
211-
completionTokens = completionTokens,
220+
promptTokens = inputTokens,
221+
completionTokens = outputTokens,
212222
totalTokens = totalTokens,
213223
estimatedCostUsd = cost,
214224
)
@@ -325,27 +335,15 @@ class TgptUpdateProcessor(
325335
return
326336
}
327337

328-
val params = ChatCompletionCreateParams.builder()
338+
val params = ResponseCreateParams.builder()
329339
.model(ChatModel.of(openAiService.model()))
330-
.messages(
331-
listOf(
332-
ChatCompletionMessageParam.ofSystem(
333-
ChatCompletionSystemMessageParam.builder()
334-
.content(CHECKLIST_SYSTEM_PROMPT)
335-
.build(),
336-
),
337-
ChatCompletionMessageParam.ofUser(
338-
ChatCompletionUserMessageParam.builder()
339-
.content(sourceText)
340-
.build(),
341-
),
342-
),
343-
)
344-
.maxCompletionTokens(openAiService.maxTokens().toLong())
340+
.instructions(CHECKLIST_SYSTEM_PROMPT)
341+
.input(sourceText)
342+
.maxOutputTokens(openAiService.maxTokens().toLong())
345343
.build()
346344

347-
val completion = try {
348-
openAiService.chatCompletion(params)
345+
val response = try {
346+
openAiService.createResponse(params)
349347
} catch (e: Exception) {
350348
log.error("OpenAI API call failed for /checklist", e)
351349
replyToMessage(
@@ -354,8 +352,7 @@ class TgptUpdateProcessor(
354352
)
355353
return
356354
}
357-
val checklistText = completion.choices().firstOrNull()
358-
?.message()?.content()?.orElse(null)
355+
val checklistText = extractOutputText(response)
359356
?.trim()
360357
?.takeIf { it.isNotBlank() }
361358
?: "Could not generate checklist."
@@ -552,17 +549,6 @@ class TgptUpdateProcessor(
552549
return description.contains("can't parse entities", ignoreCase = true)
553550
}
554551

555-
private fun prependTelegramMarkdownPrompt(
556-
messages: List<ChatCompletionMessageParam>,
557-
): List<ChatCompletionMessageParam> {
558-
val markdownPrompt = ChatCompletionMessageParam.ofSystem(
559-
ChatCompletionSystemMessageParam.builder()
560-
.content(TELEGRAM_MARKDOWN_SYSTEM_PROMPT)
561-
.build(),
562-
)
563-
return listOf(markdownPrompt) + messages
564-
}
565-
566552
private fun parseChecklistTasks(text: String): List<String> {
567553
return text
568554
.lineSequence()
@@ -648,129 +634,135 @@ class TgptUpdateProcessor(
648634
return null
649635
}
650636

651-
private suspend fun buildOpenAiMessages(
652-
threadMessages: List<ThreadMessage>,
653-
): List<ChatCompletionMessageParam> {
654-
return threadMessages.map { msg ->
655-
when (msg.role) {
656-
MessageRole.system -> ChatCompletionMessageParam.ofSystem(
657-
ChatCompletionSystemMessageParam.builder()
658-
.content(msg.content)
659-
.build(),
660-
)
637+
private fun extractOutputText(response: Response): String? {
638+
return response.output()
639+
.filter { it.isMessage() }
640+
.flatMap { it.asMessage().content() }
641+
.filter { it.isOutputText() }
642+
.joinToString("") { it.asOutputText().text() }
643+
.ifBlank { null }
644+
}
661645

662-
MessageRole.user -> when (msg.contentType) {
663-
ContentType.image_url -> {
664-
val payload = deserializeImageContentPayload(msg.content)
665-
val parts = mutableListOf<ChatCompletionContentPart>()
666-
667-
payload.caption
668-
?.takeIf { it.isNotBlank() }
669-
?.let {
670-
parts += ChatCompletionContentPart.ofText(
671-
ChatCompletionContentPartText.builder()
672-
.text(it)
646+
private suspend fun buildResponseInput(
647+
threadMessages: List<ThreadMessage>,
648+
): List<ResponseInputItem> {
649+
return threadMessages
650+
.filter { it.role != MessageRole.system }
651+
.map { msg ->
652+
when (msg.role) {
653+
MessageRole.system -> error("system messages should be filtered out")
654+
655+
MessageRole.user -> when (msg.contentType) {
656+
ContentType.image_url -> {
657+
val payload = deserializeImageContentPayload(msg.content)
658+
val contentParts = mutableListOf<ResponseInputContent>()
659+
660+
payload.caption
661+
?.takeIf { it.isNotBlank() }
662+
?.let {
663+
contentParts += ResponseInputContent.ofInputText(
664+
ResponseInputText.builder()
665+
.text(it)
666+
.build(),
667+
)
668+
}
669+
670+
payload.fileIds.forEach { fileId ->
671+
val imageBytes = telegramFileService.downloadFile(fileId)
672+
val mimeType = detectImageMimeType(imageBytes)
673+
val base64 = Base64.encode(imageBytes)
674+
val imageUrl = "data:$mimeType;base64,$base64"
675+
676+
contentParts += ResponseInputContent.ofInputImage(
677+
ResponseInputImage.builder()
678+
.imageUrl(imageUrl)
679+
.detail(ResponseInputImage.Detail.AUTO)
673680
.build(),
674681
)
675682
}
676683

677-
payload.fileIds.forEach { fileId ->
678-
val imageBytes = telegramFileService.downloadFile(fileId)
679-
val mimeType = detectImageMimeType(imageBytes)
680-
val base64 = Base64.encode(imageBytes)
681-
val imageUrl = "data:$mimeType;base64,$base64"
682-
683-
parts += ChatCompletionContentPart.ofImageUrl(
684-
ChatCompletionContentPartImage.builder()
685-
.imageUrl(
686-
ChatCompletionContentPartImage.ImageUrl.builder()
687-
.url(imageUrl)
684+
val safeParts = if (contentParts.isEmpty()) {
685+
listOf(
686+
ResponseInputContent.ofInputText(
687+
ResponseInputText.builder()
688+
.text("[Image message]")
688689
.build(),
689-
)
690+
),
691+
)
692+
} else {
693+
contentParts
694+
}
695+
696+
ResponseInputItem.ofEasyInputMessage(
697+
EasyInputMessage.builder()
698+
.role(EasyInputMessage.Role.USER)
699+
.contentOfResponseInputMessageContentList(safeParts)
690700
.build(),
691701
)
692702
}
693703

694-
val safeParts = if (parts.isEmpty()) {
695-
listOf(
696-
ChatCompletionContentPart.ofText(
697-
ChatCompletionContentPartText.builder()
698-
.text("[Image message]")
699-
.build(),
700-
),
704+
ContentType.file -> {
705+
val payload = deserializeFileContentPayload(msg.content)
706+
val contentParts = mutableListOf<ResponseInputContent>()
707+
708+
payload.caption
709+
?.takeIf { it.isNotBlank() }
710+
?.let {
711+
contentParts += ResponseInputContent.ofInputText(
712+
ResponseInputText.builder()
713+
.text(it)
714+
.build(),
715+
)
716+
}
717+
718+
val fileBytes = telegramFileService.downloadFile(payload.fileId)
719+
val base64 = Base64.encode(fileBytes)
720+
721+
contentParts += ResponseInputContent.ofInputFile(
722+
ResponseInputFile.builder()
723+
.fileData("data:${payload.mimeType};base64,$base64")
724+
.filename(payload.fileName ?: defaultFilenameForMimeType(payload.mimeType))
725+
.build(),
726+
)
727+
728+
ResponseInputItem.ofEasyInputMessage(
729+
EasyInputMessage.builder()
730+
.role(EasyInputMessage.Role.USER)
731+
.contentOfResponseInputMessageContentList(contentParts)
732+
.build(),
701733
)
702-
} else {
703-
parts
704734
}
705735

706-
ChatCompletionMessageParam.ofUser(
707-
ChatCompletionUserMessageParam.builder()
708-
.contentOfArrayOfContentParts(safeParts)
736+
ContentType.transcription -> ResponseInputItem.ofEasyInputMessage(
737+
EasyInputMessage.builder()
738+
.role(EasyInputMessage.Role.USER)
739+
.content("[Voice message]: ${msg.content}")
709740
.build(),
710741
)
711-
}
712-
713-
ContentType.file -> {
714-
val payload = deserializeFileContentPayload(msg.content)
715-
val parts = mutableListOf<ChatCompletionContentPart>()
716-
717-
payload.caption
718-
?.takeIf { it.isNotBlank() }
719-
?.let {
720-
parts += ChatCompletionContentPart.ofText(
721-
ChatCompletionContentPartText.builder()
722-
.text(it)
723-
.build(),
724-
)
725-
}
726742

727-
val fileBytes = telegramFileService.downloadFile(payload.fileId)
728-
val base64 = Base64.encode(fileBytes)
729-
730-
parts += ChatCompletionContentPart.ofFile(
731-
ChatCompletionContentPart.File.builder()
732-
.file(
733-
ChatCompletionContentPart.File.FileObject.builder()
734-
.fileData("data:${payload.mimeType};base64,$base64")
735-
.filename(payload.fileName ?: defaultFilenameForMimeType(payload.mimeType))
736-
.build(),
737-
)
743+
ContentType.text -> ResponseInputItem.ofEasyInputMessage(
744+
EasyInputMessage.builder()
745+
.role(EasyInputMessage.Role.USER)
746+
.content(msg.content)
738747
.build(),
739748
)
740749

741-
ChatCompletionMessageParam.ofUser(
742-
ChatCompletionUserMessageParam.builder()
743-
.contentOfArrayOfContentParts(parts)
750+
null -> ResponseInputItem.ofEasyInputMessage(
751+
EasyInputMessage.builder()
752+
.role(EasyInputMessage.Role.USER)
753+
.content(msg.content)
744754
.build(),
745755
)
746756
}
747757

748-
ContentType.transcription -> ChatCompletionMessageParam.ofUser(
749-
ChatCompletionUserMessageParam.builder()
750-
.content("[Voice message]: ${msg.content}")
751-
.build(),
752-
)
753-
754-
ContentType.text -> ChatCompletionMessageParam.ofUser(
755-
ChatCompletionUserMessageParam.builder()
756-
.content(msg.content)
757-
.build(),
758-
)
759-
760-
null -> ChatCompletionMessageParam.ofUser(
761-
ChatCompletionUserMessageParam.builder()
758+
MessageRole.assistant -> ResponseInputItem.ofEasyInputMessage(
759+
EasyInputMessage.builder()
760+
.role(EasyInputMessage.Role.ASSISTANT)
762761
.content(msg.content)
763762
.build(),
764763
)
765764
}
766-
767-
MessageRole.assistant -> ChatCompletionMessageParam.ofAssistant(
768-
ChatCompletionAssistantMessageParam.builder()
769-
.content(msg.content)
770-
.build(),
771-
)
772765
}
773-
}
774766
}
775767

776768
private fun serializeImageContentPayload(

0 commit comments

Comments
 (0)