Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2
github.com/joho/godotenv v1.5.1
github.com/openai/openai-go/v2 v2.7.1
github.com/openai/openai-go/v3 v3.12.0
github.com/samber/lo v1.51.0
github.com/stretchr/testify v1.10.0
go.mongodb.org/mongo-driver/v2 v2.3.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/openai/openai-go/v2 v2.7.1 h1:/tfvTJhfv7hTSL8mWwc5VL4WLLSDL5yn9VqVykdu9r8=
github.com/openai/openai-go/v2 v2.7.1/go.mod h1:jrJs23apqJKKbT+pqtFgNKpRju/KP9zpUTZhz3GElQE=
github.com/openai/openai-go/v3 v3.12.0 h1:NkrImaglFQeDycc/n/fEmpFV8kKr8snl9/8X2x4eHOg=
github.com/openai/openai-go/v3 v3.12.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand Down
74 changes: 4 additions & 70 deletions internal/api/chat/create_conversation_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package chat
import (
"context"

"paperdebugger/internal/api/mapper"
"paperdebugger/internal/libs/contextutil"
"paperdebugger/internal/libs/shared"
"paperdebugger/internal/models"
Expand Down Expand Up @@ -115,7 +114,7 @@ func (s *ChatServer) createConversation(
userInstructions string,
userMessage string,
userSelectedText string,
languageModel models.LanguageModel,
modelSlug string,
conversationType chatv1.ConversationType,
) (*models.Conversation, error) {
systemPrompt, err := s.chatService.GetSystemPrompt(ctx, latexFullSource, projectInstructions, userInstructions, conversationType)
Expand All @@ -135,7 +134,7 @@ func (s *ChatServer) createConversation(
}

return s.chatService.InsertConversationToDB(
ctx, userId, projectId, languageModel, messages, oaiHistory.OfInputItemList,
ctx, userId, projectId, modelSlug, messages, oaiHistory.OfInputItemList,
)
}

Expand Down Expand Up @@ -180,7 +179,7 @@ func (s *ChatServer) appendConversationMessage(

// 如果 conversationId 是 "", 就创建新对话,否则就追加消息到对话
// conversationType 可以在一次 conversation 中多次切换
func (s *ChatServer) prepare(ctx context.Context, projectId string, conversationId string, userMessage string, userSelectedText string, languageModel models.LanguageModel, conversationType chatv1.ConversationType) (context.Context, *models.Conversation, *models.Settings, error) {
func (s *ChatServer) prepare(ctx context.Context, projectId string, conversationId string, userMessage string, userSelectedText string, modelSlug string, conversationType chatv1.ConversationType) (context.Context, *models.Conversation, *models.Settings, error) {
actor, err := contextutil.GetActor(ctx)
if err != nil {
return ctx, nil, nil, err
Expand Down Expand Up @@ -223,7 +222,7 @@ func (s *ChatServer) prepare(ctx context.Context, projectId string, conversation
userInstructions,
userMessage,
userSelectedText,
languageModel,
modelSlug,
conversationType,
)
} else {
Expand Down Expand Up @@ -251,68 +250,3 @@ func (s *ChatServer) prepare(ctx context.Context, projectId string, conversation

return ctx, conversation, settings, nil
}

// Deprecated: Use CreateConversationMessageStream instead.
//
// CreateConversationMessage handles one non-streaming chat turn: it
// prepares (or creates) the conversation, runs a single ChatCompletion
// round, persists both chat histories, and launches a best-effort
// background title generation before returning the full conversation.
func (s *ChatServer) CreateConversationMessage(
	ctx context.Context,
	req *chatv1.CreateConversationMessageRequest,
) (*chatv1.CreateConversationMessageResponse, error) {
	// Resolve the requested model and load or create the conversation,
	// along with the caller's settings (prepare also validates the actor
	// carried in ctx).
	languageModel := models.LanguageModel(req.GetLanguageModel())
	ctx, conversation, settings, err := s.prepare(
		ctx,
		req.GetProjectId(),
		req.GetConversationId(),
		req.GetUserMessage(),
		req.GetUserSelectedText(),
		languageModel,
		req.GetConversationType(),
	)
	if err != nil {
		return nil, err
	}

	// Pair the server-configured endpoint with the user's own API key.
	llmProvider := &models.LLMProviderConfig{
		Endpoint: s.cfg.OpenAIBaseURL,
		APIKey:   settings.OpenAIAPIKey,
	}
	openaiChatHistory, inappChatHistory, err := s.aiClient.ChatCompletion(ctx, languageModel, conversation.OpenaiChatHistory, llmProvider)
	if err != nil {
		return nil, err
	}

	// Convert the newly produced in-app messages to raw BSON before
	// persisting (the model stores bson.M to avoid protobuf decoding
	// issues).
	bsonMessages := make([]bson.M, len(inappChatHistory))
	for i := range inappChatHistory {
		bsonMsg, err := convertToBSON(&inappChatHistory[i])
		if err != nil {
			return nil, err
		}
		bsonMessages[i] = bsonMsg
	}
	conversation.InappChatHistory = append(conversation.InappChatHistory, bsonMessages...)
	conversation.OpenaiChatHistory = openaiChatHistory

	if err := s.chatService.UpdateConversation(conversation); err != nil {
		return nil, err
	}

	// Best-effort title generation in the background; failures are only
	// logged, never surfaced to the caller.
	// NOTE(review): this goroutine captures the request ctx, which is
	// typically canceled once the RPC returns — confirm the title call
	// is not being silently aborted mid-flight.
	go func() {
		protoMessages := make([]*chatv1.Message, len(conversation.InappChatHistory))
		for i, bsonMsg := range conversation.InappChatHistory {
			protoMessages[i] = mapper.BSONToChatMessage(bsonMsg)
		}
		title, err := s.aiClient.GetConversationTitle(ctx, protoMessages, llmProvider)
		if err != nil {
			s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex())
			return
		}
		conversation.Title = title
		if err := s.chatService.UpdateConversation(conversation); err != nil {
			s.logger.Error("Failed to update conversation with new title", "error", err, "conversationID", conversation.ID.Hex())
			return
		}
	}()

	return &chatv1.CreateConversationMessageResponse{
		Conversation: mapper.MapModelConversationToProto(conversation),
	}, nil
}
13 changes: 8 additions & 5 deletions internal/api/chat/create_conversation_message_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@ func (s *ChatServer) CreateConversationMessageStream(
) error {
ctx := stream.Context()

languageModel := models.LanguageModel(req.GetLanguageModel())
modelSlug := req.GetModelSlug()
if modelSlug == "" {
modelSlug = models.LanguageModel(req.GetLanguageModel()).Name()
}

ctx, conversation, settings, err := s.prepare(
ctx,
req.GetProjectId(),
req.GetConversationId(),
req.GetUserMessage(),
req.GetUserSelectedText(),
languageModel,
modelSlug,
req.GetConversationType(),
)
if err != nil {
Expand All @@ -41,11 +45,10 @@ func (s *ChatServer) CreateConversationMessageStream(

// 用法跟 ChatCompletion 一样,只是传递了 stream 参数
llmProvider := &models.LLMProviderConfig{
Endpoint: s.cfg.OpenAIBaseURL,
APIKey: settings.OpenAIAPIKey,
APIKey: settings.OpenAIAPIKey,
}

openaiChatHistory, inappChatHistory, err := s.aiClient.ChatCompletionStream(ctx, stream, conversation.ID.Hex(), languageModel, conversation.OpenaiChatHistory, llmProvider)
openaiChatHistory, inappChatHistory, err := s.aiClient.ChatCompletionStream(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistory, llmProvider)
if err != nil {
return s.sendStreamError(stream, err)
}
Expand Down
15 changes: 13 additions & 2 deletions internal/api/mapper/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,21 @@ func MapModelConversationToProto(conversation *models.Conversation) *chatv1.Conv
return msg.GetPayload().GetMessageType() != &chatv1.MessagePayload_System{}
})

modelSlug := conversation.ModelSlug
if modelSlug == "" {
modelSlug = models.SlugFromLanguageModel(models.LanguageModel(conversation.LanguageModel))
}

languageModel := chatv1.LanguageModel(conversation.LanguageModel)
if languageModel == chatv1.LanguageModel_LANGUAGE_MODEL_UNSPECIFIED {
languageModel = chatv1.LanguageModel(models.LanguageModelFromSlug(modelSlug))
}

return &chatv1.Conversation{
Id: conversation.ID.Hex(),
Title: conversation.Title,
LanguageModel: chatv1.LanguageModel(conversation.LanguageModel),
Messages: filteredMessages,
LanguageModel: languageModel, // backward compatibility
// ModelSlug: modelSlug,
Messages: filteredMessages,
}
}
1 change: 1 addition & 0 deletions internal/models/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Conversation struct {
ProjectID string `bson:"project_id"`
Title string `bson:"title"`
LanguageModel LanguageModel `bson:"language_model"`
ModelSlug string `bson:"model_slug"`
InappChatHistory []bson.M `bson:"inapp_chat_history"` // Store as raw BSON to avoid protobuf decoding issues

OpenaiChatHistory responses.ResponseInputParam `bson:"openai_chat_history"` // 实际上发给 GPT 的聊天历史
Expand Down
66 changes: 66 additions & 0 deletions internal/models/language_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,69 @@ func (x LanguageModel) Name() string {
return openai.ChatModelGPT5
}
}

// LanguageModelFromSlug resolves a model slug (e.g. "gpt-4o") to its
// LanguageModel enum value. Slugs without a mapping resolve to
// LANGUAGE_MODEL_UNSPECIFIED.
func LanguageModelFromSlug(slug string) LanguageModel {
	bySlug := map[string]LanguageModel{
		"gpt-4o":            LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT4O),
		"gpt-4.1":           LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT41),
		"gpt-4.1-mini":      LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT41_MINI),
		"gpt-5":             LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5),
		"gpt-5-mini":        LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5_MINI),
		"gpt-5-nano":        LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5_NANO),
		"gpt-5-chat-latest": LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5_CHAT_LATEST),
		"o1":                LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O1),
		"o1-mini":           LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O1_MINI),
		"o3":                LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O3),
		"o3-mini":           LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O3_MINI),
		"o4-mini":           LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O4_MINI),
		"codex-mini-latest": LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_CODEX_MINI_LATEST),
	}
	if model, ok := bySlug[slug]; ok {
		return model
	}
	return LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_UNSPECIFIED)
}

