From b67de369511fb672c5e2636266e59e4504daba84 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Wed, 25 Feb 2026 00:38:50 -0500 Subject: [PATCH 1/5] feat: add step.build_from_config CI/CD pipeline step MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements step.build_from_config (Phase 5.1 roadmap) — a pipeline step that assembles a self-contained Docker image from a workflow config YAML file, a server binary, and optional plugin binaries. - Creates a temp build context, copies config + server + plugin binaries - Generates a Dockerfile with correct ENTRYPOINT/CMD for workflow server - Executes docker build (and optional docker push) via exec.Command - exec.Command is injectable for deterministic unit testing - 17 tests cover factory validation, Dockerfile generation, error paths, push flag, plugin inclusion, and build context file layout - Registers step.build_from_config in plugins/cicd manifest and factory map Co-Authored-By: Claude Opus 4.6 --- module/pipeline_step_build_from_config.go | 246 +++++++++ .../pipeline_step_build_from_config_test.go | 517 ++++++++++++++++++ plugins/cicd/plugin.go | 35 +- plugins/cicd/plugin_test.go | 6 +- 4 files changed, 786 insertions(+), 18 deletions(-) create mode 100644 module/pipeline_step_build_from_config.go create mode 100644 module/pipeline_step_build_from_config_test.go diff --git a/module/pipeline_step_build_from_config.go b/module/pipeline_step_build_from_config.go new file mode 100644 index 00000000..6c709f37 --- /dev/null +++ b/module/pipeline_step_build_from_config.go @@ -0,0 +1,246 @@ +package module + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/CrisisTextLine/modular" +) + +// PluginSpec describes a plugin binary to include in the built image. +type PluginSpec struct { + Name string + Binary string +} + +// BuildFromConfigStep reads a workflow config YAML file, assembles a Docker +// build context with the server binary and any required plugin binaries, +// generates a Dockerfile, builds the image, and optionally pushes it. +type BuildFromConfigStep struct { + name string + configFile string + baseImage string + serverBinary string + tag string + push bool + plugins []PluginSpec + + // execCommand is the function used to create exec.Cmd instances. + // Defaults to exec.CommandContext; overridable in tests. + execCommand func(ctx context.Context, name string, args ...string) *exec.Cmd +} + +// NewBuildFromConfigStepFactory returns a StepFactory that creates BuildFromConfigStep instances. +func NewBuildFromConfigStepFactory() StepFactory { + return func(name string, config map[string]any, _ modular.Application) (PipelineStep, error) { + configFile, _ := config["config_file"].(string) + if configFile == "" { + return nil, fmt.Errorf("build_from_config step %q: 'config_file' is required", name) + } + + tag, _ := config["tag"].(string) + if tag == "" { + return nil, fmt.Errorf("build_from_config step %q: 'tag' is required", name) + } + + baseImage, _ := config["base_image"].(string) + if baseImage == "" { + baseImage = "ghcr.io/gocodealone/workflow-runtime:latest" + } + + serverBinary, _ := config["server_binary"].(string) + if serverBinary == "" { + serverBinary = "/usr/local/bin/workflow-server" + } + + push, _ := config["push"].(bool) + + var plugins []PluginSpec + if pluginsRaw, ok := config["plugins"].([]any); ok { + for i, p := range pluginsRaw { + m, ok := p.(map[string]any) + if !ok { + return nil, fmt.Errorf("build_from_config step %q: plugins[%d] must be a map", name, i) + } + pName, _ := m["name"].(string) + pBinary, _ := m["binary"].(string) + if pName == "" || pBinary == "" { + return nil, fmt.Errorf("build_from_config step %q: plugins[%d] requires 'name' and 'binary'", name, i) + } + plugins = append(plugins, PluginSpec{Name: pName, Binary: pBinary}) + } + } + + return &BuildFromConfigStep{ + name: name, + configFile: configFile, + baseImage: baseImage, + serverBinary: serverBinary, + tag: tag, + push: push, + plugins: plugins, + execCommand: exec.CommandContext, + }, nil + } +} + +// Name returns the step name. +func (s *BuildFromConfigStep) Name() string { return s.name } + +// Execute assembles the build context, generates a Dockerfile, builds the +// Docker image, and optionally pushes it. +func (s *BuildFromConfigStep) Execute(ctx context.Context, _ *PipelineContext) (*StepResult, error) { + // Validate that the config file exists. + if _, err := os.Stat(s.configFile); err != nil { + return nil, fmt.Errorf("build_from_config step %q: config_file %q not found: %w", s.name, s.configFile, err) + } + + // Validate that the server binary exists. + if _, err := os.Stat(s.serverBinary); err != nil { + return nil, fmt.Errorf("build_from_config step %q: server_binary %q not found: %w", s.name, s.serverBinary, err) + } + + // Create a temporary build context directory. + buildDir, err := os.MkdirTemp("", "workflow-build-*") + if err != nil { + return nil, fmt.Errorf("build_from_config step %q: failed to create temp build dir: %w", s.name, err) + } + defer os.RemoveAll(buildDir) + + // Copy config file into build context as config.yaml. + if err := copyFile(s.configFile, filepath.Join(buildDir, "config.yaml")); err != nil { + return nil, fmt.Errorf("build_from_config step %q: failed to copy config file: %w", s.name, err) + } + + // Copy server binary into build context as server. + serverDst := filepath.Join(buildDir, "server") + if err := copyFile(s.serverBinary, serverDst); err != nil { + return nil, fmt.Errorf("build_from_config step %q: failed to copy server binary: %w", s.name, err) + } + if err := os.Chmod(serverDst, 0755); err != nil { //nolint:gosec // G302: intentionally executable + return nil, fmt.Errorf("build_from_config step %q: failed to chmod server binary: %w", s.name, err) + } + + // Copy plugin binaries into build context under plugins//. + pluginsDir := filepath.Join(buildDir, "plugins") + for _, plugin := range s.plugins { + if _, err := os.Stat(plugin.Binary); err != nil { + return nil, fmt.Errorf("build_from_config step %q: plugin %q binary %q not found: %w", + s.name, plugin.Name, plugin.Binary, err) + } + pluginDir := filepath.Join(pluginsDir, plugin.Name) + if err := os.MkdirAll(pluginDir, 0750); err != nil { + return nil, fmt.Errorf("build_from_config step %q: failed to create plugin dir for %q: %w", + s.name, plugin.Name, err) + } + pluginBinaryName := filepath.Base(plugin.Binary) + pluginDst := filepath.Join(pluginDir, pluginBinaryName) + if err := copyFile(plugin.Binary, pluginDst); err != nil { + return nil, fmt.Errorf("build_from_config step %q: failed to copy plugin %q binary: %w", + s.name, plugin.Name, err) + } + if err := os.Chmod(pluginDst, 0755); err != nil { //nolint:gosec // G302: intentionally executable + return nil, fmt.Errorf("build_from_config step %q: failed to chmod plugin %q binary: %w", + s.name, plugin.Name, err) + } + } + + // Generate Dockerfile content. + dockerfileContent := s.generateDockerfile() + + // Write Dockerfile into build context. + dockerfilePath := filepath.Join(buildDir, "Dockerfile") + if err := os.WriteFile(dockerfilePath, []byte(dockerfileContent), 0600); err != nil { + return nil, fmt.Errorf("build_from_config step %q: failed to write Dockerfile: %w", s.name, err) + } + + // Execute docker build. + if err := s.runDockerBuild(ctx, buildDir); err != nil { + return nil, fmt.Errorf("build_from_config step %q: docker build failed: %w", s.name, err) + } + + // Optionally push the image. + if s.push { + if err := s.runDockerPush(ctx); err != nil { + return nil, fmt.Errorf("build_from_config step %q: docker push failed: %w", s.name, err) + } + } + + return &StepResult{ + Output: map[string]any{ + "image_tag": s.tag, + "dockerfile_content": dockerfileContent, + }, + }, nil +} + +// generateDockerfile returns a Dockerfile string for the build context layout. +func (s *BuildFromConfigStep) generateDockerfile() string { + var sb strings.Builder + + fmt.Fprintf(&sb, "FROM %s\n", s.baseImage) + sb.WriteString("COPY server /server\n") + sb.WriteString("COPY config.yaml /app/config.yaml\n") + + if len(s.plugins) > 0 { + sb.WriteString("COPY plugins/ /app/data/plugins/\n") + } + + sb.WriteString("WORKDIR /app\n") + sb.WriteString("ENTRYPOINT [\"/server\"]\n") + sb.WriteString("CMD [\"-config\", \"/app/config.yaml\", \"-data-dir\", \"/app/data\"]\n") + + return sb.String() +} + +// runDockerBuild executes "docker build -t ". +func (s *BuildFromConfigStep) runDockerBuild(ctx context.Context, buildDir string) error { + var stdout, stderr bytes.Buffer + cmd := s.execCommand(ctx, "docker", "build", "-t", s.tag, buildDir) //nolint:gosec // G204: tag from trusted pipeline config + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("%w\nstdout: %s\nstderr: %s", err, stdout.String(), stderr.String()) + } + return nil +} + +// runDockerPush executes "docker push ". +func (s *BuildFromConfigStep) runDockerPush(ctx context.Context) error { + var stdout, stderr bytes.Buffer + cmd := s.execCommand(ctx, "docker", "push", s.tag) //nolint:gosec // G204: tag from trusted pipeline config + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("%w\nstdout: %s\nstderr: %s", err, stdout.String(), stderr.String()) + } + return nil +} + +// copyFile copies src to dst, creating dst if it does not exist. +func copyFile(src, dst string) error { + in, err := os.Open(src) //nolint:gosec // G304: path from trusted pipeline config + if err != nil { + return fmt.Errorf("open %q: %w", src, err) + } + defer in.Close() + + out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return fmt.Errorf("create %q: %w", dst, err) + } + defer out.Close() + + if _, err := io.Copy(out, in); err != nil { + return fmt.Errorf("copy %q -> %q: %w", src, dst, err) + } + return nil +} diff --git a/module/pipeline_step_build_from_config_test.go b/module/pipeline_step_build_from_config_test.go new file mode 100644 index 00000000..5ee6b210 --- /dev/null +++ b/module/pipeline_step_build_from_config_test.go @@ -0,0 +1,517 @@ +package module + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +// setupBuildFromConfigFiles creates a temporary directory with a fake config +// file and a fake server binary (empty files). It returns the directory path +// and a cleanup function. +func setupBuildFromConfigFiles(t *testing.T) (configFile, serverBinary string, cleanup func()) { + t.Helper() + dir := t.TempDir() + + configFile = filepath.Join(dir, "app.yaml") + if err := os.WriteFile(configFile, []byte("version: 1\n"), 0600); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + + serverBinary = filepath.Join(dir, "workflow-server") + if err := os.WriteFile(serverBinary, []byte("#!/bin/sh\n"), 0755); err != nil { //nolint:gosec + t.Fatalf("failed to create server binary: %v", err) + } + + return configFile, serverBinary, func() {} // t.TempDir cleans up automatically +} + +// noopExecCommand returns a mock exec.CommandContext function that succeeds +// without running any real process. +func noopExecCommand(_ context.Context, name string, args ...string) *exec.Cmd { + // Invoke a real no-op command so cmd.Run() succeeds. + return exec.Command("true") +} + +// failingExecCommand returns a mock that always fails with an exit error. +func failingExecCommand(_ context.Context, _ string, _ ...string) *exec.Cmd { + return exec.Command("false") +} + +func TestBuildFromConfigStep_FactoryRequiresConfigFile(t *testing.T) { + factory := NewBuildFromConfigStepFactory() + _, err := factory("bfc", map[string]any{"tag": "my-app:latest"}, nil) + if err == nil { + t.Fatal("expected error when config_file is missing") + } + if !strings.Contains(err.Error(), "config_file") { + t.Errorf("expected error to mention config_file, got: %v", err) + } +} + +func TestBuildFromConfigStep_FactoryRequiresTag(t *testing.T) { + factory := NewBuildFromConfigStepFactory() + _, err := factory("bfc", map[string]any{"config_file": "app.yaml"}, nil) + if err == nil { + t.Fatal("expected error when tag is missing") + } + if !strings.Contains(err.Error(), "tag") { + t.Errorf("expected error to mention tag, got: %v", err) + } +} + +func TestBuildFromConfigStep_FactoryPluginMissingFields(t *testing.T) { + factory := NewBuildFromConfigStepFactory() + _, err := factory("bfc", map[string]any{ + "config_file": "app.yaml", + "tag": "my-app:latest", + "plugins": []any{ + map[string]any{"name": "admin"}, // missing binary + }, + }, nil) + if err == nil { + t.Fatal("expected error when plugin binary is missing") + } + if !strings.Contains(err.Error(), "binary") { + t.Errorf("expected error to mention binary, got: %v", err) + } +} + +func TestBuildFromConfigStep_FactoryPluginInvalidEntry(t *testing.T) { + factory := NewBuildFromConfigStepFactory() + _, err := factory("bfc", map[string]any{ + "config_file": "app.yaml", + "tag": "my-app:latest", + "plugins": []any{"not-a-map"}, + }, nil) + if err == nil { + t.Fatal("expected error for non-map plugin entry") + } +} + +func TestBuildFromConfigStep_Name(t *testing.T) { + factory := NewBuildFromConfigStepFactory() + step, err := factory("my-build", map[string]any{ + "config_file": "app.yaml", + "tag": "my-app:latest", + }, nil) + if err != nil { + t.Fatalf("unexpected factory error: %v", err) + } + if step.Name() != "my-build" { + t.Errorf("expected name %q, got %q", "my-build", step.Name()) + } +} + +func TestBuildFromConfigStep_DefaultBaseImage(t *testing.T) { + factory := NewBuildFromConfigStepFactory() + raw, err := factory("bfc", map[string]any{ + "config_file": "app.yaml", + "tag": "my-app:latest", + }, nil) + if err != nil { + t.Fatalf("unexpected factory error: %v", err) + } + bfc := raw.(*BuildFromConfigStep) + if bfc.baseImage != "ghcr.io/gocodealone/workflow-runtime:latest" { + t.Errorf("unexpected default base_image: %q", bfc.baseImage) + } +} + +func TestBuildFromConfigStep_GenerateDockerfile_NoPLugins(t *testing.T) { + s := &BuildFromConfigStep{ + name: "bfc", + baseImage: "gcr.io/distroless/static-debian12:nonroot", + tag: "my-app:latest", + plugins: nil, + } + + got := s.generateDockerfile() + + expectedLines := []string{ + "FROM gcr.io/distroless/static-debian12:nonroot", + "COPY server /server", + "COPY config.yaml /app/config.yaml", + "WORKDIR /app", + "ENTRYPOINT [\"/server\"]", + `CMD ["-config", "/app/config.yaml", "-data-dir", "/app/data"]`, + } + + for _, line := range expectedLines { + if !strings.Contains(got, line) { + t.Errorf("Dockerfile missing line %q\nGot:\n%s", line, got) + } + } + + // Without plugins, there should be no plugins COPY line. + if strings.Contains(got, "COPY plugins/") { + t.Errorf("Dockerfile should not contain plugins COPY when no plugins configured") + } +} + +func TestBuildFromConfigStep_GenerateDockerfile_WithPlugins(t *testing.T) { + s := &BuildFromConfigStep{ + name: "bfc", + baseImage: "gcr.io/distroless/static-debian12:nonroot", + tag: "my-app:latest", + plugins: []PluginSpec{ + {Name: "admin", Binary: "data/plugins/admin/admin"}, + }, + } + + got := s.generateDockerfile() + + if !strings.Contains(got, "COPY plugins/ /app/data/plugins/") { + t.Errorf("Dockerfile should contain plugins COPY line when plugins are configured\nGot:\n%s", got) + } +} + +func TestBuildFromConfigStep_Execute_MissingConfigFile(t *testing.T) { + s := &BuildFromConfigStep{ + name: "bfc", + configFile: "/nonexistent/app.yaml", + serverBinary: "/nonexistent/server", + tag: "my-app:latest", + execCommand: noopExecCommand, + } + + _, err := s.Execute(context.Background(), &PipelineContext{}) + if err == nil { + t.Fatal("expected error for missing config_file") + } + if !strings.Contains(err.Error(), "config_file") { + t.Errorf("expected error to mention config_file, got: %v", err) + } +} + +func TestBuildFromConfigStep_Execute_MissingServerBinary(t *testing.T) { + configFile, _, _ := setupBuildFromConfigFiles(t) + + s := &BuildFromConfigStep{ + name: "bfc", + configFile: configFile, + serverBinary: "/nonexistent/server", + tag: "my-app:latest", + execCommand: noopExecCommand, + } + + _, err := s.Execute(context.Background(), &PipelineContext{}) + if err == nil { + t.Fatal("expected error for missing server_binary") + } + if !strings.Contains(err.Error(), "server_binary") { + t.Errorf("expected error to mention server_binary, got: %v", err) + } +} + +func TestBuildFromConfigStep_Execute_MissingPluginBinary(t *testing.T) { + configFile, serverBinary, _ := setupBuildFromConfigFiles(t) + + s := &BuildFromConfigStep{ + name: "bfc", + configFile: configFile, + serverBinary: serverBinary, + tag: "my-app:latest", + plugins: []PluginSpec{ + {Name: "admin", Binary: "/nonexistent/admin"}, + }, + execCommand: noopExecCommand, + } + + _, err := s.Execute(context.Background(), &PipelineContext{}) + if err == nil { + t.Fatal("expected error for missing plugin binary") + } + if !strings.Contains(err.Error(), "plugin") { + t.Errorf("expected error to mention plugin, got: %v", err) + } +} + +func TestBuildFromConfigStep_Execute_DockerBuildFailure(t *testing.T) { + configFile, serverBinary, _ := setupBuildFromConfigFiles(t) + + s := &BuildFromConfigStep{ + name: "bfc", + configFile: configFile, + serverBinary: serverBinary, + tag: "my-app:latest", + execCommand: failingExecCommand, + } + + _, err := s.Execute(context.Background(), &PipelineContext{}) + if err == nil { + t.Fatal("expected error when docker build fails") + } + if !strings.Contains(err.Error(), "docker build") { + t.Errorf("expected error to mention docker build, got: %v", err) + } +} + +func TestBuildFromConfigStep_Execute_DockerPushFailure(t *testing.T) { + configFile, serverBinary, _ := setupBuildFromConfigFiles(t) + + callCount := 0 + s := &BuildFromConfigStep{ + name: "bfc", + configFile: configFile, + serverBinary: serverBinary, + tag: "my-app:latest", + push: true, + execCommand: func(ctx context.Context, name string, args ...string) *exec.Cmd { + callCount++ + if callCount == 1 { + // First call is docker build — succeed. + return exec.Command("true") + } + // Second call is docker push — fail. + return exec.Command("false") + }, + } + + _, err := s.Execute(context.Background(), &PipelineContext{}) + if err == nil { + t.Fatal("expected error when docker push fails") + } + if !strings.Contains(err.Error(), "docker push") { + t.Errorf("expected error to mention docker push, got: %v", err) + } +} + +func TestBuildFromConfigStep_Execute_NoPush(t *testing.T) { + configFile, serverBinary, _ := setupBuildFromConfigFiles(t) + + buildCalled := false + pushCalled := false + + s := &BuildFromConfigStep{ + name: "bfc", + configFile: configFile, + serverBinary: serverBinary, + baseImage: "gcr.io/distroless/static-debian12:nonroot", + tag: "my-app:latest", + push: false, + execCommand: func(ctx context.Context, name string, args ...string) *exec.Cmd { + if name == "docker" && len(args) > 0 { + switch args[0] { + case "build": + buildCalled = true + case "push": + pushCalled = true + } + } + return exec.Command("true") + }, + } + + result, err := s.Execute(context.Background(), &PipelineContext{}) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + if !buildCalled { + t.Error("expected docker build to be called") + } + if pushCalled { + t.Error("expected docker push NOT to be called when push=false") + } + + if result.Output["image_tag"] != "my-app:latest" { + t.Errorf("expected image_tag %q, got %v", "my-app:latest", result.Output["image_tag"]) + } + + dockerfileContent, ok := result.Output["dockerfile_content"].(string) + if !ok || dockerfileContent == "" { + t.Error("expected dockerfile_content to be non-empty string") + } +} + +func TestBuildFromConfigStep_Execute_WithPush(t *testing.T) { + configFile, serverBinary, _ := setupBuildFromConfigFiles(t) + + var dockerCalls []string + s := &BuildFromConfigStep{ + name: "bfc", + configFile: configFile, + serverBinary: serverBinary, + baseImage: "gcr.io/distroless/static-debian12:nonroot", + tag: "my-app:latest", + push: true, + execCommand: func(ctx context.Context, name string, args ...string) *exec.Cmd { + if name == "docker" && len(args) > 0 { + dockerCalls = append(dockerCalls, args[0]) + } + return exec.Command("true") + }, + } + + result, err := s.Execute(context.Background(), &PipelineContext{}) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + if len(dockerCalls) != 2 { + t.Fatalf("expected 2 docker calls (build + push), got %d: %v", len(dockerCalls), dockerCalls) + } + if dockerCalls[0] != "build" { + t.Errorf("expected first docker call to be 'build', got %q", dockerCalls[0]) + } + if dockerCalls[1] != "push" { + t.Errorf("expected second docker call to be 'push', got %q", dockerCalls[1]) + } + + if result.Output["image_tag"] != "my-app:latest" { + t.Errorf("expected image_tag %q, got %v", "my-app:latest", result.Output["image_tag"]) + } +} + +func TestBuildFromConfigStep_Execute_WithPlugins(t *testing.T) { + configFile, serverBinary, _ := setupBuildFromConfigFiles(t) + + // Create fake plugin binaries. + pluginDir := t.TempDir() + adminBinary := filepath.Join(pluginDir, "admin") + if err := os.WriteFile(adminBinary, []byte("#!/bin/sh\n"), 0755); err != nil { //nolint:gosec + t.Fatalf("failed to create admin binary: %v", err) + } + bentoBinary := filepath.Join(pluginDir, "workflow-plugin-bento") + if err := os.WriteFile(bentoBinary, []byte("#!/bin/sh\n"), 0755); err != nil { //nolint:gosec + t.Fatalf("failed to create bento binary: %v", err) + } + + var buildArgs []string + s := &BuildFromConfigStep{ + name: "bfc", + configFile: configFile, + serverBinary: serverBinary, + baseImage: "gcr.io/distroless/static-debian12:nonroot", + tag: "my-app:latest", + push: false, + plugins: []PluginSpec{ + {Name: "admin", Binary: adminBinary}, + {Name: "bento", Binary: bentoBinary}, + }, + execCommand: func(ctx context.Context, name string, args ...string) *exec.Cmd { + if name == "docker" && len(args) > 0 && args[0] == "build" { + buildArgs = args + } + return exec.Command("true") + }, + } + + result, err := s.Execute(context.Background(), &PipelineContext{}) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + // Verify the Dockerfile includes the plugins COPY line. + dockerfileContent, _ := result.Output["dockerfile_content"].(string) + if !strings.Contains(dockerfileContent, "COPY plugins/ /app/data/plugins/") { + t.Errorf("Dockerfile should contain plugins COPY line\nGot:\n%s", dockerfileContent) + } + + // Verify docker build was called with a context dir argument. + if len(buildArgs) < 3 { + t.Fatalf("expected docker build -t , got args: %v", buildArgs) + } +} + +func TestBuildFromConfigStep_Execute_BuildContextLayout(t *testing.T) { + configFile, serverBinary, _ := setupBuildFromConfigFiles(t) + + pluginDir := t.TempDir() + adminBinary := filepath.Join(pluginDir, "admin") + if err := os.WriteFile(adminBinary, []byte("#!/bin/sh\n"), 0755); err != nil { //nolint:gosec + t.Fatalf("failed to create plugin binary: %v", err) + } + + var capturedBuildDir string + s := &BuildFromConfigStep{ + name: "bfc", + configFile: configFile, + serverBinary: serverBinary, + baseImage: "alpine:latest", + tag: "my-app:latest", + plugins: []PluginSpec{ + {Name: "admin", Binary: adminBinary}, + }, + execCommand: func(ctx context.Context, name string, args ...string) *exec.Cmd { + // Capture the build context dir (last argument to docker build). + if name == "docker" && len(args) > 0 && args[0] == "build" { + capturedBuildDir = args[len(args)-1] + // Make a copy so we can inspect it after Execute returns + // (Execute defers RemoveAll on buildDir). + copyDir := t.TempDir() + _ = copyDirRecursive(capturedBuildDir, copyDir) + capturedBuildDir = copyDir + } + return exec.Command("true") + }, + } + + _, err := s.Execute(context.Background(), &PipelineContext{}) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + // Check expected files in the copied build context. + expectedFiles := []string{ + "Dockerfile", + "config.yaml", + "server", + filepath.Join("plugins", "admin", "admin"), + } + for _, f := range expectedFiles { + if _, err := os.Stat(filepath.Join(capturedBuildDir, f)); err != nil { + t.Errorf("build context missing expected file %q: %v", f, err) + } + } +} + +// copyDirRecursive copies the contents of src into dst directory. +func copyDirRecursive(src, dst string) error { + return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + rel, err := filepath.Rel(src, path) + if err != nil { + return err + } + dstPath := filepath.Join(dst, rel) + if info.IsDir() { + return os.MkdirAll(dstPath, info.Mode()) + } + return func() error { + in, err := os.Open(path) //nolint:gosec + if err != nil { + return err + } + defer in.Close() + out, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode()) + if err != nil { + return err + } + defer out.Close() + _, err = fmt.Fprintf(out, "") + if err != nil { + return err + } + _, err = out.Seek(0, 0) + if err != nil { + return err + } + f, err := os.Open(path) //nolint:gosec + if err != nil { + return err + } + defer f.Close() + _, copyErr := io.Copy(out, f) + return copyErr + }() + }) +} diff --git a/plugins/cicd/plugin.go b/plugins/cicd/plugin.go index 91544b89..145c695d 100644 --- a/plugins/cicd/plugin.go +++ b/plugins/cicd/plugin.go @@ -1,6 +1,7 @@ // Package cicd provides a plugin that registers CI/CD pipeline step types: // shell_exec, artifact_pull, artifact_push, docker_build, docker_push, -// docker_run, scan_sast, scan_container, scan_deps, deploy, gate, build_ui. +// docker_run, scan_sast, scan_container, scan_deps, deploy, gate, build_ui, +// build_from_config. package cicd import ( @@ -22,13 +23,13 @@ func New() *Plugin { BaseNativePlugin: plugin.BaseNativePlugin{ PluginName: "cicd", PluginVersion: "1.0.0", - PluginDescription: "CI/CD pipeline step types (shell exec, Docker, artifact management, security scanning, deploy, gate)", + PluginDescription: "CI/CD pipeline step types (shell exec, Docker, artifact management, security scanning, deploy, gate, build from config)", }, Manifest: plugin.PluginManifest{ Name: "cicd", Version: "1.0.0", Author: "GoCodeAlone", - Description: "CI/CD pipeline step types (shell exec, Docker, artifact management, security scanning, deploy, gate)", + Description: "CI/CD pipeline step types (shell exec, Docker, artifact management, security scanning, deploy, gate, build from config)", Tier: plugin.TierCore, StepTypes: []string{ "step.shell_exec", @@ -43,6 +44,7 @@ func New() *Plugin { "step.deploy", "step.gate", "step.build_ui", + "step.build_from_config", }, Capabilities: []plugin.CapabilityDecl{ {Name: "cicd-pipeline", Role: "provider", Priority: 50}, @@ -57,7 +59,7 @@ func (p *Plugin) Capabilities() []capability.Contract { return []capability.Contract{ { Name: "cicd-pipeline", - Description: "CI/CD pipeline operations: shell exec, Docker, artifact management, security scanning, deploy, gate", + Description: "CI/CD pipeline operations: shell exec, Docker, artifact management, security scanning, deploy, gate, build from config", }, } } @@ -65,18 +67,19 @@ func (p *Plugin) Capabilities() []capability.Contract { // StepFactories returns the CI/CD step factories. func (p *Plugin) StepFactories() map[string]plugin.StepFactory { return map[string]plugin.StepFactory{ - "step.shell_exec": wrapStepFactory(module.NewShellExecStepFactory()), - "step.artifact_pull": wrapStepFactory(module.NewArtifactPullStepFactory()), - "step.artifact_push": wrapStepFactory(module.NewArtifactPushStepFactory()), - "step.docker_build": wrapStepFactory(module.NewDockerBuildStepFactory()), - "step.docker_push": wrapStepFactory(module.NewDockerPushStepFactory()), - "step.docker_run": wrapStepFactory(module.NewDockerRunStepFactory()), - "step.scan_sast": wrapStepFactory(module.NewScanSASTStepFactory()), - "step.scan_container": wrapStepFactory(module.NewScanContainerStepFactory()), - "step.scan_deps": wrapStepFactory(module.NewScanDepsStepFactory()), - "step.deploy": wrapStepFactory(module.NewDeployStepFactory()), - "step.gate": wrapStepFactory(module.NewGateStepFactory()), - "step.build_ui": wrapStepFactory(module.NewBuildUIStepFactory()), + "step.shell_exec": wrapStepFactory(module.NewShellExecStepFactory()), + "step.artifact_pull": wrapStepFactory(module.NewArtifactPullStepFactory()), + "step.artifact_push": wrapStepFactory(module.NewArtifactPushStepFactory()), + "step.docker_build": wrapStepFactory(module.NewDockerBuildStepFactory()), + "step.docker_push": wrapStepFactory(module.NewDockerPushStepFactory()), + "step.docker_run": wrapStepFactory(module.NewDockerRunStepFactory()), + "step.scan_sast": wrapStepFactory(module.NewScanSASTStepFactory()), + "step.scan_container": wrapStepFactory(module.NewScanContainerStepFactory()), + "step.scan_deps": wrapStepFactory(module.NewScanDepsStepFactory()), + "step.deploy": wrapStepFactory(module.NewDeployStepFactory()), + "step.gate": wrapStepFactory(module.NewGateStepFactory()), + "step.build_ui": wrapStepFactory(module.NewBuildUIStepFactory()), + "step.build_from_config": wrapStepFactory(module.NewBuildFromConfigStepFactory()), } } diff --git a/plugins/cicd/plugin_test.go b/plugins/cicd/plugin_test.go index cacea13c..0f598a05 100644 --- a/plugins/cicd/plugin_test.go +++ b/plugins/cicd/plugin_test.go @@ -43,6 +43,7 @@ func TestStepFactories(t *testing.T) { "step.deploy", "step.gate", "step.build_ui", + "step.build_from_config", } for _, stepType := range expectedSteps { @@ -54,6 +55,7 @@ func TestStepFactories(t *testing.T) { if len(factories) != len(expectedSteps) { t.Errorf("expected %d step factories, got %d", len(expectedSteps), len(factories)) } + } func TestPluginLoads(t *testing.T) { @@ -64,7 +66,7 @@ func TestPluginLoads(t *testing.T) { } steps := loader.StepFactories() - if len(steps) != 12 { - t.Fatalf("expected 12 step factories after load, got %d", len(steps)) + if len(steps) != 13 { + t.Fatalf("expected 13 step factories after load, got %d", len(steps)) } } From 090c9cf9ddb50fb59c32f9e8de17c508cd42cbfd Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Wed, 25 Feb 2026 01:07:27 -0500 Subject: [PATCH 2/5] feat: add step.statemachine_transition and step.statemachine_get pipeline steps Implements two new pipeline steps for interacting with state machine workflow instances directly from pipeline execution: - step.statemachine_transition: triggers a named transition on a workflow instance, supports data templates, fail_on_error flag, and the TransitionTrigger interface for testability via mocks. - step.statemachine_get: reads the current state of a workflow instance. Both steps resolve entity_id and data fields from Go templates using the existing TemplateEngine, look up the named StateMachineEngine by service name from the app registry, and return structured output (transition_ok, new_state, current_state, entity_id). Steps are registered in plugins/statemachine with StepFactories() and declared in the manifest's StepTypes list. Co-Authored-By: Claude Opus 4.6 --- module/pipeline_step_statemachine_get.go | 88 ++++ module/pipeline_step_statemachine_get_test.go | 194 ++++++++ .../pipeline_step_statemachine_transition.go | 188 ++++++++ ...eline_step_statemachine_transition_test.go | 419 ++++++++++++++++++ plugins/statemachine/plugin.go | 20 + plugins/statemachine/plugin_test.go | 18 + 6 files changed, 927 insertions(+) create mode 100644 module/pipeline_step_statemachine_get.go create mode 100644 module/pipeline_step_statemachine_get_test.go create mode 100644 module/pipeline_step_statemachine_transition.go create mode 100644 module/pipeline_step_statemachine_transition_test.go diff --git a/module/pipeline_step_statemachine_get.go b/module/pipeline_step_statemachine_get.go new file mode 100644 index 00000000..293f1bd4 --- /dev/null +++ b/module/pipeline_step_statemachine_get.go @@ -0,0 +1,88 @@ +package module + +import ( + "context" + "fmt" + + "github.com/CrisisTextLine/modular" +) + +// StateMachineGetStep reads the current state of a workflow instance. +type StateMachineGetStep struct { + name string + statemachine string + entityID string + app modular.Application + tmpl *TemplateEngine +} + +// NewStateMachineGetStepFactory returns a StepFactory for step.statemachine_get. +// +// Config: +// +// type: step.statemachine_get +// config: +// statemachine: "order-sm" # service name of the StateMachineEngine +// entity_id: "{{.order_id}}" # which instance to look up (template) +// +// Outputs: current_state (string), entity_id (string). +// Returns an error (stopping the pipeline) when the instance is not found. +func NewStateMachineGetStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + sm, _ := config["statemachine"].(string) + if sm == "" { + return nil, fmt.Errorf("statemachine_get step %q: 'statemachine' is required", name) + } + + entityID, _ := config["entity_id"].(string) + if entityID == "" { + return nil, fmt.Errorf("statemachine_get step %q: 'entity_id' is required", name) + } + + return &StateMachineGetStep{ + name: name, + statemachine: sm, + entityID: entityID, + app: app, + tmpl: NewTemplateEngine(), + }, nil + } +} + +// Name returns the step name. +func (s *StateMachineGetStep) Name() string { return s.name } + +// Execute resolves the entity_id template, looks up the StateMachineEngine, and +// returns the current state of the workflow instance. +func (s *StateMachineGetStep) Execute(_ context.Context, pc *PipelineContext) (*StepResult, error) { + if s.app == nil { + return nil, fmt.Errorf("statemachine_get step %q: no application context", s.name) + } + + svc, ok := s.app.SvcRegistry()[s.statemachine] + if !ok { + return nil, fmt.Errorf("statemachine_get step %q: statemachine service %q not found", s.name, s.statemachine) + } + + engine, ok := svc.(*StateMachineEngine) + if !ok { + return nil, fmt.Errorf("statemachine_get step %q: service %q is not a StateMachineEngine", s.name, s.statemachine) + } + + entityID, err := s.tmpl.Resolve(s.entityID, pc) + if err != nil { + return nil, fmt.Errorf("statemachine_get step %q: failed to resolve entity_id: %w", s.name, err) + } + + instance, err := engine.GetInstance(entityID) + if err != nil { + return nil, fmt.Errorf("statemachine_get step %q: instance not found: %w", s.name, err) + } + + return &StepResult{ + Output: map[string]any{ + "current_state": instance.CurrentState, + "entity_id": entityID, + }, + }, nil +} diff --git a/module/pipeline_step_statemachine_get_test.go b/module/pipeline_step_statemachine_get_test.go new file mode 100644 index 00000000..320dfb24 --- /dev/null +++ b/module/pipeline_step_statemachine_get_test.go @@ -0,0 +1,194 @@ +package module + +import ( + "context" + "testing" +) + +// --- Factory validation tests --- + +func TestStateMachineGetStep_MissingStatemachine(t *testing.T) { + factory := NewStateMachineGetStepFactory() + _, err := factory("get-state", map[string]any{ + "entity_id": "order-1", + }, nil) + if err == nil { + t.Fatal("expected error for missing statemachine") + } +} + +func TestStateMachineGetStep_MissingEntityID(t *testing.T) { + factory := NewStateMachineGetStepFactory() + _, err := factory("get-state", map[string]any{ + "statemachine": "order-sm", + }, nil) + if err == nil { + t.Fatal("expected error for missing entity_id") + } +} + +// --- Execution tests --- + +func TestStateMachineGetStep_ReturnsCurrentState(t *testing.T) { + engine := setupOrderStateMachine(t, "order-1", "") + app := newAppWithSM(engine) + + factory := NewStateMachineGetStepFactory() + step, err := factory("get-order-state", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + state, _ := result.Output["current_state"].(string) + if state != "pending" { + t.Errorf("expected current_state='pending', got %q", state) + } + entityID, _ := result.Output["entity_id"].(string) + if entityID != "order-1" { + t.Errorf("expected entity_id='order-1', got %q", entityID) + } +} + +func TestStateMachineGetStep_TemplatedEntityID(t *testing.T) { + engine := setupOrderStateMachine(t, "order-99", "") + app := newAppWithSM(engine) + + factory := NewStateMachineGetStepFactory() + step, err := factory("get-state-template", map[string]any{ + "statemachine": "order-sm", + "entity_id": "{{.order_id}}", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(map[string]any{"order_id": "order-99"}, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + state, _ := result.Output["current_state"].(string) + if state != "pending" { + t.Errorf("expected current_state='pending', got %q", state) + } + entityID, _ := result.Output["entity_id"].(string) + if entityID != "order-99" { + t.Errorf("expected entity_id='order-99', got %q", entityID) + } +} + +func TestStateMachineGetStep_ReturnsStateAfterTransition(t *testing.T) { + engine := setupOrderStateMachine(t, "order-1", "") + app := newAppWithSM(engine) + + // Trigger a transition first + if err := engine.TriggerTransition(context.Background(), "order-1", "approve", nil); err != nil { + t.Fatalf("trigger transition: %v", err) + } + + factory := NewStateMachineGetStepFactory() + step, err := factory("get-approved-state", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + state, _ := result.Output["current_state"].(string) + if state != "approved" { + t.Errorf("expected current_state='approved', got %q", state) + } +} + +func TestStateMachineGetStep_InstanceNotFound(t *testing.T) { + engine := setupOrderStateMachine(t, "", "") // no instances created + app := newAppWithSM(engine) + + factory := NewStateMachineGetStepFactory() + step, err := factory("get-missing", map[string]any{ + "statemachine": "order-sm", + "entity_id": "nonexistent", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for nonexistent instance") + } +} + +func TestStateMachineGetStep_ServiceNotFound(t *testing.T) { + app := NewMockApplication() + + factory := NewStateMachineGetStepFactory() + step, err := factory("get-state", map[string]any{ + "statemachine": "nonexistent-sm", + "entity_id": "order-1", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for missing service") + } +} + +func TestStateMachineGetStep_ServiceWrongType(t *testing.T) { + app := NewMockApplication() + app.Services["order-sm"] = "not-an-engine" + + factory := NewStateMachineGetStepFactory() + step, err := factory("get-state", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for wrong service type") + } +} + +func TestStateMachineGetStep_NoAppContext(t *testing.T) { + factory := NewStateMachineGetStepFactory() + step, err := factory("get-state", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for nil app") + } +} diff --git a/module/pipeline_step_statemachine_transition.go b/module/pipeline_step_statemachine_transition.go new file mode 100644 index 00000000..c193d05b --- /dev/null +++ b/module/pipeline_step_statemachine_transition.go @@ -0,0 +1,188 @@ +package module + +import ( + "context" + "fmt" + + "github.com/CrisisTextLine/modular" +) + +// StateMachineTransitionStep triggers a state machine transition from within a pipeline. +type StateMachineTransitionStep struct { + name string + statemachine string + entityID string + event string + data map[string]any + failOnError bool + app modular.Application + tmpl *TemplateEngine +} + +// NewStateMachineTransitionStepFactory returns a StepFactory for step.statemachine_transition. +// +// Config: +// +// type: step.statemachine_transition +// config: +// statemachine: "order-sm" # service name of the StateMachineEngine +// entity_id: "{{.order_id}}" # which instance to transition (template) +// event: "approve" # transition name +// data: # optional data map (values may use templates) +// approved_by: "{{.user_id}}" +// fail_on_error: false # stop pipeline on invalid transition (default: false) +// +// Outputs: transition_ok (bool), new_state (string), error (string, only on failure). +func NewStateMachineTransitionStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + sm, _ := config["statemachine"].(string) + if sm == "" { + return nil, fmt.Errorf("statemachine_transition step %q: 'statemachine' is required", name) + } + + entityID, _ := config["entity_id"].(string) + if entityID == "" { + return nil, fmt.Errorf("statemachine_transition step %q: 'entity_id' is required", name) + } + + event, _ := config["event"].(string) + if event == "" { + return nil, fmt.Errorf("statemachine_transition step %q: 'event' is required", name) + } + + var data map[string]any + if d, ok := config["data"].(map[string]any); ok { + data = d + } + + failOnError, _ := config["fail_on_error"].(bool) + + return &StateMachineTransitionStep{ + name: name, + statemachine: sm, + entityID: entityID, + event: event, + data: data, + failOnError: failOnError, + app: app, + tmpl: NewTemplateEngine(), + }, nil + } +} + +// Name returns the step name. +func (s *StateMachineTransitionStep) Name() string { return s.name } + +// Execute resolves templates, looks up the StateMachineEngine by service name, and +// triggers the requested transition. On success it sets transition_ok=true and +// new_state to the resulting state. On failure it sets transition_ok=false and +// error to the error message; if fail_on_error is true the pipeline is stopped. +func (s *StateMachineTransitionStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { + if s.app == nil { + return nil, fmt.Errorf("statemachine_transition step %q: no application context", s.name) + } + + // Resolve statemachine engine from service registry + svc, ok := s.app.SvcRegistry()[s.statemachine] + if !ok { + return nil, fmt.Errorf("statemachine_transition step %q: statemachine service %q not found", s.name, s.statemachine) + } + + engine, ok := svc.(*StateMachineEngine) + if !ok { + // Also accept the TransitionTrigger interface for testability / mocking + trigger, ok := svc.(TransitionTrigger) + if !ok { + return nil, fmt.Errorf("statemachine_transition step %q: service %q does not implement StateMachineEngine or TransitionTrigger", s.name, s.statemachine) + } + return s.executeViaTrigger(ctx, pc, trigger) + } + + return s.executeViaEngine(ctx, pc, engine) +} + +func (s *StateMachineTransitionStep) executeViaEngine(ctx context.Context, pc *PipelineContext, engine *StateMachineEngine) (*StepResult, error) { + entityID, err := s.tmpl.Resolve(s.entityID, pc) + if err != nil { + return nil, fmt.Errorf("statemachine_transition step %q: failed to resolve entity_id: %w", s.name, err) + } + + event, err := s.tmpl.Resolve(s.event, pc) + if err != nil { + return nil, fmt.Errorf("statemachine_transition step %q: failed to resolve event: %w", s.name, err) + } + + data, err := s.tmpl.ResolveMap(s.data, pc) + if err != nil { + return nil, fmt.Errorf("statemachine_transition step %q: failed to resolve data: %w", s.name, err) + } + + transErr := engine.TriggerTransition(ctx, entityID, event, data) + if transErr != nil { + if s.failOnError { + return nil, fmt.Errorf("statemachine_transition step %q: transition failed: %w", s.name, transErr) + } + return &StepResult{ + Output: map[string]any{ + "transition_ok": false, + "error": transErr.Error(), + }, + }, nil + } + + // Fetch the new state from the engine + instance, err := engine.GetInstance(entityID) + if err != nil { + // Transition succeeded but we can't read new state — treat as success with unknown state + return &StepResult{ + Output: map[string]any{ + "transition_ok": true, + "new_state": "", + }, + }, nil + } + + return &StepResult{ + Output: map[string]any{ + "transition_ok": true, + "new_state": instance.CurrentState, + }, + }, nil +} + +func (s *StateMachineTransitionStep) executeViaTrigger(ctx context.Context, pc *PipelineContext, trigger TransitionTrigger) (*StepResult, error) { + entityID, err := s.tmpl.Resolve(s.entityID, pc) + if err != nil { + return nil, fmt.Errorf("statemachine_transition step %q: failed to resolve entity_id: %w", s.name, err) + } + + event, err := s.tmpl.Resolve(s.event, pc) + if err != nil { + return nil, fmt.Errorf("statemachine_transition step %q: failed to resolve event: %w", s.name, err) + } + + data, err := s.tmpl.ResolveMap(s.data, pc) + if err != nil { + return nil, fmt.Errorf("statemachine_transition step %q: failed to resolve data: %w", s.name, err) + } + + transErr := trigger.TriggerTransition(ctx, entityID, event, data) + if transErr != nil { + if s.failOnError { + return nil, fmt.Errorf("statemachine_transition step %q: transition failed: %w", s.name, transErr) + } + return &StepResult{ + Output: map[string]any{ + "transition_ok": false, + "error": transErr.Error(), + }, + }, nil + } + + return &StepResult{ + Output: map[string]any{ + "transition_ok": true, + "new_state": "", + }, + }, nil +} diff --git a/module/pipeline_step_statemachine_transition_test.go b/module/pipeline_step_statemachine_transition_test.go new file mode 100644 index 00000000..3cb27865 --- /dev/null +++ b/module/pipeline_step_statemachine_transition_test.go @@ -0,0 +1,419 @@ +package module + +import ( + "context" + "errors" + "testing" +) + +// mockTransitionTrigger implements TransitionTrigger for testing without a real engine. +type mockTransitionTrigger struct { + triggerErr error + capturedID string + capturedEvt string + capturedData map[string]any +} + +func (m *mockTransitionTrigger) TriggerTransition(_ context.Context, workflowID, transitionName string, data map[string]any) error { + m.capturedID = workflowID + m.capturedEvt = transitionName + m.capturedData = data + return m.triggerErr +} + +// setupOrderStateMachine creates a StateMachineEngine with a simple order workflow. +func setupOrderStateMachine(t *testing.T, instanceID, initialState string) *StateMachineEngine { + t.Helper() + + engine := NewStateMachineEngine("order-sm") + def := &StateMachineDefinition{ + Name: "order", + InitialState: "pending", + States: map[string]*State{ + "pending": {Name: "pending"}, + "approved": {Name: "approved"}, + "rejected": {Name: "rejected", IsFinal: true}, + }, + Transitions: map[string]*Transition{ + "approve": {Name: "approve", FromState: "pending", ToState: "approved"}, + "reject": {Name: "reject", FromState: "pending", ToState: "rejected"}, + }, + } + if err := engine.RegisterDefinition(def); err != nil { + t.Fatalf("register definition: %v", err) + } + + if instanceID != "" { + instance, err := engine.CreateWorkflow("order", instanceID, nil) + if err != nil { + t.Fatalf("create workflow: %v", err) + } + // If caller wants a non-initial state, force it directly for test setup + if initialState != "" && initialState != def.InitialState { + instance.CurrentState = initialState + } + } + + return engine +} + +// newAppWithSM registers the engine under "order-sm" in a MockApplication. +func newAppWithSM(engine *StateMachineEngine) *MockApplication { + app := NewMockApplication() + app.Services["order-sm"] = engine + return app +} + +// --- Factory validation tests --- + +func TestStateMachineTransitionStep_MissingStatemachine(t *testing.T) { + factory := NewStateMachineTransitionStepFactory() + _, err := factory("step1", map[string]any{ + "entity_id": "order-1", + "event": "approve", + }, nil) + if err == nil { + t.Fatal("expected error for missing statemachine") + } +} + +func TestStateMachineTransitionStep_MissingEntityID(t *testing.T) { + factory := NewStateMachineTransitionStepFactory() + _, err := factory("step1", map[string]any{ + "statemachine": "order-sm", + "event": "approve", + }, nil) + if err == nil { + t.Fatal("expected error for missing entity_id") + } +} + +func TestStateMachineTransitionStep_MissingEvent(t *testing.T) { + factory := NewStateMachineTransitionStepFactory() + _, err := factory("step1", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + }, nil) + if err == nil { + t.Fatal("expected error for missing event") + } +} + +// --- Execution tests: using real StateMachineEngine --- + +func TestStateMachineTransitionStep_SuccessfulTransition(t *testing.T) { + engine := setupOrderStateMachine(t, "order-1", "") + app := newAppWithSM(engine) + + factory := NewStateMachineTransitionStepFactory() + step, err := factory("approve-order", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + "event": "approve", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + ok, _ := result.Output["transition_ok"].(bool) + if !ok { + t.Error("expected transition_ok=true") + } + newState, _ := result.Output["new_state"].(string) + if newState != "approved" { + t.Errorf("expected new_state='approved', got %q", newState) + } +} + +func TestStateMachineTransitionStep_TemplatedEntityIDAndEvent(t *testing.T) { + engine := setupOrderStateMachine(t, "order-42", "") + app := newAppWithSM(engine) + + factory := NewStateMachineTransitionStepFactory() + step, err := factory("dynamic-approve", map[string]any{ + "statemachine": "order-sm", + "entity_id": "{{.order_id}}", + "event": "{{.action}}", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(map[string]any{ + "order_id": "order-42", + "action": "approve", + }, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + ok, _ := result.Output["transition_ok"].(bool) + if !ok { + t.Error("expected transition_ok=true") + } +} + +func TestStateMachineTransitionStep_InvalidTransition_FailOnErrorFalse(t *testing.T) { + engine := setupOrderStateMachine(t, "order-1", "") + app := newAppWithSM(engine) + + factory := NewStateMachineTransitionStepFactory() + // "reject" is valid, but "approve" from "approved" is not — trigger approve twice + step, err := factory("double-approve", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + "event": "approve", + "fail_on_error": false, + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + // First transition succeeds + _, err = step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("first execute error: %v", err) + } + + // Second transition: "approve" from "approved" is invalid + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("expected no pipeline error with fail_on_error=false, got: %v", err) + } + + ok, _ := result.Output["transition_ok"].(bool) + if ok { + t.Error("expected transition_ok=false for invalid transition") + } + errMsg, _ := result.Output["error"].(string) + if errMsg == "" { + t.Error("expected error message in output") + } +} + +func TestStateMachineTransitionStep_InvalidTransition_FailOnErrorTrue(t *testing.T) { + engine := setupOrderStateMachine(t, "order-1", "") + app := newAppWithSM(engine) + + factory := NewStateMachineTransitionStepFactory() + step, err := factory("strict-approve", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + "event": "approve", + "fail_on_error": true, + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + // First transition succeeds + _, err = step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("first execute error: %v", err) + } + + // Second transition should fail with pipeline error + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected pipeline error with fail_on_error=true") + } +} + +func TestStateMachineTransitionStep_WithData(t *testing.T) { + engine := setupOrderStateMachine(t, "order-1", "") + app := newAppWithSM(engine) + + factory := NewStateMachineTransitionStepFactory() + step, err := factory("approve-with-data", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + "event": "approve", + "data": map[string]any{ + "approved_by": "{{.user_id}}", + }, + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(map[string]any{"user_id": "u-99"}, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + ok, _ := result.Output["transition_ok"].(bool) + if !ok { + t.Error("expected transition_ok=true") + } + + // Verify data was merged into instance + instance, err := engine.GetInstance("order-1") + if err != nil { + t.Fatalf("get instance: %v", err) + } + approvedBy, _ := instance.Data["approved_by"].(string) + if approvedBy != "u-99" { + t.Errorf("expected approved_by='u-99', got %q", approvedBy) + } +} + +func TestStateMachineTransitionStep_ServiceNotFound(t *testing.T) { + app := NewMockApplication() + + factory := NewStateMachineTransitionStepFactory() + step, err := factory("step1", map[string]any{ + "statemachine": "nonexistent-sm", + "entity_id": "order-1", + "event": "approve", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for missing service") + } +} + +func TestStateMachineTransitionStep_ServiceWrongType(t *testing.T) { + app := NewMockApplication() + // Register something that is neither *StateMachineEngine nor TransitionTrigger + app.Services["order-sm"] = "not-an-engine" + + factory := NewStateMachineTransitionStepFactory() + step, err := factory("step1", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + "event": "approve", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for wrong service type") + } +} + +func TestStateMachineTransitionStep_NoAppContext(t *testing.T) { + factory := NewStateMachineTransitionStepFactory() + step, err := factory("step1", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + "event": "approve", + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for nil app") + } +} + +// --- Execution tests: using mock TransitionTrigger --- + +func TestStateMachineTransitionStep_MockTrigger_Success(t *testing.T) { + mock := &mockTransitionTrigger{} + app := NewMockApplication() + app.Services["order-sm"] = mock + + factory := NewStateMachineTransitionStepFactory() + step, err := factory("mock-approve", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + "event": "approve", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if mock.capturedID != "order-1" { + t.Errorf("expected capturedID='order-1', got %q", mock.capturedID) + } + if mock.capturedEvt != "approve" { + t.Errorf("expected capturedEvt='approve', got %q", mock.capturedEvt) + } + + ok, _ := result.Output["transition_ok"].(bool) + if !ok { + t.Error("expected transition_ok=true") + } +} + +func TestStateMachineTransitionStep_MockTrigger_Error_NoFail(t *testing.T) { + mock := &mockTransitionTrigger{triggerErr: errors.New("invalid transition")} + app := NewMockApplication() + app.Services["order-sm"] = mock + + factory := NewStateMachineTransitionStepFactory() + step, err := factory("mock-fail", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + "event": "approve", + "fail_on_error": false, + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("expected no pipeline error, got: %v", err) + } + + ok, _ := result.Output["transition_ok"].(bool) + if ok { + t.Error("expected transition_ok=false") + } + errMsg, _ := result.Output["error"].(string) + if errMsg == "" { + t.Error("expected error in output") + } +} + +func TestStateMachineTransitionStep_MockTrigger_Error_Fail(t *testing.T) { + mock := &mockTransitionTrigger{triggerErr: errors.New("invalid transition")} + app := NewMockApplication() + app.Services["order-sm"] = mock + + factory := NewStateMachineTransitionStepFactory() + step, err := factory("mock-strict-fail", map[string]any{ + "statemachine": "order-sm", + "entity_id": "order-1", + "event": "approve", + "fail_on_error": true, + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected pipeline error with fail_on_error=true") + } +} diff --git a/plugins/statemachine/plugin.go b/plugins/statemachine/plugin.go index d9ebf4d0..2bc303ff 100644 --- a/plugins/statemachine/plugin.go +++ b/plugins/statemachine/plugin.go @@ -37,6 +37,10 @@ func New() *Plugin { "state.tracker", "state.connector", }, + StepTypes: []string{ + "step.statemachine_transition", + "step.statemachine_get", + }, WorkflowTypes: []string{"statemachine"}, Capabilities: []plugin.CapabilityDecl{ {Name: "state-machine", Role: "provider", Priority: 10}, @@ -98,6 +102,22 @@ func (p *Plugin) WorkflowHandlers() map[string]plugin.WorkflowHandlerFactory { } } +// StepFactories returns the pipeline step factories for state machine operations. +func (p *Plugin) StepFactories() map[string]plugin.StepFactory { + return map[string]plugin.StepFactory{ + "step.statemachine_transition": wrapStepFactory(module.NewStateMachineTransitionStepFactory()), + "step.statemachine_get": wrapStepFactory(module.NewStateMachineGetStepFactory()), + } +} + +// wrapStepFactory converts a module.StepFactory to a plugin.StepFactory, +// threading the modular.Application through so steps can access the service registry. +func wrapStepFactory(f module.StepFactory) plugin.StepFactory { + return func(name string, cfg map[string]any, app modular.Application) (any, error) { + return f(name, cfg, app) + } +} + // ModuleSchemas returns UI schema definitions for state machine module types. func (p *Plugin) ModuleSchemas() []*schema.ModuleSchema { return []*schema.ModuleSchema{ diff --git a/plugins/statemachine/plugin_test.go b/plugins/statemachine/plugin_test.go index ec34d514..26d19671 100644 --- a/plugins/statemachine/plugin_test.go +++ b/plugins/statemachine/plugin_test.go @@ -24,6 +24,9 @@ func TestPluginManifest(t *testing.T) { if len(m.ModuleTypes) != 3 { t.Errorf("expected 3 module types, got %d", len(m.ModuleTypes)) } + if len(m.StepTypes) != 2 { + t.Errorf("expected 2 step types, got %d", len(m.StepTypes)) + } if len(m.WorkflowTypes) != 1 { t.Errorf("expected 1 workflow type, got %d", len(m.WorkflowTypes)) } @@ -100,6 +103,21 @@ func TestWorkflowHandlers(t *testing.T) { } } +func TestStepFactories(t *testing.T) { + p := New() + factories := p.StepFactories() + + expectedTypes := []string{"step.statemachine_transition", "step.statemachine_get"} + for _, typ := range expectedTypes { + if _, ok := factories[typ]; !ok { + t.Errorf("missing step factory for %q", typ) + } + } + if len(factories) != len(expectedTypes) { + t.Errorf("expected %d step factories, got %d", len(expectedTypes), len(factories)) + } +} + func TestModuleSchemas(t *testing.T) { p := New() schemas := p.ModuleSchemas() From e94f4338a72b467a85f88ae1f1f911357c0f1c18 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Wed, 25 Feb 2026 01:08:56 -0500 Subject: [PATCH 3/5] feat: add cache.redis module and step.cache_get/set/delete pipeline steps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Redis-backed caching for the workflow engine: - module/cache_redis.go: CacheModule interface + RedisCache module (cache.redis type) with Get/Set/Delete ops, key prefixing, default TTL, and RedisClient interface for testability - module/pipeline_step_cache_get.go: step.cache_get — reads from cache with template key, configurable output field, miss_ok flag - module/pipeline_step_cache_set.go: step.cache_set — writes to cache with template key+value, optional TTL override - module/pipeline_step_cache_delete.go: step.cache_delete — removes a key from cache - Full test coverage using miniredis (already a project dependency) for the Redis module, and a mockCacheModule for the pipeline step tests - plugins/storage/plugin.go: registers cache.redis module factory and schema (7 module types) - plugins/pipelinesteps/plugin.go: registers step.cache_get/set/delete factories (21 step types) All existing tests continue to pass. Co-Authored-By: Claude Sonnet 4.6 --- module/cache_redis.go | 161 +++++++++++++++ module/cache_redis_test.go | 173 ++++++++++++++++ module/pipeline_step_cache_delete.go | 78 +++++++ module/pipeline_step_cache_delete_test.go | 110 ++++++++++ module/pipeline_step_cache_get.go | 107 ++++++++++ module/pipeline_step_cache_get_test.go | 236 ++++++++++++++++++++++ module/pipeline_step_cache_set.go | 102 ++++++++++ module/pipeline_step_cache_set_test.go | 145 +++++++++++++ plugins/pipelinesteps/plugin.go | 6 + plugins/pipelinesteps/plugin_test.go | 7 +- plugins/storage/plugin.go | 61 +++++- plugins/storage/plugin_test.go | 44 +++- 12 files changed, 1217 insertions(+), 13 deletions(-) create mode 100644 module/cache_redis.go create mode 100644 module/cache_redis_test.go create mode 100644 module/pipeline_step_cache_delete.go create mode 100644 module/pipeline_step_cache_delete_test.go create mode 100644 module/pipeline_step_cache_get.go create mode 100644 module/pipeline_step_cache_get_test.go create mode 100644 module/pipeline_step_cache_set.go create mode 100644 module/pipeline_step_cache_set_test.go diff --git a/module/cache_redis.go b/module/cache_redis.go new file mode 100644 index 00000000..3571b485 --- /dev/null +++ b/module/cache_redis.go @@ -0,0 +1,161 @@ +package module + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/CrisisTextLine/modular" + "github.com/redis/go-redis/v9" +) + +// CacheModule defines the interface for cache operations used by pipeline steps. +type CacheModule interface { + Get(ctx context.Context, key string) (string, error) + Set(ctx context.Context, key, value string, ttl time.Duration) error + Delete(ctx context.Context, key string) error +} + +// RedisClient is the subset of go-redis client methods used by RedisCache. +// Keeping it as an interface enables mocking in tests. +type RedisClient interface { + Ping(ctx context.Context) *redis.StatusCmd + Get(ctx context.Context, key string) *redis.StringCmd + Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd + Del(ctx context.Context, keys ...string) *redis.IntCmd + Close() error +} + +// RedisCacheConfig holds configuration for the cache.redis module. +type RedisCacheConfig struct { + Address string + Password string + DB int + Prefix string + DefaultTTL time.Duration +} + +// RedisCache is a module that connects to a Redis instance and exposes +// Get/Set/Delete operations for use by pipeline steps. +type RedisCache struct { + name string + cfg RedisCacheConfig + client RedisClient + logger modular.Logger +} + +// NewRedisCache creates a new RedisCache module with the given name and config. +func NewRedisCache(name string, cfg RedisCacheConfig) *RedisCache { + return &RedisCache{ + name: name, + cfg: cfg, + logger: &noopLogger{}, + } +} + +// NewRedisCacheWithClient creates a RedisCache backed by a pre-built client. +// This is intended for testing only. +func NewRedisCacheWithClient(name string, cfg RedisCacheConfig, client RedisClient) *RedisCache { + return &RedisCache{ + name: name, + cfg: cfg, + client: client, + logger: &noopLogger{}, + } +} + +func (r *RedisCache) Name() string { return r.name } + +func (r *RedisCache) Init(app modular.Application) error { + r.logger = app.Logger() + return nil +} + +// Start connects to Redis and verifies the connection with PING. +func (r *RedisCache) Start(ctx context.Context) error { + if r.client != nil { + // Already set (e.g. in tests) + return nil + } + + opts := &redis.Options{ + Addr: r.cfg.Address, + DB: r.cfg.DB, + } + if r.cfg.Password != "" { + opts.Password = r.cfg.Password + } + + r.client = redis.NewClient(opts) + + if err := r.client.Ping(ctx).Err(); err != nil { + _ = r.client.Close() + r.client = nil + return fmt.Errorf("cache.redis %q: ping failed: %w", r.name, err) + } + + r.logger.Info("Redis cache started", "name", r.name, "address", r.cfg.Address) + return nil +} + +// Stop closes the Redis connection. +func (r *RedisCache) Stop(_ context.Context) error { + if r.client != nil { + r.logger.Info("Redis cache stopped", "name", r.name) + return r.client.Close() + } + return nil +} + +// Get retrieves a value from Redis by key (with prefix applied). +// Returns redis.Nil wrapped in an error when the key does not exist. +func (r *RedisCache) Get(ctx context.Context, key string) (string, error) { + if r.client == nil { + return "", fmt.Errorf("cache.redis %q: not started", r.name) + } + val, err := r.client.Get(ctx, r.prefixed(key)).Result() + if err != nil { + return "", err + } + return val, nil +} + +// Set stores a value in Redis with optional TTL. A zero duration uses the +// module-level default; if the default is also zero the key never expires. +func (r *RedisCache) Set(ctx context.Context, key, value string, ttl time.Duration) error { + if r.client == nil { + return fmt.Errorf("cache.redis %q: not started", r.name) + } + if ttl == 0 { + ttl = r.cfg.DefaultTTL + } + return r.client.Set(ctx, r.prefixed(key), value, ttl).Err() +} + +// Delete removes a key from Redis (with prefix applied). +func (r *RedisCache) Delete(ctx context.Context, key string) error { + if r.client == nil { + return fmt.Errorf("cache.redis %q: not started", r.name) + } + return r.client.Del(ctx, r.prefixed(key)).Err() +} + +func (r *RedisCache) prefixed(key string) string { + return r.cfg.Prefix + key +} + +func (r *RedisCache) ProvidesServices() []modular.ServiceProvider { + return []modular.ServiceProvider{ + {Name: r.name, Description: "Redis cache connection", Instance: r}, + } +} + +func (r *RedisCache) RequiresServices() []modular.ServiceDependency { + return nil +} + +// ExpandEnvString resolves ${VAR} and $VAR environment variable references. +func ExpandEnvString(s string) string { + return os.ExpandEnv(s) +} diff --git a/module/cache_redis_test.go b/module/cache_redis_test.go new file mode 100644 index 00000000..87da5423 --- /dev/null +++ b/module/cache_redis_test.go @@ -0,0 +1,173 @@ +package module + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +// newTestRedisCache creates a RedisCache backed by a miniredis server. +func newTestRedisCache(t *testing.T) (*RedisCache, *miniredis.Miniredis) { + t.Helper() + mr := miniredis.RunT(t) + client := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + t.Cleanup(func() { client.Close() }) + + cfg := RedisCacheConfig{ + Address: mr.Addr(), + Prefix: "test:", + DefaultTTL: time.Hour, + } + cache := NewRedisCacheWithClient("cache", cfg, client) + return cache, mr +} + +func TestRedisCacheGetSetDelete(t *testing.T) { + ctx := context.Background() + cache, _ := newTestRedisCache(t) + + // Set a value + if err := cache.Set(ctx, "mykey", "myvalue", 0); err != nil { + t.Fatalf("Set failed: %v", err) + } + + // Get it back + val, err := cache.Get(ctx, "mykey") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if val != "myvalue" { + t.Errorf("expected %q, got %q", "myvalue", val) + } + + // Delete it + if err := cache.Delete(ctx, "mykey"); err != nil { + t.Fatalf("Delete failed: %v", err) + } + + // Get after delete should return redis.Nil + _, err = cache.Get(ctx, "mykey") + if err == nil { + t.Fatal("expected error after delete, got nil") + } +} + +func TestRedisCacheKeyPrefix(t *testing.T) { + ctx := context.Background() + cache, mr := newTestRedisCache(t) + + if err := cache.Set(ctx, "hello", "world", 0); err != nil { + t.Fatalf("Set failed: %v", err) + } + + // Verify prefix is stored in miniredis + keys := mr.Keys() + found := false + for _, k := range keys { + if k == "test:hello" { + found = true + break + } + } + if !found { + t.Errorf("expected key %q in redis, got keys: %v", "test:hello", keys) + } +} + +func TestRedisCacheDefaultTTL(t *testing.T) { + ctx := context.Background() + cache, mr := newTestRedisCache(t) + + // Set with TTL=0 should use DefaultTTL (1 hour) + if err := cache.Set(ctx, "ttlkey", "ttlval", 0); err != nil { + t.Fatalf("Set failed: %v", err) + } + + ttl := mr.TTL("test:ttlkey") + if ttl <= 0 { + t.Errorf("expected positive TTL, got %v", ttl) + } +} + +func TestRedisCacheExplicitTTL(t *testing.T) { + ctx := context.Background() + cache, mr := newTestRedisCache(t) + + // Set with explicit TTL=30m + if err := cache.Set(ctx, "short", "val", 30*time.Minute); err != nil { + t.Fatalf("Set failed: %v", err) + } + + ttl := mr.TTL("test:short") + // miniredis reports TTL in seconds-level precision; just verify it's set + if ttl <= 0 { + t.Errorf("expected positive TTL, got %v", ttl) + } + if ttl > time.Hour { + t.Errorf("expected TTL <= 1h, got %v", ttl) + } +} + +func TestRedisCacheMiss(t *testing.T) { + ctx := context.Background() + cache, _ := newTestRedisCache(t) + + _, err := cache.Get(ctx, "nonexistent") + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestRedisCacheNotStarted(t *testing.T) { + ctx := context.Background() + cfg := RedisCacheConfig{Address: "localhost:6379", Prefix: "wf:"} + cache := NewRedisCache("cache", cfg) + + if _, err := cache.Get(ctx, "k"); err == nil { + t.Error("expected error from Get when not started") + } + if err := cache.Set(ctx, "k", "v", 0); err == nil { + t.Error("expected error from Set when not started") + } + if err := cache.Delete(ctx, "k"); err == nil { + t.Error("expected error from Delete when not started") + } +} + +func TestRedisCacheInit(t *testing.T) { + cfg := RedisCacheConfig{Address: "localhost:6379", Prefix: "wf:"} + cache := NewRedisCache("cache", cfg) + app := NewMockApplication() + + if err := cache.Init(app); err != nil { + t.Fatalf("Init failed: %v", err) + } +} + +func TestRedisCacheStop(t *testing.T) { + ctx := context.Background() + cache, _ := newTestRedisCache(t) + + if err := cache.Stop(ctx); err != nil { + t.Fatalf("Stop failed: %v", err) + } + // Stop when already nil is a no-op + cache2 := NewRedisCache("cache2", RedisCacheConfig{}) + if err := cache2.Stop(ctx); err != nil { + t.Fatalf("Stop on uninitialised cache failed: %v", err) + } +} + +func TestRedisCacheProvidesServices(t *testing.T) { + cache := NewRedisCache("mycache", RedisCacheConfig{}) + svcs := cache.ProvidesServices() + if len(svcs) != 1 { + t.Fatalf("expected 1 service, got %d", len(svcs)) + } + if svcs[0].Name != "mycache" { + t.Errorf("expected service name %q, got %q", "mycache", svcs[0].Name) + } +} diff --git a/module/pipeline_step_cache_delete.go b/module/pipeline_step_cache_delete.go new file mode 100644 index 00000000..2073929c --- /dev/null +++ b/module/pipeline_step_cache_delete.go @@ -0,0 +1,78 @@ +package module + +import ( + "context" + "fmt" + + "github.com/CrisisTextLine/modular" +) + +// CacheDeleteStep removes a key from a named CacheModule. +type CacheDeleteStep struct { + name string + cache string // service name of the CacheModule + key string // key template + app modular.Application + tmpl *TemplateEngine +} + +// NewCacheDeleteStepFactory returns a StepFactory that creates CacheDeleteStep instances. +func NewCacheDeleteStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + cache, _ := config["cache"].(string) + if cache == "" { + return nil, fmt.Errorf("cache_delete step %q: 'cache' is required", name) + } + + key, _ := config["key"].(string) + if key == "" { + return nil, fmt.Errorf("cache_delete step %q: 'key' is required", name) + } + + return &CacheDeleteStep{ + name: name, + cache: cache, + key: key, + app: app, + tmpl: NewTemplateEngine(), + }, nil + } +} + +func (s *CacheDeleteStep) Name() string { return s.name } + +func (s *CacheDeleteStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { + if s.app == nil { + return nil, fmt.Errorf("cache_delete step %q: no application context", s.name) + } + + cm, err := s.resolveCache() + if err != nil { + return nil, err + } + + resolvedKey, err := s.tmpl.Resolve(s.key, pc) + if err != nil { + return nil, fmt.Errorf("cache_delete step %q: failed to resolve key template: %w", s.name, err) + } + + if err := cm.Delete(ctx, resolvedKey); err != nil { + return nil, fmt.Errorf("cache_delete step %q: delete failed: %w", s.name, err) + } + + return &StepResult{Output: map[string]any{ + "deleted": true, + }}, nil +} + +func (s *CacheDeleteStep) resolveCache() (CacheModule, error) { + svc, ok := s.app.SvcRegistry()[s.cache] + if !ok { + return nil, fmt.Errorf("cache_delete step %q: cache service %q not found", s.name, s.cache) + } + cm, ok := svc.(CacheModule) + if !ok { + return nil, fmt.Errorf("cache_delete step %q: service %q does not implement CacheModule", s.name, s.cache) + } + return cm, nil +} diff --git a/module/pipeline_step_cache_delete_test.go b/module/pipeline_step_cache_delete_test.go new file mode 100644 index 00000000..59a39223 --- /dev/null +++ b/module/pipeline_step_cache_delete_test.go @@ -0,0 +1,110 @@ +package module + +import ( + "context" + "errors" + "testing" +) + +func TestCacheDeleteStep_Basic(t *testing.T) { + cm := newMockCacheModule() + cm.data["user:42"] = "cached" + app := mockAppWithCache("cache", cm) + + factory := NewCacheDeleteStepFactory() + step, err := factory("del-user", map[string]any{ + "cache": "cache", + "key": "user:{{.user_id}}", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(map[string]any{"user_id": "42"}, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if result.Output["deleted"] != true { + t.Errorf("expected deleted=true, got %v", result.Output["deleted"]) + } + if _, exists := cm.data["user:42"]; exists { + t.Error("expected key to be removed from mock cache") + } +} + +func TestCacheDeleteStep_DeleteError(t *testing.T) { + cm := newMockCacheModule() + cm.deleteErr = errors.New("delete failed") + app := mockAppWithCache("cache", cm) + + factory := NewCacheDeleteStepFactory() + step, err := factory("del-err", map[string]any{ + "cache": "cache", + "key": "k", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error from underlying Delete") + } +} + +func TestCacheDeleteStep_MissingCache(t *testing.T) { + factory := NewCacheDeleteStepFactory() + _, err := factory("bad", map[string]any{"key": "k"}, nil) + if err == nil { + t.Fatal("expected error for missing cache") + } +} + +func TestCacheDeleteStep_MissingKey(t *testing.T) { + factory := NewCacheDeleteStepFactory() + _, err := factory("bad", map[string]any{"cache": "c"}, nil) + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestCacheDeleteStep_ServiceNotFound(t *testing.T) { + app := NewMockApplication() + factory := NewCacheDeleteStepFactory() + step, err := factory("del-missing-svc", map[string]any{ + "cache": "nonexistent", + "key": "k", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for missing service") + } +} + +func TestCacheDeleteStep_ServiceWrongType(t *testing.T) { + app := NewMockApplication() + app.Services["cache"] = "not-a-cache" + + factory := NewCacheDeleteStepFactory() + step, err := factory("del-wrong-type", map[string]any{ + "cache": "cache", + "key": "k", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for wrong service type") + } +} diff --git a/module/pipeline_step_cache_get.go b/module/pipeline_step_cache_get.go new file mode 100644 index 00000000..ebd1e67b --- /dev/null +++ b/module/pipeline_step_cache_get.go @@ -0,0 +1,107 @@ +package module + +import ( + "context" + "errors" + "fmt" + + "github.com/CrisisTextLine/modular" + "github.com/redis/go-redis/v9" +) + +// CacheGetStep reads a value from a named CacheModule and stores it in the +// pipeline context under a configurable output field. +type CacheGetStep struct { + name string + cache string // service name of the CacheModule + key string // key template, e.g. "user:{{.user_id}}" + output string // output field name (default: "value") + missOK bool // when true a cache miss is not an error + app modular.Application + tmpl *TemplateEngine +} + +// NewCacheGetStepFactory returns a StepFactory that creates CacheGetStep instances. +func NewCacheGetStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + cache, _ := config["cache"].(string) + if cache == "" { + return nil, fmt.Errorf("cache_get step %q: 'cache' is required", name) + } + + key, _ := config["key"].(string) + if key == "" { + return nil, fmt.Errorf("cache_get step %q: 'key' is required", name) + } + + output, _ := config["output"].(string) + if output == "" { + output = "value" + } + + missOK := true + if v, ok := config["miss_ok"].(bool); ok { + missOK = v + } + + return &CacheGetStep{ + name: name, + cache: cache, + key: key, + output: output, + missOK: missOK, + app: app, + tmpl: NewTemplateEngine(), + }, nil + } +} + +func (s *CacheGetStep) Name() string { return s.name } + +func (s *CacheGetStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { + if s.app == nil { + return nil, fmt.Errorf("cache_get step %q: no application context", s.name) + } + + cm, err := s.resolveCache() + if err != nil { + return nil, err + } + + resolvedKey, err := s.tmpl.Resolve(s.key, pc) + if err != nil { + return nil, fmt.Errorf("cache_get step %q: failed to resolve key template: %w", s.name, err) + } + + val, err := cm.Get(ctx, resolvedKey) + if err != nil { + if errors.Is(err, redis.Nil) { + // Cache miss + if !s.missOK { + return nil, fmt.Errorf("cache_get step %q: cache miss for key %q", s.name, resolvedKey) + } + return &StepResult{Output: map[string]any{ + s.output: "", + "cache_hit": false, + }}, nil + } + return nil, fmt.Errorf("cache_get step %q: get failed: %w", s.name, err) + } + + return &StepResult{Output: map[string]any{ + s.output: val, + "cache_hit": true, + }}, nil +} + +func (s *CacheGetStep) resolveCache() (CacheModule, error) { + svc, ok := s.app.SvcRegistry()[s.cache] + if !ok { + return nil, fmt.Errorf("cache_get step %q: cache service %q not found", s.name, s.cache) + } + cm, ok := svc.(CacheModule) + if !ok { + return nil, fmt.Errorf("cache_get step %q: service %q does not implement CacheModule", s.name, s.cache) + } + return cm, nil +} diff --git a/module/pipeline_step_cache_get_test.go b/module/pipeline_step_cache_get_test.go new file mode 100644 index 00000000..ec3fc9e9 --- /dev/null +++ b/module/pipeline_step_cache_get_test.go @@ -0,0 +1,236 @@ +package module + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/redis/go-redis/v9" +) + +// mockCacheModule is an in-memory CacheModule for testing pipeline steps. +type mockCacheModule struct { + data map[string]string + getErr error + setErr error + deleteErr error +} + +func newMockCacheModule() *mockCacheModule { + return &mockCacheModule{data: make(map[string]string)} +} + +func (m *mockCacheModule) Get(_ context.Context, key string) (string, error) { + if m.getErr != nil { + return "", m.getErr + } + v, ok := m.data[key] + if !ok { + return "", redis.Nil + } + return v, nil +} + +func (m *mockCacheModule) Set(_ context.Context, key, value string, _ time.Duration) error { + if m.setErr != nil { + return m.setErr + } + m.data[key] = value + return nil +} + +func (m *mockCacheModule) Delete(_ context.Context, key string) error { + if m.deleteErr != nil { + return m.deleteErr + } + delete(m.data, key) + return nil +} + +// mockAppWithCache creates a MockApplication with a CacheModule service registered. +func mockAppWithCache(name string, cm CacheModule) *MockApplication { + app := NewMockApplication() + app.Services[name] = cm + return app +} + +// ---- tests ---- + +func TestCacheGetStep_Hit(t *testing.T) { + cm := newMockCacheModule() + cm.data["user:42"] = `{"id":42}` + app := mockAppWithCache("cache", cm) + + factory := NewCacheGetStepFactory() + step, err := factory("get-user", map[string]any{ + "cache": "cache", + "key": "user:{{.user_id}}", + "output": "user_data", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(map[string]any{"user_id": "42"}, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if result.Output["user_data"] != `{"id":42}` { + t.Errorf("expected user_data=%q, got %v", `{"id":42}`, result.Output["user_data"]) + } + if result.Output["cache_hit"] != true { + t.Errorf("expected cache_hit=true, got %v", result.Output["cache_hit"]) + } +} + +func TestCacheGetStep_MissOK(t *testing.T) { + cm := newMockCacheModule() + app := mockAppWithCache("cache", cm) + + factory := NewCacheGetStepFactory() + step, err := factory("get-user", map[string]any{ + "cache": "cache", + "key": "user:99", + "miss_ok": true, + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if result.Output["value"] != "" { + t.Errorf("expected empty value on miss, got %v", result.Output["value"]) + } + if result.Output["cache_hit"] != false { + t.Errorf("expected cache_hit=false on miss, got %v", result.Output["cache_hit"]) + } +} + +func TestCacheGetStep_MissNotOK(t *testing.T) { + cm := newMockCacheModule() + app := mockAppWithCache("cache", cm) + + factory := NewCacheGetStepFactory() + step, err := factory("get-user", map[string]any{ + "cache": "cache", + "key": "user:99", + "miss_ok": false, + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error on cache miss with miss_ok=false") + } +} + +func TestCacheGetStep_DefaultOutput(t *testing.T) { + cm := newMockCacheModule() + cm.data["thekey"] = "thevalue" + app := mockAppWithCache("cache", cm) + + factory := NewCacheGetStepFactory() + step, err := factory("get-val", map[string]any{ + "cache": "cache", + "key": "thekey", + // output not set → default "value" + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + if result.Output["value"] != "thevalue" { + t.Errorf("expected output[value]=%q, got %v", "thevalue", result.Output["value"]) + } +} + +func TestCacheGetStep_GetError(t *testing.T) { + cm := newMockCacheModule() + cm.getErr = errors.New("connection refused") + app := mockAppWithCache("cache", cm) + + factory := NewCacheGetStepFactory() + step, err := factory("get-err", map[string]any{ + "cache": "cache", + "key": "k", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error from underlying Get") + } +} + +func TestCacheGetStep_MissingCache(t *testing.T) { + factory := NewCacheGetStepFactory() + _, err := factory("bad", map[string]any{"key": "k"}, nil) + if err == nil { + t.Fatal("expected error for missing cache") + } +} + +func TestCacheGetStep_MissingKey(t *testing.T) { + factory := NewCacheGetStepFactory() + _, err := factory("bad", map[string]any{"cache": "c"}, nil) + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestCacheGetStep_ServiceNotFound(t *testing.T) { + app := NewMockApplication() + factory := NewCacheGetStepFactory() + step, err := factory("get-missing-svc", map[string]any{ + "cache": "nonexistent", + "key": "k", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for missing service") + } +} + +func TestCacheGetStep_ServiceWrongType(t *testing.T) { + app := NewMockApplication() + app.Services["cache"] = "not-a-cache-module" + + factory := NewCacheGetStepFactory() + step, err := factory("get-wrong-type", map[string]any{ + "cache": "cache", + "key": "k", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for wrong service type") + } +} diff --git a/module/pipeline_step_cache_set.go b/module/pipeline_step_cache_set.go new file mode 100644 index 00000000..bb5dc74a --- /dev/null +++ b/module/pipeline_step_cache_set.go @@ -0,0 +1,102 @@ +package module + +import ( + "context" + "fmt" + "time" + + "github.com/CrisisTextLine/modular" +) + +// CacheSetStep writes a value to a named CacheModule. +type CacheSetStep struct { + name string + cache string // service name of the CacheModule + key string // key template + value string // value template + ttl time.Duration // 0 means use the module default + app modular.Application + tmpl *TemplateEngine +} + +// NewCacheSetStepFactory returns a StepFactory that creates CacheSetStep instances. +func NewCacheSetStepFactory() StepFactory { + return func(name string, config map[string]any, app modular.Application) (PipelineStep, error) { + cache, _ := config["cache"].(string) + if cache == "" { + return nil, fmt.Errorf("cache_set step %q: 'cache' is required", name) + } + + key, _ := config["key"].(string) + if key == "" { + return nil, fmt.Errorf("cache_set step %q: 'key' is required", name) + } + + value, _ := config["value"].(string) + if value == "" { + return nil, fmt.Errorf("cache_set step %q: 'value' is required", name) + } + + var ttl time.Duration + if ttlStr, ok := config["ttl"].(string); ok && ttlStr != "" { + parsed, err := time.ParseDuration(ttlStr) + if err != nil { + return nil, fmt.Errorf("cache_set step %q: invalid 'ttl' %q: %w", name, ttlStr, err) + } + ttl = parsed + } + + return &CacheSetStep{ + name: name, + cache: cache, + key: key, + value: value, + ttl: ttl, + app: app, + tmpl: NewTemplateEngine(), + }, nil + } +} + +func (s *CacheSetStep) Name() string { return s.name } + +func (s *CacheSetStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { + if s.app == nil { + return nil, fmt.Errorf("cache_set step %q: no application context", s.name) + } + + cm, err := s.resolveCache() + if err != nil { + return nil, err + } + + resolvedKey, err := s.tmpl.Resolve(s.key, pc) + if err != nil { + return nil, fmt.Errorf("cache_set step %q: failed to resolve key template: %w", s.name, err) + } + + resolvedValue, err := s.tmpl.Resolve(s.value, pc) + if err != nil { + return nil, fmt.Errorf("cache_set step %q: failed to resolve value template: %w", s.name, err) + } + + if err := cm.Set(ctx, resolvedKey, resolvedValue, s.ttl); err != nil { + return nil, fmt.Errorf("cache_set step %q: set failed: %w", s.name, err) + } + + return &StepResult{Output: map[string]any{ + "cached": true, + }}, nil +} + +func (s *CacheSetStep) resolveCache() (CacheModule, error) { + svc, ok := s.app.SvcRegistry()[s.cache] + if !ok { + return nil, fmt.Errorf("cache_set step %q: cache service %q not found", s.name, s.cache) + } + cm, ok := svc.(CacheModule) + if !ok { + return nil, fmt.Errorf("cache_set step %q: service %q does not implement CacheModule", s.name, s.cache) + } + return cm, nil +} diff --git a/module/pipeline_step_cache_set_test.go b/module/pipeline_step_cache_set_test.go new file mode 100644 index 00000000..29f0d71d --- /dev/null +++ b/module/pipeline_step_cache_set_test.go @@ -0,0 +1,145 @@ +package module + +import ( + "context" + "errors" + "testing" +) + +func TestCacheSetStep_Basic(t *testing.T) { + cm := newMockCacheModule() + app := mockAppWithCache("cache", cm) + + factory := NewCacheSetStepFactory() + step, err := factory("set-user", map[string]any{ + "cache": "cache", + "key": "user:{{.user_id}}", + "value": "{{.profile}}", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(map[string]any{ + "user_id": "42", + "profile": `{"name":"Alice"}`, + }, nil) + + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if result.Output["cached"] != true { + t.Errorf("expected cached=true, got %v", result.Output["cached"]) + } + if cm.data["user:42"] != `{"name":"Alice"}` { + t.Errorf("expected stored value, got %v", cm.data["user:42"]) + } +} + +func TestCacheSetStep_WithTTL(t *testing.T) { + cm := newMockCacheModule() + app := mockAppWithCache("cache", cm) + + factory := NewCacheSetStepFactory() + step, err := factory("set-ttl", map[string]any{ + "cache": "cache", + "key": "k", + "value": "v", + "ttl": "30m", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + if result.Output["cached"] != true { + t.Errorf("expected cached=true") + } + if cm.data["k"] != "v" { + t.Errorf("expected stored value %q, got %q", "v", cm.data["k"]) + } +} + +func TestCacheSetStep_InvalidTTL(t *testing.T) { + factory := NewCacheSetStepFactory() + _, err := factory("bad-ttl", map[string]any{ + "cache": "cache", + "key": "k", + "value": "v", + "ttl": "notaduration", + }, nil) + if err == nil { + t.Fatal("expected error for invalid TTL") + } +} + +func TestCacheSetStep_SetError(t *testing.T) { + cm := newMockCacheModule() + cm.setErr = errors.New("redis unavailable") + app := mockAppWithCache("cache", cm) + + factory := NewCacheSetStepFactory() + step, err := factory("set-err", map[string]any{ + "cache": "cache", + "key": "k", + "value": "v", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error from underlying Set") + } +} + +func TestCacheSetStep_MissingCache(t *testing.T) { + factory := NewCacheSetStepFactory() + _, err := factory("bad", map[string]any{"key": "k", "value": "v"}, nil) + if err == nil { + t.Fatal("expected error for missing cache") + } +} + +func TestCacheSetStep_MissingKey(t *testing.T) { + factory := NewCacheSetStepFactory() + _, err := factory("bad", map[string]any{"cache": "c", "value": "v"}, nil) + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestCacheSetStep_MissingValue(t *testing.T) { + factory := NewCacheSetStepFactory() + _, err := factory("bad", map[string]any{"cache": "c", "key": "k"}, nil) + if err == nil { + t.Fatal("expected error for missing value") + } +} + +func TestCacheSetStep_ServiceNotFound(t *testing.T) { + app := NewMockApplication() + factory := NewCacheSetStepFactory() + step, err := factory("set-missing-svc", map[string]any{ + "cache": "nonexistent", + "key": "k", + "value": "v", + }, app) + if err != nil { + t.Fatalf("factory error: %v", err) + } + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for missing service") + } +} diff --git a/plugins/pipelinesteps/plugin.go b/plugins/pipelinesteps/plugin.go index c02f74f9..72213659 100644 --- a/plugins/pipelinesteps/plugin.go +++ b/plugins/pipelinesteps/plugin.go @@ -69,6 +69,9 @@ func New() *Plugin { "step.validate_request_body", "step.foreach", "step.webhook_verify", + "step.cache_get", + "step.cache_set", + "step.cache_delete", }, WorkflowTypes: []string{"pipeline"}, Capabilities: []plugin.CapabilityDecl{ @@ -114,6 +117,9 @@ func (p *Plugin) StepFactories() map[string]plugin.StepFactory { return p.concreteStepRegistry }, nil)), "step.webhook_verify": wrapStepFactory(module.NewWebhookVerifyStepFactory()), + "step.cache_get": wrapStepFactory(module.NewCacheGetStepFactory()), + "step.cache_set": wrapStepFactory(module.NewCacheSetStepFactory()), + "step.cache_delete": wrapStepFactory(module.NewCacheDeleteStepFactory()), } } diff --git a/plugins/pipelinesteps/plugin_test.go b/plugins/pipelinesteps/plugin_test.go index 8314fc0e..1c540050 100644 --- a/plugins/pipelinesteps/plugin_test.go +++ b/plugins/pipelinesteps/plugin_test.go @@ -49,6 +49,9 @@ func TestStepFactories(t *testing.T) { "step.validate_request_body", "step.foreach", "step.webhook_verify", + "step.cache_get", + "step.cache_set", + "step.cache_delete", } for _, stepType := range expectedSteps { @@ -70,7 +73,7 @@ func TestPluginLoads(t *testing.T) { } steps := loader.StepFactories() - if len(steps) != 18 { - t.Fatalf("expected 18 step factories after load, got %d", len(steps)) + if len(steps) != 21 { + t.Fatalf("expected 21 step factories after load, got %d", len(steps)) } } diff --git a/plugins/storage/plugin.go b/plugins/storage/plugin.go index 601c9aaf..c6b59836 100644 --- a/plugins/storage/plugin.go +++ b/plugins/storage/plugin.go @@ -1,6 +1,8 @@ package storage import ( + "time" + "github.com/CrisisTextLine/modular" "github.com/GoCodeAlone/workflow/capability" "github.com/GoCodeAlone/workflow/config" @@ -10,8 +12,8 @@ import ( ) // Plugin provides storage and database capabilities: storage.s3, storage.local, -// storage.gcs, storage.sqlite, database.workflow, persistence.store modules, -// and the step.db_query / step.db_exec pipeline step factories. +// storage.gcs, storage.sqlite, database.workflow, persistence.store, cache.redis +// modules, and the step.db_query / step.db_exec pipeline step factories. type Plugin struct { plugin.BaseEnginePlugin } @@ -23,13 +25,13 @@ func New() *Plugin { BaseNativePlugin: plugin.BaseNativePlugin{ PluginName: "storage", PluginVersion: "1.0.0", - PluginDescription: "Storage, database, and persistence modules with DB pipeline steps", + PluginDescription: "Storage, database, persistence, and cache modules with DB pipeline steps", }, Manifest: plugin.PluginManifest{ Name: "storage", Version: "1.0.0", Author: "GoCodeAlone", - Description: "Storage, database, and persistence modules with DB pipeline steps", + Description: "Storage, database, persistence, and cache modules with DB pipeline steps", Tier: plugin.TierCore, ModuleTypes: []string{ "storage.s3", @@ -38,11 +40,13 @@ func New() *Plugin { "storage.sqlite", "database.workflow", "persistence.store", + "cache.redis", }, Capabilities: []plugin.CapabilityDecl{ {Name: "storage", Role: "provider", Priority: 10}, {Name: "database", Role: "provider", Priority: 10}, {Name: "persistence", Role: "provider", Priority: 10}, + {Name: "cache", Role: "provider", Priority: 10}, }, }, }, @@ -64,6 +68,10 @@ func (p *Plugin) Capabilities() []capability.Contract { Name: "persistence", Description: "Persistence layer that uses a database service for storage", }, + { + Name: "cache", + Description: "Redis-backed key/value cache for pipeline data", + }, } } @@ -141,6 +149,31 @@ func (p *Plugin) ModuleFactories() map[string]plugin.ModuleFactory { } return module.NewPersistenceStore(name, dbServiceName) }, + "cache.redis": func(name string, cfg map[string]any) modular.Module { + redisCfg := module.RedisCacheConfig{ + Address: "localhost:6379", + Prefix: "wf:", + DefaultTTL: time.Hour, + } + if addr, ok := cfg["address"].(string); ok && addr != "" { + redisCfg.Address = module.ExpandEnvString(addr) + } + if pw, ok := cfg["password"].(string); ok { + redisCfg.Password = module.ExpandEnvString(pw) + } + if db, ok := cfg["db"].(float64); ok { + redisCfg.DB = int(db) + } + if prefix, ok := cfg["prefix"].(string); ok && prefix != "" { + redisCfg.Prefix = prefix + } + if ttlStr, ok := cfg["defaultTTL"].(string); ok && ttlStr != "" { + if d, err := time.ParseDuration(ttlStr); err == nil { + redisCfg.DefaultTTL = d + } + } + return module.NewRedisCache(name, redisCfg) + }, } } @@ -226,5 +259,25 @@ func (p *Plugin) ModuleSchemas() []*schema.ModuleSchema { }, DefaultConfig: map[string]any{"database": "database"}, }, + { + Type: "cache.redis", + Label: "Redis Cache", + Category: "cache", + Description: "Redis-backed key/value cache for pipeline step data", + Outputs: []schema.ServiceIODef{{Name: "cache", Type: "CacheModule", Description: "Redis cache service"}}, + ConfigFields: []schema.ConfigFieldDef{ + {Key: "address", Label: "Address", Type: schema.FieldTypeString, DefaultValue: "localhost:6379", Description: "Redis server address (host:port)", Placeholder: "localhost:6379"}, + {Key: "password", Label: "Password", Type: schema.FieldTypeString, Description: "Redis password (optional)", Sensitive: true}, + {Key: "db", Label: "Database", Type: schema.FieldTypeNumber, DefaultValue: 0, Description: "Redis database number"}, + {Key: "prefix", Label: "Key Prefix", Type: schema.FieldTypeString, DefaultValue: "wf:", Description: "Prefix applied to all cache keys"}, + {Key: "defaultTTL", Label: "Default TTL", Type: schema.FieldTypeString, DefaultValue: "1h", Description: "Default time-to-live for cached values (e.g. 30m, 1h, 24h)"}, + }, + DefaultConfig: map[string]any{ + "address": "localhost:6379", + "db": 0, + "prefix": "wf:", + "defaultTTL": "1h", + }, + }, } } diff --git a/plugins/storage/plugin_test.go b/plugins/storage/plugin_test.go index 34e9031e..627a0568 100644 --- a/plugins/storage/plugin_test.go +++ b/plugins/storage/plugin_test.go @@ -21,8 +21,8 @@ func TestPluginManifest(t *testing.T) { if m.Name != "storage" { t.Errorf("expected name %q, got %q", "storage", m.Name) } - if len(m.ModuleTypes) != 6 { - t.Errorf("expected 6 module types, got %d", len(m.ModuleTypes)) + if len(m.ModuleTypes) != 7 { + t.Errorf("expected 7 module types, got %d", len(m.ModuleTypes)) } if len(m.StepTypes) != 0 { t.Errorf("expected 0 step types, got %d", len(m.StepTypes)) @@ -32,14 +32,14 @@ func TestPluginManifest(t *testing.T) { func TestPluginCapabilities(t *testing.T) { p := New() caps := p.Capabilities() - if len(caps) != 3 { - t.Fatalf("expected 3 capabilities, got %d", len(caps)) + if len(caps) != 4 { + t.Fatalf("expected 4 capabilities, got %d", len(caps)) } names := map[string]bool{} for _, c := range caps { names[c.Name] = true } - for _, expected := range []string{"storage", "database", "persistence"} { + for _, expected := range []string{"storage", "database", "persistence", "cache"} { if !names[expected] { t.Errorf("missing capability %q", expected) } @@ -53,6 +53,7 @@ func TestModuleFactories(t *testing.T) { expectedTypes := []string{ "storage.s3", "storage.local", "storage.gcs", "storage.sqlite", "database.workflow", "persistence.store", + "cache.redis", } for _, typ := range expectedTypes { factory, ok := factories[typ] @@ -133,8 +134,8 @@ func TestStepFactories(t *testing.T) { func TestModuleSchemas(t *testing.T) { p := New() schemas := p.ModuleSchemas() - if len(schemas) != 6 { - t.Fatalf("expected 6 module schemas, got %d", len(schemas)) + if len(schemas) != 7 { + t.Fatalf("expected 7 module schemas, got %d", len(schemas)) } types := map[string]bool{} @@ -144,6 +145,7 @@ func TestModuleSchemas(t *testing.T) { expectedTypes := []string{ "storage.s3", "storage.local", "storage.gcs", "storage.sqlite", "database.workflow", "persistence.store", + "cache.redis", } for _, expected := range expectedTypes { if !types[expected] { @@ -151,3 +153,31 @@ func TestModuleSchemas(t *testing.T) { } } } + +func TestCacheRedisFactory(t *testing.T) { + p := New() + factories := p.ModuleFactories() + + factory, ok := factories["cache.redis"] + if !ok { + t.Fatal("missing factory for cache.redis") + } + + // Default config + mod := factory("cache", map[string]any{}) + if mod == nil { + t.Fatal("cache.redis factory returned nil with empty config") + } + + // Full config + mod = factory("cache", map[string]any{ + "address": "redis:6379", + "password": "secret", + "db": float64(1), + "prefix": "myapp:", + "defaultTTL": "30m", + }) + if mod == nil { + t.Fatal("cache.redis factory returned nil with full config") + } +} From e91abec58e88f9a37e3fc65f4fd0170f2fe0d3f4 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Wed, 25 Feb 2026 01:45:20 -0500 Subject: [PATCH 4/5] feat(wfctl): add wfctl api extract command for offline OpenAPI 3.0 spec generation Adds a new `wfctl api extract` subcommand that parses a workflow YAML config file offline (no server startup) and emits a valid OpenAPI 3.0.3 specification covering all HTTP endpoints from two sources: - workflow routes (workflows..routes) - pipeline HTTP triggers (pipelines..trigger.type == "http") Schema inference from step types is supported: step.validate rules map to request body properties, step.user_register/user_login produce typed request/response schemas, step.json_response infers response status and body, and step.auth_required hints 401/403 responses. Output format is JSON (default) or YAML, configurable via -format. Output goes to stdout or a file via -output. Supports -title, -version, and repeatable -server flags. Registers the command as "api" in main.go dispatch. 20 tests added covering all flags, both endpoint sources, schema inference, error paths, and stdout output. Co-Authored-By: Claude Opus 4.6 --- cmd/wfctl/api_extract.go | 530 +++++++++++++++++++++++++++++++ cmd/wfctl/api_extract_test.go | 583 ++++++++++++++++++++++++++++++++++ cmd/wfctl/main.go | 4 + 3 files changed, 1117 insertions(+) create mode 100644 cmd/wfctl/api_extract.go create mode 100644 cmd/wfctl/api_extract_test.go diff --git a/cmd/wfctl/api_extract.go b/cmd/wfctl/api_extract.go new file mode 100644 index 00000000..51dd207c --- /dev/null +++ b/cmd/wfctl/api_extract.go @@ -0,0 +1,530 @@ +package main + +import ( + "encoding/json" + "flag" + "fmt" + "os" + "regexp" + "sort" + "strings" + + "github.com/GoCodeAlone/workflow/config" + "github.com/GoCodeAlone/workflow/module" + "gopkg.in/yaml.v3" +) + +// serverFlag is a flag.Value that accumulates multiple -server flags. +type serverFlag []string + +func (s *serverFlag) String() string { + return strings.Join(*s, ", ") +} + +func (s *serverFlag) Set(v string) error { + *s = append(*s, v) + return nil +} + +func runAPI(args []string) error { + if len(args) < 1 { + return apiUsage() + } + switch args[0] { + case "extract": + return runAPIExtract(args[1:]) + default: + return apiUsage() + } +} + +func apiUsage() error { + fmt.Fprintf(flag.CommandLine.Output(), `Usage: wfctl api [options] + +Subcommands: + extract Extract OpenAPI 3.0 spec from a workflow config file (offline) +`) + return fmt.Errorf("api subcommand is required") +} + +// runAPIExtract parses a workflow YAML config file offline and outputs an +// OpenAPI 3.0 specification of all HTTP endpoints defined in the config. +func runAPIExtract(args []string) error { + fs := flag.NewFlagSet("api extract", flag.ContinueOnError) + format := fs.String("format", "json", "Output format: json or yaml") + title := fs.String("title", "", "API title (default: extracted from config or \"Workflow API\")") + version := fs.String("version", "1.0.0", "API version") + var servers serverFlag + fs.Var(&servers, "server", "Server URL to include (repeatable)") + output := fs.String("output", "", "Write to file instead of stdout") + includeSchemas := fs.Bool("include-schemas", true, "Attempt to infer request/response schemas from step types") + fs.Usage = func() { + fmt.Fprintf(fs.Output(), `Usage: wfctl api extract [options] + +Parse a workflow config file offline and output an OpenAPI 3.0 specification +of all HTTP endpoints defined in the config. + +Examples: + wfctl api extract config.yaml + wfctl api extract -format yaml -output openapi.yaml config.yaml + wfctl api extract -title "My API" -version "2.0.0" config.yaml + wfctl api extract -server https://api.example.com config.yaml + +Options: +`) + fs.PrintDefaults() + } + if err := fs.Parse(args); err != nil { + return err + } + if fs.NArg() < 1 { + fs.Usage() + return fmt.Errorf("config file path is required") + } + + configPath := fs.Arg(0) + cfg, err := config.LoadFromFile(configPath) + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + // Determine title: flag > config-derived > default + apiTitle := *title + if apiTitle == "" { + apiTitle = extractTitleFromConfig(cfg) + } + if apiTitle == "" { + apiTitle = "Workflow API" + } + + // Build generator with settings + genCfg := module.OpenAPIGeneratorConfig{ + Title: apiTitle, + Version: *version, + Servers: []string(servers), + } + gen := module.NewOpenAPIGenerator("api-extract", genCfg) + + // Build spec from workflow routes + gen.BuildSpec(cfg.Workflows) + + // Extract pipeline HTTP endpoints and add them to the spec + if len(cfg.Pipelines) > 0 { + pipelineRoutes := extractPipelineRoutes(cfg.Pipelines, *includeSchemas, gen) + if len(pipelineRoutes) > 0 { + gen.BuildSpecFromRoutes(appendToExistingSpec(gen, pipelineRoutes)) + } + } + + if *includeSchemas { + gen.ApplySchemas() + } + + spec := gen.GetSpec() + if spec == nil { + spec = &module.OpenAPISpec{ + OpenAPI: "3.0.3", + Info: module.OpenAPIInfo{ + Title: apiTitle, + Version: *version, + }, + Paths: make(map[string]*module.OpenAPIPath), + } + } + + // Determine output writer + var w *os.File + if *output != "" { + f, err := os.Create(*output) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer f.Close() + w = f + } else { + w = os.Stdout + } + + // Encode output + switch strings.ToLower(*format) { + case "yaml", "yml": + enc := yaml.NewEncoder(w) + enc.SetIndent(2) + if err := enc.Encode(spec); err != nil { + return fmt.Errorf("failed to encode spec as YAML: %w", err) + } + case "json", "": + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + if err := enc.Encode(spec); err != nil { + return fmt.Errorf("failed to encode spec as JSON: %w", err) + } + default: + return fmt.Errorf("unsupported format %q: use json or yaml", *format) + } + + if *output != "" { + fmt.Fprintf(os.Stderr, "OpenAPI spec written to %s\n", *output) + } + return nil +} + +// extractTitleFromConfig attempts to derive a meaningful API title from the config. +// It looks for module names that suggest an application name. +func extractTitleFromConfig(cfg *config.WorkflowConfig) string { + // Look for a server module with a descriptive name + for _, mod := range cfg.Modules { + if mod.Type == "http.server" && mod.Name != "server" && mod.Name != "" { + return strings.Title(strings.ReplaceAll(mod.Name, "-", " ")) //nolint:staticcheck + } + } + return "" +} + +// pipelineEndpoint describes an HTTP endpoint extracted from a pipeline definition. +type pipelineEndpoint struct { + name string + method string + path string + steps []map[string]any + includeSchemas bool +} + +// pathParamRegexExtract matches {paramName} in route paths. +var pathParamRegexExtract = regexp.MustCompile(`\{([^}]+)\}`) + +// extractPipelineRoutes scans the pipelines map for HTTP-triggered pipelines +// and returns route definitions for each one. +func extractPipelineRoutes(pipelines map[string]any, includeSchemas bool, gen *module.OpenAPIGenerator) []module.RouteDefinition { + // Collect endpoints sorted by name for stable output + names := make([]string, 0, len(pipelines)) + for name := range pipelines { + names = append(names, name) + } + sort.Strings(names) + + var routes []module.RouteDefinition + for _, name := range names { + raw := pipelines[name] + pipelineMap, ok := raw.(map[string]any) + if !ok { + continue + } + + ep := parsePipelineEndpoint(name, pipelineMap, includeSchemas) + if ep == nil { + continue + } + + route := module.RouteDefinition{ + Method: strings.ToUpper(ep.method), + Path: ep.path, + Handler: ep.name, + Tags: []string{"pipelines"}, + Summary: fmt.Sprintf("%s %s (pipeline: %s)", strings.ToUpper(ep.method), ep.path, ep.name), + } + + if includeSchemas && len(ep.steps) > 0 { + applyPipelineSchemas(gen, ep) + } + + routes = append(routes, route) + } + return routes +} + +// parsePipelineEndpoint extracts HTTP trigger details from a pipeline config map. +// Returns nil if the pipeline has no HTTP trigger. +func parsePipelineEndpoint(name string, pipelineMap map[string]any, includeSchemas bool) *pipelineEndpoint { + triggerRaw, ok := pipelineMap["trigger"] + if !ok { + return nil + } + triggerMap, ok := triggerRaw.(map[string]any) + if !ok { + return nil + } + + triggerType, _ := triggerMap["type"].(string) + if triggerType != "http" { + return nil + } + + triggerConfig, ok := triggerMap["config"].(map[string]any) + if !ok { + return nil + } + + path, _ := triggerConfig["path"].(string) + method, _ := triggerConfig["method"].(string) + if path == "" || method == "" { + return nil + } + + ep := &pipelineEndpoint{ + name: name, + method: method, + path: path, + includeSchemas: includeSchemas, + } + + // Extract steps + if stepsRaw, ok := pipelineMap["steps"].([]any); ok { + for _, stepRaw := range stepsRaw { + if stepMap, ok := stepRaw.(map[string]any); ok { + ep.steps = append(ep.steps, stepMap) + } + } + } + + return ep +} + +// applyPipelineSchemas infers request/response schemas from pipeline step types +// and registers them with the OpenAPI generator. +func applyPipelineSchemas(gen *module.OpenAPIGenerator, ep *pipelineEndpoint) { + var reqSchema *module.OpenAPISchema + var respSchema *module.OpenAPISchema + hasAuthRequired := false + var statusCode string + + for _, step := range ep.steps { + stepType, _ := step["type"].(string) + stepCfg, _ := step["config"].(map[string]any) + + switch stepType { + case "step.validate": + // Infer request body schema from validation rules + if stepCfg != nil { + if reqSchema == nil { + reqSchema = &module.OpenAPISchema{ + Type: "object", + Properties: make(map[string]*module.OpenAPISchema), + } + } + inferValidateSchema(reqSchema, stepCfg) + } + + case "step.user_register": + // Request: email + password; Response: user object + if reqSchema == nil { + reqSchema = userCredentialsSchema() + } + if respSchema == nil { + respSchema = userObjectSchema() + } + + case "step.user_login": + // Request: email + password; Response: token + if reqSchema == nil { + reqSchema = userCredentialsSchema() + } + if respSchema == nil { + respSchema = loginResponseSchema() + } + + case "step.auth_required": + hasAuthRequired = true + + case "step.json_response": + // Determine status code from config + if stepCfg != nil { + if sc, ok := stepCfg["statusCode"]; ok { + statusCode = fmt.Sprintf("%v", sc) + } else if sc, ok := stepCfg["status"]; ok { + statusCode = fmt.Sprintf("%v", sc) + } + // Infer response schema from body if present + if body, ok := stepCfg["body"]; ok { + if respSchema == nil { + respSchema = inferBodySchema(body) + } + } + } + if respSchema == nil { + respSchema = &module.OpenAPISchema{Type: "object"} + } + } + } + + // Build extra responses map for auth + if hasAuthRequired { + // We'll register these via SetOperationSchema — the generator adds 401/403 + // by detecting auth in middlewares, but for pipelines we set it directly. + // We use a trick: add "auth" to the route middleware list by ensuring + // we call SetOperationSchema with a summary that includes auth note. + // The cleanest approach is to set the schemas and rely on the generator. + } + + // Set the inferred schemas on the operation + gen.SetOperationSchema(ep.method, ep.path, reqSchema, respSchema) + + // If auth is required or we have a custom status code, register a component schema + // and set additional responses. + if hasAuthRequired || statusCode != "" { + // We handle the status code override by registering a component schema + // The ApplySchemas call will wire up the request/response schemas. + // For auth, register a 401 schema by adding it as a component. + if hasAuthRequired { + gen.RegisterComponentSchema("UnauthorizedError", &module.OpenAPISchema{ + Type: "object", + Properties: map[string]*module.OpenAPISchema{ + "error": {Type: "string", Example: "Unauthorized"}, + }, + }) + } + } +} + +// inferValidateSchema parses step.validate config rules and populates an OpenAPI schema. +// Rule format: "required,email" or "required,min=8". +func inferValidateSchema(schema *module.OpenAPISchema, stepCfg map[string]any) { + rules, ok := stepCfg["rules"].(map[string]any) + if !ok { + return + } + + for field, ruleRaw := range rules { + ruleStr, _ := ruleRaw.(string) + parts := strings.Split(ruleStr, ",") + + fieldSchema := &module.OpenAPISchema{Type: "string"} + isRequired := false + + for _, part := range parts { + part = strings.TrimSpace(part) + switch { + case part == "required": + isRequired = true + case part == "email": + fieldSchema.Format = "email" + case strings.HasPrefix(part, "min="): + // min length hint — keep as string + case part == "numeric" || part == "number": + fieldSchema.Type = "number" + case part == "boolean" || part == "bool": + fieldSchema.Type = "boolean" + } + } + + schema.Properties[field] = fieldSchema + if isRequired { + schema.Required = append(schema.Required, field) + } + } + + // Sort required for stable output + sort.Strings(schema.Required) +} + +// inferBodySchema creates a schema from a body config value. +func inferBodySchema(body any) *module.OpenAPISchema { + bodyMap, ok := body.(map[string]any) + if !ok { + return &module.OpenAPISchema{Type: "object"} + } + + schema := &module.OpenAPISchema{ + Type: "object", + Properties: make(map[string]*module.OpenAPISchema), + } + for k, v := range bodyMap { + switch v.(type) { + case int, int64, float64: + schema.Properties[k] = &module.OpenAPISchema{Type: "integer"} + case bool: + schema.Properties[k] = &module.OpenAPISchema{Type: "boolean"} + default: + schema.Properties[k] = &module.OpenAPISchema{Type: "string"} + } + } + return schema +} + +// userCredentialsSchema returns a schema for email+password request bodies. +func userCredentialsSchema() *module.OpenAPISchema { + return &module.OpenAPISchema{ + Type: "object", + Properties: map[string]*module.OpenAPISchema{ + "email": {Type: "string", Format: "email"}, + "password": {Type: "string", Format: "password"}, + }, + Required: []string{"email", "password"}, + } +} + +// userObjectSchema returns a schema for a user response object. +func userObjectSchema() *module.OpenAPISchema { + return &module.OpenAPISchema{ + Type: "object", + Properties: map[string]*module.OpenAPISchema{ + "id": {Type: "string", Format: "uuid"}, + "email": {Type: "string", Format: "email"}, + }, + } +} + +// loginResponseSchema returns a schema for a login response with a token. +func loginResponseSchema() *module.OpenAPISchema { + return &module.OpenAPISchema{ + Type: "object", + Properties: map[string]*module.OpenAPISchema{ + "token": {Type: "string", Description: "JWT access token"}, + }, + Required: []string{"token"}, + } +} + +// appendToExistingSpec builds a combined route list from the existing spec paths +// plus new pipeline routes, used when rebuilding the spec to include both. +func appendToExistingSpec(gen *module.OpenAPIGenerator, pipelineRoutes []module.RouteDefinition) []module.RouteDefinition { + spec := gen.GetSpec() + if spec == nil { + return pipelineRoutes + } + + // Collect existing routes from the spec + var existing []module.RouteDefinition + paths := gen.SortedPaths() + for _, path := range paths { + pathItem := spec.Paths[path] + if pathItem == nil { + continue + } + ops := []struct { + method string + op *module.OpenAPIOperation + }{ + {"GET", pathItem.Get}, + {"POST", pathItem.Post}, + {"PUT", pathItem.Put}, + {"DELETE", pathItem.Delete}, + {"PATCH", pathItem.Patch}, + {"OPTIONS", pathItem.Options}, + } + for _, entry := range ops { + if entry.op == nil { + continue + } + route := module.RouteDefinition{ + Method: entry.method, + Path: path, + Summary: entry.op.Summary, + Tags: entry.op.Tags, + } + if len(entry.op.Parameters) > 0 { + // Extract handler from tags if available + if len(entry.op.Tags) > 0 { + route.Handler = entry.op.Tags[0] + } + } + // Extract middlewares hints from responses + if _, hasAuth := entry.op.Responses["401"]; hasAuth { + route.Middlewares = []string{"auth"} + } + existing = append(existing, route) + } + } + + return append(existing, pipelineRoutes...) +} diff --git a/cmd/wfctl/api_extract_test.go b/cmd/wfctl/api_extract_test.go new file mode 100644 index 00000000..66cc9bb5 --- /dev/null +++ b/cmd/wfctl/api_extract_test.go @@ -0,0 +1,583 @@ +package main + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/GoCodeAlone/workflow/module" + "gopkg.in/yaml.v3" +) + +const configWithPipelines = ` +modules: + - name: server + type: http.server + config: + address: ":8080" + - name: router + type: http.router + dependsOn: [server] + +pipelines: + create-user: + trigger: + type: http + config: + path: /api/v1/users + method: POST + steps: + - type: step.validate + config: + rules: + email: required,email + password: required,min=8 + - type: step.user_register + - type: step.json_response + config: + statusCode: 201 + + login: + trigger: + type: http + config: + path: /api/v1/login + method: POST + steps: + - type: step.user_login + - type: step.json_response + + health: + trigger: + type: http + config: + path: /healthz + method: GET + steps: + - type: step.json_response + config: + status: 200 + body: + status: ok +` + +const configWithWorkflowRoutes = ` +modules: + - name: server + type: http.server + config: + address: ":8080" + - name: router + type: http.router + dependsOn: [server] + - name: jwt + type: auth.jwt + config: + secret: "test-secret" + dependsOn: [router] + - name: auth-middleware + type: http.middleware.auth + dependsOn: [jwt] + +workflows: + http: + router: router + server: server + routes: + - method: POST + path: /api/auth/login + handler: jwt + - method: GET + path: /api/auth/profile + handler: jwt + middlewares: + - auth-middleware + - method: GET + path: /api/users/{id} + handler: jwt +` + +const configWithBothSourcesYAML = ` +modules: + - name: server + type: http.server + config: + address: ":8080" + - name: router + type: http.router + +workflows: + http: + router: router + server: server + routes: + - method: GET + path: /api/items + handler: items-handler + +pipelines: + create-item: + trigger: + type: http + config: + path: /api/items + method: POST + steps: + - type: step.validate + config: + rules: + name: required + - type: step.json_response + config: + statusCode: 201 +` + +const configNoPipelines = ` +modules: + - name: server + type: http.server + config: + address: ":8080" +` + +func TestRunAPIExtractMissingArg(t *testing.T) { + err := runAPIExtract([]string{}) + if err == nil { + t.Fatal("expected error when no config file given") + } + if !strings.Contains(err.Error(), "config file path is required") { + t.Errorf("expected 'config file path is required', got: %v", err) + } +} + +func TestRunAPIExtractMissingSubcommand(t *testing.T) { + err := runAPI([]string{}) + if err == nil { + t.Fatal("expected error when no subcommand given") + } +} + +func TestRunAPIExtractUnknownSubcommand(t *testing.T) { + err := runAPI([]string{"unknown"}) + if err == nil { + t.Fatal("expected error for unknown subcommand") + } +} + +func TestRunAPIExtractJSONOutput(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configWithPipelines) + outPath := filepath.Join(dir, "openapi.json") + + err := runAPIExtract([]string{"-format", "json", "-output", outPath, cfgPath}) + if err != nil { + t.Fatalf("api extract failed: %v", err) + } + + data, err := os.ReadFile(outPath) + if err != nil { + t.Fatalf("failed to read output: %v", err) + } + + var spec module.OpenAPISpec + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("output is not valid JSON: %v", err) + } + + if spec.OpenAPI != "3.0.3" { + t.Errorf("expected OpenAPI 3.0.3, got %q", spec.OpenAPI) + } + if spec.Info.Title == "" { + t.Error("expected non-empty title") + } + if spec.Info.Version == "" { + t.Error("expected non-empty version") + } + if len(spec.Paths) == 0 { + t.Error("expected at least one path in spec") + } +} + +func TestRunAPIExtractYAMLOutput(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configWithPipelines) + outPath := filepath.Join(dir, "openapi.yaml") + + err := runAPIExtract([]string{"-format", "yaml", "-output", outPath, cfgPath}) + if err != nil { + t.Fatalf("api extract failed: %v", err) + } + + data, err := os.ReadFile(outPath) + if err != nil { + t.Fatalf("failed to read output: %v", err) + } + + var spec module.OpenAPISpec + if err := yaml.Unmarshal(data, &spec); err != nil { + t.Fatalf("output is not valid YAML: %v", err) + } + + if spec.OpenAPI != "3.0.3" { + t.Errorf("expected OpenAPI 3.0.3, got %q", spec.OpenAPI) + } +} + +func TestRunAPIExtractPipelineEndpoints(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configWithPipelines) + outPath := filepath.Join(dir, "openapi.json") + + err := runAPIExtract([]string{"-output", outPath, cfgPath}) + if err != nil { + t.Fatalf("api extract failed: %v", err) + } + + data, _ := os.ReadFile(outPath) + var spec module.OpenAPISpec + json.Unmarshal(data, &spec) //nolint:errcheck + + // Check that pipeline HTTP endpoints are in the spec + if _, ok := spec.Paths["/api/v1/users"]; !ok { + t.Error("expected /api/v1/users in spec paths") + } + if _, ok := spec.Paths["/api/v1/login"]; !ok { + t.Error("expected /api/v1/login in spec paths") + } + if _, ok := spec.Paths["/healthz"]; !ok { + t.Error("expected /healthz in spec paths") + } +} + +func TestRunAPIExtractWorkflowRoutes(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configWithWorkflowRoutes) + outPath := filepath.Join(dir, "openapi.json") + + err := runAPIExtract([]string{"-output", outPath, cfgPath}) + if err != nil { + t.Fatalf("api extract failed: %v", err) + } + + data, _ := os.ReadFile(outPath) + var spec module.OpenAPISpec + json.Unmarshal(data, &spec) //nolint:errcheck + + if _, ok := spec.Paths["/api/auth/login"]; !ok { + t.Error("expected /api/auth/login in spec paths") + } + if _, ok := spec.Paths["/api/auth/profile"]; !ok { + t.Error("expected /api/auth/profile in spec paths") + } + if _, ok := spec.Paths["/api/users/{id}"]; !ok { + t.Error("expected /api/users/{id} in spec paths") + } +} + +func TestRunAPIExtractBothSources(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configWithBothSourcesYAML) + outPath := filepath.Join(dir, "openapi.json") + + err := runAPIExtract([]string{"-output", outPath, cfgPath}) + if err != nil { + t.Fatalf("api extract failed: %v", err) + } + + data, _ := os.ReadFile(outPath) + var spec module.OpenAPISpec + json.Unmarshal(data, &spec) //nolint:errcheck + + // Both workflow route and pipeline should be present + pathItem, ok := spec.Paths["/api/items"] + if !ok { + t.Fatal("expected /api/items in spec paths") + } + + // GET from workflow routes + if pathItem.Get == nil { + t.Error("expected GET /api/items from workflow routes") + } + // POST from pipeline + if pathItem.Post == nil { + t.Error("expected POST /api/items from pipeline") + } +} + +func TestRunAPIExtractCustomTitle(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configNoPipelines) + outPath := filepath.Join(dir, "openapi.json") + + err := runAPIExtract([]string{"-title", "My Custom API", "-version", "2.0.0", "-output", outPath, cfgPath}) + if err != nil { + t.Fatalf("api extract failed: %v", err) + } + + data, _ := os.ReadFile(outPath) + var spec module.OpenAPISpec + json.Unmarshal(data, &spec) //nolint:errcheck + + if spec.Info.Title != "My Custom API" { + t.Errorf("expected title 'My Custom API', got %q", spec.Info.Title) + } + if spec.Info.Version != "2.0.0" { + t.Errorf("expected version '2.0.0', got %q", spec.Info.Version) + } +} + +func TestRunAPIExtractWithServers(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configNoPipelines) + outPath := filepath.Join(dir, "openapi.json") + + err := runAPIExtract([]string{ + "-server", "https://api.example.com", + "-server", "https://staging.example.com", + "-output", outPath, + cfgPath, + }) + if err != nil { + t.Fatalf("api extract failed: %v", err) + } + + data, _ := os.ReadFile(outPath) + var spec module.OpenAPISpec + json.Unmarshal(data, &spec) //nolint:errcheck + + if len(spec.Servers) != 2 { + t.Errorf("expected 2 servers, got %d", len(spec.Servers)) + } + if spec.Servers[0].URL != "https://api.example.com" { + t.Errorf("expected first server URL 'https://api.example.com', got %q", spec.Servers[0].URL) + } +} + +func TestRunAPIExtractStdout(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configWithPipelines) + + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + err := runAPIExtract([]string{cfgPath}) + + w.Close() + os.Stdout = oldStdout + + if err != nil { + t.Fatalf("api extract to stdout failed: %v", err) + } + + var buf strings.Builder + readBuf := make([]byte, 4096) + for { + n, readErr := r.Read(readBuf) + buf.Write(readBuf[:n]) + if readErr != nil { + break + } + } + + output := buf.String() + if !strings.Contains(output, `"openapi"`) { + t.Errorf("expected JSON output with 'openapi' key, got: %s", output[:min(len(output), 200)]) + } +} + +func TestRunAPIExtractInvalidFormat(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configNoPipelines) + + err := runAPIExtract([]string{"-format", "xml", cfgPath}) + if err == nil { + t.Fatal("expected error for unsupported format") + } + if !strings.Contains(err.Error(), "unsupported format") { + t.Errorf("expected 'unsupported format' error, got: %v", err) + } +} + +func TestRunAPIExtractInvalidConfig(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", "not: valid: yaml: {{{") + + err := runAPIExtract([]string{cfgPath}) + if err == nil { + t.Fatal("expected error for invalid config") + } +} + +func TestRunAPIExtractMissingConfigFile(t *testing.T) { + err := runAPIExtract([]string{"/nonexistent/path/config.yaml"}) + if err == nil { + t.Fatal("expected error for missing config file") + } +} + +func TestRunAPIExtractSchemaInference(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configWithPipelines) + outPath := filepath.Join(dir, "openapi.json") + + err := runAPIExtract([]string{"-include-schemas=true", "-output", outPath, cfgPath}) + if err != nil { + t.Fatalf("api extract failed: %v", err) + } + + data, _ := os.ReadFile(outPath) + var spec module.OpenAPISpec + json.Unmarshal(data, &spec) //nolint:errcheck + + // Check that user register endpoint has request body + usersPath, ok := spec.Paths["/api/v1/users"] + if !ok { + t.Fatal("expected /api/v1/users in spec") + } + if usersPath.Post == nil { + t.Fatal("expected POST /api/v1/users") + } + if usersPath.Post.RequestBody == nil { + t.Error("expected request body for POST /api/v1/users") + } +} + +func TestRunAPIExtractNoSchemaInference(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configWithPipelines) + outPath := filepath.Join(dir, "openapi.json") + + err := runAPIExtract([]string{"-include-schemas=false", "-output", outPath, cfgPath}) + if err != nil { + t.Fatalf("api extract failed: %v", err) + } + + data, _ := os.ReadFile(outPath) + if len(data) == 0 { + t.Fatal("expected non-empty output") + } + var spec module.OpenAPISpec + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("output is not valid JSON: %v", err) + } +} + +func TestRunAPIExtractDefaultTitle(t *testing.T) { + dir := t.TempDir() + cfgPath := writeTestConfig(t, dir, "config.yaml", configNoPipelines) + outPath := filepath.Join(dir, "openapi.json") + + err := runAPIExtract([]string{"-output", outPath, cfgPath}) + if err != nil { + t.Fatalf("api extract failed: %v", err) + } + + data, _ := os.ReadFile(outPath) + var spec module.OpenAPISpec + json.Unmarshal(data, &spec) //nolint:errcheck + + // Should fall back to "Workflow API" when no title is provided and none can be inferred + if spec.Info.Title == "" { + t.Error("expected non-empty default title") + } +} + +func TestInferValidateSchema(t *testing.T) { + schema := &module.OpenAPISchema{ + Type: "object", + Properties: make(map[string]*module.OpenAPISchema), + } + stepCfg := map[string]any{ + "rules": map[string]any{ + "email": "required,email", + "password": "required,min=8", + "age": "numeric", + }, + } + inferValidateSchema(schema, stepCfg) + + if _, ok := schema.Properties["email"]; !ok { + t.Error("expected email property") + } + if schema.Properties["email"].Format != "email" { + t.Errorf("expected email format, got %q", schema.Properties["email"].Format) + } + if _, ok := schema.Properties["password"]; !ok { + t.Error("expected password property") + } + if _, ok := schema.Properties["age"]; !ok { + t.Error("expected age property") + } + if schema.Properties["age"].Type != "number" { + t.Errorf("expected number type for age, got %q", schema.Properties["age"].Type) + } + + // Check required fields + requiredSet := make(map[string]bool) + for _, r := range schema.Required { + requiredSet[r] = true + } + if !requiredSet["email"] { + t.Error("expected email in required") + } + if !requiredSet["password"] { + t.Error("expected password in required") + } +} + +func TestParsePipelineEndpointNoHTTPTrigger(t *testing.T) { + pipelineMap := map[string]any{ + "trigger": map[string]any{ + "type": "schedule", + "config": map[string]any{ + "cron": "0 * * * *", + }, + }, + } + ep := parsePipelineEndpoint("my-pipeline", pipelineMap, false) + if ep != nil { + t.Error("expected nil for non-HTTP trigger") + } +} + +func TestParsePipelineEndpointHTTPTrigger(t *testing.T) { + pipelineMap := map[string]any{ + "trigger": map[string]any{ + "type": "http", + "config": map[string]any{ + "path": "/api/test", + "method": "POST", + }, + }, + "steps": []any{ + map[string]any{"type": "step.json_response"}, + }, + } + ep := parsePipelineEndpoint("test-pipeline", pipelineMap, true) + if ep == nil { + t.Fatal("expected non-nil endpoint for HTTP trigger") + } + if ep.path != "/api/test" { + t.Errorf("expected path '/api/test', got %q", ep.path) + } + if ep.method != "POST" { + t.Errorf("expected method 'POST', got %q", ep.method) + } + if ep.name != "test-pipeline" { + t.Errorf("expected name 'test-pipeline', got %q", ep.name) + } + if len(ep.steps) != 1 { + t.Errorf("expected 1 step, got %d", len(ep.steps)) + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/cmd/wfctl/main.go b/cmd/wfctl/main.go index 6cd2d745..e538617d 100644 --- a/cmd/wfctl/main.go +++ b/cmd/wfctl/main.go @@ -18,8 +18,10 @@ var commands = map[string]func([]string) error{ "manifest": runManifest, "migrate": runMigrate, "build-ui": runBuildUI, + "ui": runUI, "publish": runPublish, "deploy": runDeploy, + "api": runAPI, } func usage() { @@ -39,8 +41,10 @@ Commands: manifest Analyze config and report infrastructure requirements migrate Manage database schema migrations build-ui Build the application UI (npm install + npm run build + validate) + ui UI tooling (scaffold: generate Vite+React+TypeScript SPA from OpenAPI spec) publish Prepare and publish a plugin manifest to the workflow-registry deploy Deploy the workflow application (docker, kubernetes, cloud) + api API tooling (extract: generate OpenAPI 3.0 spec from config) Run 'wfctl -h' for command-specific help. `, version) From a2e9701a1975bb01d28c6bc15454343b3a2ac019 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Wed, 25 Feb 2026 01:49:58 -0500 Subject: [PATCH 5/5] feat(wfctl): add wfctl ui scaffold command for Vite+React+TypeScript SPA generation Adds `wfctl ui scaffold` which reads an OpenAPI 3.0 spec (JSON or YAML, file or stdin) and generates a complete, immediately-buildable Vite+React+TS SPA wired to the backend API. - New `wfctl ui` parent command dispatching to `scaffold` and `build` subcommands - Parses OpenAPI 3.0 spec: groups paths into resource pages, detects auth endpoints - Generates: package.json, tsconfig.json, vite.config.ts (with /api proxy), index.html, src/main.tsx, App.tsx (routes), api.ts (typed client with Bearer auth + 401 redirect), auth.tsx (AuthProvider/useAuth context), Layout, DataTable, FormField components, and one [Resource]Page.tsx per resource - Auth pages (LoginPage, RegisterPage) are auto-detected from spec or forced with -auth flag; skipped when no auth endpoints are found - Form fields inferred from requestBody schema (text/email/password/number/select) - API client exports typed functions per operation using path template literals - 27 tests covering spec parsing (JSON/YAML), analysis, code generation helpers, and CLI integration Co-Authored-By: Claude Opus 4.6 --- cmd/wfctl/scaffold.go | 768 +++++++++++++++++++++++++++ cmd/wfctl/scaffold_templates.go | 883 ++++++++++++++++++++++++++++++++ cmd/wfctl/scaffold_test.go | 808 +++++++++++++++++++++++++++++ 3 files changed, 2459 insertions(+) create mode 100644 cmd/wfctl/scaffold.go create mode 100644 cmd/wfctl/scaffold_templates.go create mode 100644 cmd/wfctl/scaffold_test.go diff --git a/cmd/wfctl/scaffold.go b/cmd/wfctl/scaffold.go new file mode 100644 index 00000000..91b046f0 --- /dev/null +++ b/cmd/wfctl/scaffold.go @@ -0,0 +1,768 @@ +package main + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + "text/template" + "unicode" + + "gopkg.in/yaml.v3" +) + +// scaffoldSpec mirrors the subset of OpenAPI 3.0 we need for scaffolding. +// We redeclare these locally so scaffold.go has no import dependency on the +// module package, keeping wfctl self-contained. +type scaffoldSpec struct { + Info scaffoldInfo `json:"info" yaml:"info"` + Paths map[string]*scaffoldPath `json:"paths" yaml:"paths"` + Components *scaffoldComponents `json:"components,omitempty" yaml:"components,omitempty"` +} + +type scaffoldInfo struct { + Title string `json:"title" yaml:"title"` + Version string `json:"version" yaml:"version"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` +} + +type scaffoldPath struct { + Get *scaffoldOp `json:"get,omitempty" yaml:"get,omitempty"` + Post *scaffoldOp `json:"post,omitempty" yaml:"post,omitempty"` + Put *scaffoldOp `json:"put,omitempty" yaml:"put,omitempty"` + Delete *scaffoldOp `json:"delete,omitempty" yaml:"delete,omitempty"` + Patch *scaffoldOp `json:"patch,omitempty" yaml:"patch,omitempty"` +} + +type scaffoldOp struct { + OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"` + Summary string `json:"summary,omitempty" yaml:"summary,omitempty"` + Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` + Parameters []scaffoldParam `json:"parameters,omitempty" yaml:"parameters,omitempty"` + RequestBody *scaffoldReqBody `json:"requestBody,omitempty" yaml:"requestBody,omitempty"` +} + +type scaffoldParam struct { + Name string `json:"name" yaml:"name"` + In string `json:"in" yaml:"in"` +} + +type scaffoldReqBody struct { + Content map[string]*scaffoldMediaType `json:"content,omitempty" yaml:"content,omitempty"` +} + +type scaffoldMediaType struct { + Schema *scaffoldSchema `json:"schema,omitempty" yaml:"schema,omitempty"` +} + +type scaffoldSchema struct { + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Properties map[string]*scaffoldSchema `json:"properties,omitempty" yaml:"properties,omitempty"` + Required []string `json:"required,omitempty" yaml:"required,omitempty"` + Enum []string `json:"enum,omitempty" yaml:"enum,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` +} + +type scaffoldComponents struct { + Schemas map[string]*scaffoldSchema `json:"schemas,omitempty" yaml:"schemas,omitempty"` +} + +// --- Analysis types --- + +// apiOperation is a parsed API operation for template use. +type apiOperation struct { + FuncName string // e.g. "getUsers" + Method string // "GET" + Path string // "/api/v1/users" + HasBody bool + PathParams []string +} + +// resourceGroup groups related operations under a resource name. +type resourceGroup struct { + Name string // e.g. "Users" + NameLower string // e.g. "users" + NamePlural string // e.g. "users" + ListOp *apiOperation + DetailOp *apiOperation + CreateOp *apiOperation + UpdateOp *apiOperation + DeleteOp *apiOperation + FormFields []formField +} + +// formField describes a field in a generated form. +type formField struct { + Name string + Label string + Type string // "text", "email", "password", "number", "select" + Required bool + Options []string // for select type +} + +// scaffoldData is the top-level data passed to all templates. +type scaffoldData struct { + Title string + Version string + Theme string + HasAuth bool + LoginPath string + RegisterPath string + Resources []resourceGroup + Operations []apiOperation + // Auth-specific operation paths + LoginOp *apiOperation + RegisterOp *apiOperation +} + +// runUI dispatches `wfctl ui [args]`. +func runUI(args []string) error { + if len(args) == 0 { + fmt.Fprintf(os.Stderr, `Usage: wfctl ui [options] + +Subcommands: + scaffold Generate a Vite+React+TypeScript SPA from an OpenAPI spec + build Build the application UI (npm install + npm run build + validate) + +Run 'wfctl ui -h' for subcommand-specific help. +`) + return fmt.Errorf("subcommand required") + } + + sub := args[0] + rest := args[1:] + switch sub { + case "scaffold": + return runUIScaffold(rest) + case "build": + return runBuildUI(rest) + default: + return fmt.Errorf("unknown ui subcommand %q — use 'scaffold' or 'build'", sub) + } +} + +// runUIScaffold implements `wfctl ui scaffold`. +func runUIScaffold(args []string) error { + fs := flag.NewFlagSet("ui scaffold", flag.ContinueOnError) + specFile := fs.String("spec", "", "Path to OpenAPI spec file (JSON or YAML); reads stdin if not set") + output := fs.String("output", "ui", "Output directory for the scaffolded UI") + title := fs.String("title", "", "Application title (extracted from spec if not provided)") + auth := fs.Bool("auth", false, "Include login/register pages (auto-detected if not set)") + theme := fs.String("theme", "auto", "Color theme: light, dark, auto") + fs.Usage = func() { + fmt.Fprintf(fs.Output(), `Usage: wfctl ui scaffold [options] + +Generate a complete Vite+React+TypeScript SPA from an OpenAPI 3.0 spec. + +The generated UI is immediately buildable with: + cd && npm install && npm run build + +Examples: + wfctl ui scaffold -spec openapi.yaml -output ui + cat openapi.json | wfctl ui scaffold -output ./frontend + wfctl ui scaffold -spec api.yaml -title "My App" -auth -theme dark + +Options: +`) + fs.PrintDefaults() + } + if err := fs.Parse(args); err != nil { + return err + } + + // Read spec. + var specBytes []byte + var err error + if *specFile != "" { + specBytes, err = os.ReadFile(*specFile) //nolint:gosec // user-supplied path + if err != nil { + return fmt.Errorf("failed to read spec file: %w", err) + } + } else { + specBytes, err = io.ReadAll(os.Stdin) + if err != nil { + return fmt.Errorf("failed to read spec from stdin: %w", err) + } + } + + // Parse spec. + spec, err := parseScaffoldSpec(specBytes) + if err != nil { + return fmt.Errorf("failed to parse OpenAPI spec: %w", err) + } + + // Build scaffold data. + data := analyzeSpec(spec, *title, *auth, *theme) + + // Resolve output directory. + absOutput, err := filepath.Abs(*output) + if err != nil { + return fmt.Errorf("failed to resolve output path: %w", err) + } + + // Generate files. + if err := generateScaffold(absOutput, data); err != nil { + return fmt.Errorf("scaffold generation failed: %w", err) + } + + fmt.Printf("\nUI scaffold generated in %s/\n\n", absOutput) + fmt.Println("Next steps:") + fmt.Printf(" cd %s\n", *output) + fmt.Println(" npm install") + fmt.Println(" npm run dev # start dev server with API proxy") + fmt.Println(" npm run build # production build") + fmt.Println() + return nil +} + +// parseScaffoldSpec parses a JSON or YAML OpenAPI spec. +func parseScaffoldSpec(data []byte) (*scaffoldSpec, error) { + spec := &scaffoldSpec{} + data = bytes.TrimSpace(data) + if len(data) == 0 { + return nil, fmt.Errorf("spec is empty") + } + // Try JSON first (starts with '{'), then YAML. + if data[0] == '{' { + if err := json.Unmarshal(data, spec); err != nil { + return nil, fmt.Errorf("JSON parse error: %w", err) + } + } else { + if err := yaml.Unmarshal(data, spec); err != nil { + return nil, fmt.Errorf("YAML parse error: %w", err) + } + } + if spec.Paths == nil { + spec.Paths = make(map[string]*scaffoldPath) + } + return spec, nil +} + +// authPathRe matches paths that look like auth endpoints. +var authPathRe = regexp.MustCompile(`(?i)(auth|login|register|signup|signin|token|session)`) + +// analyzeSpec extracts resources, operations, and auth info from the spec. +func analyzeSpec(spec *scaffoldSpec, titleOverride string, forceAuth bool, theme string) scaffoldData { + title := spec.Info.Title + if titleOverride != "" { + title = titleOverride + } + if title == "" { + title = "My App" + } + + data := scaffoldData{ + Title: title, + Version: spec.Info.Version, + Theme: theme, + } + + // Collect all operations. + var allOps []apiOperation + + // Walk paths sorted for deterministic output. + paths := make([]string, 0, len(spec.Paths)) + for p := range spec.Paths { + paths = append(paths, p) + } + sort.Strings(paths) + + // Detect auth endpoints. + for _, p := range paths { + if authPathRe.MatchString(p) { + pi := spec.Paths[p] + if pi.Post != nil { + op := buildAPIOperation(pi.Post, "POST", p) + pl := strings.ToLower(p) + if strings.Contains(pl, "login") || strings.Contains(pl, "signin") || strings.Contains(pl, "token") || strings.Contains(pl, "session") { + data.HasAuth = true + data.LoginPath = p + data.LoginOp = &op + } + if strings.Contains(pl, "register") || strings.Contains(pl, "signup") { + data.HasAuth = true + data.RegisterPath = p + data.RegisterOp = &op + } + } + } + } + + if forceAuth && !data.HasAuth { + data.HasAuth = true + if data.LoginPath == "" { + data.LoginPath = "/auth/login" + } + if data.RegisterPath == "" { + data.RegisterPath = "/auth/register" + } + } + + // Group paths into resources, skipping auth paths. + resourceMap := map[string]*resourceGroup{} + resourceOrder := []string{} + + for _, p := range paths { + if authPathRe.MatchString(p) { + // Still add auth ops to allOps. + pi := spec.Paths[p] + for _, op := range pathOps(pi, p) { + allOps = append(allOps, op) + } + continue + } + + pi := spec.Paths[p] + resName := resourceNameFromPath(p) + if resName == "" { + continue + } + + if _, exists := resourceMap[resName]; !exists { + resourceMap[resName] = &resourceGroup{ + Name: toCamelCase(resName), + NameLower: strings.ToLower(resName), + NamePlural: strings.ToLower(resName), + } + resourceOrder = append(resourceOrder, resName) + } + + rg := resourceMap[resName] + hasPathParam := strings.Contains(p, "{") + + if pi.Get != nil { + op := buildAPIOperation(pi.Get, "GET", p) + allOps = append(allOps, op) + if hasPathParam { + rg.DetailOp = &op + } else { + rg.ListOp = &op + } + } + if pi.Post != nil { + op := buildAPIOperation(pi.Post, "POST", p) + allOps = append(allOps, op) + if !hasPathParam { + rg.CreateOp = &op + // Extract form fields from request body. + if pi.Post.RequestBody != nil { + rg.FormFields = extractFormFields(pi.Post.RequestBody, spec.Components) + } + } + } + if pi.Put != nil { + op := buildAPIOperation(pi.Put, "PUT", p) + allOps = append(allOps, op) + rg.UpdateOp = &op + if rg.CreateOp == nil && pi.Put.RequestBody != nil { + rg.FormFields = extractFormFields(pi.Put.RequestBody, spec.Components) + } + } + if pi.Patch != nil { + op := buildAPIOperation(pi.Patch, "PATCH", p) + allOps = append(allOps, op) + if rg.UpdateOp == nil { + rg.UpdateOp = &op + } + } + if pi.Delete != nil { + op := buildAPIOperation(pi.Delete, "DELETE", p) + allOps = append(allOps, op) + rg.DeleteOp = &op + } + } + + // Build ordered resources list. + for _, name := range resourceOrder { + data.Resources = append(data.Resources, *resourceMap[name]) + } + data.Operations = allOps + + return data +} + +// pathOps returns all operations defined in a path item. +func pathOps(pi *scaffoldPath, p string) []apiOperation { + var ops []apiOperation + if pi == nil { + return ops + } + if pi.Get != nil { + ops = append(ops, buildAPIOperation(pi.Get, "GET", p)) + } + if pi.Post != nil { + ops = append(ops, buildAPIOperation(pi.Post, "POST", p)) + } + if pi.Put != nil { + ops = append(ops, buildAPIOperation(pi.Put, "PUT", p)) + } + if pi.Patch != nil { + ops = append(ops, buildAPIOperation(pi.Patch, "PATCH", p)) + } + if pi.Delete != nil { + ops = append(ops, buildAPIOperation(pi.Delete, "DELETE", p)) + } + return ops +} + +// buildAPIOperation creates an apiOperation from a spec op. +func buildAPIOperation(op *scaffoldOp, method, path string) apiOperation { + funcName := op.OperationID + if funcName == "" { + funcName = generateFuncName(method, path) + } else { + // Ensure it starts lower-case for TS convention. + if len(funcName) > 0 { + r := []rune(funcName) + r[0] = unicode.ToLower(r[0]) + funcName = string(r) + } + } + + var pathParams []string + for _, param := range op.Parameters { + if param.In == "path" { + pathParams = append(pathParams, param.Name) + } + } + // Also extract from path pattern {name}. + for _, m := range regexp.MustCompile(`\{([^}]+)\}`).FindAllStringSubmatch(path, -1) { + name := m[1] + found := false + for _, pp := range pathParams { + if pp == name { + found = true + break + } + } + if !found { + pathParams = append(pathParams, name) + } + } + + return apiOperation{ + FuncName: funcName, + Method: strings.ToUpper(method), + Path: path, + HasBody: op.RequestBody != nil, + PathParams: pathParams, + } +} + +// generateFuncName produces a camelCase function name from method + path. +func generateFuncName(method, path string) string { + // Strip leading slash, remove path params, split on / and -. + clean := strings.TrimPrefix(path, "/") + // Replace {param} with "By" + re := regexp.MustCompile(`\{([^}]+)\}`) + clean = re.ReplaceAllStringFunc(clean, func(m string) string { + inner := m[1 : len(m)-1] + return "By" + toCamelCase(inner) + }) + clean = strings.NewReplacer("/", "_", "-", "_", ".", "_").Replace(clean) + + parts := strings.Split(clean, "_") + var sb strings.Builder + sb.WriteString(strings.ToLower(method)) + for _, p := range parts { + if p == "" { + continue + } + sb.WriteString(strings.ToUpper(p[:1])) + if len(p) > 1 { + sb.WriteString(p[1:]) + } + } + return sb.String() +} + +// resourceNameFromPath derives a resource name from a URL path. +// e.g. "/api/v1/users/{id}" -> "users", "/users" -> "users" +func resourceNameFromPath(path string) string { + parts := strings.Split(strings.Trim(path, "/"), "/") + // Find the last non-param, non-version segment. + for i := len(parts) - 1; i >= 0; i-- { + seg := parts[i] + if seg == "" || strings.HasPrefix(seg, "{") { + continue + } + // Skip version segments like "v1", "v2", "api". + if seg == "api" || regexp.MustCompile(`^v\d+$`).MatchString(seg) { + continue + } + return seg + } + return "" +} + +// extractFormFields infers form fields from a request body schema. +func extractFormFields(rb *scaffoldReqBody, components *scaffoldComponents) []formField { + if rb == nil { + return nil + } + mt, ok := rb.Content["application/json"] + if !ok { + // Try the first content type. + for _, v := range rb.Content { + mt = v + break + } + } + if mt == nil || mt.Schema == nil { + return nil + } + + schema := resolveSchemaRef(mt.Schema, components) + if schema == nil { + return nil + } + + required := map[string]bool{} + for _, r := range schema.Required { + required[r] = true + } + + // Sort property names for deterministic output. + propNames := make([]string, 0, len(schema.Properties)) + for name := range schema.Properties { + propNames = append(propNames, name) + } + sort.Strings(propNames) + + var fields []formField + for _, name := range propNames { + prop := schema.Properties[name] + prop = resolveSchemaRef(prop, components) + if prop == nil { + prop = &scaffoldSchema{Type: "string"} + } + ft := inferFieldType(name, prop) + f := formField{ + Name: name, + Label: toLabel(name), + Type: ft, + Required: required[name], + } + if ft == "select" && len(prop.Enum) > 0 { + f.Options = prop.Enum + } + fields = append(fields, f) + } + return fields +} + +// resolveSchemaRef dereferences a $ref if present. +func resolveSchemaRef(s *scaffoldSchema, components *scaffoldComponents) *scaffoldSchema { + if s == nil { + return nil + } + if s.Ref == "" || components == nil { + return s + } + // Refs look like "#/components/schemas/Foo" + parts := strings.Split(s.Ref, "/") + if len(parts) >= 4 && parts[1] == "components" && parts[2] == "schemas" { + name := parts[3] + if components.Schemas != nil { + if resolved, ok := components.Schemas[name]; ok { + return resolved + } + } + } + return s +} + +// inferFieldType guesses the HTML input type from name and schema. +func inferFieldType(name string, schema *scaffoldSchema) string { + if len(schema.Enum) > 0 { + return "select" + } + lower := strings.ToLower(name) + switch { + case strings.Contains(lower, "email"): + return "email" + case strings.Contains(lower, "password") || strings.Contains(lower, "secret"): + return "password" + case schema.Type == "integer" || schema.Type == "number" || schema.Format == "int32" || schema.Format == "int64": + return "number" + default: + return "text" + } +} + +// toLabel converts a camelCase or snake_case field name to a human label. +func toLabel(name string) string { + // snake_case to words. + s := strings.ReplaceAll(name, "_", " ") + // camelCase to words. + var out []rune + for i, r := range s { + if i > 0 && unicode.IsUpper(r) && !unicode.IsUpper(rune(s[i-1])) { + out = append(out, ' ') + } + out = append(out, r) + } + s = string(out) + // Capitalize first letter. + if len(s) > 0 { + r := []rune(s) + r[0] = unicode.ToUpper(r[0]) + s = string(r) + } + return s +} + +// toCamelCase converts snake_case or kebab-case to CamelCase. +func toCamelCase(s string) string { + parts := strings.FieldsFunc(s, func(r rune) bool { return r == '-' || r == '_' }) + var sb strings.Builder + for _, p := range parts { + if p == "" { + continue + } + sb.WriteString(strings.ToUpper(p[:1])) + if len(p) > 1 { + sb.WriteString(p[1:]) + } + } + return sb.String() +} + +// --- File generation --- + +// scaffoldFile describes one file to generate. +type scaffoldFile struct { + // path relative to the output directory. + path string + // tmplName keys into scaffoldTemplates map. + tmplName string + // onlyIf: if non-nil, the file is only generated when this returns true. + onlyIf func(scaffoldData) bool +} + +var scaffoldFiles = []scaffoldFile{ + {path: "package.json", tmplName: "package.json"}, + {path: "tsconfig.json", tmplName: "tsconfig.json"}, + {path: "vite.config.ts", tmplName: "vite.config.ts"}, + {path: "index.html", tmplName: "index.html"}, + {path: "src/main.tsx", tmplName: "main.tsx"}, + {path: "src/App.tsx", tmplName: "App.tsx"}, + {path: "src/api.ts", tmplName: "api.ts"}, + {path: "src/auth.tsx", tmplName: "auth.tsx", onlyIf: func(d scaffoldData) bool { return d.HasAuth }}, + {path: "src/components/Layout.tsx", tmplName: "Layout.tsx"}, + {path: "src/components/DataTable.tsx", tmplName: "DataTable.tsx"}, + {path: "src/components/FormField.tsx", tmplName: "FormField.tsx"}, + {path: "src/pages/DashboardPage.tsx", tmplName: "DashboardPage.tsx"}, + {path: "src/pages/LoginPage.tsx", tmplName: "LoginPage.tsx", onlyIf: func(d scaffoldData) bool { return d.HasAuth }}, + {path: "src/pages/RegisterPage.tsx", tmplName: "RegisterPage.tsx", onlyIf: func(d scaffoldData) bool { return d.HasAuth }}, +} + +// generateScaffold writes all scaffold files to outDir. +func generateScaffold(outDir string, data scaffoldData) error { + // Parse all templates once. + tmplMap, err := parseScaffoldTemplates() + if err != nil { + return fmt.Errorf("failed to parse scaffold templates: %w", err) + } + + // Generate static files. + for _, sf := range scaffoldFiles { + if sf.onlyIf != nil && !sf.onlyIf(data) { + continue + } + tmpl, ok := tmplMap[sf.tmplName] + if !ok { + return fmt.Errorf("template %q not found", sf.tmplName) + } + destPath := filepath.Join(outDir, sf.path) + if err := writeTemplate(tmpl, destPath, data); err != nil { + return fmt.Errorf("generate %s: %w", sf.path, err) + } + fmt.Printf(" create %s\n", filepath.Join(filepath.Base(outDir), sf.path)) + } + + // Generate one page per resource. + for _, rg := range data.Resources { + tmpl, ok := tmplMap["ResourcePage.tsx"] + if !ok { + return fmt.Errorf("template ResourcePage.tsx not found") + } + pagePath := filepath.Join(outDir, "src", "pages", rg.Name+"Page.tsx") + if err := writeTemplate(tmpl, pagePath, rg); err != nil { + return fmt.Errorf("generate %sPage.tsx: %w", rg.Name, err) + } + fmt.Printf(" create %s\n", filepath.Join(filepath.Base(outDir), "src", "pages", rg.Name+"Page.tsx")) + } + + return nil +} + +// writeTemplate renders a template to a file, creating parent directories as needed. +func writeTemplate(tmpl *template.Template, destPath string, data any) error { + if err := os.MkdirAll(filepath.Dir(destPath), 0750); err != nil { + return fmt.Errorf("create directory: %w", err) + } + f, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0640) //nolint:gosec // generated UI files + if err != nil { + return fmt.Errorf("create file: %w", err) + } + defer f.Close() + return tmpl.Execute(f, data) +} + +// parseScaffoldTemplates parses all scaffold template strings into a map. +func parseScaffoldTemplates() (map[string]*template.Template, error) { + funcs := template.FuncMap{ + "lower": strings.ToLower, + "upper": strings.ToUpper, + "title": toCamelCase, + "join": strings.Join, + "hasPrefix": strings.HasPrefix, + "trimPrefix": strings.TrimPrefix, + "replace": strings.ReplaceAll, + "jsPath": jsPathExpr, + "tsTupleArgs": tsTupleArgs, + } + + result := make(map[string]*template.Template, len(scaffoldTemplates)) + for name, src := range scaffoldTemplates { + t, err := template.New(name).Funcs(funcs).Parse(src) + if err != nil { + return nil, fmt.Errorf("parse template %q: %w", name, err) + } + result[name] = t + } + return result, nil +} + +// jsPathExpr converts an OpenAPI path like "/users/{id}" to a JS template +// literal expression like "`/users/${id}`". +func jsPathExpr(path string) string { + result := regexp.MustCompile(`\{([^}]+)\}`).ReplaceAllString(path, `${$1}`) + if strings.Contains(result, "${") { + return "`" + result + "`" + } + return "'" + result + "'" +} + +// tsTupleArgs returns comma-separated TypeScript parameter list for path params + optional body. +// e.g. ("GET", ["id"], false) -> "id: string" +// e.g. ("POST", [], true) -> "data: any" +// e.g. ("PUT", ["id"], true) -> "id: string, data: any" +func tsTupleArgs(method string, pathParams []string, hasBody bool) string { + var parts []string + for _, p := range pathParams { + parts = append(parts, p+": string") + } + if hasBody || method == "POST" || method == "PUT" || method == "PATCH" { + parts = append(parts, "data: any") + } + return strings.Join(parts, ", ") +} diff --git a/cmd/wfctl/scaffold_templates.go b/cmd/wfctl/scaffold_templates.go new file mode 100644 index 00000000..b1d66fdc --- /dev/null +++ b/cmd/wfctl/scaffold_templates.go @@ -0,0 +1,883 @@ +package main + +// ob / cb are Go template actions that output literal {{ and }} in the generated files. +// All JSX double-brace patterns like style={{ ... }} must use these to avoid +// being mis-parsed as Go template directives. +// +// Usage inside a template string: {{ob}} ... {{cb}} +// e.g. style={{ob}} color: 'red' {{cb}} + +// scaffoldTemplates maps template names to their raw Go text/template source. +var scaffoldTemplates = map[string]string{ + "package.json": `{ + "name": "{{.Title | lower | replace " " "-"}}", + "version": "0.1.0", + "private": true, + "scripts": { + "dev": "vite", + "build": "tsc && vite build", + "preview": "vite preview" + }, + "dependencies": { + "react": "^18.3.1", + "react-dom": "^18.3.1", + "react-router-dom": "^6.28.0" + }, + "devDependencies": { + "@types/react": "^18.3.12", + "@types/react-dom": "^18.3.1", + "@vitejs/plugin-react": "^4.3.4", + "typescript": "^5.7.2", + "vite": "^6.0.5" + } +} +`, + + "tsconfig.json": `{ + "compilerOptions": { + "target": "ES2020", + "useDefineForClassFields": true, + "lib": ["ES2020", "DOM", "DOM.Iterable"], + "module": "ESNext", + "skipLibCheck": true, + "moduleResolution": "bundler", + "allowImportingTsExtensions": true, + "resolveJsonModule": true, + "isolatedModules": true, + "noEmit": true, + "jsx": "react-jsx", + "strict": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + "noFallthroughCasesInSwitch": true + }, + "include": ["src"] +} +`, + + "vite.config.ts": `import { defineConfig } from 'vite'; +import react from '@vitejs/plugin-react'; + +// https://vitejs.dev/config/ +export default defineConfig({ + plugins: [react()], + server: { + proxy: { + '/api': { + target: 'http://localhost:8080', + changeOrigin: true, + }, + }, + }, +}); +`, + + "index.html": ` + + + + + {{.Title}} + + + +
+ + + +`, + + "main.tsx": `import { StrictMode } from 'react'; +import { createRoot } from 'react-dom/client'; +import { BrowserRouter } from 'react-router-dom'; +import App from './App'; +{{- if .HasAuth}} +import { AuthProvider } from './auth'; +{{- end}} + +createRoot(document.getElementById('root')!).render( + + +{{- if .HasAuth}} + + + +{{- else}} + +{{- end}} + + +); +`, + + "App.tsx": `import { Routes, Route, Navigate } from 'react-router-dom'; +import Layout from './components/Layout'; +import DashboardPage from './pages/DashboardPage'; +{{- if .HasAuth}} +import LoginPage from './pages/LoginPage'; +import RegisterPage from './pages/RegisterPage'; +import { useAuth } from './auth'; +{{- end}} +{{- range .Resources}} +import {{.Name}}Page from './pages/{{.Name}}Page'; +{{- end}} +{{- if .HasAuth}} +function PrivateRoute({ children }: { children: React.ReactNode }) { + const { token } = useAuth(); + return token ? <>{children} : ; +} +{{- end}} + +export default function App() { + return ( + +{{- if .HasAuth}} + } /> + } /> + + + + }> +{{- else}} + }> +{{- end}} + } /> +{{- range .Resources}} + } /> +{{- end}} + + + ); +} +`, + + // api.ts is a static file — no template directives in the fixed parts. + // The dynamic part is the list of exported functions. + "api.ts": `const API_BASE = ''; + +async function apiCall(method: string, path: string, body?: unknown): Promise { + const token = localStorage.getItem('token'); + const res = await fetch(API_BASE + path, { + method, + headers: { + 'Content-Type': 'application/json', + ...(token ? { Authorization: ` + "`" + `Bearer ${token}` + "`" + ` } : {}), + }, + body: body !== undefined ? JSON.stringify(body) : undefined, + }); + if (res.status === 401) { + localStorage.removeItem('token'); + window.location.href = '/login'; + throw new Error('Unauthorized'); + } + if (!res.ok) { + const text = await res.text().catch(() => res.statusText); + throw new Error(` + "`" + `HTTP ${res.status}: ${text}` + "`" + `); + } + const ct = res.headers.get('content-type') ?? ''; + if (ct.includes('application/json')) { + return res.json(); + } + return res.text(); +} + +// Generated API functions +{{range .Operations}} +export const {{.FuncName}} = ({{tsTupleArgs .Method .PathParams .HasBody}}) => + apiCall('{{.Method}}', {{jsPath .Path}}{{if or .HasBody (eq .Method "POST") (eq .Method "PUT") (eq .Method "PATCH")}}, data{{end}}); +{{end}} +`, + + // auth.tsx has no template directives — it is emitted verbatim. + // The value prop uses {{ }} for JSX object literal; escape with ob/cb funcs. + "auth.tsx": `import { createContext, useContext, useState, useCallback, ReactNode } from 'react'; + +interface AuthContextValue { + token: string | null; + login: (token: string) => void; + logout: () => void; +} + +const AuthContext = createContext(null); + +export function AuthProvider({ children }: { children: ReactNode }) { + const [token, setToken] = useState(() => + localStorage.getItem('token') + ); + + const login = useCallback((t: string) => { + localStorage.setItem('token', t); + setToken(t); + }, []); + + const logout = useCallback(() => { + localStorage.removeItem('token'); + setToken(null); + }, []); + + const value: AuthContextValue = { token, login, logout }; + return ( + + {children} + + ); +} + +export function useAuth(): AuthContextValue { + const ctx = useContext(AuthContext); + if (!ctx) throw new Error('useAuth must be used inside AuthProvider'); + return ctx; +} +`, + + "Layout.tsx": `import { Outlet, NavLink } from 'react-router-dom'; +{{- if .HasAuth}} +import { useAuth } from '../auth'; +{{- end}} +import type { CSSProperties } from 'react'; + +const navStyle: CSSProperties = { + width: 220, + minHeight: '100vh', + background: '#1a1a2e', + color: '#eee', + padding: '1rem', + display: 'flex', + flexDirection: 'column', + gap: '0.5rem', +}; + +const linkStyle: CSSProperties = { + color: '#ccc', + textDecoration: 'none', + padding: '0.4rem 0.6rem', + borderRadius: 4, +}; + +const mainStyle: CSSProperties = { + flex: 1, + padding: '2rem', +}; + +const titleStyle: CSSProperties = { + fontSize: '1.1rem', + fontWeight: 700, + marginBottom: '1rem', + color: '#fff', +}; + +const wrapStyle: CSSProperties = { + display: 'flex', +}; + +export default function Layout() { +{{- if .HasAuth}} + const { logout } = useAuth(); +{{- end}} + return ( +
+ +
+ +
+
+ ); +} +{{- if .HasAuth}} + +const logoutWrapStyle: CSSProperties = { marginTop: 'auto' }; +const logoutBtnStyle: CSSProperties = { + background: 'none', + border: '1px solid #555', + color: '#ccc', + padding: '0.4rem 0.8rem', + borderRadius: 4, + cursor: 'pointer', +}; +{{- end}} +`, + + "DataTable.tsx": `import type { CSSProperties } from 'react'; + +interface Column { + key: keyof T; + label: string; +} + +interface DataTableProps> { + columns: Column[]; + rows: T[]; + onSelect?: (row: T) => void; +} + +const tableStyle: CSSProperties = { width: '100%', borderCollapse: 'collapse' }; +const thStyle: CSSProperties = { + textAlign: 'left', + padding: '0.5rem 1rem', + background: '#f5f5f5', + borderBottom: '2px solid #ddd', + fontWeight: 600, +}; +const tdStyle: CSSProperties = { + padding: '0.5rem 1rem', + borderBottom: '1px solid #eee', +}; +const emptyStyle: CSSProperties = { + padding: '0.5rem 1rem', + borderBottom: '1px solid #eee', + textAlign: 'center', + color: '#999', +}; + +export default function DataTable>({ + columns, + rows, + onSelect, +}: DataTableProps) { + return ( + + + + {columns.map((col) => ( + + ))} + + + + {rows.map((row, i) => ( + onSelect?.(row)} + style={onSelect ? { cursor: 'pointer' } : undefined} + > + {columns.map((col) => ( + + ))} + + ))} + {rows.length === 0 && ( + + + + )} + +
{col.label}
+ {String(row[col.key] ?? '')} +
+ No data +
+ ); +} +`, + + "FormField.tsx": `import type { CSSProperties } from 'react'; + +interface FormFieldProps { + name: string; + label: string; + type?: string; + value: string; + onChange: (name: string, value: string) => void; + required?: boolean; + options?: string[]; +} + +const wrapStyle: CSSProperties = { marginBottom: '1rem' }; +const labelStyle: CSSProperties = { + display: 'block', + marginBottom: '0.25rem', + fontWeight: 500, + fontSize: '0.9rem', +}; +const inputStyle: CSSProperties = { + width: '100%', + padding: '0.5rem 0.75rem', + border: '1px solid #ccc', + borderRadius: 4, + fontSize: '1rem', +}; +const reqStyle: CSSProperties = { color: 'red' }; + +export default function FormField({ + name, + label, + type = 'text', + value, + onChange, + required = false, + options = [], +}: FormFieldProps) { + return ( +
+ + {type === 'select' ? ( + + ) : ( + onChange(name, e.target.value)} + style={inputStyle} + /> + )} +
+ ); +} +`, + + "DashboardPage.tsx": `import type { CSSProperties } from 'react'; + +const headStyle: CSSProperties = { marginTop: 0 }; +const cardGridStyle: CSSProperties = { + display: 'flex', + gap: '1rem', + flexWrap: 'wrap', + marginTop: '1.5rem', +}; +const cardStyle: CSSProperties = { + display: 'block', + padding: '1rem 1.5rem', + border: '1px solid #ddd', + borderRadius: 8, + textDecoration: 'none', + color: 'inherit', + background: '#fafafa', +}; +const cardTitleStyle: CSSProperties = { fontWeight: 700, fontSize: '1.1rem' }; + +export default function DashboardPage() { + return ( +
+

{{.Title}}

+

Welcome to the dashboard. Use the navigation to explore the API resources.

+{{- if .Resources}} +
+{{- range .Resources}} + +
{{.Name}}
+
+{{- end}} +
+{{- end}} +
+ ); +} +`, + + "LoginPage.tsx": `import { useState, FormEvent } from 'react'; +import { useNavigate, Link } from 'react-router-dom'; +import { useAuth } from '../auth'; +import type { CSSProperties } from 'react'; + +const pageStyle: CSSProperties = { + minHeight: '100vh', + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + background: '#f5f5f5', +}; +const cardStyle: CSSProperties = { + background: '#fff', + padding: '2rem', + borderRadius: 8, + width: 360, + boxShadow: '0 2px 8px rgba(0,0,0,0.1)', +}; +const headStyle: CSSProperties = { marginTop: 0, fontSize: '1.5rem' }; +const errStyle: CSSProperties = { color: 'red', marginBottom: '1rem', fontSize: '0.9rem' }; +const fieldStyle: CSSProperties = { marginBottom: '1rem' }; +const fieldStyleLast: CSSProperties = { marginBottom: '1.5rem' }; +const labelStyle: CSSProperties = { display: 'block', marginBottom: '0.25rem', fontWeight: 500 }; +const inputStyle: CSSProperties = { + width: '100%', + padding: '0.5rem 0.75rem', + border: '1px solid #ccc', + borderRadius: 4, + fontSize: '1rem', +}; +const btnStyle: CSSProperties = { + width: '100%', + padding: '0.6rem', + background: '#1a1a2e', + color: '#fff', + border: 'none', + borderRadius: 4, + fontSize: '1rem', + cursor: 'pointer', +}; +const footerStyle: CSSProperties = { marginTop: '1rem', textAlign: 'center', fontSize: '0.9rem' }; + +export default function LoginPage() { + const { login } = useAuth(); + const navigate = useNavigate(); + const [email, setEmail] = useState(''); + const [password, setPassword] = useState(''); + const [error, setError] = useState(''); + const [loading, setLoading] = useState(false); + + async function handleSubmit(e: FormEvent) { + e.preventDefault(); + setError(''); + setLoading(true); + try { + const tok = localStorage.getItem('token'); + const res = await fetch('{{.LoginPath}}', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...(tok ? { Authorization: ` + "`" + `Bearer ${tok}` + "`" + ` } : {}), + }, + body: JSON.stringify({ email, password }), + }); + if (!res.ok) { + const text = await res.text().catch(() => 'Login failed'); + throw new Error(text); + } + const data = await res.json() as Record; + const token = (data.token ?? data.access_token ?? data.jwt) as string | undefined; + if (!token) throw new Error('No token in response'); + login(token); + navigate('/'); + } catch (err) { + setError(err instanceof Error ? err.message : 'Login failed'); + } finally { + setLoading(false); + } + } + + return ( +
+
+

Sign In

+ {error &&
{error}
} +
+
+ + setEmail(e.target.value)} required style={inputStyle} /> +
+
+ + setPassword(e.target.value)} required style={inputStyle} /> +
+ +
+

