Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
run: |
export PD_API_ENDPOINT=https://app.paperdebugger.com
export BETA_BUILD=false
export GRAFANA_API_KEY=${{ secrets.GRAFANA_API_KEY }}
cd webapp/_webapp
npm install
npm run build
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/google/wire v0.7.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2
github.com/joho/godotenv v1.5.1
github.com/openai/openai-go/v2 v2.1.1
github.com/openai/openai-go/v2 v2.7.1
github.com/samber/lo v1.51.0
github.com/stretchr/testify v1.10.0
go.mongodb.org/mongo-driver/v2 v2.3.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/openai/openai-go/v2 v2.1.1 h1:/RMA/V3D+yF/Cc4jHXFt6lkqSOWRf5roRi+DvZaDYQI=
github.com/openai/openai-go/v2 v2.1.1/go.mod h1:sIUkR+Cu/PMUVkSKhkk742PRURkQOCFhiwJ7eRSBqmk=
github.com/openai/openai-go/v2 v2.7.1 h1:/tfvTJhfv7hTSL8mWwc5VL4WLLSDL5yn9VqVykdu9r8=
github.com/openai/openai-go/v2 v2.7.1/go.mod h1:jrJs23apqJKKbT+pqtFgNKpRju/KP9zpUTZhz3GElQE=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand Down
34 changes: 22 additions & 12 deletions internal/api/chat/create_conversation_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,20 +180,20 @@ func (s *ChatServer) appendConversationMessage(

// 如果 conversationId 是 "", 就创建新对话,否则就追加消息到对话
// conversationType 可以在一次 conversation 中多次切换
func (s *ChatServer) prepare(ctx context.Context, projectId string, conversationId string, userMessage string, userSelectedText string, languageModel models.LanguageModel, conversationType chatv1.ConversationType) (context.Context, *models.Conversation, error) {
func (s *ChatServer) prepare(ctx context.Context, projectId string, conversationId string, userMessage string, userSelectedText string, languageModel models.LanguageModel, conversationType chatv1.ConversationType) (context.Context, *models.Conversation, *models.Settings, error) {
actor, err := contextutil.GetActor(ctx)
if err != nil {
return ctx, nil, err
return ctx, nil, nil, err
}

project, err := s.projectService.GetProject(ctx, actor.ID, projectId)
if err != nil && err != mongo.ErrNoDocuments {
return ctx, nil, err
return ctx, nil, nil, err
}

userInstructions, err := s.userService.GetUserInstructions(ctx, actor.ID)
if err != nil {
return ctx, nil, err
return ctx, nil, nil, err
}

var latexFullSource string
Expand All @@ -202,12 +202,12 @@ func (s *ChatServer) prepare(ctx context.Context, projectId string, conversation
latexFullSource = "latex_full_source is not available in debug mode"
default:
if project == nil || project.IsOutOfDate() {
return ctx, nil, shared.ErrProjectOutOfDate("project is out of date")
return ctx, nil, nil, shared.ErrProjectOutOfDate("project is out of date")
}

latexFullSource, err = project.GetFullContent()
if err != nil {
return ctx, nil, err
return ctx, nil, nil, err
}
}

Expand Down Expand Up @@ -238,34 +238,44 @@ func (s *ChatServer) prepare(ctx context.Context, projectId string, conversation
}

if err != nil {
return ctx, nil, err
return ctx, nil, nil, err
}

ctx = contextutil.SetProjectID(ctx, conversation.ProjectID)
ctx = contextutil.SetConversationID(ctx, conversation.ID.Hex())

return ctx, conversation, nil
settings, err := s.userService.GetUserSettings(ctx, actor.ID)
if err != nil {
return ctx, conversation, nil, err
}

return ctx, conversation, settings, nil
}

// Deprecated: Use CreateConversationMessageStream instead.
func (s *ChatServer) CreateConversationMessage(
ctx context.Context,
req *chatv1.CreateConversationMessageRequest,
) (*chatv1.CreateConversationMessageResponse, error) {
ctx, conversation, err := s.prepare(
languageModel := models.LanguageModel(req.GetLanguageModel())
ctx, conversation, settings, err := s.prepare(
ctx,
req.GetProjectId(),
req.GetConversationId(),
req.GetUserMessage(),
req.GetUserSelectedText(),
models.LanguageModel(req.GetLanguageModel()),
languageModel,
req.GetConversationType(),
)
if err != nil {
return nil, err
}

openaiChatHistory, inappChatHistory, err := s.aiClient.ChatCompletion(ctx, conversation.LanguageModel, conversation.OpenaiChatHistory)
llmProvider := &models.LLMProviderConfig{
Endpoint: s.cfg.OpenAIBaseURL,
APIKey: settings.OpenAIAPIKey,
}
openaiChatHistory, inappChatHistory, err := s.aiClient.ChatCompletion(ctx, languageModel, conversation.OpenaiChatHistory, llmProvider)
if err != nil {
return nil, err
}
Expand All @@ -290,7 +300,7 @@ func (s *ChatServer) CreateConversationMessage(
for i, bsonMsg := range conversation.InappChatHistory {
protoMessages[i] = mapper.BSONToChatMessage(bsonMsg)
}
title, err := s.aiClient.GetConversationTitle(ctx, protoMessages)
title, err := s.aiClient.GetConversationTitle(ctx, protoMessages, llmProvider)
if err != nil {
s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex())
return
Expand Down
15 changes: 11 additions & 4 deletions internal/api/chat/create_conversation_message_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,28 @@ func (s *ChatServer) CreateConversationMessageStream(
stream chatv1.ChatService_CreateConversationMessageStreamServer,
) error {
ctx := stream.Context()
ctx, conversation, err := s.prepare(

languageModel := models.LanguageModel(req.GetLanguageModel())
ctx, conversation, settings, err := s.prepare(
ctx,
req.GetProjectId(),
req.GetConversationId(),
req.GetUserMessage(),
req.GetUserSelectedText(),
models.LanguageModel(req.GetLanguageModel()),
languageModel,
req.GetConversationType(),
)
if err != nil {
return s.sendStreamError(stream, err)
}

// 用法跟 ChatCompletion 一样,只是传递了 stream 参数
openaiChatHistory, inappChatHistory, err := s.aiClient.ChatCompletionStream(ctx, stream, conversation.ID.Hex(), conversation.LanguageModel, conversation.OpenaiChatHistory)
llmProvider := &models.LLMProviderConfig{
Endpoint: s.cfg.OpenAIBaseURL,
APIKey: settings.OpenAIAPIKey,
}

openaiChatHistory, inappChatHistory, err := s.aiClient.ChatCompletionStream(ctx, stream, conversation.ID.Hex(), languageModel, conversation.OpenaiChatHistory, llmProvider)
if err != nil {
return s.sendStreamError(stream, err)
}
Expand All @@ -64,7 +71,7 @@ func (s *ChatServer) CreateConversationMessageStream(
for i, bsonMsg := range conversation.InappChatHistory {
protoMessages[i] = mapper.BSONToChatMessage(bsonMsg)
}
title, err := s.aiClient.GetConversationTitle(ctx, protoMessages)
title, err := s.aiClient.GetConversationTitle(ctx, protoMessages, llmProvider)
if err != nil {
s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex())
return
Expand Down
104 changes: 104 additions & 0 deletions internal/api/chat/list_supported_models.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package chat

import (
"context"
"strings"

"paperdebugger/internal/libs/contextutil"
chatv1 "paperdebugger/pkg/gen/api/chat/v1"

"github.com/openai/openai-go/v2"
)

// ListSupportedModels returns the chat models available to the calling user.
// Users without a personal OpenAI API key get the small base set (served via
// the system key); users who configured their own key additionally get the
// extended catalog. Display names are consistent across both cases.
func (s *ChatServer) ListSupportedModels(
	ctx context.Context,
	req *chatv1.ListSupportedModelsRequest,
) (*chatv1.ListSupportedModelsResponse, error) {
	actor, err := contextutil.GetActor(ctx)
	if err != nil {
		return nil, err
	}

	settings, err := s.userService.GetUserSettings(ctx, actor.ID)
	if err != nil {
		return nil, err
	}

	// Base models are always listed, regardless of whether the user
	// supplied their own API key.
	supported := []*chatv1.SupportedModel{
		{
			Name: "GPT 4o",
			Slug: openai.ChatModelGPT4o,
		},
		{
			Name: "GPT 4.1",
			Slug: openai.ChatModelGPT4_1,
		},
		{
			Name: "GPT 4.1 mini",
			Slug: openai.ChatModelGPT4_1Mini,
		},
	}

	// The extended catalog is only exposed when the user configured a
	// personal OpenAI API key (those calls are billed to their key).
	if strings.TrimSpace(settings.OpenAIAPIKey) != "" {
		supported = append(supported,
			&chatv1.SupportedModel{
				Name: "GPT 5",
				Slug: openai.ChatModelGPT5,
			},
			&chatv1.SupportedModel{
				Name: "GPT 5 mini",
				Slug: openai.ChatModelGPT5Mini,
			},
			&chatv1.SupportedModel{
				Name: "GPT 5 nano",
				Slug: openai.ChatModelGPT5Nano,
			},
			&chatv1.SupportedModel{
				Name: "GPT 5 Chat Latest",
				Slug: openai.ChatModelGPT5ChatLatest,
			},
			&chatv1.SupportedModel{
				Name: "o1",
				Slug: openai.ChatModelO1,
			},
			&chatv1.SupportedModel{
				Name: "o1 mini",
				Slug: openai.ChatModelO1Mini,
			},
			&chatv1.SupportedModel{
				Name: "o3",
				Slug: openai.ChatModelO3,
			},
			&chatv1.SupportedModel{
				Name: "o3 mini",
				Slug: openai.ChatModelO3Mini,
			},
			&chatv1.SupportedModel{
				Name: "o4 mini",
				Slug: openai.ChatModelO4Mini,
			},
			&chatv1.SupportedModel{
				Name: "Codex Mini Latest",
				Slug: openai.ChatModelCodexMiniLatest,
			},
		)
	}

	return &chatv1.ListSupportedModelsResponse{
		Models: supported,
	}, nil
}
2 changes: 2 additions & 0 deletions internal/api/mapper/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ func MapProtoSettingsToModel(settings *userv1.Settings) *models.Settings {
EnableCompletion: settings.EnableCompletion,
FullDocumentRag: settings.FullDocumentRag,
ShowedOnboarding: settings.ShowedOnboarding,
OpenAIAPIKey: settings.OpenaiApiKey,
}
}

Expand All @@ -22,5 +23,6 @@ func MapModelSettingsToProto(settings *models.Settings) *userv1.Settings {
EnableCompletion: settings.EnableCompletion,
FullDocumentRag: settings.FullDocumentRag,
ShowedOnboarding: settings.ShowedOnboarding,
OpenaiApiKey: settings.OpenAIAPIKey,
}
}
14 changes: 14 additions & 0 deletions internal/models/language_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ func (x LanguageModel) Name() string {
return openai.ChatModelGPT5Mini
case chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5_NANO:
return openai.ChatModelGPT5Nano
case chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_GPT5_CHAT_LATEST:
return openai.ChatModelGPT5ChatLatest
case chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O1:
return openai.ChatModelO1
case chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O1_MINI:
return openai.ChatModelO1Mini
case chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O3:
return openai.ChatModelO3
case chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O3_MINI:
return openai.ChatModelO3Mini
case chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_O4_MINI:
return openai.ChatModelO4Mini
case chatv1.LanguageModel_LANGUAGE_MODEL_OPENAI_CODEX_MINI_LATEST:
return openai.ChatModelCodexMiniLatest
default:
return openai.ChatModelGPT5
}
Expand Down
14 changes: 14 additions & 0 deletions internal/models/llm_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package models

// LLMProviderConfig describes which LLM endpoint to talk to and how to
// authenticate. Leaving Endpoint and APIKey empty means "use the system
// default provider".
type LLMProviderConfig struct {
	Endpoint  string
	APIKey    string
	ModelName string
}

// IsCustom reports whether the user supplied their own provider credentials.
// A nil receiver is treated as "no custom configuration".
func (c *LLMProviderConfig) IsCustom() bool {
	if c == nil {
		return false
	}
	return c.APIKey != ""
}
11 changes: 6 additions & 5 deletions internal/models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package models
import "go.mongodb.org/mongo-driver/v2/bson"

type Settings struct {
ShowShortcutsAfterSelection bool `bson:"show_shortcuts_after_selection"`
FullWidthPaperDebuggerButton bool `bson:"full_width_paper_debugger_button"`
EnableCompletion bool `bson:"enable_completion"`
FullDocumentRag bool `bson:"full_document_rag"`
ShowedOnboarding bool `bson:"showed_onboarding"`
ShowShortcutsAfterSelection bool `bson:"show_shortcuts_after_selection"`
FullWidthPaperDebuggerButton bool `bson:"full_width_paper_debugger_button"`
EnableCompletion bool `bson:"enable_completion"`
FullDocumentRag bool `bson:"full_document_rag"`
ShowedOnboarding bool `bson:"showed_onboarding"`
OpenAIAPIKey string `bson:"openai_api_key"`
}

type User struct {
Expand Down
25 changes: 23 additions & 2 deletions internal/services/toolkit/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
)

type AIClient struct {
openaiClient *openai.Client
toolCallHandler *handler.ToolCallHandler

db *mongo.Database
Expand All @@ -29,6 +28,29 @@ type AIClient struct {
logger *logger.Logger
}

// GetOpenAIClient builds an OpenAI client for the given LLM provider config.
// Any field left empty in llmConfig (or a nil llmConfig) falls back to the
// system defaults from the server configuration.
func (a *AIClient) GetOpenAIClient(llmConfig *models.LLMProviderConfig) *openai.Client {
	// Start from the system defaults and let the user's config override them.
	endpoint := a.cfg.OpenAIBaseURL
	apiKey := a.cfg.OpenAIAPIKey

	// Guard against a nil config so callers without user settings don't panic.
	if llmConfig != nil {
		if llmConfig.Endpoint != "" {
			endpoint = llmConfig.Endpoint
		}
		if llmConfig.APIKey != "" {
			apiKey = llmConfig.APIKey
		}
	}

	client := openai.NewClient(
		option.WithAPIKey(apiKey),
		option.WithBaseURL(endpoint),
	)
	return &client
}

func NewAIClient(
db *db.DB,

Expand Down Expand Up @@ -73,7 +95,6 @@ func NewAIClient(

toolCallHandler := handler.NewToolCallHandler(toolRegistry)
client := &AIClient{
openaiClient: &oaiClient,
toolCallHandler: toolCallHandler,

db: database,
Expand Down
Loading