// SlugFromLanguageModel returns the canonical slug string for a
// LanguageModel enum value, or "unknown" when no mapping exists.
func SlugFromLanguageModel(languageModel LanguageModel) string {
	byModel := map[LanguageModel]string{
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT4O):             "gpt-4o",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT41):             "gpt-4.1",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT41_MINI):        "gpt-4.1-mini",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5):              "gpt-5",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5_MINI):         "gpt-5-mini",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5_NANO):         "gpt-5-nano",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5_CHAT_LATEST):  "gpt-5-chat-latest",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O1):                "o1",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O1_MINI):           "o1-mini",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O3):                "o3",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O3_MINI):           "o3-mini",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O4_MINI):           "o4-mini",
		LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_CODEX_MINI_LATEST): "codex-mini-latest",
	}
	if slug, ok := byModel[languageModel]; ok {
		return slug
	}
	return "unknown"
}
4 changes: 2 additions & 2 deletions internal/services/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (s *ChatService) GetPrompt(ctx context.Context, content string, selectedTex
return strings.TrimSpace(userPromptBuffer.String()), nil
}

func (s *ChatService) InsertConversationToDB(ctx context.Context, userID bson.ObjectID, projectID string, languageModel models.LanguageModel, inappChatHistory []*chatv1.Message, openaiChatHistory responses.ResponseInputParam) (*models.Conversation, error) {
func (s *ChatService) InsertConversationToDB(ctx context.Context, userID bson.ObjectID, projectID string, modelSlug string, inappChatHistory []*chatv1.Message, openaiChatHistory responses.ResponseInputParam) (*models.Conversation, error) {
// Convert protobuf messages to BSON
bsonMessages := make([]bson.M, len(inappChatHistory))
for i := range inappChatHistory {
Expand All @@ -116,7 +116,7 @@ func (s *ChatService) InsertConversationToDB(ctx context.Context, userID bson.Ob
UserID: userID,
ProjectID: projectID,
Title: DefaultConversationTitle,
LanguageModel: languageModel,
ModelSlug: modelSlug,
InappChatHistory: bsonMessages,
OpenaiChatHistory: openaiChatHistory,
}
Expand Down
12 changes: 6 additions & 6 deletions internal/services/toolkit/client/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ import (
// Parameters:
//
// ctx: The context for controlling cancellation and deadlines.
// languageModel: The language model to use for completion (e.g., GPT-3.5, GPT-4).
// modelSlug: The language model to use for completion (e.g., GPT-3.5, GPT-4).
// messages: The full chat history (as input) to send to the language model.
//
// Returns:
// 1. The full chat history sent to the language model (including any tool call results).
// 2. The incremental chat history visible to the user (including tool call results and assistant responses).
// 3. An error, if any occurred during the process.
func (a *AIClient) ChatCompletion(ctx context.Context, languageModel models.LanguageModel, messages responses.ResponseInputParam, llmProvider *models.LLMProviderConfig) (responses.ResponseInputParam, []chatv1.Message, error) {
openaiChatHistory, inappChatHistory, err := a.ChatCompletionStream(ctx, nil, "", languageModel, messages, llmProvider)
func (a *AIClient) ChatCompletion(ctx context.Context, modelSlug string, messages responses.ResponseInputParam, llmProvider *models.LLMProviderConfig) (responses.ResponseInputParam, []chatv1.Message, error) {
openaiChatHistory, inappChatHistory, err := a.ChatCompletionStream(ctx, nil, "", modelSlug, messages, llmProvider)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -50,19 +50,19 @@ func (a *AIClient) ChatCompletion(ctx context.Context, languageModel models.Lang
// - If tool calls are required, it handles them and appends the results to the chat history, then continues the loop.
// - If no tool calls are needed, it appends the assistant's response and exits the loop.
// - Finally, it returns the updated chat histories and any error encountered.
func (a *AIClient) ChatCompletionStream(ctx context.Context, callbackStream chatv1.ChatService_CreateConversationMessageStreamServer, conversationId string, languageModel models.LanguageModel, messages responses.ResponseInputParam, llmProvider *models.LLMProviderConfig) (responses.ResponseInputParam, []chatv1.Message, error) {
func (a *AIClient) ChatCompletionStream(ctx context.Context, callbackStream chatv1.ChatService_CreateConversationMessageStreamServer, conversationId string, modelSlug string, messages responses.ResponseInputParam, llmProvider *models.LLMProviderConfig) (responses.ResponseInputParam, []chatv1.Message, error) {
openaiChatHistory := responses.ResponseNewParamsInputUnion{OfInputItemList: messages}
inappChatHistory := []chatv1.Message{}

streamHandler := handler.NewStreamHandler(callbackStream, conversationId, languageModel)
streamHandler := handler.NewStreamHandler(callbackStream, conversationId, modelSlug)

streamHandler.SendInitialization()
defer func() {
streamHandler.SendFinalization()
}()

oaiClient := a.GetOpenAIClient(llmProvider)
params := getDefaultParams(languageModel, openaiChatHistory, a.toolCallHandler.Registry)
params := getDefaultParams(modelSlug, openaiChatHistory, a.toolCallHandler.Registry)

for {
params.Input = openaiChatHistory
Expand Down
2 changes: 1 addition & 1 deletion internal/services/toolkit/client/get_conversation_title.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (a *AIClient) GetConversationTitle(ctx context.Context, inappChatHistory []
message := strings.Join(messages, "\n")
message = fmt.Sprintf("%s\nBased on above conversation, generate a short, clear, and descriptive title that summarizes the main topic or purpose of the discussion. The title should be concise, specific, and use natural language. Avoid vague or generic titles. Use abbreviation and short words if possible. Use 3-5 words if possible. Give me the title only, no other text including any other words.", message)

_, resp, err := a.ChatCompletion(ctx, models.LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT41_MINI), responses.ResponseInputParam{
_, resp, err := a.ChatCompletion(ctx, "gpt-4.1-mini", responses.ResponseInputParam{
{
OfInputMessage: &responses.ResponseInputItemMessageParam{
Role: "system",
Expand Down
32 changes: 18 additions & 14 deletions internal/services/toolkit/client/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ This file contains utility functions for the client package. (Mainly miscellaneo
It is used to append assistant responses to both OpenAI and in-app chat histories, and to create response items for chat interactions.
*/
import (
"paperdebugger/internal/models"
"paperdebugger/internal/services/toolkit/registry"
chatv1 "paperdebugger/pkg/gen/api/chat/v1"

"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/responses"
"github.com/samber/lo"
)

// appendAssistantTextResponse appends the assistant's response to both OpenAI and in-app chat histories.
Expand Down Expand Up @@ -43,26 +43,30 @@ func appendAssistantTextResponse(openaiChatHistory *responses.ResponseNewParamsI
// getDefaultParams constructs the default parameters for a chat completion request.
// The tool registry is managed centrally by the registry package.
// The chat history is constructed manually, so Store must be set to false.
func getDefaultParams(languageModel models.LanguageModel, chatHistory responses.ResponseNewParamsInputUnion, toolRegistry *registry.ToolRegistry) responses.ResponseNewParams {
if languageModel == models.LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5) ||
languageModel == models.LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5_MINI) ||
languageModel == models.LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5_NANO) ||
languageModel == models.LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5_CHAT_LATEST) ||
languageModel == models.LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O4_MINI) ||
languageModel == models.LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O3_MINI) ||
languageModel == models.LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O3) ||
languageModel == models.LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O1_MINI) ||
languageModel == models.LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O1) ||
languageModel == models.LanguageModel(chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_CODEX_MINI_LATEST) {
func getDefaultParams(modelSlug string, chatHistory responses.ResponseNewParamsInputUnion, toolRegistry *registry.ToolRegistry) responses.ResponseNewParams {
var reasoningModels = []string{
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
"gpt-5-chat-latest",
"o4-mini",
"o3-mini",
"o3",
"o1-mini",
"o1",
"codex-mini-latest",
}
if lo.Contains(reasoningModels, modelSlug) {
return responses.ResponseNewParams{
Model: languageModel.Name(),
Model: modelSlug,
Tools: toolRegistry.GetTools(),
Input: chatHistory,
Store: openai.Bool(false),
}
}

return responses.ResponseNewParams{
Model: languageModel.Name(),
Model: modelSlug,
Temperature: openai.Float(0.7),
MaxOutputTokens: openai.Int(4000), // DEBUG POINT: change this to test the frontend handler
Tools: toolRegistry.GetTools(), // 工具注册由 registry 统一管理
Expand Down
Loading