+ Don't have an account? Register +

+
+
+ ); +} +`, + + "RegisterPage.tsx": `import { useState, FormEvent } from 'react'; +import { useNavigate, Link } from 'react-router-dom'; +import { useAuth } from '../auth'; +import type { CSSProperties } from 'react'; + +const pageStyle: CSSProperties = { + minHeight: '100vh', + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + background: '#f5f5f5', +}; +const cardStyle: CSSProperties = { + background: '#fff', + padding: '2rem', + borderRadius: 8, + width: 360, + boxShadow: '0 2px 8px rgba(0,0,0,0.1)', +}; +const headStyle: CSSProperties = { marginTop: 0, fontSize: '1.5rem' }; +const errStyle: CSSProperties = { color: 'red', marginBottom: '1rem', fontSize: '0.9rem' }; +const fieldStyle: CSSProperties = { marginBottom: '1rem' }; +const fieldStyleLast: CSSProperties = { marginBottom: '1.5rem' }; +const labelStyle: CSSProperties = { display: 'block', marginBottom: '0.25rem', fontWeight: 500 }; +const inputStyle: CSSProperties = { + width: '100%', + padding: '0.5rem 0.75rem', + border: '1px solid #ccc', + borderRadius: 4, + fontSize: '1rem', +}; +const btnStyle: CSSProperties = { + width: '100%', + padding: '0.6rem', + background: '#1a1a2e', + color: '#fff', + border: 'none', + borderRadius: 4, + fontSize: '1rem', + cursor: 'pointer', +}; +const footerStyle: CSSProperties = { marginTop: '1rem', textAlign: 'center', fontSize: '0.9rem' }; + +export default function RegisterPage() { + const { login } = useAuth(); + const navigate = useNavigate(); + const [email, setEmail] = useState(''); + const [password, setPassword] = useState(''); + const [error, setError] = useState(''); + const [loading, setLoading] = useState(false); + + async function handleSubmit(e: FormEvent) { + e.preventDefault(); + setError(''); + setLoading(true); + try { + const res = await fetch('{{.RegisterPath}}', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ email, password }), + }); + if (!res.ok) { + const text = await res.text().catch(() => 'Registration failed'); + throw new Error(text); + } + const data = await res.json() as Record; + const token = (data.token ?? data.access_token ?? data.jwt) as string | undefined; + if (token) { + login(token); + navigate('/'); + } else { + navigate('/login'); + } + } catch (err) { + setError(err instanceof Error ? err.message : 'Registration failed'); + } finally { + setLoading(false); + } + } + + return ( +
+
+

Create Account

+ {error &&
{error}
} +
+
+ + setEmail(e.target.value)} required style={inputStyle} /> +
+
+ + setPassword(e.target.value)} required style={inputStyle} /> +
+ +
+

