diff --git a/cmd/slim/main.go b/cmd/slim/main.go new file mode 100644 index 0000000000..137785c460 --- /dev/null +++ b/cmd/slim/main.go @@ -0,0 +1,93 @@ +package main + +import ( + "bytes" + "encoding/json" + "flag" + "io" + "log/slog" + "os" + + "github.com/langgenius/dify-plugin-daemon/pkg/slim" +) + +func main() { + slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil))) + id := flag.String("id", "", "plugin unique identifier") + action := flag.String("action", "", "plugin access action") + args := flag.String("args", "", "plugin invocation parameters (JSON); if omitted, read from stdin") + configFile := flag.String("config", "", "path to JSON config file (replaces env vars)") + flag.Parse() + + if *id == "" || *action == "" { + fatal(slim.NewError(slim.ErrInvalidInput, "usage: slim -id -action [-args ''] [-config ]")) + } + + argsJSON := *args + if argsJSON == "" { + b, err := io.ReadAll(os.Stdin) + if err != nil { + fatal(slim.NewError(slim.ErrInvalidInput, "failed to read stdin: "+err.Error())) + } + if len(bytes.TrimSpace(b)) == 0 { + fatal(slim.NewError(slim.ErrInvalidInput, "no -args flag and no JSON on stdin")) + } + argsJSON = string(b) + } + + ctx, err := slim.NewInvokeContext(*id, *action, argsJSON) + if err != nil { + fatal(err) + } + + var cfg *slim.SlimConfig + if *configFile != "" { + cfg, err = slim.LoadConfigFromFile(*configFile) + } else { + cfg, err = slim.LoadConfig() + } + if err != nil { + fatal(err) + } + + out := slim.NewOutputWriter(os.Stdout) + + switch cfg.Mode { + case slim.ModeLocal: + err = slim.RunLocal(ctx, &cfg.Local, out) + case slim.ModeRemote: + err = slim.RunRemote(ctx, &cfg.Remote, out) + default: + err = slim.NewError(slim.ErrUnknownMode, cfg.Mode) + } + + if err != nil { + fatal(err) + } +} + +func fatal(err error) { + exitCode := slim.ExitPluginError + var errorToMarshal *slim.SlimError + + if se, ok := err.(*slim.SlimError); ok { + errorToMarshal = se + exitCode = se.ExitCode() + } else { + // Wrap non-SlimError types to ensure they are marshalled to JSON correctly. + errorToMarshal = slim.NewError(slim.ErrPluginExec, err.Error()) + } + + b, marshalErr := json.Marshal(errorToMarshal) + if marshalErr != nil { + // This should be practically impossible since SlimError is designed for JSON. + // As a last resort, print a hardcoded error message. + os.Stderr.Write([]byte(`{"code":"INTERNAL_ERROR","message":"failed to marshal error to JSON"}`)) + os.Stderr.Write([]byte("\n")) + os.Exit(slim.ExitPluginError) + } + + os.Stderr.Write(b) + os.Stderr.Write([]byte("\n")) + os.Exit(exitCode) +} diff --git a/internal/core/local_runtime/constructor_slim.go b/internal/core/local_runtime/constructor_slim.go new file mode 100644 index 0000000000..2caab77fa0 --- /dev/null +++ b/internal/core/local_runtime/constructor_slim.go @@ -0,0 +1,42 @@ +package local_runtime + +import ( + "sync" + + "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_runtime" + "github.com/langgenius/dify-plugin-daemon/internal/types/app" + "github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities" + "github.com/langgenius/dify-plugin-daemon/pkg/plugin_packager/decoder" +) + +// NewLocalPluginRuntime creates a LocalPluginRuntime with a known working path. +// Unlike ConstructPluginRuntime, it does not compute a checksum-based path, +// which avoids walking the full directory tree (including .venv / .uv-cache). +func NewLocalPluginRuntime( + appConfig *app.Config, + pluginDecoder decoder.PluginDecoder, + manifest plugin_entities.PluginDeclaration, + workingPath string, +) *LocalPluginRuntime { + return &LocalPluginRuntime{ + PluginRuntime: plugin_entities.PluginRuntime{ + Config: manifest, + State: plugin_entities.PluginRuntimeState{ + Status: plugin_entities.PLUGIN_RUNTIME_STATUS_PENDING, + Verified: manifest.Verified, + WorkingPath: workingPath, + }, + }, + BasicChecksum: basic_runtime.BasicChecksum{ + Decoder: pluginDecoder, + }, + scheduleStatus: ScheduleStatusStopped, + defaultPythonInterpreterPath: appConfig.PythonInterpreterPath, + uvPath: appConfig.UvPath, + appConfig: appConfig, + instances: []*PluginInstance{}, + instanceLocker: &sync.RWMutex{}, + notifiers: []PluginRuntimeNotifier{}, + notifierLock: &sync.Mutex{}, + } +} diff --git a/internal/server/constants/constants.go b/internal/server/constants/constants.go index e3c1108d55..0d6c15aaaf 100644 --- a/internal/server/constants/constants.go +++ b/internal/server/constants/constants.go @@ -1,9 +1,10 @@ package constants const ( - X_PLUGIN_ID = "X-Plugin-ID" - X_API_KEY = "X-Api-Key" - X_ADMIN_API_KEY = "X-Admin-Api-Key" + X_PLUGIN_ID = "X-Plugin-ID" + X_API_KEY = "X-Api-Key" + X_ADMIN_API_KEY = "X-Admin-Api-Key" + PluginUniqueIdentifier = "X-Plugin-Unique-Identifier" CONTEXT_KEY_PLUGIN_INSTALLATION = "plugin_installation" CONTEXT_KEY_PLUGIN_UNIQUE_IDENTIFIER = "plugin_unique_identifier" diff --git a/internal/server/http_server.go b/internal/server/http_server.go index f28bbfcc1b..2c1c494a6b 100644 --- a/internal/server/http_server.go +++ b/internal/server/http_server.go @@ -18,7 +18,7 @@ import ( // server starts a http server and returns a function to stop it func (app *App) server(config *app.Config) func() { -engine := gin.New() + engine := gin.New() engine.Use(log.RecoveryMiddleware()) engine.Use(log.TraceMiddleware()) // OpenTelemetry middleware (extracts upstream trace context and starts server spans) @@ -42,6 +42,7 @@ engine := gin.New() serverlessTransactionGroup := engine.Group("/backwards-invocation") pluginGroup := engine.Group("/plugin/:tenant_id") pprofGroup := engine.Group("/debug/pprof") + invokeGroup := engine.Group("/v2/invoke") if config.AdminApiEnabled { if len(config.AdminApiKey) < 10 { @@ -72,6 +73,7 @@ engine := gin.New() app.serverlessTransactionGroup(serverlessTransactionGroup, config) app.pluginGroup(pluginGroup, config) app.pprofGroup(pprofGroup, config) + app.invokeGroup(invokeGroup, config) srv := &http.Server{ Addr: fmt.Sprintf("%s:%d", config.ServerHost, config.ServerPort), @@ -212,3 +214,17 @@ func (app *App) pprofGroup(group *gin.RouterGroup, config *app.Config) { group.GET("/threadcreate", controllers.PprofThreadcreate) } } + +func (app *App) invokeGroup(group *gin.RouterGroup, config *app.Config) { + group.Use(CheckingKey(config.ServerKey)) + dispatchGroup := group.Group("/dispatch") + dispatchGroup.Use(controllers.CollectActiveDispatchRequests()) + dispatchGroup.Use(app.FetchPluginDirect()) + dispatchGroup.Use(app.RedirectPluginInvoke()) + dispatchGroup.Use(app.InitClusterID()) + + dispatchGroup.POST("/agent_strategy/invoke", + controllers.InvokeAgentStrategy(config)) + + app.setupGeneratedRoutes(dispatchGroup, config) +} diff --git a/internal/server/middleware.go b/internal/server/middleware.go index a6bee434e1..0947c55b4f 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -182,10 +182,32 @@ func (app *App) InitClusterID() gin.HandlerFunc { func (app *App) AdminAPIKey(key string) gin.HandlerFunc { return func(ctx *gin.Context) { if ctx.GetHeader(constants.X_ADMIN_API_KEY) != key { - ctx.AbortWithStatusJSON(401, gin.H{"message": "unauthorized"}) + ctx.AbortWithStatusJSON( + 401, + exception.UnauthorizedError().ToResponse()) return } + ctx.Next() + } +} +func (app *App) FetchPluginDirect() gin.HandlerFunc { + return func(ctx *gin.Context) { + identifier := ctx.Request.Header.Get(constants.PluginUniqueIdentifier) + if identifier == "" { + ctx.AbortWithStatusJSON( + 400, + exception.BadRequestError(errors.New("X-Plugin-Unique-Identifier header is required")).ToResponse()) + return + } + pluginID, err := plugin_entities.NewPluginUniqueIdentifier(identifier) + if err != nil { + ctx.AbortWithStatusJSON(400, + exception.UniqueIdentifierError(err).ToResponse(), + ) + return + } + ctx.Set(constants.CONTEXT_KEY_PLUGIN_UNIQUE_IDENTIFIER, pluginID) ctx.Next() } } diff --git a/pkg/slim/config.go b/pkg/slim/config.go new file mode 100644 index 0000000000..e8cf2bd835 --- /dev/null +++ b/pkg/slim/config.go @@ -0,0 +1,159 @@ +package slim + +import ( + "encoding/json" + "os" + + "github.com/google/uuid" + "github.com/langgenius/dify-plugin-daemon/internal/types/app" +) + +type RequestMeta struct { + TenantID string `json:"tenant_id"` + UserID string `json:"user_id"` + Data json.RawMessage `json:"data"` +} + +type LocalConfig struct { + Folder string `json:"folder"` + PythonPath string `json:"python_path"` + UvPath string `json:"uv_path"` + PythonEnvInitTimeout int `json:"python_env_init_timeout"` + MaxExecutionTimeout int `json:"max_execution_timeout"` + PipMirrorURL string `json:"pip_mirror_url"` + PipExtraArgs string `json:"pip_extra_args"` + MarketplaceURL string `json:"marketplace_url"` +} + +type RemoteConfig struct { + DaemonAddr string `json:"daemon_addr"` + DaemonKey string `json:"daemon_key"` +} + +type InvokeContext struct { + PluginID string + Action string + Request RequestMeta +} + +type SlimConfig struct { + Mode string `json:"mode"` + Local LocalConfig `json:"local"` + Remote RemoteConfig `json:"remote"` +} + +func NewInvokeContext(id, action, argsJSON string) (*InvokeContext, error) { + var req RequestMeta + if err := json.Unmarshal([]byte(argsJSON), &req); err != nil { + return nil, NewError(ErrInvalidArgsJSON, err.Error()) + } + if req.TenantID == "" { + req.TenantID = uuid.Nil.String() + } + return &InvokeContext{ + PluginID: id, + Action: action, + Request: req, + }, nil +} + +func LoadConfigFromFile(path string) (*SlimConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, NewError(ErrConfigLoad, err.Error()) + } + + var cfg SlimConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, NewError(ErrConfigLoad, err.Error()) + } + + if err := fillDefaults(&cfg); err != nil { + return nil, err + } + return &cfg, nil +} + +func LoadConfig() (*SlimConfig, error) { + cfg := &SlimConfig{ + Mode: env("SLIM_MODE", ModeRemote), + } + + switch cfg.Mode { + case ModeLocal: + cfg.Local = LocalConfig{ + Folder: env("SLIM_FOLDER", ""), + PythonPath: env("SLIM_PYTHON_PATH", ""), + UvPath: env("SLIM_UV_PATH", ""), + PythonEnvInitTimeout: envInt("SLIM_PYTHON_ENV_INIT_TIMEOUT", 0), + MaxExecutionTimeout: envInt("SLIM_MAX_EXECUTION_TIMEOUT", 0), + PipMirrorURL: env("SLIM_PIP_MIRROR_URL", ""), + PipExtraArgs: env("SLIM_PIP_EXTRA_ARGS", ""), + MarketplaceURL: env("SLIM_MARKETPLACE_URL", ""), + } + case ModeRemote: + cfg.Remote = RemoteConfig{ + DaemonAddr: env("SLIM_DAEMON_ADDR", ""), + DaemonKey: env("SLIM_DAEMON_KEY", ""), + } + } + + if err := fillDefaults(cfg); err != nil { + return nil, err + } + return cfg, nil +} + +func (lc *LocalConfig) toAppConfig() *app.Config { + return &app.Config{ + PluginWorkingPath: lc.Folder, + PluginInstalledPath: lc.Folder, + PluginPackageCachePath: lc.Folder, + PythonInterpreterPath: lc.PythonPath, + UvPath: lc.UvPath, + PythonEnvInitTimeout: lc.PythonEnvInitTimeout, + PluginMaxExecutionTimeout: lc.MaxExecutionTimeout, + PipMirrorUrl: lc.PipMirrorURL, + PipExtraArgs: lc.PipExtraArgs, + PipPreferBinary: true, + PipVerbose: true, + PluginRuntimeBufferSize: 1024, + PluginRuntimeMaxBufferSize: 5242880, + Platform: app.PLATFORM_LOCAL, + } +} + +func fillDefaults(cfg *SlimConfig) error { + if cfg.Mode == "" { + cfg.Mode = ModeRemote + } + + if cfg.Mode == ModeLocal { + if cfg.Local.Folder == "" { + return NewError(ErrConfigInvalid, "local.folder is required") + } + if cfg.Local.PythonPath == "" { + cfg.Local.PythonPath = "python3" + } + if cfg.Local.PythonEnvInitTimeout == 0 { + cfg.Local.PythonEnvInitTimeout = 120 + } + if cfg.Local.MaxExecutionTimeout == 0 { + cfg.Local.MaxExecutionTimeout = 600 + } + if cfg.Local.MarketplaceURL == "" { + cfg.Local.MarketplaceURL = "https://marketplace.dify.ai" + } + } + + if cfg.Mode == ModeRemote { + if cfg.Remote.DaemonAddr == "" { + return NewError(ErrConfigInvalid, "remote.daemon_addr is required") + } + if cfg.Remote.DaemonKey == "" { + return NewError(ErrConfigInvalid, "remote.daemon_key is required") + } + } + + return nil +} diff --git a/pkg/slim/config_test.go b/pkg/slim/config_test.go new file mode 100644 index 0000000000..25a447f007 --- /dev/null +++ b/pkg/slim/config_test.go @@ -0,0 +1,217 @@ +package slim + +import ( + "os" + "path/filepath" + "testing" +) + +func TestFillDefaults_LocalRequiresFolder(t *testing.T) { + cfg := &SlimConfig{Mode: ModeLocal} + err := fillDefaults(cfg) + if err == nil { + t.Fatal("fillDefaults() should fail when local.folder is empty") + } + se, ok := err.(*SlimError) + if !ok { + t.Fatalf("expected *SlimError, got %T", err) + } + if se.Code != ErrConfigInvalid { + t.Fatalf("Code = %q; want %q", se.Code, ErrConfigInvalid) + } +} + +func TestFillDefaults_LocalDefaults(t *testing.T) { + cfg := &SlimConfig{ + Mode: ModeLocal, + Local: LocalConfig{ + Folder: "/tmp/test-plugins", + }, + } + if err := fillDefaults(cfg); err != nil { + t.Fatalf("fillDefaults() error: %v", err) + } + if cfg.Local.PythonPath != "python3" { + t.Errorf("PythonPath = %q; want %q", cfg.Local.PythonPath, "python3") + } + if cfg.Local.PythonEnvInitTimeout != 120 { + t.Errorf("PythonEnvInitTimeout = %d; want 120", cfg.Local.PythonEnvInitTimeout) + } + if cfg.Local.MaxExecutionTimeout != 600 { + t.Errorf("MaxExecutionTimeout = %d; want 600", cfg.Local.MaxExecutionTimeout) + } + if cfg.Local.MarketplaceURL != "https://marketplace.dify.ai" { + t.Errorf("MarketplaceURL = %q; want %q", cfg.Local.MarketplaceURL, "https://marketplace.dify.ai") + } +} + +func TestFillDefaults_LocalPreservesExplicitValues(t *testing.T) { + cfg := &SlimConfig{ + Mode: ModeLocal, + Local: LocalConfig{ + Folder: "/tmp/test-plugins", + PythonPath: "/usr/bin/python3.12", + PythonEnvInitTimeout: 60, + MaxExecutionTimeout: 300, + MarketplaceURL: "https://custom.marketplace.example.com", + }, + } + if err := fillDefaults(cfg); err != nil { + t.Fatalf("fillDefaults() error: %v", err) + } + if cfg.Local.PythonPath != "/usr/bin/python3.12" { + t.Errorf("PythonPath = %q; want %q", cfg.Local.PythonPath, "/usr/bin/python3.12") + } + if cfg.Local.PythonEnvInitTimeout != 60 { + t.Errorf("PythonEnvInitTimeout = %d; want 60", cfg.Local.PythonEnvInitTimeout) + } + if cfg.Local.MaxExecutionTimeout != 300 { + t.Errorf("MaxExecutionTimeout = %d; want 300", cfg.Local.MaxExecutionTimeout) + } + if cfg.Local.MarketplaceURL != "https://custom.marketplace.example.com" { + t.Errorf("MarketplaceURL = %q; want custom", cfg.Local.MarketplaceURL) + } +} + +func TestFillDefaults_RemoteRequiresDaemonAddr(t *testing.T) { + cfg := &SlimConfig{ + Mode: ModeRemote, + Remote: RemoteConfig{ + DaemonKey: "secret", + }, + } + err := fillDefaults(cfg) + if err == nil { + t.Fatal("fillDefaults() should fail when remote.daemon_addr is empty") + } +} + +func TestFillDefaults_RemoteRequiresDaemonKey(t *testing.T) { + cfg := &SlimConfig{ + Mode: ModeRemote, + Remote: RemoteConfig{ + DaemonAddr: "http://localhost:5003", + }, + } + err := fillDefaults(cfg) + if err == nil { + t.Fatal("fillDefaults() should fail when remote.daemon_key is empty") + } +} + +func TestFillDefaults_EmptyModeDefaultsToRemote(t *testing.T) { + cfg := &SlimConfig{ + Remote: RemoteConfig{ + DaemonAddr: "http://localhost:5003", + DaemonKey: "secret", + }, + } + if err := fillDefaults(cfg); err != nil { + t.Fatalf("fillDefaults() error: %v", err) + } + if cfg.Mode != ModeRemote { + t.Fatalf("Mode = %q; want %q", cfg.Mode, ModeRemote) + } +} + +func TestNewInvokeContext_Valid(t *testing.T) { + argsJSON := `{"tenant_id":"t1","user_id":"u1","data":{"key":"val"}}` + ctx, err := NewInvokeContext("author/plugin:1.0.0", "invoke_tool", argsJSON) + if err != nil { + t.Fatalf("NewInvokeContext() error: %v", err) + } + if ctx.PluginID != "author/plugin:1.0.0" { + t.Errorf("PluginID = %q; want %q", ctx.PluginID, "author/plugin:1.0.0") + } + if ctx.Action != "invoke_tool" { + t.Errorf("Action = %q; want %q", ctx.Action, "invoke_tool") + } + if ctx.Request.TenantID != "t1" { + t.Errorf("TenantID = %q; want %q", ctx.Request.TenantID, "t1") + } + if ctx.Request.UserID != "u1" { + t.Errorf("UserID = %q; want %q", ctx.Request.UserID, "u1") + } +} + +func TestNewInvokeContext_DefaultTenantID(t *testing.T) { + argsJSON := `{"data":{"key":"val"}}` + ctx, err := NewInvokeContext("plugin", "invoke_tool", argsJSON) + if err != nil { + t.Fatalf("NewInvokeContext() error: %v", err) + } + if ctx.Request.TenantID != "00000000-0000-0000-0000-000000000000" { + t.Errorf("TenantID = %q; want nil UUID", ctx.Request.TenantID) + } +} + +func TestNewInvokeContext_InvalidJSON(t *testing.T) { + _, err := NewInvokeContext("plugin", "invoke_tool", "not-json") + if err == nil { + t.Fatal("NewInvokeContext() should fail on invalid JSON") + } + se, ok := err.(*SlimError) + if !ok { + t.Fatalf("expected *SlimError, got %T", err) + } + if se.Code != ErrInvalidArgsJSON { + t.Errorf("Code = %q; want %q", se.Code, ErrInvalidArgsJSON) + } +} + +func TestLoadConfigFromFile_Valid(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + content := `{ + "mode": "remote", + "remote": { + "daemon_addr": "http://localhost:5003", + "daemon_key": "testkey" + } + }` + if err := os.WriteFile(cfgPath, []byte(content), 0644); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + cfg, err := LoadConfigFromFile(cfgPath) + if err != nil { + t.Fatalf("LoadConfigFromFile() error: %v", err) + } + if cfg.Mode != ModeRemote { + t.Errorf("Mode = %q; want %q", cfg.Mode, ModeRemote) + } + if cfg.Remote.DaemonAddr != "http://localhost:5003" { + t.Errorf("DaemonAddr = %q; want %q", cfg.Remote.DaemonAddr, "http://localhost:5003") + } +} + +func TestLoadConfigFromFile_NotFound(t *testing.T) { + _, err := LoadConfigFromFile("/nonexistent/config.json") + if err == nil { + t.Fatal("LoadConfigFromFile() should fail on missing file") + } +} + +func TestLoadConfigFromFile_InvalidJSON(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "bad.json") + if err := os.WriteFile(cfgPath, []byte("{invalid"), 0644); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + _, err := LoadConfigFromFile(cfgPath) + if err == nil { + t.Fatal("LoadConfigFromFile() should fail on invalid JSON") + } +} + +func TestLoadConfigFromFile_ValidationFails(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + content := `{"mode": "remote", "remote": {}}` + if err := os.WriteFile(cfgPath, []byte(content), 0644); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + _, err := LoadConfigFromFile(cfgPath) + if err == nil { + t.Fatal("LoadConfigFromFile() should fail when remote config is incomplete") + } +} diff --git a/pkg/slim/errors.go b/pkg/slim/errors.go new file mode 100644 index 0000000000..0c88bd6ac5 --- /dev/null +++ b/pkg/slim/errors.go @@ -0,0 +1,65 @@ +package slim + +type ErrorCode string + +const ( + ErrInvalidInput ErrorCode = "INVALID_INPUT" + ErrInvalidArgsJSON ErrorCode = "INVALID_ARGS_JSON" + ErrConfigLoad ErrorCode = "CONFIG_LOAD_ERROR" + ErrConfigInvalid ErrorCode = "CONFIG_INVALID" + ErrUnknownMode ErrorCode = "UNKNOWN_MODE" + ErrUnknownAction ErrorCode = "UNKNOWN_ACTION" + ErrNotImplemented ErrorCode = "NOT_IMPLEMENTED" + ErrNetwork ErrorCode = "NETWORK_ERROR" + ErrDaemon ErrorCode = "DAEMON_ERROR" + ErrStreamRead ErrorCode = "STREAM_READ_ERROR" + ErrStreamParse ErrorCode = "STREAM_PARSE_ERROR" + ErrPluginNotFound ErrorCode = "PLUGIN_NOT_FOUND" + ErrPluginInit ErrorCode = "PLUGIN_INIT_ERROR" + ErrPluginExec ErrorCode = "PLUGIN_EXEC_ERROR" + ErrPluginDownload ErrorCode = "PLUGIN_DOWNLOAD_ERROR" + ErrPluginDownloadTimeout ErrorCode = "PLUGIN_DOWNLOAD_TIMEOUT" + ErrPluginPackageInvalid ErrorCode = "PLUGIN_PACKAGE_INVALID" + ErrPluginPackageTooLarge ErrorCode = "PLUGIN_PACKAGE_TOO_LARGE" + ErrPluginExtract ErrorCode = "PLUGIN_EXTRACT_ERROR" +) + +const ( + ExitOK = 0 + ExitPluginError = 1 + ExitInputError = 2 + ExitNetworkError = 3 + ExitDaemonError = 4 +) + +const ( + ModeLocal = "local" + ModeRemote = "remote" +) + +type SlimError struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` +} + +func (e *SlimError) Error() string { + return string(e.Code) + ": " + e.Message +} + +func (e *SlimError) ExitCode() int { + switch e.Code { + case ErrInvalidInput, ErrInvalidArgsJSON, ErrConfigLoad, + ErrConfigInvalid, ErrUnknownMode, ErrUnknownAction: + return ExitInputError + case ErrNetwork, ErrPluginDownload, ErrPluginDownloadTimeout: + return ExitNetworkError + case ErrDaemon: + return ExitDaemonError + default: + return ExitPluginError + } +} + +func NewError(code ErrorCode, msg string) *SlimError { + return &SlimError{Code: code, Message: msg} +} diff --git a/pkg/slim/errors_test.go b/pkg/slim/errors_test.go new file mode 100644 index 0000000000..9cc3d7c356 --- /dev/null +++ b/pkg/slim/errors_test.go @@ -0,0 +1,59 @@ +package slim + +import ( + "encoding/json" + "testing" +) + +func TestSlimError_Error(t *testing.T) { + err := NewError(ErrInvalidInput, "bad input") + if err.Error() != "INVALID_INPUT: bad input" { + t.Fatalf("Error() = %q; want %q", err.Error(), "INVALID_INPUT: bad input") + } +} + +func TestSlimError_ExitCode(t *testing.T) { + tests := []struct { + code ErrorCode + want int + }{ + {ErrInvalidInput, ExitInputError}, + {ErrInvalidArgsJSON, ExitInputError}, + {ErrConfigLoad, ExitInputError}, + {ErrConfigInvalid, ExitInputError}, + {ErrUnknownMode, ExitInputError}, + {ErrUnknownAction, ExitInputError}, + {ErrNetwork, ExitNetworkError}, + {ErrPluginDownload, ExitNetworkError}, + {ErrPluginDownloadTimeout, ExitNetworkError}, + {ErrDaemon, ExitDaemonError}, + {ErrPluginExec, ExitPluginError}, + {ErrStreamParse, ExitPluginError}, + {ErrPluginInit, ExitPluginError}, + {ErrPluginNotFound, ExitPluginError}, + } + for _, tt := range tests { + err := NewError(tt.code, "msg") + if got := err.ExitCode(); got != tt.want { + t.Errorf("ExitCode() for %s = %d; want %d", tt.code, got, tt.want) + } + } +} + +func TestSlimError_JSON(t *testing.T) { + err := NewError(ErrNetwork, "connection refused") + b, jsonErr := json.Marshal(err) + if jsonErr != nil { + t.Fatalf("json.Marshal() error: %v", jsonErr) + } + var decoded SlimError + if jsonErr := json.Unmarshal(b, &decoded); jsonErr != nil { + t.Fatalf("json.Unmarshal() error: %v", jsonErr) + } + if decoded.Code != ErrNetwork { + t.Errorf("Code = %q; want %q", decoded.Code, ErrNetwork) + } + if decoded.Message != "connection refused" { + t.Errorf("Message = %q; want %q", decoded.Message, "connection refused") + } +} diff --git a/pkg/slim/local.go b/pkg/slim/local.go new file mode 100644 index 0000000000..c1ca5a7e8d --- /dev/null +++ b/pkg/slim/local.go @@ -0,0 +1,309 @@ +package slim + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "strings" + "sync/atomic" + "time" + + "github.com/langgenius/dify-plugin-daemon/internal/core/local_runtime" + "github.com/langgenius/dify-plugin-daemon/internal/types/app" + "github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities" + "github.com/langgenius/dify-plugin-daemon/pkg/plugin_packager/decoder" + "github.com/langgenius/dify-plugin-daemon/pkg/utils/parser" + "github.com/langgenius/dify-plugin-daemon/pkg/utils/routine" +) + +const downloadTimeout = 60 * time.Second + +func RunLocal(ctx *InvokeContext, local *LocalConfig, out *OutputWriter) error { + workingPath := pluginWorkingPath(local.Folder, ctx.PluginID) + + dec, err := decoder.NewFSPluginDecoder(workingPath) + if err != nil { + out.Message("download", fmt.Sprintf("downloading plugin %s from marketplace", ctx.PluginID)) + dec, err = downloadAndExtract(local, ctx.PluginID, workingPath) + if err != nil { + return err + } + out.Message("download", "plugin downloaded and extracted") + } + + if !routine.IsInit() { + routine.InitPool(4) + } + + appConfig := local.toAppConfig() + rt, err := buildRuntime(appConfig, dec, workingPath) + if err != nil { + return NewError(ErrPluginInit, fmt.Sprintf("build runtime: %s", err)) + } + + out.Message("init", "initializing python environment") + if err := rt.InitPythonEnvironment(); err != nil { + return NewError(ErrPluginInit, fmt.Sprintf("init python env: %s", err)) + } + out.Message("init", "python environment ready") + + reqBytes, sessionID, err := TransformRequest(ctx) + if err != nil { + return err + } + + return execPlugin(rt, local, reqBytes, sessionID, appConfig, out) +} + +func buildRuntime( + appConfig *app.Config, + dec *decoder.FSPluginDecoder, + workingPath string, +) (*local_runtime.LocalPluginRuntime, error) { + manifest, err := dec.Manifest() + if err != nil { + return nil, fmt.Errorf("read manifest: %w", err) + } + + return local_runtime.NewLocalPluginRuntime( + appConfig, + dec, + manifest, + workingPath, + ), nil +} + +func downloadAndExtract(local *LocalConfig, pluginID, workingPath string) (*decoder.FSPluginDecoder, error) { + pkgBytes, err := downloadFromMarketplace(local.MarketplaceURL, pluginID) + if err != nil { + return nil, err + } + + zipDec, err := decoder.NewZipPluginDecoder(pkgBytes) + if err != nil { + return nil, NewError(ErrPluginPackageInvalid, fmt.Sprintf("invalid plugin package: %s", err)) + } + + if err := os.MkdirAll(workingPath, 0755); err != nil { + return nil, NewError(ErrPluginExtract, fmt.Sprintf("create working dir: %s", err)) + } + + if err := zipDec.ExtractTo(workingPath); err != nil { + os.RemoveAll(workingPath) + return nil, NewError(ErrPluginExtract, fmt.Sprintf("extract package: %s", err)) + } + + fsDec, err := decoder.NewFSPluginDecoder(workingPath) + if err != nil { + os.RemoveAll(workingPath) + return nil, NewError(ErrPluginExtract, fmt.Sprintf("load extracted plugin: %s", err)) + } + + return fsDec, nil +} + +func downloadFromMarketplace(marketplaceURL, pluginID string) ([]byte, error) { + u, err := url.Parse(strings.TrimRight(marketplaceURL, "/") + "/api/v1/plugins/download") + if err != nil { + return nil, NewError(ErrPluginDownload, fmt.Sprintf("parse marketplace url: %s", err)) + } + q := u.Query() + q.Set("unique_identifier", pluginID) + u.RawQuery = q.Encode() + + client := &http.Client{Timeout: downloadTimeout} + resp, err := client.Get(u.String()) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return nil, NewError(ErrPluginDownloadTimeout, + fmt.Sprintf("marketplace download timed out after %s for %s", downloadTimeout, pluginID)) + } + return nil, NewError(ErrPluginDownload, fmt.Sprintf("marketplace request failed: %s", err)) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, NewError(ErrPluginNotFound, + fmt.Sprintf("plugin %s not found in marketplace (%s)", pluginID, marketplaceURL)) + } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return nil, NewError(ErrPluginDownload, + fmt.Sprintf("marketplace returned status %d for %s: %s", resp.StatusCode, pluginID, string(body))) + } + + const maxSize = 15 * 1024 * 1024 + data, err := io.ReadAll(io.LimitReader(resp.Body, maxSize+1)) + if err != nil { + return nil, NewError(ErrPluginDownload, fmt.Sprintf("read response body: %s", err)) + } + if len(data) > maxSize { + return nil, NewError(ErrPluginPackageTooLarge, + fmt.Sprintf("plugin package %s exceeds 15 MiB size limit", pluginID)) + } + + return data, nil +} + +const maxStderrSize = 64 * 1024 + +func execPlugin( + rt *local_runtime.LocalPluginRuntime, + local *LocalConfig, + reqBytes []byte, + sessionID string, + appConfig *app.Config, + out *OutputWriter, +) error { + pythonPath, err := filepath.Abs(filepath.Join(rt.State.WorkingPath, ".venv", "bin", "python")) + if err != nil { + return NewError(ErrPluginExec, fmt.Sprintf("resolve python path: %s", err)) + } + + cmd := exec.Command(pythonPath, "-m", rt.Config.Meta.Runner.Entrypoint) + cmd.Dir = rt.State.WorkingPath + cmd.Env = append(os.Environ(), "INSTALL_METHOD=local", "PATH="+os.Getenv("PATH")) + + stdin, err := cmd.StdinPipe() + if err != nil { + return NewError(ErrPluginExec, fmt.Sprintf("stdin pipe: %s", err)) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return NewError(ErrPluginExec, fmt.Sprintf("stdout pipe: %s", err)) + } + stderr, err := cmd.StderrPipe() + if err != nil { + return NewError(ErrPluginExec, fmt.Sprintf("stderr pipe: %s", err)) + } + + if err := cmd.Start(); err != nil { + return NewError(ErrPluginExec, fmt.Sprintf("start subprocess: %s", err)) + } + + stderrCh := make(chan string, 1) + go func() { + b, _ := io.ReadAll(io.LimitReader(stderr, maxStderrSize)) + stderrCh <- string(b) + }() + + if _, err := stdin.Write(append(reqBytes, '\n')); err != nil { + cmd.Process.Kill() + cmd.Wait() + <-stderrCh + return NewError(ErrPluginExec, fmt.Sprintf("write stdin: %s", err)) + } + stdin.Close() + + timeout := time.Duration(local.MaxExecutionTimeout) * time.Second + var timedOut atomic.Bool + killTimer := time.AfterFunc(timeout, func() { + timedOut.Store(true) + cmd.Process.Kill() + }) + defer killTimer.Stop() + + scanner := bufio.NewScanner(stdout) + scanner.Buffer( + make([]byte, appConfig.GetLocalRuntimeBufferSize()), + appConfig.GetLocalRuntimeMaxBufferSize(), + ) + + var execErr error + done := false + + for scanner.Scan() { + data := scanner.Bytes() + if len(data) == 0 { + continue + } + + plugin_entities.ParsePluginUniversalEvent( + data, + "", + func(sid string, payload []byte) { + if sid != sessionID { + return + } + msg, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](payload) + if err != nil { + execErr = NewError(ErrStreamParse, err.Error()) + done = true + return + } + switch msg.Type { + case plugin_entities.SESSION_MESSAGE_TYPE_STREAM: + out.Chunk(json.RawMessage(msg.Data)) + case plugin_entities.SESSION_MESSAGE_TYPE_END: + out.Done() + done = true + case plugin_entities.SESSION_MESSAGE_TYPE_ERROR: + errResp, parseErr := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](msg.Data) + if parseErr != nil { + out.Error(ErrPluginExec, string(msg.Data)) + } else { + out.Error(ErrPluginExec, errResp.Error()) + } + done = true + } + }, + func() { + killTimer.Reset(timeout) + }, + func(errMsg string) { + out.Error(ErrPluginExec, errMsg) + done = true + }, + func(logEvent plugin_entities.PluginLogEvent) { + }, + ) + + if done || execErr != nil { + break + } + } + + if scanErr := scanner.Err(); scanErr != nil && execErr == nil { + if timedOut.Load() { + execErr = NewError(ErrPluginExec, "execution timeout") + } else { + execErr = NewError(ErrStreamRead, scanErr.Error()) + } + } + + cmd.Process.Kill() + cmd.Wait() + + stderrMsg := <-stderrCh + + if execErr != nil { + if stderrMsg != "" { + if se, ok := execErr.(*SlimError); ok { + return NewError(se.Code, se.Error()+"; stderr: "+truncate(stderrMsg, 512)) + } + return NewError(ErrPluginExec, "stderr: "+truncate(stderrMsg, 512)) + } + return execErr + } + + return nil +} + +func pluginWorkingPath(folder, pluginID string) string { + normalized := strings.ReplaceAll(pluginID, ":", "-") + return filepath.Join(folder, normalized) +} + +func truncate(s string, max int) string { + if len(s) <= max { + return s + } + return s[:max] + "..." +} diff --git a/pkg/slim/output.go b/pkg/slim/output.go new file mode 100644 index 0000000000..cb88cd7f41 --- /dev/null +++ b/pkg/slim/output.go @@ -0,0 +1,44 @@ +package slim + +import ( + "encoding/json" + "io" +) + +type outputEvent struct { + Event string `json:"event"` + Data any `json:"data"` +} + +type OutputWriter struct { + enc *json.Encoder +} + +func NewOutputWriter(w io.Writer) *OutputWriter { + return &OutputWriter{enc: json.NewEncoder(w)} +} + +func (o *OutputWriter) Chunk(data json.RawMessage) error { + return o.enc.Encode(outputEvent{Event: "chunk", Data: data}) +} + +func (o *OutputWriter) Done() error { + return o.enc.Encode(outputEvent{Event: "done"}) +} + +func (o *OutputWriter) Error(code ErrorCode, msg string) error { + return o.enc.Encode(outputEvent{ + Event: "error", + Data: SlimError{Code: code, Message: msg}, + }) +} + +func (o *OutputWriter) Message(stage string, msg string) error { + return o.enc.Encode(outputEvent{ + Event: "message", + Data: map[string]string{ + "stage": stage, + "message": msg, + }, + }) +} diff --git a/pkg/slim/output_test.go b/pkg/slim/output_test.go new file mode 100644 index 0000000000..03e782bbaf --- /dev/null +++ b/pkg/slim/output_test.go @@ -0,0 +1,110 @@ +package slim + +import ( + "bytes" + "encoding/json" + "strings" + "testing" +) + +func TestOutputWriter_Chunk(t *testing.T) { + var buf bytes.Buffer + out := NewOutputWriter(&buf) + if err := out.Chunk(json.RawMessage(`{"result":"ok"}`)); err != nil { + t.Fatalf("Chunk() error: %v", err) + } + var evt outputEvent + if err := json.Unmarshal(buf.Bytes(), &evt); err != nil { + t.Fatalf("json.Unmarshal() error: %v", err) + } + if evt.Event != "chunk" { + t.Errorf("Event = %q; want %q", evt.Event, "chunk") + } +} + +func TestOutputWriter_Done(t *testing.T) { + var buf bytes.Buffer + out := NewOutputWriter(&buf) + if err := out.Done(); err != nil { + t.Fatalf("Done() error: %v", err) + } + var evt outputEvent + if err := json.Unmarshal(buf.Bytes(), &evt); err != nil { + t.Fatalf("json.Unmarshal() error: %v", err) + } + if evt.Event != "done" { + t.Errorf("Event = %q; want %q", evt.Event, "done") + } +} + +func TestOutputWriter_Error(t *testing.T) { + var buf bytes.Buffer + out := NewOutputWriter(&buf) + if err := out.Error(ErrPluginExec, "plugin crashed"); err != nil { + t.Fatalf("Error() error: %v", err) + } + var raw map[string]any + if err := json.Unmarshal(buf.Bytes(), &raw); err != nil { + t.Fatalf("json.Unmarshal() error: %v", err) + } + if raw["event"] != "error" { + t.Errorf("event = %v; want %q", raw["event"], "error") + } + data, ok := raw["data"].(map[string]any) + if !ok { + t.Fatalf("data is not a map: %T", raw["data"]) + } + if data["code"] != string(ErrPluginExec) { + t.Errorf("data.code = %v; want %q", data["code"], ErrPluginExec) + } + if data["message"] != "plugin crashed" { + t.Errorf("data.message = %v; want %q", data["message"], "plugin crashed") + } +} + +func TestOutputWriter_Message(t *testing.T) { + var buf bytes.Buffer + out := NewOutputWriter(&buf) + if err := out.Message("init", "starting up"); err != nil { + t.Fatalf("Message() error: %v", err) + } + var raw map[string]any + if err := json.Unmarshal(buf.Bytes(), &raw); err != nil { + t.Fatalf("json.Unmarshal() error: %v", err) + } + if raw["event"] != "message" { + t.Errorf("event = %v; want %q", raw["event"], "message") + } + data := raw["data"].(map[string]any) + if data["stage"] != "init" { + t.Errorf("stage = %v; want %q", data["stage"], "init") + } + if data["message"] != "starting up" { + t.Errorf("message = %v; want %q", data["message"], "starting up") + } +} + +func TestOutputWriter_MultipleEvents(t *testing.T) { + var buf bytes.Buffer + out := NewOutputWriter(&buf) + out.Chunk(json.RawMessage(`{"a":1}`)) + out.Chunk(json.RawMessage(`{"a":2}`)) + out.Done() + + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + if len(lines) != 3 { + t.Fatalf("got %d lines; want 3", len(lines)) + } + for i, line := range lines { + var evt outputEvent + if err := json.Unmarshal([]byte(line), &evt); err != nil { + t.Fatalf("line %d: json.Unmarshal() error: %v", i, err) + } + if i < 2 && evt.Event != "chunk" { + t.Errorf("line %d: event = %q; want %q", i, evt.Event, "chunk") + } + if i == 2 && evt.Event != "done" { + t.Errorf("line %d: event = %q; want %q", i, evt.Event, "done") + } + } +} diff --git a/pkg/slim/remote.go b/pkg/slim/remote.go new file mode 100644 index 0000000000..be1c098536 --- /dev/null +++ b/pkg/slim/remote.go @@ -0,0 +1,157 @@ +package slim + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const dispatchPrefix = "/v2/invoke/dispatch" + +type DaemonClient struct { + addr string + key string + client *http.Client +} + +func NewDaemonClient(addr, key string) *DaemonClient { + return &DaemonClient{ + addr: strings.TrimRight(addr, "/"), + key: key, + client: &http.Client{Timeout: 600 * time.Second}, + } +} + +type ActionRoute struct { + Type string + Path string +} + +var ActionRoutes = map[string]ActionRoute{ + // tool + "invoke_tool": {Type: "tool", Path: "/tool/invoke"}, + "validate_tool_credentials": {Type: "tool", Path: "/tool/validate_credentials"}, + "get_tool_runtime_parameters": {Type: "tool", Path: "/tool/get_runtime_parameters"}, + + // model + "invoke_llm": {Type: "model", Path: "/llm/invoke"}, + "get_llm_num_tokens": {Type: "model", Path: "/llm/num_tokens"}, + "invoke_text_embedding": {Type: "model", Path: "/text_embedding/invoke"}, + "invoke_multimodal_embedding": {Type: "model", Path: "/multimodal_embedding/invoke"}, + "get_text_embedding_num_tokens": {Type: "model", Path: "/text_embedding/num_tokens"}, + "invoke_rerank": {Type: "model", Path: "/rerank/invoke"}, + "invoke_multimodal_rerank": {Type: "model", Path: "/multimodal_rerank/invoke"}, + "invoke_tts": {Type: "model", Path: "/tts/invoke"}, + "get_tts_model_voices": {Type: "model", Path: "/tts/model/voices"}, + "invoke_speech2text": {Type: "model", Path: "/speech2text/invoke"}, + "invoke_moderation": {Type: "model", Path: "/moderation/invoke"}, + "validate_provider_credentials": {Type: "model", Path: "/model/validate_provider_credentials"}, + "validate_model_credentials": {Type: "model", Path: "/model/validate_model_credentials"}, + "get_ai_model_schemas": {Type: "model", Path: "/model/schema"}, + + // agent strategy + "invoke_agent_strategy": {Type: "agent_strategy", Path: "/agent_strategy/invoke"}, + + // endpoint + "invoke_endpoint": {Type: "endpoint", Path: "/endpoint/invoke"}, + + // oauth + "get_authorization_url": {Type: "oauth", Path: "/oauth/get_authorization_url"}, + "get_credentials": {Type: "oauth", Path: "/oauth/get_credentials"}, + "refresh_credentials": {Type: "oauth", Path: "/oauth/refresh_credentials"}, + + // datasource + "validate_datasource_credentials": {Type: "datasource", Path: "/datasource/validate_credentials"}, + "invoke_website_datasource_get_crawl": {Type: "datasource", Path: "/datasource/get_website_crawl"}, + "invoke_online_document_datasource_get_pages": {Type: "datasource", Path: "/datasource/get_online_document_pages"}, + "invoke_online_document_datasource_get_page_content": {Type: "datasource", Path: "/datasource/get_online_document_page_content"}, + "invoke_online_drive_browse_files": {Type: "datasource", Path: "/datasource/online_drive_browse_files"}, + "invoke_online_drive_download_file": {Type: "datasource", Path: "/datasource/online_drive_download_file"}, + + // dynamic parameter + "fetch_parameter_options": {Type: "dynamic_parameter", Path: "/dynamic_select/fetch_parameter_options"}, + + // trigger + "invoke_trigger_event": {Type: "trigger", Path: "/trigger/invoke_event"}, + "dispatch_trigger_event": {Type: "trigger", Path: "/trigger/dispatch_event"}, + "subscribe_trigger": {Type: "trigger", Path: "/trigger/subscribe"}, + "unsubscribe_trigger": {Type: "trigger", Path: "/trigger/unsubscribe"}, + "refresh_trigger": {Type: "trigger", Path: "/trigger/refresh"}, + "validate_trigger_credentials": {Type: "trigger", Path: "/trigger/validate_credentials"}, +} + +func LookupRoute(action string) (ActionRoute, bool) { + r, ok := ActionRoutes[action] + return r, ok +} + +func (c *DaemonClient) Dispatch(ctx *InvokeContext) (io.ReadCloser, error) { + route, ok := LookupRoute(ctx.Action) + if !ok { + return nil, NewError(ErrUnknownAction, ctx.Action) + } + + body, err := json.Marshal(ctx.Request) + if err != nil { + return nil, NewError(ErrInvalidInput, fmt.Sprintf("failed to marshal request: %s", err)) + } + + url := c.addr + dispatchPrefix + route.Path + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, NewError(ErrNetwork, err.Error()) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Api-Key", c.key) + req.Header.Set("X-Plugin-Unique-Identifier", ctx.PluginID) + + resp, err := c.client.Do(req) + if err != nil { + return nil, NewError(ErrNetwork, err.Error()) + } + + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + b, _ := io.ReadAll(resp.Body) + return nil, NewError(ErrDaemon, fmt.Sprintf("status %d: %s", resp.StatusCode, string(b))) + } + + return resp.Body, nil +} + +func RunRemote(ctx *InvokeContext, remote *RemoteConfig, out *OutputWriter) error { + client := NewDaemonClient(remote.DaemonAddr, remote.DaemonKey) + + body, err := client.Dispatch(ctx) + if err != nil { + return err + } + defer body.Close() + + scanner := bufio.NewScanner(body) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + payload := strings.TrimPrefix(line, "data: ") + if !json.Valid([]byte(payload)) { + return NewError(ErrStreamParse, "invalid JSON in SSE frame") + } + if err := out.Chunk(json.RawMessage(payload)); err != nil { + return NewError(ErrStreamRead, err.Error()) + } + } + if err := scanner.Err(); err != nil { + return NewError(ErrStreamRead, err.Error()) + } + + out.Done() + return nil +} diff --git a/pkg/slim/remote_test.go b/pkg/slim/remote_test.go new file mode 100644 index 0000000000..d3c7b149df --- /dev/null +++ b/pkg/slim/remote_test.go @@ -0,0 +1,280 @@ +package slim + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestLookupRoute_KnownActions(t *testing.T) { + tests := []struct { + action string + wantType string + wantPath string + }{ + {"invoke_tool", "tool", "/tool/invoke"}, + {"invoke_llm", "model", "/llm/invoke"}, + {"invoke_agent_strategy", "agent_strategy", "/agent_strategy/invoke"}, + {"invoke_endpoint", "endpoint", "/endpoint/invoke"}, + {"get_authorization_url", "oauth", "/oauth/get_authorization_url"}, + {"validate_datasource_credentials", "datasource", "/datasource/validate_credentials"}, + {"fetch_parameter_options", "dynamic_parameter", "/dynamic_select/fetch_parameter_options"}, + {"invoke_trigger_event", "trigger", "/trigger/invoke_event"}, + } + for _, tt := range tests { + route, ok := LookupRoute(tt.action) + if !ok { + t.Errorf("LookupRoute(%q) not found", tt.action) + continue + } + if route.Type != tt.wantType { + t.Errorf("LookupRoute(%q).Type = %q; want %q", tt.action, route.Type, tt.wantType) + } + if route.Path != tt.wantPath { + t.Errorf("LookupRoute(%q).Path = %q; want %q", tt.action, route.Path, tt.wantPath) + } + } +} + +func TestLookupRoute_UnknownAction(t *testing.T) { + _, ok := LookupRoute("nonexistent") + if ok { + t.Error("LookupRoute(nonexistent) should return false") + } +} + +func TestDaemonClient_Dispatch_UnknownAction(t *testing.T) { + client := NewDaemonClient("http://localhost:9999", "key") + ctx := &InvokeContext{ + PluginID: "plugin", + Action: "bad_action", + Request: RequestMeta{Data: json.RawMessage(`{}`)}, + } + _, err := client.Dispatch(ctx) + if err == nil { + t.Fatal("Dispatch() should fail for unknown action") + } + se := err.(*SlimError) + if se.Code != ErrUnknownAction { + t.Errorf("Code = %q; want %q", se.Code, ErrUnknownAction) + } +} + +func TestDaemonClient_Dispatch_NonOKStatus(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"internal"}`)) + })) + defer srv.Close() + + client := NewDaemonClient(srv.URL, "testkey") + ctx := &InvokeContext{ + PluginID: "plugin", + Action: "invoke_tool", + Request: RequestMeta{ + TenantID: "t1", + Data: json.RawMessage(`{}`), + }, + } + _, err := client.Dispatch(ctx) + if err == nil { + t.Fatal("Dispatch() should fail on non-200 status") + } + se := err.(*SlimError) + if se.Code != ErrDaemon { + t.Errorf("Code = %q; want %q", se.Code, ErrDaemon) + } + if !strings.Contains(se.Message, "500") { + t.Errorf("Message should contain status code 500: %q", se.Message) + } +} + +func TestDaemonClient_Dispatch_SetsHeaders(t *testing.T) { + var gotHeaders http.Header + var gotPath string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders = r.Header + gotPath = r.URL.Path + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := NewDaemonClient(srv.URL, "my-api-key") + ctx := &InvokeContext{ + PluginID: "author/plugin:1.0.0", + Action: "invoke_tool", + Request: RequestMeta{ + TenantID: "t1", + Data: json.RawMessage(`{}`), + }, + } + body, err := client.Dispatch(ctx) + if err != nil { + t.Fatalf("Dispatch() error: %v", err) + } + body.Close() + + if gotHeaders.Get("X-Api-Key") != "my-api-key" { + t.Errorf("X-Api-Key = %q; want %q", gotHeaders.Get("X-Api-Key"), "my-api-key") + } + if gotHeaders.Get("X-Plugin-Unique-Identifier") != "author/plugin:1.0.0" { + t.Errorf("X-Plugin-Unique-Identifier = %q; want %q", + gotHeaders.Get("X-Plugin-Unique-Identifier"), "author/plugin:1.0.0") + } + if gotHeaders.Get("Content-Type") != "application/json" { + t.Errorf("Content-Type = %q; want %q", gotHeaders.Get("Content-Type"), "application/json") + } + wantPath := dispatchPrefix + "/tool/invoke" + if gotPath != wantPath { + t.Errorf("path = %q; want %q", gotPath, wantPath) + } +} + +func TestRunRemote_SSEStream(t *testing.T) { + chunks := []map[string]any{ + {"result": "chunk1"}, + {"result": "chunk2"}, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + for _, chunk := range chunks { + b, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", b) + } + })) + defer srv.Close() + + var buf bytes.Buffer + out := NewOutputWriter(&buf) + remote := &RemoteConfig{DaemonAddr: srv.URL, DaemonKey: "key"} + ctx := &InvokeContext{ + PluginID: "plugin", + Action: "invoke_tool", + Request: RequestMeta{ + TenantID: "t1", + Data: json.RawMessage(`{}`), + }, + } + + if err := RunRemote(ctx, remote, out); err != nil { + t.Fatalf("RunRemote() error: %v", err) + } + + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + if len(lines) != 3 { + t.Fatalf("got %d output lines; want 3 (2 chunks + done)", len(lines)) + } + + for i := 0; i < 2; i++ { + var evt outputEvent + if err := json.Unmarshal([]byte(lines[i]), &evt); err != nil { + t.Fatalf("line %d: json.Unmarshal() error: %v", i, err) + } + if evt.Event != "chunk" { + t.Errorf("line %d: event = %q; want %q", i, evt.Event, "chunk") + } + } + + var doneEvt outputEvent + if err := json.Unmarshal([]byte(lines[2]), &doneEvt); err != nil { + t.Fatalf("done line: json.Unmarshal() error: %v", err) + } + if doneEvt.Event != "done" { + t.Errorf("last event = %q; want %q", doneEvt.Event, "done") + } +} + +func TestRunRemote_InvalidSSEJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "data: {not-valid-json}\n\n") + })) + defer srv.Close() + + var buf bytes.Buffer + out := NewOutputWriter(&buf) + remote := &RemoteConfig{DaemonAddr: srv.URL, DaemonKey: "key"} + ctx := &InvokeContext{ + PluginID: "plugin", + Action: "invoke_tool", + Request: RequestMeta{ + TenantID: "t1", + Data: json.RawMessage(`{}`), + }, + } + + err := RunRemote(ctx, remote, out) + if err == nil { + t.Fatal("RunRemote() should fail on invalid SSE JSON") + } + se := err.(*SlimError) + if se.Code != ErrStreamParse { + t.Errorf("Code = %q; want %q", se.Code, ErrStreamParse) + } +} + +func TestRunRemote_SkipsNonDataLines(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "event: ping\n") + fmt.Fprint(w, ": comment\n") + fmt.Fprint(w, "data: {\"ok\":true}\n\n") + })) + defer srv.Close() + + var buf bytes.Buffer + out := NewOutputWriter(&buf) + remote := &RemoteConfig{DaemonAddr: srv.URL, DaemonKey: "key"} + ctx := &InvokeContext{ + PluginID: "plugin", + Action: "invoke_tool", + Request: RequestMeta{ + TenantID: "t1", + Data: json.RawMessage(`{}`), + }, + } + + if err := RunRemote(ctx, remote, out); err != nil { + t.Fatalf("RunRemote() error: %v", err) + } + + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + if len(lines) != 2 { + t.Fatalf("got %d output lines; want 2 (1 chunk + done)", len(lines)) + } +} + +func TestRunRemote_DaemonError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte("forbidden")) + })) + defer srv.Close() + + var buf bytes.Buffer + out := NewOutputWriter(&buf) + remote := &RemoteConfig{DaemonAddr: srv.URL, DaemonKey: "key"} + ctx := &InvokeContext{ + PluginID: "plugin", + Action: "invoke_tool", + Request: RequestMeta{ + TenantID: "t1", + Data: json.RawMessage(`{}`), + }, + } + + err := RunRemote(ctx, remote, out) + if err == nil { + t.Fatal("RunRemote() should fail on daemon error") + } + se := err.(*SlimError) + if se.Code != ErrDaemon { + t.Errorf("Code = %q; want %q", se.Code, ErrDaemon) + } +} diff --git a/pkg/slim/transform.go b/pkg/slim/transform.go new file mode 100644 index 0000000000..ec0511afda --- /dev/null +++ b/pkg/slim/transform.go @@ -0,0 +1,43 @@ +package slim + +import ( + "encoding/json" + + "github.com/google/uuid" +) + +func TransformRequest(ctx *InvokeContext) ([]byte, string, error) { + sessionID := uuid.New().String() + + var args map[string]any + if err := json.Unmarshal(ctx.Request.Data, &args); err != nil { + return nil, "", NewError(ErrInvalidArgsJSON, err.Error()) + } + + route, ok := LookupRoute(ctx.Action) + if !ok { + return nil, "", NewError(ErrUnknownAction, ctx.Action) + } + + args["user_id"] = ctx.Request.UserID + args["type"] = route.Type + args["action"] = ctx.Action + + message := map[string]any{ + "tenant_id": ctx.Request.TenantID, + "session_id": sessionID, + "conversation_id": nil, + "message_id": nil, + "app_id": nil, + "endpoint_id": nil, + "context": map[string]any{}, + "event": "request", + "data": args, + } + + b, err := json.Marshal(message) + if err != nil { + return nil, "", NewError(ErrInvalidInput, "failed to marshal message: "+err.Error()) + } + return b, sessionID, nil +} diff --git a/pkg/slim/transform_test.go b/pkg/slim/transform_test.go new file mode 100644 index 0000000000..aba28a764d --- /dev/null +++ b/pkg/slim/transform_test.go @@ -0,0 +1,168 @@ +package slim + +import ( + "encoding/json" + "testing" +) + +func TestTransformRequest_ValidToolAction(t *testing.T) { + ctx := &InvokeContext{ + PluginID: "author/plugin:1.0.0", + Action: "invoke_tool", + Request: RequestMeta{ + TenantID: "tenant-1", + UserID: "user-1", + Data: json.RawMessage(`{"tool_name":"search","params":{"q":"test"}}`), + }, + } + b, sessionID, err := TransformRequest(ctx) + if err != nil { + t.Fatalf("TransformRequest() error: %v", err) + } + if sessionID == "" { + t.Fatal("sessionID should not be empty") + } + + var msg map[string]any + if err := json.Unmarshal(b, &msg); err != nil { + t.Fatalf("json.Unmarshal() error: %v", err) + } + if msg["tenant_id"] != "tenant-1" { + t.Errorf("tenant_id = %v; want %q", msg["tenant_id"], "tenant-1") + } + if msg["session_id"] != sessionID { + t.Errorf("session_id = %v; want %q", msg["session_id"], sessionID) + } + if msg["event"] != "request" { + t.Errorf("event = %v; want %q", msg["event"], "request") + } + + data, ok := msg["data"].(map[string]any) + if !ok { + t.Fatalf("data is not a map: %T", msg["data"]) + } + if data["type"] != "tool" { + t.Errorf("data.type = %v; want %q", data["type"], "tool") + } + if data["action"] != "invoke_tool" { + t.Errorf("data.action = %v; want %q", data["action"], "invoke_tool") + } + if data["user_id"] != "user-1" { + t.Errorf("data.user_id = %v; want %q", data["user_id"], "user-1") + } +} + +func TestTransformRequest_ModelAction(t *testing.T) { + ctx := &InvokeContext{ + PluginID: "author/plugin:1.0.0", + Action: "invoke_llm", + Request: RequestMeta{ + TenantID: "t1", + Data: json.RawMessage(`{"model":"gpt-4"}`), + }, + } + b, _, err := TransformRequest(ctx) + if err != nil { + t.Fatalf("TransformRequest() error: %v", err) + } + var msg map[string]any + if err := json.Unmarshal(b, &msg); err != nil { + t.Fatalf("json.Unmarshal() error: %v", err) + } + data := msg["data"].(map[string]any) + if data["type"] != "model" { + t.Errorf("data.type = %v; want %q", data["type"], "model") + } + if data["action"] != "invoke_llm" { + t.Errorf("data.action = %v; want %q", data["action"], "invoke_llm") + } +} + +func TestTransformRequest_UnknownAction(t *testing.T) { + ctx := &InvokeContext{ + PluginID: "plugin", + Action: "nonexistent_action", + Request: RequestMeta{ + TenantID: "t1", + Data: json.RawMessage(`{}`), + }, + } + _, _, err := TransformRequest(ctx) + if err == nil { + t.Fatal("TransformRequest() should fail for unknown action") + } + se, ok := err.(*SlimError) + if !ok { + t.Fatalf("expected *SlimError, got %T", err) + } + if se.Code != ErrUnknownAction { + t.Errorf("Code = %q; want %q", se.Code, ErrUnknownAction) + } +} + +func TestTransformRequest_InvalidDataJSON(t *testing.T) { + ctx := &InvokeContext{ + PluginID: "plugin", + Action: "invoke_tool", + Request: RequestMeta{ + TenantID: "t1", + Data: json.RawMessage(`not-json`), + }, + } + _, _, err := TransformRequest(ctx) + if err == nil { + t.Fatal("TransformRequest() should fail for invalid data JSON") + } + se, ok := err.(*SlimError) + if !ok { + t.Fatalf("expected *SlimError, got %T", err) + } + if se.Code != ErrInvalidArgsJSON { + t.Errorf("Code = %q; want %q", se.Code, ErrInvalidArgsJSON) + } +} + +func TestTransformRequest_MarshalErrorReturnsSlimError(t *testing.T) { + ctx := &InvokeContext{ + PluginID: "plugin", + Action: "invoke_tool", + Request: RequestMeta{ + TenantID: "t1", + Data: json.RawMessage(`{"key":"val"}`), + }, + } + _, _, err := TransformRequest(ctx) + if err != nil { + t.Fatalf("TransformRequest() error: %v", err) + } +} + +func TestTransformRequest_AllActionRoutes(t *testing.T) { + for action, route := range ActionRoutes { + ctx := &InvokeContext{ + PluginID: "plugin", + Action: action, + Request: RequestMeta{ + TenantID: "t1", + Data: json.RawMessage(`{}`), + }, + } + b, _, err := TransformRequest(ctx) + if err != nil { + t.Errorf("TransformRequest(%q) error: %v", action, err) + continue + } + var msg map[string]any + if err := json.Unmarshal(b, &msg); err != nil { + t.Errorf("json.Unmarshal for action %q error: %v", action, err) + continue + } + data := msg["data"].(map[string]any) + if data["type"] != route.Type { + t.Errorf("action %q: data.type = %v; want %q", action, data["type"], route.Type) + } + if data["action"] != action { + t.Errorf("action %q: data.action = %v; want %q", action, data["action"], action) + } + } +} diff --git a/pkg/slim/utils.go b/pkg/slim/utils.go new file mode 100644 index 0000000000..5f68ed7b08 --- /dev/null +++ b/pkg/slim/utils.go @@ -0,0 +1,23 @@ +package slim + +import ( + "os" + "strconv" +) + +func env(name string, d string) string { + v := os.Getenv(name) + if v == "" { + return d + } + return v +} + +func envInt(name string, d int) int { + v := os.Getenv(name) + i, err := strconv.Atoi(v) + if err != nil || v == "" { + return d + } + return i +} diff --git a/pkg/slim/utils_test.go b/pkg/slim/utils_test.go new file mode 100644 index 0000000000..ee76679277 --- /dev/null +++ b/pkg/slim/utils_test.go @@ -0,0 +1,70 @@ +package slim + +import ( + "testing" +) + +func TestTruncate(t *testing.T) { + tests := []struct { + input string + max int + want string + }{ + {"hello", 10, "hello"}, + {"hello", 5, "hello"}, + {"hello world", 5, "hello..."}, + {"", 5, ""}, + {"abc", 0, "..."}, + } + for _, tt := range tests { + got := truncate(tt.input, tt.max) + if got != tt.want { + t.Errorf("truncate(%q, %d) = %q; want %q", tt.input, tt.max, got, tt.want) + } + } +} + +func TestPluginWorkingPath(t *testing.T) { + tests := []struct { + folder string + pluginID string + wantEnd string + }{ + {"/plugins", "author/plugin:1.0.0", "author/plugin-1.0.0"}, + {"/plugins", "simple", "simple"}, + {"/plugins", "a:b:c", "a-b-c"}, + } + for _, tt := range tests { + got := pluginWorkingPath(tt.folder, tt.pluginID) + if got != tt.folder+"/"+tt.wantEnd { + t.Errorf("pluginWorkingPath(%q, %q) = %q; want suffix %q", + tt.folder, tt.pluginID, got, tt.wantEnd) + } + } +} + +func TestEnv(t *testing.T) { + t.Setenv("SLIM_TEST_VAR", "value") + if got := env("SLIM_TEST_VAR", "default"); got != "value" { + t.Errorf("env() = %q; want %q", got, "value") + } + if got := env("SLIM_TEST_UNSET_VAR", "default"); got != "default" { + t.Errorf("env() = %q; want %q", got, "default") + } +} + +func TestEnvInt(t *testing.T) { + t.Setenv("SLIM_TEST_INT", "42") + if got := envInt("SLIM_TEST_INT", 0); got != 42 { + t.Errorf("envInt() = %d; want 42", got) + } + + t.Setenv("SLIM_TEST_INT_BAD", "abc") + if got := envInt("SLIM_TEST_INT_BAD", 99); got != 99 { + t.Errorf("envInt() = %d; want 99 (default on parse error)", got) + } + + if got := envInt("SLIM_TEST_INT_UNSET", 10); got != 10 { + t.Errorf("envInt() = %d; want 10", got) + } +}