From 1d1dafd80b75542f70c67a1fe4cb58a13f161bd7 Mon Sep 17 00:00:00 2001 From: Reece Bedding Date: Tue, 10 Mar 2026 12:20:27 +0000 Subject: [PATCH 1/8] feat: provide v2 plugin interface for pre eval init actions --- .gitignore | 1 + Makefile | 10 +- cmd/agent.go | 183 +++++++++++-- cmd/agent_test.go | 240 ++++++++++++++++++ .../0001-support-versioned-runner-plugins.md | 38 +++ internal/oci.go | 47 +++- internal/oci_manifest_test.go | 67 +++++ runner/grpc.go | 45 +++- runner/grpc_test.go | 33 +++ runner/plugin.go | 8 +- runner/proto/results.pb.go | 2 +- runner/proto/runner.pb.go | 137 ++++++++-- runner/proto/runner.proto | 9 + runner/proto/runner_grpc.pb.go | 37 +++ runner/proto/types.pb.go | 2 +- 15 files changed, 802 insertions(+), 57 deletions(-) create mode 100644 docs/adr/0001-support-versioned-runner-plugins.md create mode 100644 internal/oci_manifest_test.go create mode 100644 runner/grpc_test.go diff --git a/.gitignore b/.gitignore index db2bc2a..aea7037 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ main cover.out config.yml concom +.config diff --git a/Makefile b/Makefile index f5f1e7f..341d986 100644 --- a/Makefile +++ b/Makefile @@ -42,4 +42,12 @@ test: ## Run tests $(WARN) "Tests failed"; \ exit 1; \ fi ; \ - $(OK) Tests passed \ No newline at end of file + $(OK) Tests passed + + +build: ## Build the project + @go build -o dist/concom main.go + +run: ## Run the project + @go run main.go agent --config ./.config/config.yaml + diff --git a/cmd/agent.go b/cmd/agent.go index 6280a0a..c2a9daa 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -6,9 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/compliance-framework/agent/runner/proto" - "github.com/google/uuid" - "github.com/robfig/cron/v3" "math/rand" "net/http" "os" @@ -16,11 +13,16 @@ import ( "os/signal" "path" "runtime" + "strconv" "strings" "sync" "syscall" "time" + "github.com/compliance-framework/agent/runner/proto" + "github.com/google/uuid" + "github.com/robfig/cron/v3" + "github.com/compliance-framework/agent/internal" "github.com/compliance-framework/agent/runner" "github.com/compliance-framework/api/sdk" @@ -46,11 +48,13 @@ type agentPolicy string type agentPluginConfig map[string]string type agentPlugin struct { - Schedule *string `mapstructure:"schedule,omitempty"` - Source string `mapstructure:"source"` - Policies []agentPolicy `mapstructure:"policies"` - Config agentPluginConfig `mapstructure:"config"` - Labels map[string]string `mapstructure:"labels"` + ProtocolVersion int32 `mapstructure:"protocol_version"` + Schedule *string `mapstructure:"schedule,omitempty"` + Source string `mapstructure:"source"` + Policies []agentPolicy `mapstructure:"policies"` + Config agentPluginConfig `mapstructure:"config"` + Labels map[string]string `mapstructure:"labels"` + protocolSet bool } type agentConfig struct { @@ -82,6 +86,9 @@ func (ac *agentConfig) validate() error { const AgentPluginDir = ".compliance-framework/plugins" const AgentPolicyDir = ".compliance-framework/policies" +const DefaultProtocolVersion int32 = 1 +const RunnerV2ProtocolVersion int32 = 2 +const AnnotationProtocolVersionKey = "org.ccf.plugin.protocol.version" func AgentCmd() *cobra.Command { var agentCmd = &cobra.Command{ @@ -138,13 +145,75 @@ func mergeConfig(cmd *cobra.Command, fileConfig *viper.Viper) (*agentConfig, err config := &agentConfig{} err := fileConfig.Unmarshal(config) + if err != nil { return nil, err } + markExplicitPluginProtocols(fileConfig, config) + updateAllPluginProtocols(config) + return config, nil } +func markExplicitPluginProtocols(fileConfig *viper.Viper, config *agentConfig) { + rawPlugins := fileConfig.GetStringMap("plugins") + for name, rawPlugin := range rawPlugins { + pluginConfig, ok := config.Plugins[name] + if !ok || pluginConfig == nil { + continue + } + + pluginMap, ok := rawPlugin.(map[string]interface{}) + if !ok { + continue + } + + _, pluginConfig.protocolSet = pluginMap["protocol_version"] + } +} + +func updateAllPluginProtocols(agentConfig *agentConfig) { + for _, pluginConfig := range agentConfig.Plugins { + if pluginConfig.ProtocolVersion == 0 { + pluginConfig.ProtocolVersion = DefaultProtocolVersion + } + } +} + +func protocolVersionFromAnnotations(annotations map[string]string) (int32, bool) { + value, ok := annotations[AnnotationProtocolVersionKey] + if !ok { + return 0, false + } + + parsed, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return 0, false + } + + if parsed < 1 { + return 0, false + } + + if parsed != int64(DefaultProtocolVersion) && parsed != int64(RunnerV2ProtocolVersion) { + return 0, false + } + + return int32(parsed), true +} + +func runnerDispenseName(protocolVersion int32) (string, error) { + switch protocolVersion { + case DefaultProtocolVersion: + return "runner", nil + case RunnerV2ProtocolVersion: + return "runner-v2", nil + default: + return "", fmt.Errorf("unsupported plugin protocol_version=%d", protocolVersion) + } +} + func loadConfig(cmd *cobra.Command, v *viper.Viper) (*agentConfig, error) { err := v.ReadInConfig() if err != nil { @@ -234,16 +303,18 @@ type AgentRunner struct { mu sync.Mutex config *agentConfig - pluginLocations map[string]string - policyLocations map[string]string + pluginLocations map[string]string + policyLocations map[string]string + fetchAnnotations func(source string, option ...remote.Option) (map[string]string, error) queryBundles []*rego.Rego } func NewAgentRunner() *AgentRunner { return &AgentRunner{ - pluginLocations: map[string]string{}, - policyLocations: map[string]string{}, + pluginLocations: map[string]string{}, + policyLocations: map[string]string{}, + fetchAnnotations: internal.GetAnnotations, } } @@ -266,6 +337,8 @@ func (ar *AgentRunner) Run(ctx context.Context) error { return err } + ar.resolvePluginProtocols() + err = ar.DownloadPolicies(ctx) if err != nil { ar.logger.Error("Error downloading policies", "error", err) @@ -281,6 +354,33 @@ func (ar *AgentRunner) Run(ctx context.Context) error { return ar.runAllPlugins(ctx) } +func (ar *AgentRunner) resolvePluginProtocols() { + for pluginName, pluginConfig := range ar.config.Plugins { + if pluginConfig == nil || pluginConfig.protocolSet || !internal.IsOCI(pluginConfig.Source) { + continue + } + + annotations, err := ar.fetchAnnotations(pluginConfig.Source) + if err != nil { + ar.logger.Warn("Failed to fetch plugin annotations, using configured/default protocol version", "plugin", pluginName, "source", pluginConfig.Source, "protocol_version", pluginConfig.ProtocolVersion, "error", err) + continue + } + + value, ok := annotations[AnnotationProtocolVersionKey] + if !ok { + continue + } + + protocolVersion, ok := protocolVersionFromAnnotations(annotations) + if !ok { + ar.logger.Warn("Ignoring unsupported plugin protocol version annotation", "plugin", pluginName, "source", pluginConfig.Source, "value", value, "protocol_version", pluginConfig.ProtocolVersion) + continue + } + + pluginConfig.ProtocolVersion = protocolVersion + } +} + // Should never return, either handles any error or panics. func (ar *AgentRunner) runDaemon(ctx context.Context) { sigs := make(chan os.Signal, 1) @@ -363,7 +463,7 @@ func (ar *AgentRunner) setupCron(ctx context.Context) (*cron.Cron, error) { err := ar.runPlugin(ctx, pluginName, pluginConfig) if err != nil { // TODO how will we handle these errors ? - ar.logger.Error("Error running plugin", "error", err) + ar.logger.Error("Error running plugin", "error", err, "protocol_version", pluginConfig.ProtocolVersion) } }) @@ -405,13 +505,13 @@ func (ar *AgentRunner) runAllPlugins(ctx context.Context) error { source := ar.pluginLocations[pluginConfig.Source] - logger.Debug("Running plugin", "source", source) + logger.Debug("Running plugin", "source", source, "protocol_version", pluginConfig.ProtocolVersion) if _, err := os.ReadFile(source); err != nil { return err } - runnerInstance, err := ar.getRunnerInstance(logger, source) + runnerInstance, err := ar.getRunnerInstance(logger, source, pluginConfig.ProtocolVersion) if err != nil { return err @@ -444,6 +544,21 @@ func (ar *AgentRunner) runAllPlugins(ctx context.Context) error { // Create a new results helper for the plugin to send results back to resultsHelper := runner.NewApiHelper(logger, client, labels) + if pluginConfig.ProtocolVersion > 1 { + runnerV2, ok := runnerInstance.(runner.RunnerV2) + if !ok { + return fmt.Errorf("plugin %s configured as protocol_version=%d but does not support RunnerV2", pluginName, pluginConfig.ProtocolVersion) + } + + _, err := runnerV2.Init(&proto.InitRequest{ + PolicyPaths: policyPaths, + }, resultsHelper) + + if err != nil { + return err + } + } + // TODO: Send failed results to the database? _, err = runnerInstance.Eval(&proto.EvalRequest{ PolicyPaths: policyPaths, @@ -496,8 +611,8 @@ func (ar *AgentRunner) runPlugin(ctx context.Context, name string, plugin *agent OS: runtime.GOOS, })) - fmt.Println("Running plugin", "source", plugin.Source) - fmt.Println("Running plugin", "source", pluginExecutable) + ar.logger.Info("Running plugin", "source", plugin.Source, "protocol_version", plugin.ProtocolVersion) + ar.logger.Info("Running plugin", "source", pluginExecutable, "protocol_version", plugin.ProtocolVersion) if err != nil { return err @@ -517,13 +632,13 @@ func (ar *AgentRunner) runPlugin(ctx context.Context, name string, plugin *agent labels[k] = v } - logger.Debug("Running plugin", "source", pluginExecutable) + logger.Debug("Running plugin", "source", pluginExecutable, "protocol_version", plugin.ProtocolVersion) if _, err := os.ReadFile(pluginExecutable); err != nil { return err } - runnerInstance, err := ar.getRunnerInstance(logger, pluginExecutable) + runnerInstance, err := ar.getRunnerInstance(logger, pluginExecutable, plugin.ProtocolVersion) if err != nil { return err @@ -539,6 +654,21 @@ func (ar *AgentRunner) runPlugin(ctx context.Context, name string, plugin *agent // Create a new results helper for the plugin to send results back to resultsHelper := runner.NewApiHelper(logger, client, labels) + if plugin.ProtocolVersion > 1 { + runnerV2, ok := runnerInstance.(runner.RunnerV2) + if !ok { + return fmt.Errorf("plugin %s configured as protocol_version=%d but does not support RunnerV2", name, plugin.ProtocolVersion) + } + + _, err := runnerV2.Init(&proto.InitRequest{ + PolicyPaths: policyPaths, + }, resultsHelper) + + if err != nil { + return err + } + } + // TODO: Send failed results to the database? _, err = runnerInstance.Eval(&proto.EvalRequest{ PolicyPaths: policyPaths, @@ -581,7 +711,7 @@ func (ar *AgentRunner) SendHeartbeat(ctx context.Context, staticAgentUUID uuid.U return nil } -func (ar *AgentRunner) getRunnerInstance(logger hclog.Logger, path string) (runner.Runner, error) { +func (ar *AgentRunner) getRunnerInstance(logger hclog.Logger, path string, protocolVersion int32) (runner.Runner, error) { // We're a host! Start by launching the plugin process. client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: runner.HandshakeConfig, @@ -598,15 +728,24 @@ func (ar *AgentRunner) getRunnerInstance(logger hclog.Logger, path string) (runn return nil, err } + dispenseName, err := runnerDispenseName(protocolVersion) + if err != nil { + return nil, err + } + // Request the plugin - raw, err := rpcClient.Dispense("runner") + logger.Debug("Dispensing plugin", "runner", dispenseName) + raw, err := rpcClient.Dispense(dispenseName) if err != nil { return nil, err } // We should have a Greeter now! This feels like a normal interface // implementation but is in fact over an RPC connection. - runnerInstance := raw.(runner.Runner) + runnerInstance, ok := raw.(runner.Runner) + if !ok { + return nil, fmt.Errorf("dispensed plugin %q does not implement runner.Runner", dispenseName) + } return runnerInstance, nil } diff --git a/cmd/agent_test.go b/cmd/agent_test.go index 3555c46..15456f4 100644 --- a/cmd/agent_test.go +++ b/cmd/agent_test.go @@ -2,9 +2,11 @@ package cmd import ( "bytes" + "errors" "fmt" "testing" + "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/spf13/viper" ) @@ -118,3 +120,241 @@ func TestAgentCmd_ConfigurationMerging(t *testing.T) { } }) } + +func TestMergeConfig_DefaultsPluginProtocolVersion(t *testing.T) { + v := viper.New() + v.SetConfigType("yaml") + err := v.ReadConfig(bytes.NewBufferString("api:\n url: http://localhost:8080\n\nplugins:\n plugin-with-default:\n source: ghcr.io/some-plugin:v1\n plugin-with-explicit:\n source: ghcr.io/some-plugin:v2\n protocol_version: 2\n")) + if err != nil { + t.Fatalf("Error reading config: %v", err) + } + + config, err := mergeConfig(AgentCmd(), v) + if err != nil { + t.Fatalf("Error merging config: %v", err) + } + + if got := config.Plugins["plugin-with-default"].ProtocolVersion; got != 1 { + t.Fatalf("Expected plugin-with-default protocol version to be 1, got %d", got) + } + + if got := config.Plugins["plugin-with-explicit"].ProtocolVersion; got != 2 { + t.Fatalf("Expected plugin-with-explicit protocol version to be 2, got %d", got) + } +} + +func TestUpdateAllPluginProtocols_DefaultsOnlyUnset(t *testing.T) { + config := &agentConfig{ + Plugins: map[string]*agentPlugin{ + "defaulted": { + Source: "ghcr.io/defaulted:v1", + }, + "explicit": { + Source: "ghcr.io/explicit:v2", + ProtocolVersion: 2, + }, + }, + } + + updateAllPluginProtocols(config) + + if got := config.Plugins["defaulted"].ProtocolVersion; got != 1 { + t.Fatalf("Expected defaulted plugin protocol version to be 1, got %d", got) + } + + if got := config.Plugins["explicit"].ProtocolVersion; got != 2 { + t.Fatalf("Expected explicit plugin protocol version to remain 2, got %d", got) + } +} + +func TestMergeConfig_DoesNotFetchAnnotations(t *testing.T) { + v := viper.New() + v.SetConfigType("yaml") + err := v.ReadConfig(bytes.NewBufferString("api:\n url: http://localhost:8080\n\nplugins:\n plugin-with-default:\n source: ghcr.io/some-plugin:v1\n")) + if err != nil { + t.Fatalf("Error reading config: %v", err) + } + + config, err := mergeConfig(AgentCmd(), v) + if err != nil { + t.Fatalf("Error merging config: %v", err) + } + + if got := config.Plugins["plugin-with-default"].ProtocolVersion; got != DefaultProtocolVersion { + t.Fatalf("Expected plugin-with-default protocol version to be %d, got %d", DefaultProtocolVersion, got) + } +} + +func TestResolvePluginProtocols_UsesAnnotationsOnlyForImplicitOCIPlugins(t *testing.T) { + lookupCount := 0 + fetchAnnotations := func(source string, option ...remote.Option) (map[string]string, error) { + lookupCount++ + return map[string]string{ + AnnotationProtocolVersionKey: "2", + }, nil + } + + config := &agentConfig{ + Plugins: map[string]*agentPlugin{ + "implicit-oci": { + Source: "ghcr.io/implicit:v1", + ProtocolVersion: DefaultProtocolVersion, + protocolSet: false, + }, + "explicit-v1": { + Source: "ghcr.io/explicit:v1", + ProtocolVersion: DefaultProtocolVersion, + protocolSet: true, + }, + "non-oci": { + Source: "/tmp/plugin", + ProtocolVersion: DefaultProtocolVersion, + protocolSet: false, + }, + }, + } + + runner := NewAgentRunner() + runner.fetchAnnotations = fetchAnnotations + runner.UpdateConfig(config) + runner.resolvePluginProtocols() + + if lookupCount != 1 { + t.Fatalf("Expected one annotation lookup, got %d", lookupCount) + } + + if got := config.Plugins["implicit-oci"].ProtocolVersion; got != RunnerV2ProtocolVersion { + t.Fatalf("Expected implicit-oci protocol version to be %d, got %d", RunnerV2ProtocolVersion, got) + } + + if got := config.Plugins["explicit-v1"].ProtocolVersion; got != DefaultProtocolVersion { + t.Fatalf("Expected explicit-v1 protocol version to remain %d, got %d", DefaultProtocolVersion, got) + } + + if got := config.Plugins["non-oci"].ProtocolVersion; got != DefaultProtocolVersion { + t.Fatalf("Expected non-oci protocol version to remain %d, got %d", DefaultProtocolVersion, got) + } +} + +func TestResolvePluginProtocols_KeepsDefaultWhenLookupFails(t *testing.T) { + fetchAnnotations := func(source string, option ...remote.Option) (map[string]string, error) { + return nil, errors.New("lookup failed") + } + + config := &agentConfig{ + Plugins: map[string]*agentPlugin{ + "implicit-oci": { + Source: "ghcr.io/implicit:v1", + ProtocolVersion: DefaultProtocolVersion, + protocolSet: false, + }, + }, + } + + runner := NewAgentRunner() + runner.fetchAnnotations = fetchAnnotations + runner.UpdateConfig(config) + runner.resolvePluginProtocols() + + if got := config.Plugins["implicit-oci"].ProtocolVersion; got != DefaultProtocolVersion { + t.Fatalf("Expected implicit-oci protocol version to remain %d, got %d", DefaultProtocolVersion, got) + } +} + +func TestProtocolVersionFromAnnotations(t *testing.T) { + tests := []struct { + name string + annotations map[string]string + expected int32 + ok bool + }{ + { + name: "Uses OCI annotation key", + annotations: map[string]string{ + AnnotationProtocolVersionKey: "2", + }, + expected: 2, + ok: true, + }, + { + name: "Rejects unsupported values", + annotations: map[string]string{ + AnnotationProtocolVersionKey: "100", + }, + expected: 0, + ok: false, + }, + { + name: "Rejects invalid values", + annotations: map[string]string{ + AnnotationProtocolVersionKey: "abc", + }, + expected: 0, + ok: false, + }, + { + name: "Rejects non-positive values", + annotations: map[string]string{ + AnnotationProtocolVersionKey: "0", + }, + expected: 0, + ok: false, + }, + { + name: "Missing keys", + annotations: map[string]string{"other": "1"}, + expected: 0, + ok: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := protocolVersionFromAnnotations(tt.annotations) + if got != tt.expected || ok != tt.ok { + t.Fatalf("protocolVersionFromAnnotations() = (%d, %t), expected (%d, %t)", got, ok, tt.expected, tt.ok) + } + }) + } +} + +func TestRunnerDispenseName(t *testing.T) { + tests := []struct { + name string + protocolVersion int32 + expected string + wantErr bool + }{ + { + name: "Uses runner for v1", + protocolVersion: DefaultProtocolVersion, + expected: "runner", + wantErr: false, + }, + { + name: "Uses runner-v2 for v2", + protocolVersion: RunnerV2ProtocolVersion, + expected: "runner-v2", + wantErr: false, + }, + { + name: "Rejects unsupported protocol version", + protocolVersion: 3, + expected: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := runnerDispenseName(tt.protocolVersion) + if (err != nil) != tt.wantErr { + t.Fatalf("runnerDispenseName() error = %v, wantErr %t", err, tt.wantErr) + } + + if got != tt.expected { + t.Fatalf("runnerDispenseName() = %q, expected %q", got, tt.expected) + } + }) + } +} diff --git a/docs/adr/0001-support-versioned-runner-plugins.md b/docs/adr/0001-support-versioned-runner-plugins.md new file mode 100644 index 0000000..9db59eb --- /dev/null +++ b/docs/adr/0001-support-versioned-runner-plugins.md @@ -0,0 +1,38 @@ +# ADR 0001: Support versioned runner plugins + +- Date: 2026-03-10 + +## Context + +The agent currently assumes every plugin speaks a single runner protocol and can be started through the `runner` dispense name and evaluated immediately. + +The agent now supports a second plugin contract that requires an `Init` step before `Eval`. At the same time, it remains compatible with existing plugins and avoids forcing every deployment to update configuration when an OCI-published plugin can already advertise its protocol version. + +## Decision + +The agent supports explicit runner protocol versions per plugin and retains backward compatibility by defaulting to protocol version 1. + +This is implemented by: + +- adding `protocol_version` to plugin configuration +- defaulting unspecified plugins to protocol version 1 +- reading `org.ccf.plugin.protocol.version` from OCI annotations for OCI plugin sources without an explicit `protocol_version` +- supporting only protocol versions 1 and 2 +- mapping protocol version 1 to the `runner` dispense name and protocol version 2 to `runner-v2` +- calling `Init` before `Eval` for protocol version 2 plugins +- treating explicit configuration as authoritative over OCI metadata + +## Consequences + +### Positive + +- Existing plugins continue to work without configuration changes. +- New plugins can adopt protocol version 2 and perform setup during `Init`. +- OCI-published plugins can self-describe their protocol version, reducing configuration drift. +- Unsupported or invalid annotations do not break execution; the agent logs and falls back to the configured or default version. + +### Negative + +- OCI-backed plugins may require an extra registry metadata lookup before execution. +- The agent now maintains two supported runner contracts instead of one. +- Plugin authors adopting protocol version 2 must implement `Init`. diff --git a/internal/oci.go b/internal/oci.go index ad3f674..0eeab33 100644 --- a/internal/oci.go +++ b/internal/oci.go @@ -2,13 +2,16 @@ package internal import ( "context" + "encoding/json" "errors" + "os" + "path" + "github.com/compliance-framework/gooci/pkg/oci" + "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/name" "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/hashicorp/go-hclog" - "os" - "path" ) func IsOCI(source string) bool { @@ -17,6 +20,46 @@ func IsOCI(source string) bool { return err == nil } +func GetAnnotations(source string, option ...remote.Option) (map[string]string, error) { + ref, err := name.ParseReference(source) + if err != nil { + return nil, err + } + + opts := append([]remote.Option{ + remote.WithAuthFromKeychain(authn.DefaultKeychain), + }, option...) + + desc, err := remote.Get(ref, opts...) + if err != nil { + return nil, err + } + + return annotationsFromDescriptor(desc), nil +} + +func annotationsFromDescriptor(desc *remote.Descriptor) map[string]string { + if desc == nil { + return map[string]string{} + } + + if len(desc.Manifest) > 0 { + var payload struct { + Annotations map[string]string `json:"annotations"` + } + + if err := json.Unmarshal(desc.Manifest, &payload); err == nil && len(payload.Annotations) > 0 { + return payload.Annotations + } + } + + if len(desc.Annotations) > 0 { + return desc.Annotations + } + + return map[string]string{} +} + func Download(ctx context.Context, source string, outputDir string, binaryPath string, logger hclog.Logger, option ...remote.Option) (string, error) { // Add a task to indicate we've downloaded the items logger.Trace("Checking for source", "source", source) diff --git a/internal/oci_manifest_test.go b/internal/oci_manifest_test.go new file mode 100644 index 0000000..cbc98b3 --- /dev/null +++ b/internal/oci_manifest_test.go @@ -0,0 +1,67 @@ +package internal + +import ( + "reflect" + "testing" + + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/remote" +) + +func TestAnnotationsFromDescriptor(t *testing.T) { + tests := []struct { + name string + desc *remote.Descriptor + expected map[string]string + }{ + { + name: "Nil descriptor", + desc: nil, + expected: map[string]string{}, + }, + { + name: "Invalid JSON falls back to descriptor annotations", + desc: &remote.Descriptor{ + Manifest: []byte("not-json"), + Descriptor: v1.Descriptor{ + Annotations: map[string]string{"from": "descriptor"}, + }, + }, + expected: map[string]string{"from": "descriptor"}, + }, + { + name: "Manifest without annotations falls back to descriptor annotations", + desc: &remote.Descriptor{ + Manifest: []byte(`{"schemaVersion":2}`), + Descriptor: v1.Descriptor{ + Annotations: map[string]string{"from": "descriptor"}, + }, + }, + expected: map[string]string{"from": "descriptor"}, + }, + { + name: "Uses manifest annotations when present", + desc: &remote.Descriptor{ + Manifest: []byte(`{"schemaVersion":2,"mediaType":"application/vnd.oci.image.index.v1+json","manifests":[],"annotations":{"org.opencontainers.image.created":"2026-02-27T10:57:27Z","org.opencontainers.image.title":"plugin-test","org.opencontainers.image.version":"v0.1.0","org.ccf.plugin.protocol.version":"2"}}`), + Descriptor: v1.Descriptor{ + Annotations: map[string]string{"from": "descriptor"}, + }, + }, + expected: map[string]string{ + "org.opencontainers.image.created": "2026-02-27T10:57:27Z", + "org.opencontainers.image.title": "plugin-test", + "org.opencontainers.image.version": "v0.1.0", + "org.ccf.plugin.protocol.version": "2", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := annotationsFromDescriptor(tt.desc) + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("annotationsFromDescriptor() = %v, expected %v", got, tt.expected) + } + }) + } +} diff --git a/runner/grpc.go b/runner/grpc.go index 649999e..55f23f9 100644 --- a/runner/grpc.go +++ b/runner/grpc.go @@ -2,10 +2,13 @@ package runner import ( "context" + "github.com/compliance-framework/agent/runner/proto" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type ApiHelper interface { @@ -43,25 +46,33 @@ type GRPCClient struct { broker *plugin.GRPCBroker } -func (m *GRPCClient) Configure(request *proto.ConfigureRequest) (*proto.ConfigureResponse, error) { - return m.client.Configure(context.Background(), request) -} - -func (m *GRPCClient) Eval(request *proto.EvalRequest, a ApiHelper) (*proto.EvalResponse, error) { +func (m *GRPCClient) startAPIServer(a ApiHelper) uint32 { apiHelperServer := &GRPCApiHelperServer{Impl: a} - var s *grpc.Server serverFunc := func(opts []grpc.ServerOption) *grpc.Server { - s = grpc.NewServer(opts...) + s := grpc.NewServer(opts...) proto.RegisterApiHelperServer(s, apiHelperServer) - return s } brokerID := m.broker.NextId() go m.broker.AcceptAndServe(brokerID, serverFunc) - request.ApiServer = brokerID + return brokerID +} + +func (m *GRPCClient) Configure(request *proto.ConfigureRequest) (*proto.ConfigureResponse, error) { + return m.client.Configure(context.Background(), request) +} + +func (m *GRPCClient) Init(request *proto.InitRequest, a ApiHelper) (*proto.InitResponse, error) { + request.ApiServer = m.startAPIServer(a) + resp, err := m.client.Init(context.Background(), request) + return resp, err +} + +func (m *GRPCClient) Eval(request *proto.EvalRequest, a ApiHelper) (*proto.EvalResponse, error) { + request.ApiServer = m.startAPIServer(a) resp, err := m.client.Eval(context.Background(), request) return resp, err } @@ -75,6 +86,22 @@ func (m *GRPCServer) Configure(ctx context.Context, req *proto.ConfigureRequest) return m.Impl.Configure(req) } +func (m *GRPCServer) Init(ctx context.Context, req *proto.InitRequest) (*proto.InitResponse, error) { + runnerV2, ok := m.Impl.(RunnerV2) + if !ok { + return nil, status.Error(codes.Unimplemented, "Init is only supported for protocol v2 plugins") + } + + conn, err := m.broker.Dial(req.ApiServer) + if err != nil { + return nil, err + } + defer conn.Close() + + a := &GRPCApiHelperClient{proto.NewApiHelperClient(conn)} + return runnerV2.Init(req, a) +} + func (m *GRPCServer) Eval(ctx context.Context, req *proto.EvalRequest) (*proto.EvalResponse, error) { conn, err := m.broker.Dial(req.ApiServer) if err != nil { diff --git a/runner/grpc_test.go b/runner/grpc_test.go new file mode 100644 index 0000000..851bb6f --- /dev/null +++ b/runner/grpc_test.go @@ -0,0 +1,33 @@ +package runner + +import ( + "context" + "testing" + + "github.com/compliance-framework/agent/runner/proto" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type testRunnerV1 struct{} + +func (t *testRunnerV1) Configure(request *proto.ConfigureRequest) (*proto.ConfigureResponse, error) { + return &proto.ConfigureResponse{}, nil +} + +func (t *testRunnerV1) Eval(request *proto.EvalRequest, a ApiHelper) (*proto.EvalResponse, error) { + return &proto.EvalResponse{}, nil +} + +func TestGRPCServerInitReturnsUnimplementedForRunnerV1(t *testing.T) { + server := &GRPCServer{Impl: &testRunnerV1{}} + + _, err := server.Init(context.Background(), &proto.InitRequest{}) + if err == nil { + t.Fatalf("expected error, got nil") + } + + if status.Code(err) != codes.Unimplemented { + t.Fatalf("expected code %v, got %v", codes.Unimplemented, status.Code(err)) + } +} diff --git a/runner/plugin.go b/runner/plugin.go index 102ddc5..3de745f 100644 --- a/runner/plugin.go +++ b/runner/plugin.go @@ -13,6 +13,11 @@ type Runner interface { Eval(request *proto.EvalRequest, a ApiHelper) (*proto.EvalResponse, error) } +type RunnerV2 interface { + Runner + Init(request *proto.InitRequest, a ApiHelper) (*proto.InitResponse, error) +} + type RunnerGRPCPlugin struct { plugin.Plugin @@ -42,5 +47,6 @@ var HandshakeConfig = plugin.HandshakeConfig{ } var PluginMap = map[string]plugin.Plugin{ - "runner": &RunnerGRPCPlugin{}, + "runner": &RunnerGRPCPlugin{}, + "runner-v2": &RunnerGRPCPlugin{}, } diff --git a/runner/proto/results.pb.go b/runner/proto/results.pb.go index d151dd5..502d6b0 100644 --- a/runner/proto/results.pb.go +++ b/runner/proto/results.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.6 +// protoc-gen-go v1.36.11 // protoc (unknown) // source: runner/proto/results.proto diff --git a/runner/proto/runner.pb.go b/runner/proto/runner.pb.go index 6b2a93f..bb34207 100644 --- a/runner/proto/runner.pb.go +++ b/runner/proto/runner.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.6 +// protoc-gen-go v1.36.11 // protoc (unknown) // source: runner/proto/runner.proto @@ -155,6 +155,94 @@ func (x *ConfigureResponse) GetValue() []byte { return nil } +type InitRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + PolicyPaths []string `protobuf:"bytes,1,rep,name=policyPaths,proto3" json:"policyPaths,omitempty"` + ApiServer uint32 `protobuf:"varint,2,opt,name=apiServer,proto3" json:"apiServer,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InitRequest) Reset() { + *x = InitRequest{} + mi := &file_runner_proto_runner_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InitRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InitRequest) ProtoMessage() {} + +func (x *InitRequest) ProtoReflect() protoreflect.Message { + mi := &file_runner_proto_runner_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InitRequest.ProtoReflect.Descriptor instead. +func (*InitRequest) Descriptor() ([]byte, []int) { + return file_runner_proto_runner_proto_rawDescGZIP(), []int{2} +} + +func (x *InitRequest) GetPolicyPaths() []string { + if x != nil { + return x.PolicyPaths + } + return nil +} + +func (x *InitRequest) GetApiServer() uint32 { + if x != nil { + return x.ApiServer + } + return 0 +} + +type InitResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InitResponse) Reset() { + *x = InitResponse{} + mi := &file_runner_proto_runner_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InitResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InitResponse) ProtoMessage() {} + +func (x *InitResponse) ProtoReflect() protoreflect.Message { + mi := &file_runner_proto_runner_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InitResponse.ProtoReflect.Descriptor instead. +func (*InitResponse) Descriptor() ([]byte, []int) { + return file_runner_proto_runner_proto_rawDescGZIP(), []int{3} +} + type EvalRequest struct { state protoimpl.MessageState `protogen:"open.v1"` PolicyPaths []string `protobuf:"bytes,1,rep,name=policyPaths,proto3" json:"policyPaths,omitempty"` @@ -165,7 +253,7 @@ type EvalRequest struct { func (x *EvalRequest) Reset() { *x = EvalRequest{} - mi := &file_runner_proto_runner_proto_msgTypes[2] + mi := &file_runner_proto_runner_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -177,7 +265,7 @@ func (x *EvalRequest) String() string { func (*EvalRequest) ProtoMessage() {} func (x *EvalRequest) ProtoReflect() protoreflect.Message { - mi := &file_runner_proto_runner_proto_msgTypes[2] + mi := &file_runner_proto_runner_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -190,7 +278,7 @@ func (x *EvalRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use EvalRequest.ProtoReflect.Descriptor instead. func (*EvalRequest) Descriptor() ([]byte, []int) { - return file_runner_proto_runner_proto_rawDescGZIP(), []int{2} + return file_runner_proto_runner_proto_rawDescGZIP(), []int{4} } func (x *EvalRequest) GetPolicyPaths() []string { @@ -220,7 +308,7 @@ type EvalResponse struct { func (x *EvalResponse) Reset() { *x = EvalResponse{} - mi := &file_runner_proto_runner_proto_msgTypes[3] + mi := &file_runner_proto_runner_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -232,7 +320,7 @@ func (x *EvalResponse) String() string { func (*EvalResponse) ProtoMessage() {} func (x *EvalResponse) ProtoReflect() protoreflect.Message { - mi := &file_runner_proto_runner_proto_msgTypes[3] + mi := &file_runner_proto_runner_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -245,7 +333,7 @@ func (x *EvalResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use EvalResponse.ProtoReflect.Descriptor instead. func (*EvalResponse) Descriptor() ([]byte, []int) { - return file_runner_proto_runner_proto_rawDescGZIP(), []int{3} + return file_runner_proto_runner_proto_rawDescGZIP(), []int{5} } func (x *EvalResponse) GetStatus() ExecutionStatus { @@ -267,6 +355,10 @@ const file_runner_proto_runner_proto_rawDesc = "" + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\")\n" + "\x11ConfigureResponse\x12\x14\n" + "\x05value\x18\x01 \x01(\fR\x05value\"M\n" + + "\vInitRequest\x12 \n" + + "\vpolicyPaths\x18\x01 \x03(\tR\vpolicyPaths\x12\x1c\n" + + "\tapiServer\x18\x02 \x01(\rR\tapiServer\"\x0e\n" + + "\fInitResponse\"M\n" + "\vEvalRequest\x12 \n" + "\vpolicyPaths\x18\x01 \x03(\tR\vpolicyPaths\x12\x1c\n" + "\tapiServer\x18\x02 \x01(\rR\tapiServer\">\n" + @@ -274,10 +366,11 @@ const file_runner_proto_runner_proto_rawDesc = "" + "\x06status\x18\x01 \x01(\x0e2\x16.proto.ExecutionStatusR\x06status*+\n" + "\x0fExecutionStatus\x12\v\n" + "\aSUCCESS\x10\x00\x12\v\n" + - "\aFAILURE\x10\x012y\n" + + "\aFAILURE\x10\x012\xaa\x01\n" + "\x06Runner\x12>\n" + "\tConfigure\x12\x17.proto.ConfigureRequest\x1a\x18.proto.ConfigureResponse\x12/\n" + - "\x04Eval\x12\x12.proto.EvalRequest\x1a\x13.proto.EvalResponseB\tZ\a./protob\x06proto3" + "\x04Eval\x12\x12.proto.EvalRequest\x1a\x13.proto.EvalResponse\x12/\n" + + "\x04Init\x12\x12.proto.InitRequest\x1a\x13.proto.InitResponseB\tZ\a./protob\x06proto3" var ( file_runner_proto_runner_proto_rawDescOnce sync.Once @@ -292,24 +385,28 @@ func file_runner_proto_runner_proto_rawDescGZIP() []byte { } var file_runner_proto_runner_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_runner_proto_runner_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_runner_proto_runner_proto_msgTypes = make([]protoimpl.MessageInfo, 7) var file_runner_proto_runner_proto_goTypes = []any{ (ExecutionStatus)(0), // 0: proto.ExecutionStatus (*ConfigureRequest)(nil), // 1: proto.ConfigureRequest (*ConfigureResponse)(nil), // 2: proto.ConfigureResponse - (*EvalRequest)(nil), // 3: proto.EvalRequest - (*EvalResponse)(nil), // 4: proto.EvalResponse - nil, // 5: proto.ConfigureRequest.ConfigEntry + (*InitRequest)(nil), // 3: proto.InitRequest + (*InitResponse)(nil), // 4: proto.InitResponse + (*EvalRequest)(nil), // 5: proto.EvalRequest + (*EvalResponse)(nil), // 6: proto.EvalResponse + nil, // 7: proto.ConfigureRequest.ConfigEntry } var file_runner_proto_runner_proto_depIdxs = []int32{ - 5, // 0: proto.ConfigureRequest.config:type_name -> proto.ConfigureRequest.ConfigEntry + 7, // 0: proto.ConfigureRequest.config:type_name -> proto.ConfigureRequest.ConfigEntry 0, // 1: proto.EvalResponse.status:type_name -> proto.ExecutionStatus 1, // 2: proto.Runner.Configure:input_type -> proto.ConfigureRequest - 3, // 3: proto.Runner.Eval:input_type -> proto.EvalRequest - 2, // 4: proto.Runner.Configure:output_type -> proto.ConfigureResponse - 4, // 5: proto.Runner.Eval:output_type -> proto.EvalResponse - 4, // [4:6] is the sub-list for method output_type - 2, // [2:4] is the sub-list for method input_type + 5, // 3: proto.Runner.Eval:input_type -> proto.EvalRequest + 3, // 4: proto.Runner.Init:input_type -> proto.InitRequest + 2, // 5: proto.Runner.Configure:output_type -> proto.ConfigureResponse + 6, // 6: proto.Runner.Eval:output_type -> proto.EvalResponse + 4, // 7: proto.Runner.Init:output_type -> proto.InitResponse + 5, // [5:8] is the sub-list for method output_type + 2, // [2:5] is the sub-list for method input_type 2, // [2:2] is the sub-list for extension type_name 2, // [2:2] is the sub-list for extension extendee 0, // [0:2] is the sub-list for field type_name @@ -326,7 +423,7 @@ func file_runner_proto_runner_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_runner_proto_runner_proto_rawDesc), len(file_runner_proto_runner_proto_rawDesc)), NumEnums: 1, - NumMessages: 5, + NumMessages: 7, NumExtensions: 0, NumServices: 1, }, diff --git a/runner/proto/runner.proto b/runner/proto/runner.proto index 9def54f..9d25118 100644 --- a/runner/proto/runner.proto +++ b/runner/proto/runner.proto @@ -16,6 +16,14 @@ message ConfigureResponse { bytes value = 1; } +message InitRequest { + repeated string policyPaths = 1; + uint32 apiServer = 2; +} + +message InitResponse { +} + message EvalRequest { repeated string policyPaths = 1; uint32 apiServer = 2; @@ -33,4 +41,5 @@ message EvalResponse { service Runner { rpc Configure(ConfigureRequest) returns (ConfigureResponse); rpc Eval(EvalRequest) returns (EvalResponse); + rpc Init(InitRequest) returns (InitResponse); } diff --git a/runner/proto/runner_grpc.pb.go b/runner/proto/runner_grpc.pb.go index efc9a51..116783c 100644 --- a/runner/proto/runner_grpc.pb.go +++ b/runner/proto/runner_grpc.pb.go @@ -21,6 +21,7 @@ const _ = grpc.SupportPackageIsVersion7 const ( Runner_Configure_FullMethodName = "/proto.Runner/Configure" Runner_Eval_FullMethodName = "/proto.Runner/Eval" + Runner_Init_FullMethodName = "/proto.Runner/Init" ) // RunnerClient is the client API for Runner service. @@ -29,6 +30,7 @@ const ( type RunnerClient interface { Configure(ctx context.Context, in *ConfigureRequest, opts ...grpc.CallOption) (*ConfigureResponse, error) Eval(ctx context.Context, in *EvalRequest, opts ...grpc.CallOption) (*EvalResponse, error) + Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*InitResponse, error) } type runnerClient struct { @@ -57,12 +59,22 @@ func (c *runnerClient) Eval(ctx context.Context, in *EvalRequest, opts ...grpc.C return out, nil } +func (c *runnerClient) Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*InitResponse, error) { + out := new(InitResponse) + err := c.cc.Invoke(ctx, Runner_Init_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // RunnerServer is the server API for Runner service. // All implementations should embed UnimplementedRunnerServer // for forward compatibility type RunnerServer interface { Configure(context.Context, *ConfigureRequest) (*ConfigureResponse, error) Eval(context.Context, *EvalRequest) (*EvalResponse, error) + Init(context.Context, *InitRequest) (*InitResponse, error) } // UnimplementedRunnerServer should be embedded to have forward compatible implementations. @@ -75,6 +87,9 @@ func (UnimplementedRunnerServer) Configure(context.Context, *ConfigureRequest) ( func (UnimplementedRunnerServer) Eval(context.Context, *EvalRequest) (*EvalResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method Eval not implemented") } +func (UnimplementedRunnerServer) Init(context.Context, *InitRequest) (*InitResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Init not implemented") +} // UnsafeRunnerServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to RunnerServer will @@ -123,6 +138,24 @@ func _Runner_Eval_Handler(srv interface{}, ctx context.Context, dec func(interfa return interceptor(ctx, in, info, handler) } +func _Runner_Init_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(InitRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(RunnerServer).Init(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Runner_Init_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(RunnerServer).Init(ctx, req.(*InitRequest)) + } + return interceptor(ctx, in, info, handler) +} + // Runner_ServiceDesc is the grpc.ServiceDesc for Runner service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -138,6 +171,10 @@ var Runner_ServiceDesc = grpc.ServiceDesc{ MethodName: "Eval", Handler: _Runner_Eval_Handler, }, + { + MethodName: "Init", + Handler: _Runner_Init_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "runner/proto/runner.proto", diff --git a/runner/proto/types.pb.go b/runner/proto/types.pb.go index 3f34340..0d2a045 100644 --- a/runner/proto/types.pb.go +++ b/runner/proto/types.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.6 +// protoc-gen-go v1.36.11 // protoc (unknown) // source: runner/proto/types.proto From efd713e5d7ad720ff0e97d0575a4e63b41fb8894 Mon Sep 17 00:00:00 2001 From: Reece Bedding Date: Tue, 10 Mar 2026 12:47:45 +0000 Subject: [PATCH 2/8] fix: copilot issues --- Makefile | 1 + cmd/agent.go | 20 +++++++++++++++++++- cmd/agent_test.go | 37 +++++++++++++++++++++++++++++++++++++ runner/grpc.go | 6 +++++- runner/grpc_test.go | 12 ++++++++++++ runner/plugin.go | 26 +++++++++++++++++++++++++- 6 files changed, 99 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 341d986..0459bbb 100644 --- a/Makefile +++ b/Makefile @@ -46,6 +46,7 @@ test: ## Run tests build: ## Build the project + @mkdir -p dist @go build -o dist/concom main.go run: ## Run the project diff --git a/cmd/agent.go b/cmd/agent.go index c2a9daa..08d70ef 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -81,6 +81,20 @@ func (ac *agentConfig) validate() error { return fmt.Errorf("no api config specified in config") } + for name, pluginConfig := range ac.Plugins { + if pluginConfig == nil { + continue + } + + if pluginConfig.ProtocolVersion == 0 { + pluginConfig.ProtocolVersion = DefaultProtocolVersion + } + + if !isSupportedProtocolVersion(pluginConfig.ProtocolVersion) { + return fmt.Errorf("plugin %s has unsupported protocol_version=%d; supported values are %d and %d", name, pluginConfig.ProtocolVersion, DefaultProtocolVersion, RunnerV2ProtocolVersion) + } + } + return nil } @@ -181,6 +195,10 @@ func updateAllPluginProtocols(agentConfig *agentConfig) { } } +func isSupportedProtocolVersion(protocolVersion int32) bool { + return protocolVersion == DefaultProtocolVersion || protocolVersion == RunnerV2ProtocolVersion +} + func protocolVersionFromAnnotations(annotations map[string]string) (int32, bool) { value, ok := annotations[AnnotationProtocolVersionKey] if !ok { @@ -196,7 +214,7 @@ func protocolVersionFromAnnotations(annotations map[string]string) (int32, bool) return 0, false } - if parsed != int64(DefaultProtocolVersion) && parsed != int64(RunnerV2ProtocolVersion) { + if !isSupportedProtocolVersion(int32(parsed)) { return 0, false } diff --git a/cmd/agent_test.go b/cmd/agent_test.go index 15456f4..a0bee1e 100644 --- a/cmd/agent_test.go +++ b/cmd/agent_test.go @@ -42,6 +42,19 @@ plugins: configYamlContent: ` api: url: http://localhost:8080 +`, + valid: false, + }, + { + name: "Unsupported Explicit Protocol Version", + configYamlContent: ` +api: + url: http://localhost:8080 + +plugins: + test-plugin: + source: ghcr.io/some-plugin:v1 + protocol_version: 100 `, valid: false, }, @@ -167,6 +180,30 @@ func TestUpdateAllPluginProtocols_DefaultsOnlyUnset(t *testing.T) { } } +func TestMergeConfig_RejectsUnsupportedExplicitProtocolVersion(t *testing.T) { + v := viper.New() + v.SetConfigType("yaml") + err := v.ReadConfig(bytes.NewBufferString("api:\n url: http://localhost:8080\n\nplugins:\n plugin-with-invalid-version:\n source: ghcr.io/some-plugin:v1\n protocol_version: 100\n")) + if err != nil { + t.Fatalf("Error reading config: %v", err) + } + + config, err := mergeConfig(AgentCmd(), v) + if err != nil { + t.Fatalf("Error merging config: %v", err) + } + + err = config.validate() + if err == nil { + t.Fatalf("Expected config validation to fail for unsupported protocol version") + } + + expected := "plugin plugin-with-invalid-version has unsupported protocol_version=100; supported values are 1 and 2" + if err.Error() != expected { + t.Fatalf("Expected error %q, got %q", expected, err.Error()) + } +} + func TestMergeConfig_DoesNotFetchAnnotations(t *testing.T) { v := viper.New() v.SetConfigType("yaml") diff --git a/runner/grpc.go b/runner/grpc.go index 55f23f9..09445a2 100644 --- a/runner/grpc.go +++ b/runner/grpc.go @@ -46,6 +46,10 @@ type GRPCClient struct { broker *plugin.GRPCBroker } +type GRPCClientV2 struct { + *GRPCClient +} + func (m *GRPCClient) startAPIServer(a ApiHelper) uint32 { apiHelperServer := &GRPCApiHelperServer{Impl: a} @@ -65,7 +69,7 @@ func (m *GRPCClient) Configure(request *proto.ConfigureRequest) (*proto.Configur return m.client.Configure(context.Background(), request) } -func (m *GRPCClient) Init(request *proto.InitRequest, a ApiHelper) (*proto.InitResponse, error) { +func (m *GRPCClientV2) Init(request *proto.InitRequest, a ApiHelper) (*proto.InitResponse, error) { request.ApiServer = m.startAPIServer(a) resp, err := m.client.Init(context.Background(), request) return resp, err diff --git a/runner/grpc_test.go b/runner/grpc_test.go index 851bb6f..5d5e2ce 100644 --- a/runner/grpc_test.go +++ b/runner/grpc_test.go @@ -31,3 +31,15 @@ func TestGRPCServerInitReturnsUnimplementedForRunnerV1(t *testing.T) { t.Fatalf("expected code %v, got %v", codes.Unimplemented, status.Code(err)) } } + +func TestGRPCClientCapabilitiesMatchProtocolVersion(t *testing.T) { + v1Client := &GRPCClient{} + if _, ok := interface{}(v1Client).(RunnerV2); ok { + t.Fatalf("expected v1 gRPC client to not implement RunnerV2") + } + + v2Client := &GRPCClientV2{GRPCClient: &GRPCClient{}} + if _, ok := interface{}(v2Client).(RunnerV2); !ok { + t.Fatalf("expected v2 gRPC client to implement RunnerV2") + } +} diff --git a/runner/plugin.go b/runner/plugin.go index 3de745f..4448475 100644 --- a/runner/plugin.go +++ b/runner/plugin.go @@ -25,6 +25,13 @@ type RunnerGRPCPlugin struct { Impl Runner } +type RunnerV2GRPCPlugin struct { + plugin.Plugin + + // Impl Injection + Impl Runner +} + func (p *RunnerGRPCPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { proto.RegisterRunnerServer(s, &GRPCServer{ Impl: p.Impl, @@ -40,6 +47,23 @@ func (p *RunnerGRPCPlugin) GRPCClient(ctx context.Context, broker *plugin.GRPCBr }, nil } +func (p *RunnerV2GRPCPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { + proto.RegisterRunnerServer(s, &GRPCServer{ + Impl: p.Impl, + broker: broker, + }) + return nil +} + +func (p *RunnerV2GRPCPlugin) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { + return &GRPCClientV2{ + GRPCClient: &GRPCClient{ + client: proto.NewRunnerClient(c), + broker: broker, + }, + }, nil +} + var HandshakeConfig = plugin.HandshakeConfig{ ProtocolVersion: 1, MagicCookieKey: "RUNNER_PLUGIN", @@ -48,5 +72,5 @@ var HandshakeConfig = plugin.HandshakeConfig{ var PluginMap = map[string]plugin.Plugin{ "runner": &RunnerGRPCPlugin{}, - "runner-v2": &RunnerGRPCPlugin{}, + "runner-v2": &RunnerV2GRPCPlugin{}, } From 0faf4344c4bad45585ff76db1fa9e640bd0b3114 Mon Sep 17 00:00:00 2001 From: Reece Bedding Date: Tue, 10 Mar 2026 13:02:59 +0000 Subject: [PATCH 3/8] fix: copilot issues --- cmd/agent.go | 18 +++++++++++-- cmd/agent_test.go | 69 +++++++++++++++++++++++++++++++++++++++++++++++ runner/plugin.go | 2 +- 3 files changed, 86 insertions(+), 3 deletions(-) diff --git a/cmd/agent.go b/cmd/agent.go index 08d70ef..b45c7a6 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -83,10 +83,14 @@ func (ac *agentConfig) validate() error { for name, pluginConfig := range ac.Plugins { if pluginConfig == nil { - continue + return fmt.Errorf("plugin %s has null configuration", name) } if pluginConfig.ProtocolVersion == 0 { + if pluginConfig.protocolSet { + return fmt.Errorf("plugin %s has unsupported protocol_version=%d; supported values are %d and %d", name, pluginConfig.ProtocolVersion, DefaultProtocolVersion, RunnerV2ProtocolVersion) + } + pluginConfig.ProtocolVersion = DefaultProtocolVersion } @@ -174,6 +178,16 @@ func markExplicitPluginProtocols(fileConfig *viper.Viper, config *agentConfig) { rawPlugins := fileConfig.GetStringMap("plugins") for name, rawPlugin := range rawPlugins { pluginConfig, ok := config.Plugins[name] + if rawPlugin == nil { + if config.Plugins == nil { + config.Plugins = map[string]*agentPlugin{} + } + if !ok { + config.Plugins[name] = nil + } + continue + } + if !ok || pluginConfig == nil { continue } @@ -189,7 +203,7 @@ func markExplicitPluginProtocols(fileConfig *viper.Viper, config *agentConfig) { func updateAllPluginProtocols(agentConfig *agentConfig) { for _, pluginConfig := range agentConfig.Plugins { - if pluginConfig.ProtocolVersion == 0 { + if pluginConfig != nil && !pluginConfig.protocolSet && pluginConfig.ProtocolVersion == 0 { pluginConfig.ProtocolVersion = DefaultProtocolVersion } } diff --git a/cmd/agent_test.go b/cmd/agent_test.go index a0bee1e..ed3803a 100644 --- a/cmd/agent_test.go +++ b/cmd/agent_test.go @@ -55,6 +55,17 @@ plugins: test-plugin: source: ghcr.io/some-plugin:v1 protocol_version: 100 +`, + valid: false, + }, + { + name: "Null Plugin Configuration", + configYamlContent: ` +api: + url: http://localhost:8080 + +plugins: + test-plugin: null `, valid: false, }, @@ -165,6 +176,12 @@ func TestUpdateAllPluginProtocols_DefaultsOnlyUnset(t *testing.T) { "explicit": { Source: "ghcr.io/explicit:v2", ProtocolVersion: 2, + protocolSet: true, + }, + "explicit-zero": { + Source: "ghcr.io/explicit-zero:v1", + ProtocolVersion: 0, + protocolSet: true, }, }, } @@ -178,6 +195,10 @@ func TestUpdateAllPluginProtocols_DefaultsOnlyUnset(t *testing.T) { if got := config.Plugins["explicit"].ProtocolVersion; got != 2 { t.Fatalf("Expected explicit plugin protocol version to remain 2, got %d", got) } + + if got := config.Plugins["explicit-zero"].ProtocolVersion; got != 0 { + t.Fatalf("Expected explicit-zero plugin protocol version to remain 0, got %d", got) + } } func TestMergeConfig_RejectsUnsupportedExplicitProtocolVersion(t *testing.T) { @@ -204,6 +225,54 @@ func TestMergeConfig_RejectsUnsupportedExplicitProtocolVersion(t *testing.T) { } } +func TestMergeConfig_RejectsExplicitZeroProtocolVersion(t *testing.T) { + v := viper.New() + v.SetConfigType("yaml") + err := v.ReadConfig(bytes.NewBufferString("api:\n url: http://localhost:8080\n\nplugins:\n plugin-with-zero-version:\n source: ghcr.io/some-plugin:v1\n protocol_version: 0\n")) + if err != nil { + t.Fatalf("Error reading config: %v", err) + } + + config, err := mergeConfig(AgentCmd(), v) + if err != nil { + t.Fatalf("Error merging config: %v", err) + } + + err = config.validate() + if err == nil { + t.Fatalf("Expected config validation to fail for explicit zero protocol version") + } + + expected := "plugin plugin-with-zero-version has unsupported protocol_version=0; supported values are 1 and 2" + if err.Error() != expected { + t.Fatalf("Expected error %q, got %q", expected, err.Error()) + } +} + +func TestMergeConfig_RejectsNullPluginConfiguration(t *testing.T) { + v := viper.New() + v.SetConfigType("yaml") + err := v.ReadConfig(bytes.NewBufferString("api:\n url: http://localhost:8080\n\nplugins:\n null-plugin: null\n")) + if err != nil { + t.Fatalf("Error reading config: %v", err) + } + + config, err := mergeConfig(AgentCmd(), v) + if err != nil { + t.Fatalf("Error merging config: %v", err) + } + + err = config.validate() + if err == nil { + t.Fatalf("Expected config validation to fail for null plugin configuration") + } + + expected := "plugin null-plugin has null configuration" + if err.Error() != expected { + t.Fatalf("Expected error %q, got %q", expected, err.Error()) + } +} + func TestMergeConfig_DoesNotFetchAnnotations(t *testing.T) { v := viper.New() v.SetConfigType("yaml") diff --git a/runner/plugin.go b/runner/plugin.go index 4448475..d49ecb9 100644 --- a/runner/plugin.go +++ b/runner/plugin.go @@ -29,7 +29,7 @@ type RunnerV2GRPCPlugin struct { plugin.Plugin // Impl Injection - Impl Runner + Impl RunnerV2 } func (p *RunnerGRPCPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { From 6816a5e6f94677d2cd51fcdb24d6a751398a1dc2 Mon Sep 17 00:00:00 2001 From: Reece Bedding Date: Tue, 10 Mar 2026 13:18:52 +0000 Subject: [PATCH 4/8] fix: copilot issues --- cmd/agent.go | 20 +++++++++++++------- cmd/agent_test.go | 16 ++++++++++++---- internal/oci.go | 5 +++-- internal/utils_test.go | 5 +++++ 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/cmd/agent.go b/cmd/agent.go index b45c7a6..c6cc592 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -337,7 +337,7 @@ type AgentRunner struct { pluginLocations map[string]string policyLocations map[string]string - fetchAnnotations func(source string, option ...remote.Option) (map[string]string, error) + fetchAnnotations func(ctx context.Context, source string, option ...remote.Option) (map[string]string, error) queryBundles []*rego.Rego } @@ -369,7 +369,7 @@ func (ar *AgentRunner) Run(ctx context.Context) error { return err } - ar.resolvePluginProtocols() + ar.resolvePluginProtocols(ctx) err = ar.DownloadPolicies(ctx) if err != nil { @@ -386,13 +386,19 @@ func (ar *AgentRunner) Run(ctx context.Context) error { return ar.runAllPlugins(ctx) } -func (ar *AgentRunner) resolvePluginProtocols() { +func (ar *AgentRunner) resolvePluginProtocols(ctx context.Context) { + if ctx == nil { + ctx = context.Background() + } + for pluginName, pluginConfig := range ar.config.Plugins { if pluginConfig == nil || pluginConfig.protocolSet || !internal.IsOCI(pluginConfig.Source) { continue } - annotations, err := ar.fetchAnnotations(pluginConfig.Source) + annotationCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + annotations, err := ar.fetchAnnotations(annotationCtx, pluginConfig.Source) + cancel() if err != nil { ar.logger.Warn("Failed to fetch plugin annotations, using configured/default protocol version", "plugin", pluginName, "source", pluginConfig.Source, "protocol_version", pluginConfig.ProtocolVersion, "error", err) continue @@ -643,13 +649,13 @@ func (ar *AgentRunner) runPlugin(ctx context.Context, name string, plugin *agent OS: runtime.GOOS, })) - ar.logger.Info("Running plugin", "source", plugin.Source, "protocol_version", plugin.ProtocolVersion) - ar.logger.Info("Running plugin", "source", pluginExecutable, "protocol_version", plugin.ProtocolVersion) - if err != nil { return err } + ar.logger.Info("Running plugin", "source", plugin.Source, "protocol_version", plugin.ProtocolVersion) + ar.logger.Info("Running plugin", "source", pluginExecutable, "protocol_version", plugin.ProtocolVersion) + logger := hclog.New(&hclog.LoggerOptions{ Name: fmt.Sprintf("runner.%s", name), Output: os.Stdout, diff --git a/cmd/agent_test.go b/cmd/agent_test.go index ed3803a..e06d84a 100644 --- a/cmd/agent_test.go +++ b/cmd/agent_test.go @@ -2,6 +2,7 @@ package cmd import ( "bytes" + "context" "errors" "fmt" "testing" @@ -293,8 +294,12 @@ func TestMergeConfig_DoesNotFetchAnnotations(t *testing.T) { func TestResolvePluginProtocols_UsesAnnotationsOnlyForImplicitOCIPlugins(t *testing.T) { lookupCount := 0 - fetchAnnotations := func(source string, option ...remote.Option) (map[string]string, error) { + ctx := context.Background() + fetchAnnotations := func(fetchCtx context.Context, source string, option ...remote.Option) (map[string]string, error) { lookupCount++ + if fetchCtx == nil { + t.Fatalf("expected fetchAnnotations context to be set") + } return map[string]string{ AnnotationProtocolVersionKey: "2", }, nil @@ -323,7 +328,7 @@ func TestResolvePluginProtocols_UsesAnnotationsOnlyForImplicitOCIPlugins(t *test runner := NewAgentRunner() runner.fetchAnnotations = fetchAnnotations runner.UpdateConfig(config) - runner.resolvePluginProtocols() + runner.resolvePluginProtocols(ctx) if lookupCount != 1 { t.Fatalf("Expected one annotation lookup, got %d", lookupCount) @@ -343,7 +348,10 @@ func TestResolvePluginProtocols_UsesAnnotationsOnlyForImplicitOCIPlugins(t *test } func TestResolvePluginProtocols_KeepsDefaultWhenLookupFails(t *testing.T) { - fetchAnnotations := func(source string, option ...remote.Option) (map[string]string, error) { + fetchAnnotations := func(fetchCtx context.Context, source string, option ...remote.Option) (map[string]string, error) { + if fetchCtx == nil { + t.Fatalf("expected fetchAnnotations context to be set") + } return nil, errors.New("lookup failed") } @@ -360,7 +368,7 @@ func TestResolvePluginProtocols_KeepsDefaultWhenLookupFails(t *testing.T) { runner := NewAgentRunner() runner.fetchAnnotations = fetchAnnotations runner.UpdateConfig(config) - runner.resolvePluginProtocols() + runner.resolvePluginProtocols(context.Background()) if got := config.Plugins["implicit-oci"].ProtocolVersion; got != DefaultProtocolVersion { t.Fatalf("Expected implicit-oci protocol version to remain %d, got %d", DefaultProtocolVersion, got) diff --git a/internal/oci.go b/internal/oci.go index 0eeab33..bc16924 100644 --- a/internal/oci.go +++ b/internal/oci.go @@ -16,17 +16,18 @@ import ( func IsOCI(source string) bool { // Check whether this can be parsed as an OCI endpoint - _, err := name.NewTag(source, name.StrictValidation) + _, err := name.ParseReference(source, name.StrictValidation) return err == nil } -func GetAnnotations(source string, option ...remote.Option) (map[string]string, error) { +func GetAnnotations(ctx context.Context, source string, option ...remote.Option) (map[string]string, error) { ref, err := name.ParseReference(source) if err != nil { return nil, err } opts := append([]remote.Option{ + remote.WithContext(ctx), remote.WithAuthFromKeychain(authn.DefaultKeychain), }, option...) diff --git a/internal/utils_test.go b/internal/utils_test.go index eaf3f75..a8236a0 100644 --- a/internal/utils_test.go +++ b/internal/utils_test.go @@ -86,6 +86,11 @@ func TestIsOci(t *testing.T) { source: "docker.io/library/alpine:latest", expected: true, }, + { + name: "Basic OCI url with digest", + source: "ghcr.io/example/plugin@sha256:88252198a40099248f5cc3272bc879fade8b7001a2bcb36d7b43aa8f54328714", + expected: true, + }, { name: "Tar artifact", source: "docker.io/library/alpine.tar.gz", From 3b6ede3fc03d1ab99613a4afe1eecdd652d7f95b Mon Sep 17 00:00:00 2001 From: Reece Bedding Date: Tue, 10 Mar 2026 14:16:33 +0000 Subject: [PATCH 5/8] fix: copilot issues --- cmd/agent.go | 2 +- cmd/agent_test.go | 33 +++++++++++++++++++++++++++++++++ internal/oci.go | 8 +++++++- internal/utils_test.go | 34 +++++++++++++++++++++++++++++++++- 4 files changed, 74 insertions(+), 3 deletions(-) diff --git a/cmd/agent.go b/cmd/agent.go index c6cc592..2f03f41 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -392,7 +392,7 @@ func (ar *AgentRunner) resolvePluginProtocols(ctx context.Context) { } for pluginName, pluginConfig := range ar.config.Plugins { - if pluginConfig == nil || pluginConfig.protocolSet || !internal.IsOCI(pluginConfig.Source) { + if pluginConfig == nil || pluginConfig.protocolSet || !internal.IsOCIReference(pluginConfig.Source) { continue } diff --git a/cmd/agent_test.go b/cmd/agent_test.go index e06d84a..9c683c9 100644 --- a/cmd/agent_test.go +++ b/cmd/agent_test.go @@ -375,6 +375,39 @@ func TestResolvePluginProtocols_KeepsDefaultWhenLookupFails(t *testing.T) { } } +func TestResolvePluginProtocols_UsesAnnotationsForImplicitOCIDigestPlugins(t *testing.T) { + lookupCount := 0 + fetchAnnotations := func(fetchCtx context.Context, source string, option ...remote.Option) (map[string]string, error) { + lookupCount++ + return map[string]string{ + AnnotationProtocolVersionKey: "2", + }, nil + } + + config := &agentConfig{ + Plugins: map[string]*agentPlugin{ + "implicit-oci-digest": { + Source: "ghcr.io/example/plugin@sha256:88252198a40099248f5cc3272bc879fade8b7001a2bcb36d7b43aa8f54328714", + ProtocolVersion: DefaultProtocolVersion, + protocolSet: false, + }, + }, + } + + runner := NewAgentRunner() + runner.fetchAnnotations = fetchAnnotations + runner.UpdateConfig(config) + runner.resolvePluginProtocols(context.Background()) + + if lookupCount != 1 { + t.Fatalf("Expected one annotation lookup, got %d", lookupCount) + } + + if got := config.Plugins["implicit-oci-digest"].ProtocolVersion; got != RunnerV2ProtocolVersion { + t.Fatalf("Expected implicit-oci-digest protocol version to be %d, got %d", RunnerV2ProtocolVersion, got) + } +} + func TestProtocolVersionFromAnnotations(t *testing.T) { tests := []struct { name string diff --git a/internal/oci.go b/internal/oci.go index bc16924..92afdad 100644 --- a/internal/oci.go +++ b/internal/oci.go @@ -15,13 +15,19 @@ import ( ) func IsOCI(source string) bool { + // Check whether this can be parsed as an OCI tag, which is what our downloader supports. + _, err := name.NewTag(source, name.StrictValidation) + return err == nil +} + +func IsOCIReference(source string) bool { // Check whether this can be parsed as an OCI endpoint _, err := name.ParseReference(source, name.StrictValidation) return err == nil } func GetAnnotations(ctx context.Context, source string, option ...remote.Option) (map[string]string, error) { - ref, err := name.ParseReference(source) + ref, err := name.ParseReference(source, name.StrictValidation) if err != nil { return nil, err } diff --git a/internal/utils_test.go b/internal/utils_test.go index a8236a0..5f5e054 100644 --- a/internal/utils_test.go +++ b/internal/utils_test.go @@ -89,7 +89,7 @@ func TestIsOci(t *testing.T) { { name: "Basic OCI url with digest", source: "ghcr.io/example/plugin@sha256:88252198a40099248f5cc3272bc879fade8b7001a2bcb36d7b43aa8f54328714", - expected: true, + expected: false, }, { name: "Tar artifact", @@ -121,6 +121,38 @@ func TestIsOci(t *testing.T) { } } +func TestIsOCIReference(t *testing.T) { + tests := []struct { + name string + source string + expected bool + }{ + { + name: "Basic OCI url with version", + source: "docker.io/library/alpine:1.0", + expected: true, + }, + { + name: "Basic OCI url with digest", + source: "ghcr.io/example/plugin@sha256:88252198a40099248f5cc3272bc879fade8b7001a2bcb36d7b43aa8f54328714", + expected: true, + }, + { + name: "Tar artifact", + source: "docker.io/library/alpine.tar.gz", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsOCIReference(tt.source); got != tt.expected { + t.Errorf("IsOCIReference() = %v, want %v", got, tt.expected) + } + }) + } +} + func TestSeededUUID(t *testing.T) { t.Run("SeededUUID generates a consistent ID for the same seed", func(t *testing.T) { seedData := []string{ From ecfda80fadb979e791a78efb3fe4dc42870f120d Mon Sep 17 00:00:00 2001 From: Reece Bedding Date: Tue, 10 Mar 2026 14:45:39 +0000 Subject: [PATCH 6/8] fix: copilot issues --- cmd/agent.go | 2 +- cmd/agent_test.go | 33 ------------------------------- internal/oci.go | 6 ------ internal/utils_test.go | 32 ------------------------------ runner/grpc.go | 44 ++++++++++++++++++++++++++++++++---------- 5 files changed, 35 insertions(+), 82 deletions(-) diff --git a/cmd/agent.go b/cmd/agent.go index 2f03f41..c6cc592 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -392,7 +392,7 @@ func (ar *AgentRunner) resolvePluginProtocols(ctx context.Context) { } for pluginName, pluginConfig := range ar.config.Plugins { - if pluginConfig == nil || pluginConfig.protocolSet || !internal.IsOCIReference(pluginConfig.Source) { + if pluginConfig == nil || pluginConfig.protocolSet || !internal.IsOCI(pluginConfig.Source) { continue } diff --git a/cmd/agent_test.go b/cmd/agent_test.go index 9c683c9..e06d84a 100644 --- a/cmd/agent_test.go +++ b/cmd/agent_test.go @@ -375,39 +375,6 @@ func TestResolvePluginProtocols_KeepsDefaultWhenLookupFails(t *testing.T) { } } -func TestResolvePluginProtocols_UsesAnnotationsForImplicitOCIDigestPlugins(t *testing.T) { - lookupCount := 0 - fetchAnnotations := func(fetchCtx context.Context, source string, option ...remote.Option) (map[string]string, error) { - lookupCount++ - return map[string]string{ - AnnotationProtocolVersionKey: "2", - }, nil - } - - config := &agentConfig{ - Plugins: map[string]*agentPlugin{ - "implicit-oci-digest": { - Source: "ghcr.io/example/plugin@sha256:88252198a40099248f5cc3272bc879fade8b7001a2bcb36d7b43aa8f54328714", - ProtocolVersion: DefaultProtocolVersion, - protocolSet: false, - }, - }, - } - - runner := NewAgentRunner() - runner.fetchAnnotations = fetchAnnotations - runner.UpdateConfig(config) - runner.resolvePluginProtocols(context.Background()) - - if lookupCount != 1 { - t.Fatalf("Expected one annotation lookup, got %d", lookupCount) - } - - if got := config.Plugins["implicit-oci-digest"].ProtocolVersion; got != RunnerV2ProtocolVersion { - t.Fatalf("Expected implicit-oci-digest protocol version to be %d, got %d", RunnerV2ProtocolVersion, got) - } -} - func TestProtocolVersionFromAnnotations(t *testing.T) { tests := []struct { name string diff --git a/internal/oci.go b/internal/oci.go index 92afdad..560d6ee 100644 --- a/internal/oci.go +++ b/internal/oci.go @@ -20,12 +20,6 @@ func IsOCI(source string) bool { return err == nil } -func IsOCIReference(source string) bool { - // Check whether this can be parsed as an OCI endpoint - _, err := name.ParseReference(source, name.StrictValidation) - return err == nil -} - func GetAnnotations(ctx context.Context, source string, option ...remote.Option) (map[string]string, error) { ref, err := name.ParseReference(source, name.StrictValidation) if err != nil { diff --git a/internal/utils_test.go b/internal/utils_test.go index 5f5e054..dcb8d20 100644 --- a/internal/utils_test.go +++ b/internal/utils_test.go @@ -121,38 +121,6 @@ func TestIsOci(t *testing.T) { } } -func TestIsOCIReference(t *testing.T) { - tests := []struct { - name string - source string - expected bool - }{ - { - name: "Basic OCI url with version", - source: "docker.io/library/alpine:1.0", - expected: true, - }, - { - name: "Basic OCI url with digest", - source: "ghcr.io/example/plugin@sha256:88252198a40099248f5cc3272bc879fade8b7001a2bcb36d7b43aa8f54328714", - expected: true, - }, - { - name: "Tar artifact", - source: "docker.io/library/alpine.tar.gz", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := IsOCIReference(tt.source); got != tt.expected { - t.Errorf("IsOCIReference() = %v, want %v", got, tt.expected) - } - }) - } -} - func TestSeededUUID(t *testing.T) { t.Run("SeededUUID generates a consistent ID for the same seed", func(t *testing.T) { seedData := []string{ diff --git a/runner/grpc.go b/runner/grpc.go index 09445a2..a52238f 100644 --- a/runner/grpc.go +++ b/runner/grpc.go @@ -2,6 +2,7 @@ package runner import ( "context" + "sync" "github.com/compliance-framework/agent/runner/proto" "github.com/hashicorp/go-hclog" @@ -28,12 +29,27 @@ func (m *GRPCApiHelperClient) CreateEvidence(ctx context.Context, evidence []*pr } type GRPCApiHelperServer struct { + mu sync.RWMutex + // This is the real implementation Impl ApiHelper } +func (m *GRPCApiHelperServer) SetImpl(impl ApiHelper) { + m.mu.Lock() + defer m.mu.Unlock() + m.Impl = impl +} + func (m *GRPCApiHelperServer) CreateEvidence(ctx context.Context, req *proto.CreateEvidenceRequest) (resp *proto.CreateEvidenceResponse, err error) { - err = m.Impl.CreateEvidence(ctx, req.GetEvidence()) + m.mu.RLock() + impl := m.Impl + m.mu.RUnlock() + if impl == nil { + return nil, status.Error(codes.FailedPrecondition, "API helper server is not configured") + } + + err = impl.CreateEvidence(ctx, req.GetEvidence()) if err != nil { return nil, err } @@ -44,6 +60,10 @@ func (m *GRPCApiHelperServer) CreateEvidence(ctx context.Context, req *proto.Cre type GRPCClient struct { client proto.RunnerClient broker *plugin.GRPCBroker + + apiHelperServer *GRPCApiHelperServer + apiServerID uint32 + apiServerOnce sync.Once } type GRPCClientV2 struct { @@ -51,18 +71,22 @@ type GRPCClientV2 struct { } func (m *GRPCClient) startAPIServer(a ApiHelper) uint32 { - apiHelperServer := &GRPCApiHelperServer{Impl: a} + m.apiServerOnce.Do(func() { + m.apiHelperServer = &GRPCApiHelperServer{} - serverFunc := func(opts []grpc.ServerOption) *grpc.Server { - s := grpc.NewServer(opts...) - proto.RegisterApiHelperServer(s, apiHelperServer) - return s - } + serverFunc := func(opts []grpc.ServerOption) *grpc.Server { + s := grpc.NewServer(opts...) + proto.RegisterApiHelperServer(s, m.apiHelperServer) + return s + } + + m.apiServerID = m.broker.NextId() + go m.broker.AcceptAndServe(m.apiServerID, serverFunc) + }) - brokerID := m.broker.NextId() - go m.broker.AcceptAndServe(brokerID, serverFunc) + m.apiHelperServer.SetImpl(a) - return brokerID + return m.apiServerID } func (m *GRPCClient) Configure(request *proto.ConfigureRequest) (*proto.ConfigureResponse, error) { From c7350a68d9a9458b9ff78c14b79a15b983b968e6 Mon Sep 17 00:00:00 2001 From: Reece Bedding Date: Tue, 10 Mar 2026 15:25:58 +0000 Subject: [PATCH 7/8] fix: copilot issues --- cmd/agent.go | 41 ++++++++++--------- .../0001-support-versioned-runner-plugins.md | 4 ++ internal/oci.go | 13 +++++- internal/oci_manifest_test.go | 31 ++++++++++++++ internal/utils_test.go | 2 +- 5 files changed, 69 insertions(+), 22 deletions(-) diff --git a/cmd/agent.go b/cmd/agent.go index c6cc592..eb0f624 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -91,7 +91,7 @@ func (ac *agentConfig) validate() error { return fmt.Errorf("plugin %s has unsupported protocol_version=%d; supported values are %d and %d", name, pluginConfig.ProtocolVersion, DefaultProtocolVersion, RunnerV2ProtocolVersion) } - pluginConfig.ProtocolVersion = DefaultProtocolVersion + continue } if !isSupportedProtocolVersion(pluginConfig.ProtocolVersion) { @@ -396,26 +396,29 @@ func (ar *AgentRunner) resolvePluginProtocols(ctx context.Context) { continue } - annotationCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - annotations, err := ar.fetchAnnotations(annotationCtx, pluginConfig.Source) - cancel() - if err != nil { - ar.logger.Warn("Failed to fetch plugin annotations, using configured/default protocol version", "plugin", pluginName, "source", pluginConfig.Source, "protocol_version", pluginConfig.ProtocolVersion, "error", err) - continue - } + func() { + annotationCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() - value, ok := annotations[AnnotationProtocolVersionKey] - if !ok { - continue - } + annotations, err := ar.fetchAnnotations(annotationCtx, pluginConfig.Source) + if err != nil { + ar.logger.Warn("Failed to fetch plugin annotations, using configured/default protocol version", "plugin", pluginName, "source", pluginConfig.Source, "protocol_version", pluginConfig.ProtocolVersion, "error", err) + return + } - protocolVersion, ok := protocolVersionFromAnnotations(annotations) - if !ok { - ar.logger.Warn("Ignoring unsupported plugin protocol version annotation", "plugin", pluginName, "source", pluginConfig.Source, "value", value, "protocol_version", pluginConfig.ProtocolVersion) - continue - } + value, ok := annotations[AnnotationProtocolVersionKey] + if !ok { + return + } + + protocolVersion, ok := protocolVersionFromAnnotations(annotations) + if !ok { + ar.logger.Warn("Ignoring unsupported plugin protocol version annotation", "plugin", pluginName, "source", pluginConfig.Source, "value", value, "protocol_version", pluginConfig.ProtocolVersion) + return + } - pluginConfig.ProtocolVersion = protocolVersion + pluginConfig.ProtocolVersion = protocolVersion + }() } } @@ -772,7 +775,7 @@ func (ar *AgentRunner) getRunnerInstance(logger hclog.Logger, path string, proto } // Request the plugin - logger.Debug("Dispensing plugin", "runner", dispenseName) + logger.Debug("Dispensing plugin", "dispense_name", dispenseName) raw, err := rpcClient.Dispense(dispenseName) if err != nil { return nil, err diff --git a/docs/adr/0001-support-versioned-runner-plugins.md b/docs/adr/0001-support-versioned-runner-plugins.md index 9db59eb..ddcdf4a 100644 --- a/docs/adr/0001-support-versioned-runner-plugins.md +++ b/docs/adr/0001-support-versioned-runner-plugins.md @@ -17,6 +17,8 @@ This is implemented by: - adding `protocol_version` to plugin configuration - defaulting unspecified plugins to protocol version 1 - reading `org.ccf.plugin.protocol.version` from OCI annotations for OCI plugin sources without an explicit `protocol_version` + - this currently applies to tag-form OCI references such as `ghcr.io/example/plugin:v1` + - digest-form references such as `ghcr.io/example/plugin@sha256:...` are not currently treated as supported OCI download sources by the agent - supporting only protocol versions 1 and 2 - mapping protocol version 1 to the `runner` dispense name and protocol version 2 to `runner-v2` - calling `Init` before `Eval` for protocol version 2 plugins @@ -29,10 +31,12 @@ This is implemented by: - Existing plugins continue to work without configuration changes. - New plugins can adopt protocol version 2 and perform setup during `Init`. - OCI-published plugins can self-describe their protocol version, reducing configuration drift. +- The supported OCI source shape is explicit: tag-form references participate in annotation lookup and download. - Unsupported or invalid annotations do not break execution; the agent logs and falls back to the configured or default version. ### Negative - OCI-backed plugins may require an extra registry metadata lookup before execution. +- Digest-form OCI references are not currently supported for plugin download or annotation-based protocol resolution. - The agent now maintains two supported runner contracts instead of one. - Plugin authors adopting protocol version 2 must implement `Init`. diff --git a/internal/oci.go b/internal/oci.go index 560d6ee..d1fcd51 100644 --- a/internal/oci.go +++ b/internal/oci.go @@ -50,17 +50,26 @@ func annotationsFromDescriptor(desc *remote.Descriptor) map[string]string { } if err := json.Unmarshal(desc.Manifest, &payload); err == nil && len(payload.Annotations) > 0 { - return payload.Annotations + return copyAnnotations(payload.Annotations) } } if len(desc.Annotations) > 0 { - return desc.Annotations + return copyAnnotations(desc.Annotations) } return map[string]string{} } +func copyAnnotations(in map[string]string) map[string]string { + out := make(map[string]string, len(in)) + for key, value := range in { + out[key] = value + } + + return out +} + func Download(ctx context.Context, source string, outputDir string, binaryPath string, logger hclog.Logger, option ...remote.Option) (string, error) { // Add a task to indicate we've downloaded the items logger.Trace("Checking for source", "source", source) diff --git a/internal/oci_manifest_test.go b/internal/oci_manifest_test.go index cbc98b3..7388d0f 100644 --- a/internal/oci_manifest_test.go +++ b/internal/oci_manifest_test.go @@ -65,3 +65,34 @@ func TestAnnotationsFromDescriptor(t *testing.T) { }) } } + +func TestAnnotationsFromDescriptor_ReturnsDefensiveCopy(t *testing.T) { + t.Run("Descriptor annotations are copied", func(t *testing.T) { + desc := &remote.Descriptor{ + Descriptor: v1.Descriptor{ + Annotations: map[string]string{"from": "descriptor"}, + }, + } + + got := annotationsFromDescriptor(desc) + got["from"] = "modified" + + if desc.Annotations["from"] != "descriptor" { + t.Fatalf("expected descriptor annotations to remain unchanged, got %q", desc.Annotations["from"]) + } + }) + + t.Run("Manifest annotations are copied", func(t *testing.T) { + desc := &remote.Descriptor{ + Manifest: []byte(`{"schemaVersion":2,"annotations":{"org.ccf.plugin.protocol.version":"2"}}`), + } + + got := annotationsFromDescriptor(desc) + got["org.ccf.plugin.protocol.version"] = "1" + + again := annotationsFromDescriptor(desc) + if again["org.ccf.plugin.protocol.version"] != "2" { + t.Fatalf("expected manifest annotations to remain unchanged, got %q", again["org.ccf.plugin.protocol.version"]) + } + }) +} diff --git a/internal/utils_test.go b/internal/utils_test.go index dcb8d20..b402ad5 100644 --- a/internal/utils_test.go +++ b/internal/utils_test.go @@ -87,7 +87,7 @@ func TestIsOci(t *testing.T) { expected: true, }, { - name: "Basic OCI url with digest", + name: "Digest OCI reference is not treated as supported OCI tag", source: "ghcr.io/example/plugin@sha256:88252198a40099248f5cc3272bc879fade8b7001a2bcb36d7b43aa8f54328714", expected: false, }, From 4783fcf251d85014c3013ebcb039f2554614c9c8 Mon Sep 17 00:00:00 2001 From: Reece Bedding Date: Tue, 10 Mar 2026 15:35:50 +0000 Subject: [PATCH 8/8] fix: minor comment change --- runner/grpc.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/runner/grpc.go b/runner/grpc.go index a52238f..14cce6e 100644 --- a/runner/grpc.go +++ b/runner/grpc.go @@ -56,7 +56,7 @@ func (m *GRPCApiHelperServer) CreateEvidence(ctx context.Context, req *proto.Cre return &proto.CreateEvidenceResponse{}, err } -// GRPCClient is an implementation of KV that talks over RPC. +// GRPCClient implements Runner over go-plugin gRPC. type GRPCClient struct { client proto.RunnerClient broker *plugin.GRPCBroker @@ -66,6 +66,7 @@ type GRPCClient struct { apiServerOnce sync.Once } +// GRPCClientV2 extends GRPCClient with RunnerV2 support over go-plugin gRPC. type GRPCClientV2 struct { *GRPCClient }