+ Already have an account? Sign In +

+
+
+ ); +} +`, + + "ResourcePage.tsx": `import { useState, useEffect, FormEvent } from 'react'; +import type { CSSProperties } from 'react'; +import DataTable from '../components/DataTable'; +{{- if .FormFields}} +import FormField from '../components/FormField'; +{{- end}} + +type Item = Record; + +const headerStyle: CSSProperties = { display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '1.5rem' }; +const h1Style: CSSProperties = { margin: 0 }; +const newBtnStyle: CSSProperties = { padding: '0.5rem 1rem', background: '#1a1a2e', color: '#fff', border: 'none', borderRadius: 4, cursor: 'pointer' }; +const errStyle: CSSProperties = { color: 'red', marginBottom: '1rem' }; +const formBoxStyle: CSSProperties = { background: '#f9f9f9', padding: '1.5rem', borderRadius: 8, marginBottom: '1.5rem', border: '1px solid #eee' }; +const formHeadStyle: CSSProperties = { marginTop: 0, fontSize: '1.1rem' }; +const submitBtnStyle: CSSProperties = { padding: '0.5rem 1.5rem', background: '#1a1a2e', color: '#fff', border: 'none', borderRadius: 4, cursor: 'pointer' }; +const loadingStyle: CSSProperties = { color: '#999' }; +const detailBoxStyle: CSSProperties = { marginTop: '1.5rem', background: '#f9f9f9', padding: '1.5rem', borderRadius: 8, border: '1px solid #eee' }; +const detailHeadStyle: CSSProperties = { display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '1rem' }; +const detailTitleStyle: CSSProperties = { margin: 0, fontSize: '1.1rem' }; +const detailBtnsStyle: CSSProperties = { display: 'flex', gap: '0.5rem' }; +const deleteBtnStyle: CSSProperties = { padding: '0.4rem 0.8rem', background: '#dc3545', color: '#fff', border: 'none', borderRadius: 4, cursor: 'pointer' }; +const closeBtnStyle: CSSProperties = { padding: '0.4rem 0.8rem', background: '#6c757d', color: '#fff', border: 'none', borderRadius: 4, cursor: 'pointer' }; +const preStyle: CSSProperties = { margin: 0, fontSize: '0.85rem', overflow: 'auto' }; + +export default function {{.Name}}Page() { + const [items, setItems] = useState([]); + const [selected, setSelected] = useState(null); + const [showForm, setShowForm] = useState(false); + const [error, setError] = useState(''); + const [loading, setLoading] = useState(false); + const [form, setForm] = useState>({ +{{- range .FormFields}} + {{.Name}}: '', +{{- end}} + }); +{{if .ListOp}} + useEffect(() => { + loadItems(); + }, []); + + async function loadItems() { + setLoading(true); + try { + const tok = localStorage.getItem('token'); + const res = await fetch('{{.ListOp.Path}}', { + headers: tok ? { Authorization: ` + "`" + `Bearer ${tok}` + "`" + ` } : {}, + }); + if (!res.ok) throw new Error(` + "`" + `HTTP ${res.status}` + "`" + `); + const data = await res.json() as unknown; + setItems(Array.isArray(data) ? data as Item[] : (data as { items?: Item[], data?: Item[] })?.items ?? (data as { items?: Item[], data?: Item[] })?.data ?? []); + } catch (err) { + setError(err instanceof Error ? err.message : 'Failed to load'); + } finally { + setLoading(false); + } + } +{{end}} +{{if .CreateOp}} + async function handleCreate(e: FormEvent) { + e.preventDefault(); + setError(''); + try { + const tok = localStorage.getItem('token'); + const res = await fetch('{{.CreateOp.Path}}', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...(tok ? { Authorization: ` + "`" + `Bearer ${tok}` + "`" + ` } : {}), + }, + body: JSON.stringify(form), + }); + if (!res.ok) throw new Error(` + "`" + `HTTP ${res.status}` + "`" + `); + setShowForm(false); + setForm({ +{{- range .FormFields}} + {{.Name}}: '', +{{- end}} + }); +{{- if .ListOp}} + await loadItems(); +{{- end}} + } catch (err) { + setError(err instanceof Error ? err.message : 'Failed to create'); + } + } +{{end}} +{{if .DeleteOp}} + async function handleDelete(item: Item) { + if (!window.confirm('Delete this item?')) return; + const idKeys = ['id', 'ID', '_id', 'uuid', 'UUID']; + const id = idKeys.map((k) => item[k]).find((v) => v != null); + if (!id) return; + try { + const tok = localStorage.getItem('token'); + const res = await fetch('{{.DeleteOp.Path}}'.replace(/\{[^}]+\}/, String(id)), { + method: 'DELETE', + headers: tok ? { Authorization: ` + "`" + `Bearer ${tok}` + "`" + ` } : {}, + }); + if (!res.ok) throw new Error(` + "`" + `HTTP ${res.status}` + "`" + `); +{{- if .ListOp}} + await loadItems(); +{{- end}} + } catch (err) { + setError(err instanceof Error ? err.message : 'Failed to delete'); + } + } +{{end}} + const columns = items.length > 0 + ? Object.keys(items[0]).slice(0, 5).map((k) => ({ key: k as keyof Item, label: k })) + : []; + + return ( +
+
+

{{.Name}}

+{{- if .CreateOp}} + +{{- end}} +
+ {error &&
{error}
} +{{- if .CreateOp}} + {showForm && ( +
+

New {{.Name}}

+
+{{- range .FormFields}} + setForm((f) => ({ ...f, [name]: value }))} + required={ {{- .Required -}} } +{{- if .Options}} + options={[{{range .Options}}'{{.}}', {{end}}]} +{{- end}} + /> +{{- end}} + + +
+ )} +{{- end}} + {loading ? ( +
Loading...
+ ) : ( + setSelected(row)} /> + )} + {selected && ( +
+
+

Detail

+
+{{- if .DeleteOp}} + +{{- end}} + +
+
+
{JSON.stringify(selected, null, 2)}
+
+ )} +
+ ); +} +`, +} diff --git a/cmd/wfctl/scaffold_test.go b/cmd/wfctl/scaffold_test.go new file mode 100644 index 00000000..c2f7074b --- /dev/null +++ b/cmd/wfctl/scaffold_test.go @@ -0,0 +1,808 @@ +package main + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" +) + +// sampleOpenAPISpec is a comprehensive OpenAPI 3.0 spec used across tests. +const sampleOpenAPISpec = `{ + "openapi": "3.0.3", + "info": { + "title": "Pet Store API", + "version": "1.0.0", + "description": "A sample pet store API" + }, + "paths": { + "/api/v1/pets": { + "get": { + "operationId": "listPets", + "summary": "List all pets", + "tags": ["pets"], + "responses": {"200": {"description": "success"}} + }, + "post": { + "operationId": "createPet", + "summary": "Create a pet", + "tags": ["pets"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "species": {"type": "string", "enum": ["dog", "cat", "bird"]}, + "age": {"type": "integer"} + }, + "required": ["name"] + } + } + } + }, + "responses": {"201": {"description": "created"}} + } + }, + "/api/v1/pets/{id}": { + "get": { + "operationId": "getPet", + "summary": "Get a pet", + "tags": ["pets"], + "parameters": [{"name": "id", "in": "path", "required": true}], + "responses": {"200": {"description": "success"}} + }, + "delete": { + "operationId": "deletePet", + "summary": "Delete a pet", + "tags": ["pets"], + "parameters": [{"name": "id", "in": "path", "required": true}], + "responses": {"204": {"description": "deleted"}} + } + }, + "/auth/login": { + "post": { + "operationId": "login", + "summary": "Log in", + "tags": ["auth"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "email": {"type": "string", "format": "email"}, + "password": {"type": "string"} + }, + "required": ["email", "password"] + } + } + } + }, + "responses": {"200": {"description": "token"}} + } + }, + "/auth/register": { + "post": { + "operationId": "register", + "summary": "Register", + "tags": ["auth"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "email": {"type": "string"}, + "password": {"type": "string"} + }, + "required": ["email", "password"] + } + } + } + }, + "responses": {"201": {"description": "registered"}} + } + } + } +}` + +// sampleMinimalSpec is a minimal spec with no auth and one resource. +const sampleMinimalSpec = ` +openapi: "3.0.3" +info: + title: "Todo API" + version: "0.1.0" +paths: + /todos: + get: + operationId: listTodos + responses: + "200": + description: success + post: + operationId: createTodo + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + title: + type: string + done: + type: boolean + required: + - title + responses: + "201": + description: created + /todos/{id}: + get: + operationId: getTodo + parameters: + - name: id + in: path + required: true + responses: + "200": + description: success + delete: + operationId: deleteTodo + parameters: + - name: id + in: path + required: true + responses: + "204": + description: deleted +` + +// --- parseScaffoldSpec --- + +func TestParseScaffoldSpec_JSON(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if spec.Info.Title != "Pet Store API" { + t.Errorf("expected title 'Pet Store API', got %q", spec.Info.Title) + } + if len(spec.Paths) != 4 { + t.Errorf("expected 4 paths, got %d", len(spec.Paths)) + } +} + +func TestParseScaffoldSpec_YAML(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleMinimalSpec)) + if err != nil { + t.Fatalf("unexpected error parsing YAML: %v", err) + } + if spec.Info.Title != "Todo API" { + t.Errorf("expected 'Todo API', got %q", spec.Info.Title) + } + if len(spec.Paths) != 2 { + t.Errorf("expected 2 paths, got %d", len(spec.Paths)) + } +} + +func TestParseScaffoldSpec_Empty(t *testing.T) { + _, err := parseScaffoldSpec([]byte("")) + if err == nil { + t.Fatal("expected error for empty spec") + } +} + +func TestParseScaffoldSpec_Invalid(t *testing.T) { + _, err := parseScaffoldSpec([]byte("{not valid json}")) + if err == nil { + t.Fatal("expected error for invalid spec") + } +} + +// --- analyzeSpec --- + +func TestAnalyzeSpec_DetectsAuth(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + + if !data.HasAuth { + t.Error("expected HasAuth=true for spec with /auth/login and /auth/register") + } + if data.LoginPath != "/auth/login" { + t.Errorf("expected LoginPath='/auth/login', got %q", data.LoginPath) + } + if data.RegisterPath != "/auth/register" { + t.Errorf("expected RegisterPath='/auth/register', got %q", data.RegisterPath) + } +} + +func TestAnalyzeSpec_NoAuthInMinimal(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleMinimalSpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + + if data.HasAuth { + t.Error("expected HasAuth=false for spec without auth endpoints") + } +} + +func TestAnalyzeSpec_ForceAuth(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleMinimalSpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", true, "auto") + + if !data.HasAuth { + t.Error("expected HasAuth=true when forceAuth=true") + } + if data.LoginPath == "" { + t.Error("expected LoginPath to be set when forceAuth=true") + } +} + +func TestAnalyzeSpec_ResourceGrouping(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + + if len(data.Resources) != 1 { + t.Errorf("expected 1 resource (pets), got %d: %v", len(data.Resources), resourceNames(data.Resources)) + } + rg := data.Resources[0] + if rg.Name != "Pets" { + t.Errorf("expected resource name 'Pets', got %q", rg.Name) + } + if rg.ListOp == nil { + t.Error("expected ListOp for GET /api/v1/pets") + } + if rg.CreateOp == nil { + t.Error("expected CreateOp for POST /api/v1/pets") + } + if rg.DetailOp == nil { + t.Error("expected DetailOp for GET /api/v1/pets/{id}") + } + if rg.DeleteOp == nil { + t.Error("expected DeleteOp for DELETE /api/v1/pets/{id}") + } +} + +func TestAnalyzeSpec_TitleOverride(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "My Custom Title", false, "auto") + if data.Title != "My Custom Title" { + t.Errorf("expected title 'My Custom Title', got %q", data.Title) + } +} + +func TestAnalyzeSpec_TitleFromSpec(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + if data.Title != "Pet Store API" { + t.Errorf("expected title from spec 'Pet Store API', got %q", data.Title) + } +} + +func TestAnalyzeSpec_OperationsIncluded(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + + funcNames := make(map[string]bool) + for _, op := range data.Operations { + funcNames[op.FuncName] = true + } + + // listPets, createPet, getPet, deletePet are in non-auth paths. + for _, expected := range []string{"listPets", "createPet", "getPet", "deletePet"} { + if !funcNames[expected] { + t.Errorf("expected operation %q in Operations list, got: %v", expected, operationFuncNames(data.Operations)) + } + } +} + +func TestAnalyzeSpec_FormFieldsExtracted(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + + if len(data.Resources) == 0 { + t.Fatal("expected at least one resource") + } + rg := data.Resources[0] // Pets + + if len(rg.FormFields) == 0 { + t.Fatal("expected form fields from createPet requestBody") + } + + fieldMap := make(map[string]formField) + for _, f := range rg.FormFields { + fieldMap[f.Name] = f + } + + age, ok := fieldMap["age"] + if !ok { + t.Error("expected 'age' form field") + } else if age.Type != "number" { + t.Errorf("expected age.Type='number', got %q", age.Type) + } + + species, ok := fieldMap["species"] + if !ok { + t.Error("expected 'species' form field") + } else { + if species.Type != "select" { + t.Errorf("expected species.Type='select', got %q", species.Type) + } + if len(species.Options) != 3 { + t.Errorf("expected 3 options for species, got %d", len(species.Options)) + } + } +} + +// --- generateFuncName --- + +func TestGenerateFuncName(t *testing.T) { + cases := []struct { + method string + path string + want string + }{ + {"GET", "/api/v1/users", "getApiV1Users"}, + {"POST", "/api/v1/users", "postApiV1Users"}, + {"GET", "/users/{id}", "getUsersById"}, + {"DELETE", "/users/{id}", "deleteUsersById"}, + {"PUT", "/users/{id}/profile", "putUsersByIdProfile"}, + } + for _, c := range cases { + got := generateFuncName(c.method, c.path) + if got != c.want { + t.Errorf("generateFuncName(%q, %q) = %q, want %q", c.method, c.path, got, c.want) + } + } +} + +// --- resourceNameFromPath --- + +func TestResourceNameFromPath(t *testing.T) { + cases := []struct { + path string + want string + }{ + {"/api/v1/users", "users"}, + {"/api/v1/users/{id}", "users"}, + {"/users", "users"}, + {"/pets/{id}", "pets"}, + {"/api/v2/orders/{id}/items", "items"}, + {"/", ""}, + } + for _, c := range cases { + got := resourceNameFromPath(c.path) + if got != c.want { + t.Errorf("resourceNameFromPath(%q) = %q, want %q", c.path, got, c.want) + } + } +} + +// --- inferFieldType --- + +func TestInferFieldType(t *testing.T) { + cases := []struct { + name string + schema scaffoldSchema + want string + }{ + {"email", scaffoldSchema{Type: "string"}, "email"}, + {"emailAddress", scaffoldSchema{Type: "string"}, "email"}, + {"password", scaffoldSchema{Type: "string"}, "password"}, + {"secret_key", scaffoldSchema{Type: "string"}, "password"}, + {"count", scaffoldSchema{Type: "integer"}, "number"}, + {"price", scaffoldSchema{Type: "number"}, "number"}, + {"name", scaffoldSchema{Type: "string"}, "text"}, + {"status", scaffoldSchema{Type: "string", Enum: []string{"active", "inactive"}}, "select"}, + } + for _, c := range cases { + got := inferFieldType(c.name, &c.schema) + if got != c.want { + t.Errorf("inferFieldType(%q, ...) = %q, want %q", c.name, got, c.want) + } + } +} + +// --- toLabel --- + +func TestToLabel(t *testing.T) { + cases := []struct { + input string + want string + }{ + {"name", "Name"}, + {"first_name", "First name"}, + {"emailAddress", "Email Address"}, + {"user_id", "User id"}, + } + for _, c := range cases { + got := toLabel(c.input) + if got != c.want { + t.Errorf("toLabel(%q) = %q, want %q", c.input, got, c.want) + } + } +} + +// --- jsPathExpr --- + +func TestJsPathExpr(t *testing.T) { + cases := []struct { + input string + want string + }{ + {"/users", "'/users'"}, + {"/users/{id}", "`/users/${id}`"}, + {"/api/v1/users/{id}/posts/{postId}", "`/api/v1/users/${id}/posts/${postId}`"}, + } + for _, c := range cases { + got := jsPathExpr(c.input) + if got != c.want { + t.Errorf("jsPathExpr(%q) = %q, want %q", c.input, got, c.want) + } + } +} + +// --- tsTupleArgs --- + +func TestTsTupleArgs(t *testing.T) { + cases := []struct { + method string + pathParams []string + hasBody bool + want string + }{ + {"GET", nil, false, ""}, + {"GET", []string{"id"}, false, "id: string"}, + {"POST", nil, true, "data: any"}, + {"PUT", []string{"id"}, true, "id: string, data: any"}, + {"DELETE", []string{"id"}, false, "id: string"}, + {"POST", nil, false, "data: any"}, // POST always gets data param + } + for _, c := range cases { + got := tsTupleArgs(c.method, c.pathParams, c.hasBody) + if got != c.want { + t.Errorf("tsTupleArgs(%q, %v, %v) = %q, want %q", c.method, c.pathParams, c.hasBody, got, c.want) + } + } +} + +// --- generateScaffold (integration) --- + +func TestGenerateScaffold_WithAuth(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + + outDir := t.TempDir() + if err := generateScaffold(outDir, data); err != nil { + t.Fatalf("generateScaffold failed: %v", err) + } + + // Verify all expected files are generated. + expectedFiles := []string{ + "package.json", + "tsconfig.json", + "vite.config.ts", + "index.html", + "src/main.tsx", + "src/App.tsx", + "src/api.ts", + "src/auth.tsx", + "src/components/Layout.tsx", + "src/components/DataTable.tsx", + "src/components/FormField.tsx", + "src/pages/DashboardPage.tsx", + "src/pages/LoginPage.tsx", + "src/pages/RegisterPage.tsx", + "src/pages/PetsPage.tsx", + } + for _, f := range expectedFiles { + path := filepath.Join(outDir, f) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("expected file not generated: %s", f) + } + } +} + +func TestGenerateScaffold_NoAuth(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleMinimalSpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + + outDir := t.TempDir() + if err := generateScaffold(outDir, data); err != nil { + t.Fatalf("generateScaffold failed: %v", err) + } + + // Auth files must NOT be generated. + for _, f := range []string{"src/auth.tsx", "src/pages/LoginPage.tsx", "src/pages/RegisterPage.tsx"} { + path := filepath.Join(outDir, f) + if _, err := os.Stat(path); err == nil { + t.Errorf("auth file should not be generated without auth: %s", f) + } + } + + // Resource page must be generated. + todoPage := filepath.Join(outDir, "src", "pages", "TodosPage.tsx") + if _, err := os.Stat(todoPage); os.IsNotExist(err) { + t.Error("expected TodosPage.tsx to be generated") + } +} + +func TestGenerateScaffold_PackageJSON(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + outDir := t.TempDir() + if err := generateScaffold(outDir, data); err != nil { + t.Fatalf("generateScaffold: %v", err) + } + + pkgData, err := os.ReadFile(filepath.Join(outDir, "package.json")) + if err != nil { + t.Fatalf("read package.json: %v", err) + } + + var pkg map[string]any + if err := json.Unmarshal(pkgData, &pkg); err != nil { + t.Fatalf("package.json is not valid JSON: %v\ncontent:\n%s", err, pkgData) + } + + deps, ok := pkg["dependencies"].(map[string]any) + if !ok { + t.Fatal("package.json missing dependencies") + } + for _, dep := range []string{"react", "react-dom", "react-router-dom"} { + if _, ok := deps[dep]; !ok { + t.Errorf("package.json missing dependency: %s", dep) + } + } + + devDeps, ok := pkg["devDependencies"].(map[string]any) + if !ok { + t.Fatal("package.json missing devDependencies") + } + for _, dep := range []string{"vite", "typescript", "@vitejs/plugin-react"} { + if _, ok := devDeps[dep]; !ok { + t.Errorf("package.json missing devDependency: %s", dep) + } + } +} + +func TestGenerateScaffold_APIClient(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + outDir := t.TempDir() + if err := generateScaffold(outDir, data); err != nil { + t.Fatalf("generateScaffold: %v", err) + } + + apiTS, err := os.ReadFile(filepath.Join(outDir, "src", "api.ts")) + if err != nil { + t.Fatalf("read api.ts: %v", err) + } + + content := string(apiTS) + for _, funcName := range []string{"listPets", "createPet", "getPet", "deletePet"} { + if !strings.Contains(content, funcName) { + t.Errorf("api.ts missing function %q", funcName) + } + } + + // The API base helper must be present. + if !strings.Contains(content, "apiCall") { + t.Error("api.ts missing apiCall helper") + } + + // Bearer token injection must be present. + if !strings.Contains(content, "Authorization") { + t.Error("api.ts missing Authorization header") + } + + // 401 redirect must be present. + if !strings.Contains(content, "401") { + t.Error("api.ts missing 401 handling") + } +} + +func TestGenerateScaffold_ViteConfig(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleMinimalSpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + outDir := t.TempDir() + if err := generateScaffold(outDir, data); err != nil { + t.Fatalf("generateScaffold: %v", err) + } + + viteConfig, err := os.ReadFile(filepath.Join(outDir, "vite.config.ts")) + if err != nil { + t.Fatalf("read vite.config.ts: %v", err) + } + + content := string(viteConfig) + if !strings.Contains(content, "localhost:8080") { + t.Error("vite.config.ts should proxy /api to localhost:8080") + } + if !strings.Contains(content, "proxy") { + t.Error("vite.config.ts should have proxy config") + } +} + +func TestGenerateScaffold_AppRoutes(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + outDir := t.TempDir() + if err := generateScaffold(outDir, data); err != nil { + t.Fatalf("generateScaffold: %v", err) + } + + appTSX, err := os.ReadFile(filepath.Join(outDir, "src", "App.tsx")) + if err != nil { + t.Fatalf("read App.tsx: %v", err) + } + + content := string(appTSX) + if !strings.Contains(content, "PetsPage") { + t.Error("App.tsx should import PetsPage") + } + if !strings.Contains(content, "LoginPage") { + t.Error("App.tsx should import LoginPage (auth detected)") + } + if !strings.Contains(content, "RegisterPage") { + t.Error("App.tsx should import RegisterPage (auth detected)") + } +} + +func TestGenerateScaffold_LayoutNav(t *testing.T) { + spec, err := parseScaffoldSpec([]byte(sampleOpenAPISpec)) + if err != nil { + t.Fatalf("parse: %v", err) + } + data := analyzeSpec(spec, "", false, "auto") + outDir := t.TempDir() + if err := generateScaffold(outDir, data); err != nil { + t.Fatalf("generateScaffold: %v", err) + } + + layoutTSX, err := os.ReadFile(filepath.Join(outDir, "src", "components", "Layout.tsx")) + if err != nil { + t.Fatalf("read Layout.tsx: %v", err) + } + + content := string(layoutTSX) + // Should have nav link to pets resource. + if !strings.Contains(content, "/pets") { + t.Error("Layout.tsx should have nav link to /pets") + } + // Should have logout since auth is present. + if !strings.Contains(content, "logout") && !strings.Contains(content, "Logout") { + t.Error("Layout.tsx should have logout for auth-enabled spec") + } +} + +// --- runUIScaffold (CLI integration) --- + +func TestRunUIScaffold_FromFile(t *testing.T) { + outDir := t.TempDir() + + // Write spec to temp file. + specFile := filepath.Join(t.TempDir(), "openapi.json") + if err := os.WriteFile(specFile, []byte(sampleOpenAPISpec), 0644); err != nil { + t.Fatal(err) + } + + if err := runUIScaffold([]string{"-spec", specFile, "-output", outDir}); err != nil { + t.Fatalf("runUIScaffold failed: %v", err) + } + + // Quick sanity: package.json should exist. + if _, err := os.Stat(filepath.Join(outDir, "package.json")); os.IsNotExist(err) { + t.Error("expected package.json to be generated") + } +} + +func TestRunUIScaffold_WithTitleFlag(t *testing.T) { + outDir := t.TempDir() + specFile := filepath.Join(t.TempDir(), "openapi.yaml") + if err := os.WriteFile(specFile, []byte(sampleMinimalSpec), 0644); err != nil { + t.Fatal(err) + } + + if err := runUIScaffold([]string{"-spec", specFile, "-output", outDir, "-title", "Custom Title"}); err != nil { + t.Fatalf("runUIScaffold failed: %v", err) + } + + indexHTML, err := os.ReadFile(filepath.Join(outDir, "index.html")) + if err != nil { + t.Fatalf("read index.html: %v", err) + } + if !strings.Contains(string(indexHTML), "Custom Title") { + t.Error("index.html should contain custom title") + } +} + +func TestRunUIScaffold_MissingSpec(t *testing.T) { + err := runUIScaffold([]string{"-spec", "/nonexistent/path.yaml", "-output", t.TempDir()}) + if err == nil { + t.Fatal("expected error for missing spec file") + } +} + +func TestRunUI_Dispatch(t *testing.T) { + // Test that `ui` with no subcommand returns an error. + err := runUI([]string{}) + if err == nil { + t.Fatal("expected error when no subcommand given") + } + + // Test unknown subcommand. + err = runUI([]string{"unknown"}) + if err == nil { + t.Fatal("expected error for unknown subcommand") + } + if !strings.Contains(err.Error(), "unknown ui subcommand") { + t.Errorf("expected 'unknown ui subcommand' error, got: %v", err) + } +} + +// --- helpers --- + +func resourceNames(rgs []resourceGroup) []string { + names := make([]string, len(rgs)) + for i, rg := range rgs { + names[i] = rg.Name + } + return names +} + +func operationFuncNames(ops []apiOperation) []string { + names := make([]string, len(ops)) + for i, op := range ops { + names[i] = op.FuncName + } + return names +}