diff --git a/e2e/internal/contrasttest/contrasttest.go b/e2e/internal/contrasttest/contrasttest.go index 4f9cef472e..28c44514fe 100644 --- a/e2e/internal/contrasttest/contrasttest.go +++ b/e2e/internal/contrasttest/contrasttest.go @@ -41,6 +41,7 @@ import ( "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" + corev1ac "k8s.io/client-go/applyconfigurations/core/v1" ) // Flags contains the parsed Flags for the test. @@ -85,6 +86,10 @@ type ContrastTest struct { GHCRToken string Kubeclient *kubeclient.Kubeclient + // SkipKDSProxy disables injection of kds-proxy env vars + CA bundle into. + // Used by kds-pcs-downtime test. + SkipKDSProxy bool + // outputs of contrast subcommands meshCACertPEM []byte rootCACertPEM []byte @@ -187,6 +192,12 @@ func (ct *ContrastTest) Init(t *testing.T, resources []any) { resources = kuberesource.AddLogging(resources, "debug", "*") resources = kuberesource.PatchNodeSelector(resources) resources = ct.OverrideStorageClass(t, resources) + if !ct.SkipKDSProxy { + ct.copyKDSProxyCA(t) + resources = kuberesource.AddKDSProxy(resources, + kuberesource.KDSProxyDefaultService, + kuberesource.KDSProxyCAConfigMap) + } unstructuredResources, err := kuberesource.ResourcesToUnstructured(resources) require.NoError(err) @@ -198,6 +209,27 @@ func (ct *ContrastTest) Init(t *testing.T, resources []any) { ct.installRuntime(t, resources) } +// copyKDSProxyCA copies the cluster-wide kds-proxy CA ConfigMap from "default" +// into the test's namespace, where the AddKDSProxy mutator mounts it. +func (ct *ContrastTest) copyKDSProxyCA(t *testing.T) { + require := require.New(t) + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) + defer cancel() + + src, err := ct.Kubeclient.Client.CoreV1(). + ConfigMaps("default"). + Get(ctx, kuberesource.KDSProxyCAConfigMap, metav1.GetOptions{}) + require.NoError(err, + "kds-proxy CA ConfigMap %q missing in default namespace — run the kds-proxy bootstrap step", + kuberesource.KDSProxyCAConfigMap) + + cm := corev1ac.ConfigMap(kuberesource.KDSProxyCAConfigMap, ct.Namespace). + WithData(src.Data) + unstr, err := kuberesource.ResourcesToUnstructured([]any{cm}) + require.NoError(err) + require.NoError(ct.Kubeclient.Apply(ctx, unstr...)) +} + // OverrideStorageClass looks for a StorageClass with a well-known label and modifies the resources to use that class. func (ct *ContrastTest) OverrideStorageClass(t *testing.T, resources []any) []any { ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) diff --git a/e2e/kds-pcs-downtime/kds-pcs-downtime_test.go b/e2e/kds-pcs-downtime/kds-pcs-downtime_test.go index 1bb9d7ce2f..89d644a1c2 100644 --- a/e2e/kds-pcs-downtime/kds-pcs-downtime_test.go +++ b/e2e/kds-pcs-downtime/kds-pcs-downtime_test.go @@ -34,6 +34,8 @@ func TestKDSPCSDowntime(t *testing.T) { platform, err := platforms.FromString(contrasttest.Flags.PlatformStr) require.NoError(t, err) ct := contrasttest.New(t) + // This test drives https_proxy itself via goproxy. + ct.SkipKDSProxy = true runtimeHandler, err := manifest.RuntimeHandler(platform) require.NoError(t, err) diff --git a/go.work b/go.work index 044c7cb157..b5a33f13d5 100644 --- a/go.work +++ b/go.work @@ -6,6 +6,7 @@ use ( ./imagepuller ./imagestore ./initdata-processor + ./kds-proxy ./service-mesh ./tools/debugshell ./tools/fifo diff --git a/internal/kuberesource/mutators.go b/internal/kuberesource/mutators.go index 854145c6e8..60f8feddfe 100644 --- a/internal/kuberesource/mutators.go +++ b/internal/kuberesource/mutators.go @@ -32,6 +32,21 @@ const ( securePVAnnotationKey = "contrast.edgeless.systems/secure-pv" workloadSecretIDAnnotationKey = "contrast.edgeless.systems/workload-secret-id" imageStoreSizeAnnotationKey = "contrast.edgeless.systems/image-store-size" + skipKDSProxyAnnotationKey = "contrast.edgeless.systems/skip-kds-proxy" +) + +// Defaults for routing Contrast pod attestation traffic through the in-cluster kds-proxy. +const ( + KDSProxyDefaultService = "http://kds-proxy.default.svc:3128" + KDSProxyCAConfigMap = "kds-proxy-ca" + KDSProxyCAKey = "ca.crt" + kdsProxyCAVolumeName = "kds-proxy-ca" + kdsProxyMountDir = "/etc/ssl/kds-proxy" + kdsProxyCAPath = kdsProxyMountDir + "/" + KDSProxyCAKey + + // Keep in-cluster traffic out of the forward proxy. + kdsProxyNoProxy = "localhost,127.0.0.1,.svc,.svc.cluster.local,.cluster.local," + + "10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,169.254.0.0/16" ) // contrastRuntimeClassPrefixes lists runtime class prefixes that identify Contrast pods. @@ -300,6 +315,70 @@ func ensureVolumeExists(spec *applycorev1.PodSpecApplyConfiguration, volumeName return nil } +// AddKDSProxy mounts the proxy CA from configMapName and uses the SSL_CERT_FILE in every Contrast container. +func AddKDSProxy(resources []any, proxyURL, configMapName string) []any { + out := make([]any, 0, len(resources)) + for _, resource := range resources { + out = append(out, MapPodSpecWithMeta(resource, func(meta *applymetav1.ObjectMetaApplyConfiguration, spec *applycorev1.PodSpecApplyConfiguration) (*applymetav1.ObjectMetaApplyConfiguration, *applycorev1.PodSpecApplyConfiguration) { + if !IsContrastPod(spec) { + return meta, spec + } + if meta != nil && meta.Annotations[skipKDSProxyAnnotationKey] == "true" { + return meta, spec + } + injectKDSProxy(spec, proxyURL, configMapName) + return meta, spec + })) + } + return out +} + +func injectKDSProxy(spec *applycorev1.PodSpecApplyConfiguration, proxyURL, configMapName string) { + if !hasVolumeNamed(spec, kdsProxyCAVolumeName) { + spec.Volumes = append(spec.Volumes, *applycorev1.Volume(). + WithName(kdsProxyCAVolumeName). + WithConfigMap(applycorev1.ConfigMapVolumeSource().WithName(configMapName))) + } + + envs := []struct{ name, value string }{ + {"https_proxy", proxyURL}, + {"HTTPS_PROXY", proxyURL}, + {"no_proxy", kdsProxyNoProxy}, + {"NO_PROXY", kdsProxyNoProxy}, + {"SSL_CERT_FILE", kdsProxyCAPath}, + } + for i := range spec.Containers { + c := &spec.Containers[i] + for _, e := range envs { + if !hasEnvNamed(c, e.name) { + c.Env = append(c.Env, *applycorev1.EnvVar().WithName(e.name).WithValue(e.value)) + } + } + addOrReplaceVolumeMount(c, *applycorev1.VolumeMount(). + WithName(kdsProxyCAVolumeName). + WithMountPath(kdsProxyMountDir). + WithReadOnly(true)) + } +} + +func hasVolumeNamed(spec *applycorev1.PodSpecApplyConfiguration, name string) bool { + for _, v := range spec.Volumes { + if v.Name != nil && *v.Name == name { + return true + } + } + return false +} + +func hasEnvNamed(c *applycorev1.ContainerApplyConfiguration, name string) bool { + for _, e := range c.Env { + if e.Name != nil && *e.Name == name { + return true + } + } + return false +} + // AddPortForwarders adds a port-forwarder for each Service. func AddPortForwarders(resources []any) []any { var out []any diff --git a/internal/kuberesource/mutators_test.go b/internal/kuberesource/mutators_test.go index 4970577bfb..9b9cfbce40 100644 --- a/internal/kuberesource/mutators_test.go +++ b/internal/kuberesource/mutators_test.go @@ -783,6 +783,62 @@ spec: } } +func TestAddKDSProxy(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + contrastPod := applycorev1.Pod("worker", "default"). + WithSpec(applycorev1.PodSpec(). + WithRuntimeClassName("contrast-cc-foo"). + WithContainers(applycorev1.Container().WithName("app").WithImage("nginx"))) + skipPod := applycorev1.Pod("skipper", "default"). + WithAnnotations(map[string]string{skipKDSProxyAnnotationKey: "true"}). + WithSpec(applycorev1.PodSpec(). + WithRuntimeClassName("contrast-cc-foo"). + WithContainers(applycorev1.Container().WithName("app").WithImage("nginx"))) + nonContrastPod := applycorev1.Pod("plain", "default"). + WithSpec(applycorev1.PodSpec(). + WithContainers(applycorev1.Container().WithName("app").WithImage("nginx"))) + + out := AddKDSProxy([]any{contrastPod, skipPod, nonContrastPod}, + "http://kds-proxy:3128", "kds-proxy-ca") + require.Len(out, 3) + + mutated, ok := out[0].(*applycorev1.PodApplyConfiguration) + require.True(ok) + spec := mutated.Spec + assert.Empty(spec.InitContainers, "should not add an init container") + + require.Len(spec.Containers, 1) + envs := map[string]string{} + for _, e := range spec.Containers[0].Env { + envs[*e.Name] = *e.Value + } + assert.Equal("http://kds-proxy:3128", envs["https_proxy"]) + assert.Equal("http://kds-proxy:3128", envs["HTTPS_PROXY"]) + assert.Equal(kdsProxyNoProxy, envs["no_proxy"]) + assert.Equal(kdsProxyNoProxy, envs["NO_PROXY"]) + assert.Equal(kdsProxyCAPath, envs["SSL_CERT_FILE"]) + require.Len(spec.Volumes, 1) + assert.Equal(kdsProxyCAVolumeName, *spec.Volumes[0].Name) + + skipped, ok := out[1].(*applycorev1.PodApplyConfiguration) + require.True(ok) + assert.Empty(skipped.Spec.Volumes) + assert.Empty(skipped.Spec.Containers[0].Env) + + plain, ok := out[2].(*applycorev1.PodApplyConfiguration) + require.True(ok) + assert.Empty(plain.Spec.Volumes) + assert.Empty(plain.Spec.Containers[0].Env) + + out2 := AddKDSProxy(out, "http://kds-proxy:3128", "kds-proxy-ca") + mutated2, ok := out2[0].(*applycorev1.PodApplyConfiguration) + require.True(ok) + assert.Len(mutated2.Spec.Volumes, 1) + assert.Len(mutated2.Spec.Containers[0].Env, 5) +} + func TestMapPodSpecWithErrors(t *testing.T) { require := require.New(t) diff --git a/internal/kuberesource/parts.go b/internal/kuberesource/parts.go index 322c18c999..0befc48669 100644 --- a/internal/kuberesource/parts.go +++ b/internal/kuberesource/parts.go @@ -573,3 +573,58 @@ func GetPodCPUCount(spec *applycorev1.PodSpecApplyConfiguration) uint64 { totalCPUs := (totalMilliCPUs+999)/1000 + 1 return uint64(totalCPUs) } + +// KDSProxy returns the resources for an in-cluster HTTPS forward proxy caching responses from AMD KDS, Intel PCS, and NVIDIA RIM endpoints. +func KDSProxy(namespace, storageClassName string) []any { + const ( + name = "kds-proxy" + port = int32(3128) + stateDir = "/var/lib/kds-proxy" + stateVol = "state" + ) + labels := map[string]string{"app.kubernetes.io/name": name} + + pvcSpec := applycorev1.PersistentVolumeClaimSpec(). + WithAccessModes(corev1.ReadWriteOnce). + WithResources(applycorev1.VolumeResourceRequirements(). + WithRequests(corev1.ResourceList{corev1.ResourceStorage: resource.MustParse("1Gi")})) + if storageClassName != "" { + pvcSpec = pvcSpec.WithStorageClassName(storageClassName) + } + pvc := applycorev1.PersistentVolumeClaim(name+"-state", namespace).WithSpec(pvcSpec) + + mem := corev1.ResourceList{corev1.ResourceMemory: resource.MustParse("256Mi")} + deployment := Deployment(name, namespace). + WithSpec(DeploymentSpec(). + WithReplicas(1). + WithSelector(LabelSelector().WithMatchLabels(labels)). + WithStrategy(applyappsv1.DeploymentStrategy(). + WithType(appsv1.RecreateDeploymentStrategyType)). + WithTemplate(PodTemplateSpec(). + WithLabels(labels). + WithSpec(PodSpec(). + WithVolumes(applycorev1.Volume(). + WithName(stateVol). + WithPersistentVolumeClaim(applycorev1.PersistentVolumeClaimVolumeSource(). + WithClaimName(name + "-state"))). + WithContainers(applycorev1.Container(). + WithName(name). + WithImage("ghcr.io/edgelesssys/contrast/kds-proxy:latest"). + WithArgs(fmt.Sprintf("-addr=:%d", port), "-state-dir="+stateDir). + WithPorts(applycorev1.ContainerPort(). + WithName("proxy"). + WithContainerPort(port)). + WithVolumeMounts(applycorev1.VolumeMount(). + WithName(stateVol). + WithMountPath(stateDir)). + WithReadinessProbe(applycorev1.Probe(). + WithHTTPGet(applycorev1.HTTPGetAction(). + WithPath("/healthz"). + WithPort(intstr.FromInt32(port))). + WithPeriodSeconds(5)). + WithResources(applycorev1.ResourceRequirements(). + WithRequests(mem). + WithLimits(mem)))))) + + return []any{deployment, ServiceForDeployment(deployment), pvc} +} diff --git a/internal/kuberesource/parts_test.go b/internal/kuberesource/parts_test.go index 71e23f7a9c..0e11af23d6 100644 --- a/internal/kuberesource/parts_test.go +++ b/internal/kuberesource/parts_test.go @@ -29,6 +29,17 @@ func TestCoordinator(t *testing.T) { t.Log("\n" + string(b)) } +func TestKDSProxy(t *testing.T) { + require := require.New(t) + + resources := KDSProxy("default", "") + require.Len(resources, 3) + + b, err := EncodeResources(resources...) + require.NoError(err) + t.Log("\n" + string(b)) +} + func TestNoNamespaces(t *testing.T) { coordinator := CoordinatorBundle() openssl := OpenSSL() diff --git a/internal/kuberesource/resourcegen/main.go b/internal/kuberesource/resourcegen/main.go index 608809ad43..5346323d21 100644 --- a/internal/kuberesource/resourcegen/main.go +++ b/internal/kuberesource/resourcegen/main.go @@ -42,6 +42,8 @@ func main() { switch set { case "coordinator": subResources = kuberesource.PatchRuntimeHandlers(kuberesource.CoordinatorBundle(), "contrast-cc") + case "kds-proxy": + subResources = kuberesource.KDSProxy(*namespace, *storageClass) case "runtime": platformCollection := kuberesource.PlatformCollection{} if err := platformCollection.AddFromCommaSeparated(*rawPlatform); err != nil { diff --git a/justfile b/justfile index 1b06802d07..4d6073048b 100644 --- a/justfile +++ b/justfile @@ -23,6 +23,8 @@ port-forwarder: (push "port-forwarder") service-mesh-proxy: (push "service-mesh-proxy") +kds-proxy: (push "kds-proxy") + initializer: (push "initializer") memdump: (push "memdump") @@ -429,6 +431,38 @@ wait-for-workload target=default_deploy_target set=default_set: ;; esac +# Provision the kds-proxy singleton in "default" and its kds-proxy-ca ConfigMap. Re-run after pushing a new image. +kds-proxy-bootstrap set=default_set storage_class="": kds-proxy + #!/usr/bin/env bash + set -euo pipefail + sc="{{ storage_class }}" + if [[ -z "$sc" ]]; then + sc=$(kubectl get sc -l ci.contrast.edgeless.systems/is-default-class=true -o jsonpath='{.items[0].metadata.name}' 2>/dev/null || true) + fi + args=(--image-replacements ./{{ workspace_dir }}/just.containerlookup --namespace default) + if [[ -n "$sc" ]]; then + args+=(--storage-class "$sc") + fi + nix shell .#{{ set }}.contrast.resourcegen --command resourcegen "${args[@]}" kds-proxy \ + | kubectl apply --server-side --force-conflicts -f - + kubectl -n default rollout status deploy/kds-proxy --timeout=120s + ca=$(kubectl -n default exec deploy/kds-proxy -- cat /var/lib/kds-proxy/ca/ca.crt) + kubectl -n default create configmap kds-proxy-ca \ + --from-literal=ca.crt="$ca" \ + --dry-run=client -o yaml | kubectl apply -f - + +# Wipe the existing kds-proxy (pod, service, PVC, CA configmap) and bootstrap a fresh one. +kds-proxy-redeploy set=default_set storage_class="": + #!/usr/bin/env bash + set -euo pipefail + kubectl -n default delete --ignore-not-found --wait=true \ + deploy/kds-proxy svc/kds-proxy pvc/kds-proxy-state configmap/kds-proxy-ca + just kds-proxy-bootstrap "{{ set }}" "{{ storage_class }}" + +# Print the kds-proxy /metrics endpoint. +kds-proxy-stats: + kubectl -n default exec deploy/kds-proxy -- wget -qO- http://127.0.0.1:3128/metrics + request-fifo-ticket timeout="": #!/usr/bin/env bash set -euo pipefail diff --git a/kds-proxy/go.mod b/kds-proxy/go.mod new file mode 100644 index 0000000000..7da139d5a1 --- /dev/null +++ b/kds-proxy/go.mod @@ -0,0 +1,5 @@ +module github.com/edgelesssys/contrast/kds-proxy + +go 1.25.6 + +require golang.org/x/sync v0.20.0 diff --git a/kds-proxy/go.sum b/kds-proxy/go.sum new file mode 100644 index 0000000000..733d7160e1 --- /dev/null +++ b/kds-proxy/go.sum @@ -0,0 +1,2 @@ +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= diff --git a/kds-proxy/internal/allowlist/allowlist.go b/kds-proxy/internal/allowlist/allowlist.go new file mode 100644 index 0000000000..d6aa8edfb9 --- /dev/null +++ b/kds-proxy/internal/allowlist/allowlist.go @@ -0,0 +1,17 @@ +// Copyright 2026 Edgeless Systems GmbH +// SPDX-License-Identifier: BUSL-1.1 + +package allowlist + +// Default lists the hosts that kds-proxy will proxy to. Any other host is rejected. +var Default = map[string]struct{}{ + "kdsintf.amd.com": {}, + "api.trustedservices.intel.com": {}, + "rim.attestation.nvidia.com": {}, +} + +// Allows reports whether host is in the allowlist. +func Allows(host string) bool { + _, ok := Default[host] + return ok +} diff --git a/kds-proxy/internal/ca/ca.go b/kds-proxy/internal/ca/ca.go new file mode 100644 index 0000000000..b6960ea61e --- /dev/null +++ b/kds-proxy/internal/ca/ca.go @@ -0,0 +1,204 @@ +// Copyright 2026 Edgeless Systems GmbH +// SPDX-License-Identifier: BUSL-1.1 + +package ca + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "math/big" + "os" + "path/filepath" + "sync" + "time" +) + +const ( + caCertFile = "ca.crt" + caKeyFile = "ca.key" + + caValidity = 10 * 365 * 24 * time.Hour + leafValidity = 30 * 24 * time.Hour +) + +// CA is a persistent signing authority used to mint leaf certs on demand. +type CA struct { + cert *x509.Certificate + key *ecdsa.PrivateKey + + mu sync.Mutex + leafs map[string]leafEntry +} + +type leafEntry struct { + certPEM []byte + keyPEM []byte + expires time.Time +} + +// LoadOrGenerate loads a CA from dir, generating and persisting one if absent. +func LoadOrGenerate(dir string) (*CA, error) { + if err := os.MkdirAll(dir, 0o700); err != nil { + return nil, fmt.Errorf("creating CA dir: %w", err) + } + certPath := filepath.Join(dir, caCertFile) + keyPath := filepath.Join(dir, caKeyFile) + + certPEM, certErr := os.ReadFile(certPath) + keyPEM, keyErr := os.ReadFile(keyPath) + if certErr == nil && keyErr == nil { + cert, key, err := parseCAPEM(certPEM, keyPEM) + if err != nil { + return nil, fmt.Errorf("parsing existing CA material: %w", err) + } + return &CA{cert: cert, key: key, leafs: map[string]leafEntry{}}, nil + } + if certErr != nil && !errors.Is(certErr, os.ErrNotExist) { + return nil, certErr + } + if keyErr != nil && !errors.Is(keyErr, os.ErrNotExist) { + return nil, keyErr + } + + cert, key, err := generateCA() + if err != nil { + return nil, err + } + if err := writeCAPEM(dir, cert, key); err != nil { + return nil, err + } + return &CA{cert: cert, key: key, leafs: map[string]leafEntry{}}, nil +} + +// CertPEM returns the CA certificate in PEM form for distribution to clients. +func (c *CA) CertPEM() []byte { + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.cert.Raw}) +} + +// LeafPEM returns a leaf cert + key pair (PEM-encoded) for host, minting and caching it on first use. +func (c *CA) LeafPEM(host string) (certPEM, keyPEM []byte, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if e, ok := c.leafs[host]; ok && time.Now().Before(e.expires) { + return e.certPEM, e.keyPEM, nil + } + cert, key, err := c.mintLeaf(host) + if err != nil { + return nil, nil, err + } + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return nil, nil, err + } + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + c.leafs[host] = leafEntry{ + certPEM: certPEM, + keyPEM: keyPEM, + expires: cert.NotAfter.Add(-time.Hour), + } + return certPEM, keyPEM, nil +} + +func (c *CA) mintLeaf(host string) (*x509.Certificate, *ecdsa.PrivateKey, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + serial, err := randomSerial() + if err != nil { + return nil, nil, err + } + now := time.Now() + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: host}, + DNSNames: []string{host}, + NotBefore: now.Add(-time.Minute), + NotAfter: now.Add(leafValidity), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, c.cert, &key.PublicKey, c.key) + if err != nil { + return nil, nil, err + } + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, nil, err + } + return cert, key, nil +} + +func generateCA() (*x509.Certificate, *ecdsa.PrivateKey, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + serial, err := randomSerial() + if err != nil { + return nil, nil, err + } + now := time.Now() + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: "contrast kds-proxy CA"}, + NotBefore: now.Add(-time.Minute), + NotAfter: now.Add(caValidity), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + if err != nil { + return nil, nil, err + } + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, nil, err + } + return cert, key, nil +} + +func randomSerial() (*big.Int, error) { + return rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) +} + +func parseCAPEM(certPEM, keyPEM []byte) (*x509.Certificate, *ecdsa.PrivateKey, error) { + certBlock, _ := pem.Decode(certPEM) + if certBlock == nil || certBlock.Type != "CERTIFICATE" { + return nil, nil, errors.New("invalid CA cert PEM") + } + cert, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return nil, nil, err + } + keyBlock, _ := pem.Decode(keyPEM) + if keyBlock == nil || keyBlock.Type != "EC PRIVATE KEY" { + return nil, nil, errors.New("invalid CA key PEM") + } + key, err := x509.ParseECPrivateKey(keyBlock.Bytes) + if err != nil { + return nil, nil, err + } + return cert, key, nil +} + +func writeCAPEM(dir string, cert *x509.Certificate, key *ecdsa.PrivateKey) error { + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return err + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + if err := os.WriteFile(filepath.Join(dir, caCertFile), certPEM, 0o600); err != nil { + return err + } + return os.WriteFile(filepath.Join(dir, caKeyFile), keyPEM, 0o600) +} diff --git a/kds-proxy/internal/cache/cache.go b/kds-proxy/internal/cache/cache.go new file mode 100644 index 0000000000..56131e40eb --- /dev/null +++ b/kds-proxy/internal/cache/cache.go @@ -0,0 +1,166 @@ +// Copyright 2026 Edgeless Systems GmbH +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "crypto/sha256" + "crypto/x509" + "encoding/hex" + "encoding/json" + "io/fs" + "log/slog" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" +) + +const minTTL = time.Hour + +// Entry is a cached HTTP response. +type Entry struct { + URL string `json:"url"` + Status int `json:"status"` + Header http.Header `json:"header"` + Body []byte `json:"body"` + FreshUntil time.Time `json:"freshUntil"` +} + +// Cache is an in-memory + on-disk cache for upstream HTTP responses. +type Cache struct { + dir string + mu sync.RWMutex + entries map[string]*Entry +} + +// New opens or creates a cache at dir, loading existing entries into memory. +func New(dir string) (*Cache, error) { + if err := os.MkdirAll(dir, 0o700); err != nil { + return nil, err + } + c := &Cache{dir: dir, entries: map[string]*Entry{}} + if err := c.loadAll(); err != nil { + return nil, err + } + return c, nil +} + +// Get returns the entry for url and whether or not it is fresh. +func (c *Cache) Get(url string) (*Entry, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + e, ok := c.entries[url] + if !ok { + return nil, false + } + return e, time.Now().Before(e.FreshUntil) +} + +// Put stores a response under url, computing freshness from headers / body. +func (c *Cache) Put(url string, status int, header http.Header, body []byte) (*Entry, error) { + e := &Entry{ + URL: url, + Status: status, + Header: header.Clone(), + Body: body, + FreshUntil: time.Now().Add(freshness(header, body)), + } + if err := c.writeDisk(e); err != nil { + return nil, err + } + c.mu.Lock() + c.entries[url] = e + c.mu.Unlock() + return e, nil +} + +func (c *Cache) loadAll() error { + return filepath.WalkDir(c.dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() || !strings.HasSuffix(path, ".json") { + return nil + } + raw, err := os.ReadFile(path) + if err != nil { + return err + } + var e Entry + if err := json.Unmarshal(raw, &e); err != nil { + slog.Warn("skipping corrupt cache entry", "path", path, "err", err) + return nil + } + c.entries[e.URL] = &e + return nil + }) +} + +func (c *Cache) writeDisk(e *Entry) error { + raw, err := json.Marshal(e) + if err != nil { + return err + } + path := filepath.Join(c.dir, keyHash(e.URL)+".json") + tmp := path + ".tmp" + if err := os.WriteFile(tmp, raw, 0o600); err != nil { + return err + } + return os.Rename(tmp, path) +} + +func keyHash(url string) string { + sum := sha256.Sum256([]byte(url)) + return hex.EncodeToString(sum[:]) +} + +func freshness(header http.Header, body []byte) time.Duration { + if d, ok := crlFreshness(body); ok { + return d + } + if d, ok := cacheControlMaxAge(header); ok { + return d + } + return minTTL +} + +func cacheControlMaxAge(header http.Header) (time.Duration, bool) { + cc := header.Get("Cache-Control") + if cc == "" { + return 0, false + } + if strings.Contains(cc, "no-store") || strings.Contains(cc, "no-cache") { + return 0, false + } + for _, part := range strings.Split(cc, ",") { + part = strings.TrimSpace(part) + if !strings.HasPrefix(part, "max-age=") { + continue + } + secs, err := strconv.Atoi(strings.TrimPrefix(part, "max-age=")) + if err != nil || secs <= 0 { + return 0, false + } + return time.Duration(secs) * time.Second, true + } + return 0, false +} + +func crlFreshness(body []byte) (time.Duration, bool) { + if len(body) == 0 { + return 0, false + } + crl, err := x509.ParseRevocationList(body) + if err != nil || crl.NextUpdate.IsZero() { + return 0, false + } + d := time.Until(crl.NextUpdate) + if d < 0 { + return 0, false + } + return d, true +} diff --git a/kds-proxy/internal/cache/cache_test.go b/kds-proxy/internal/cache/cache_test.go new file mode 100644 index 0000000000..a8ba5f9f2b --- /dev/null +++ b/kds-proxy/internal/cache/cache_test.go @@ -0,0 +1,94 @@ +// Copyright 2026 Edgeless Systems GmbH +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "net/http" + "testing" + "time" +) + +func TestCacheControlMaxAge(t *testing.T) { + cases := []struct { + name string + header string + want time.Duration + ok bool + }{ + {"empty", "", 0, false}, + {"max-age 60", "max-age=60", 60 * time.Second, true}, + {"public max-age", "public, max-age=3600", 3600 * time.Second, true}, + {"no-store", "no-store, max-age=60", 0, false}, + {"no-cache", "no-cache", 0, false}, + {"zero", "max-age=0", 0, false}, + {"garbage", "max-age=NaN", 0, false}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + h := http.Header{} + if c.header != "" { + h.Set("Cache-Control", c.header) + } + got, ok := cacheControlMaxAge(h) + if ok != c.ok { + t.Fatalf("ok=%v want %v", ok, c.ok) + } + if got != c.want { + t.Fatalf("dur=%v want %v", got, c.want) + } + }) + } +} + +func TestPutGetRoundTrip(t *testing.T) { + dir := t.TempDir() + c, err := New(dir) + if err != nil { + t.Fatal(err) + } + h := http.Header{} + h.Set("Cache-Control", "max-age=120") + if _, err := c.Put("https://example/x", 200, h, []byte("hello")); err != nil { + t.Fatal(err) + } + e, fresh := c.Get("https://example/x") + if e == nil || !fresh { + t.Fatalf("expected fresh entry, got %+v fresh=%v", e, fresh) + } + if string(e.Body) != "hello" { + t.Fatalf("body=%q", e.Body) + } + + // Reopen from disk. + c2, err := New(dir) + if err != nil { + t.Fatal(err) + } + e2, fresh2 := c2.Get("https://example/x") + if e2 == nil || !fresh2 || string(e2.Body) != "hello" { + t.Fatalf("disk reload failed: %+v fresh=%v", e2, fresh2) + } +} + +func TestStaleEntryStillReturned(t *testing.T) { + dir := t.TempDir() + c, err := New(dir) + if err != nil { + t.Fatal(err) + } + h := http.Header{} + h.Set("Cache-Control", "max-age=1") + e, err := c.Put("https://example/x", 200, h, []byte("hi")) + if err != nil { + t.Fatal(err) + } + e.FreshUntil = time.Now().Add(-time.Hour) + got, fresh := c.Get("https://example/x") + if got == nil { + t.Fatal("entry vanished") + } + if fresh { + t.Fatal("expected stale, got fresh") + } +} diff --git a/kds-proxy/internal/proxy/proxy.go b/kds-proxy/internal/proxy/proxy.go new file mode 100644 index 0000000000..24a9c27250 --- /dev/null +++ b/kds-proxy/internal/proxy/proxy.go @@ -0,0 +1,208 @@ +// Copyright 2026 Edgeless Systems GmbH +// SPDX-License-Identifier: BUSL-1.1 + +package proxy + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "sync/atomic" + + "github.com/edgelesssys/contrast/kds-proxy/internal/allowlist" + "github.com/edgelesssys/contrast/kds-proxy/internal/ca" + "github.com/edgelesssys/contrast/kds-proxy/internal/cache" + "github.com/edgelesssys/contrast/kds-proxy/internal/upstream" +) + +// AllowFunc decides whether the proxy will tunnel to host. +type AllowFunc func(host string) bool + +// Server is the HTTP-level proxy server. +type Server struct { + log *slog.Logger + ca *ca.CA + cache *cache.Cache + upstream *upstream.Fetcher + allows AllowFunc + + hits atomic.Uint64 + misses atomic.Uint64 + stale atomic.Uint64 + rejected atomic.Uint64 + errors atomic.Uint64 +} + +// New constructs a Server. If allows is nil, the default allowlist is used. +func New(log *slog.Logger, ca *ca.CA, c *cache.Cache, u *upstream.Fetcher, allows AllowFunc) *Server { + if allows == nil { + allows = allowlist.Allows + } + return &Server{log: log, ca: ca, cache: c, upstream: u, allows: allows} +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodConnect: + s.handleConnect(w, r) + case http.MethodGet: + switch r.URL.Path { + case "/healthz": + _, _ = w.Write([]byte("ok")) + case "/metrics": + s.writeMetrics(w) + case "/ca.crt": + w.Header().Set("Content-Type", "application/x-pem-file") + _, _ = w.Write(s.ca.CertPEM()) + default: + http.Error(w, "kds-proxy: direct requests not supported", http.StatusBadRequest) + } + default: + s.rejected.Add(1) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleConnect(w http.ResponseWriter, r *http.Request) { + host := r.URL.Host + if host == "" { + host = r.Host + } + sniHost := stripPort(host) + if !s.allows(sniHost) { + s.rejected.Add(1) + s.log.Warn("rejecting CONNECT to disallowed host", "host", host) + http.Error(w, "host not allowed", http.StatusForbidden) + return + } + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijack unsupported", http.StatusInternalServerError) + return + } + clientConn, _, err := hijacker.Hijack() + if err != nil { + s.log.Error("hijack failed", "err", err) + return + } + defer clientConn.Close() + + if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil { + return + } + + certPEM, keyPEM, err := s.ca.LeafPEM(sniHost) + if err != nil { + s.log.Error("minting leaf", "host", sniHost, "err", err) + return + } + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + s.log.Error("loading leaf keypair", "err", err) + return + } + tlsConn := tls.Server(clientConn, &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + MinVersion: tls.VersionTLS12, + }) + if err := tlsConn.HandshakeContext(r.Context()); err != nil { + s.log.Debug("TLS handshake with client failed", "host", sniHost, "err", err) + return + } + defer tlsConn.Close() + + s.serveTunneled(r.Context(), tlsConn, sniHost) +} + +func (s *Server) serveTunneled(ctx context.Context, conn net.Conn, host string) { + br := bufio.NewReader(conn) + for { + req, err := http.ReadRequest(br) + if err != nil { + if !errors.Is(err, io.EOF) { + s.log.Debug("reading tunneled request", "host", host, "err", err) + } + return + } + if req.Method != http.MethodGet { + s.rejected.Add(1) + writeStatus(conn, http.StatusMethodNotAllowed, "only GET allowed") + return + } + req.URL.Scheme = "https" + req.URL.Host = host + fullURL := req.URL.String() + + if entry, fresh := s.cache.Get(fullURL); fresh { + s.hits.Add(1) + s.log.Info("cache hit", "url", fullURL) + writeRaw(conn, entry.Status, entry.Header, entry.Body) + continue + } + + res, err := s.upstream.Get(ctx, fullURL) + if err != nil { + if entry, _ := s.cache.Get(fullURL); entry != nil { + s.stale.Add(1) + s.log.Warn("serving stale on upstream error", "url", fullURL, "err", err) + writeRaw(conn, entry.Status, entry.Header, entry.Body) + continue + } + s.errors.Add(1) + s.log.Error("upstream fetch failed and no cache entry", "url", fullURL, "err", err) + writeStatus(conn, http.StatusBadGateway, "upstream unavailable") + return + } + + s.misses.Add(1) + s.log.Info("cache miss, fetched upstream", "url", fullURL, "status", res.Status) + entry, err := s.cache.Put(fullURL, res.Status, res.Header, res.Body) + if err != nil { + s.log.Error("cache write failed", "url", fullURL, "err", err) + writeRaw(conn, res.Status, res.Header, res.Body) + continue + } + writeRaw(conn, entry.Status, entry.Header, entry.Body) + } +} + +func (s *Server) writeMetrics(w http.ResponseWriter) { + fmt.Fprintf(w, "kds_proxy_cache_hits %d\n", s.hits.Load()) + fmt.Fprintf(w, "kds_proxy_cache_misses %d\n", s.misses.Load()) + fmt.Fprintf(w, "kds_proxy_cache_stale %d\n", s.stale.Load()) + fmt.Fprintf(w, "kds_proxy_rejected %d\n", s.rejected.Load()) + fmt.Fprintf(w, "kds_proxy_upstream_errors %d\n", s.errors.Load()) +} + +func writeRaw(conn net.Conn, status int, header http.Header, body []byte) { + resp := &http.Response{ + Status: http.StatusText(status), + StatusCode: status, + ProtoMajor: 1, + ProtoMinor: 1, + Header: header, + Body: io.NopCloser(bytes.NewReader(body)), + ContentLength: int64(len(body)), + } + _ = resp.Write(conn) +} + +func writeStatus(conn net.Conn, status int, msg string) { + h := http.Header{} + h.Set("Content-Type", "text/plain; charset=utf-8") + writeRaw(conn, status, h, []byte(msg)) +} + +func stripPort(host string) string { + if h, _, err := net.SplitHostPort(host); err == nil { + return h + } + return host +} diff --git a/kds-proxy/internal/proxy/proxy_test.go b/kds-proxy/internal/proxy/proxy_test.go new file mode 100644 index 0000000000..b8a4f53c2b --- /dev/null +++ b/kds-proxy/internal/proxy/proxy_test.go @@ -0,0 +1,149 @@ +// Copyright 2026 Edgeless Systems GmbH +// SPDX-License-Identifier: BUSL-1.1 + +package proxy + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/edgelesssys/contrast/kds-proxy/internal/ca" + "github.com/edgelesssys/contrast/kds-proxy/internal/cache" + "github.com/edgelesssys/contrast/kds-proxy/internal/upstream" +) + +func TestForwardProxyEndToEnd(t *testing.T) { + var upstreamHits atomic.Int64 + upstreamSrv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamHits.Add(1) + w.Header().Set("Cache-Control", "max-age=3600") + _, _ = fmt.Fprintf(w, "vcek bytes for %s", r.URL.Path) + })) + defer upstreamSrv.Close() + upstreamURL, err := url.Parse(upstreamSrv.URL) + if err != nil { + t.Fatal(err) + } + + dialer := &net.Dialer{} + fetchClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if strings.HasPrefix(addr, "kdsintf.amd.com:") { + return dialer.DialContext(ctx, network, upstreamURL.Host) + } + return dialer.DialContext(ctx, network, addr) + }, + }, + Timeout: 5 * time.Second, + } + + authority, err := ca.LoadOrGenerate(t.TempDir()) + if err != nil { + t.Fatal(err) + } + cch, err := cache.New(t.TempDir()) + if err != nil { + t.Fatal(err) + } + srv := New(slog.New(slog.DiscardHandler), + authority, cch, upstream.New(fetchClient), nil) + + proxySrv := httptest.NewServer(srv) + defer proxySrv.Close() + proxyURL, _ := url.Parse(proxySrv.URL) + + clientPool := x509.NewCertPool() + if ok := clientPool.AppendCertsFromPEM(authority.CertPEM()); !ok { + t.Fatal("failed to append CA to pool") + } + clientCC := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{RootCAs: clientPool}, + }, + Timeout: 5 * time.Second, + } + + body1 := mustGet(t, clientCC, "https://kdsintf.amd.com/vcek/v1/Milan/abc") + if want := "vcek bytes for /vcek/v1/Milan/abc"; body1 != want { + t.Fatalf("body=%q want %q", body1, want) + } + if got := upstreamHits.Load(); got != 1 { + t.Fatalf("upstream hits=%d want 1", got) + } + + body2 := mustGet(t, clientCC, "https://kdsintf.amd.com/vcek/v1/Milan/abc") + if body2 != body1 { + t.Fatalf("body mismatch") + } + if got := upstreamHits.Load(); got != 1 { + t.Fatalf("upstream hits=%d want still 1", got) + } +} + +func TestRejectsDisallowedHost(t *testing.T) { + authority, err := ca.LoadOrGenerate(t.TempDir()) + if err != nil { + t.Fatal(err) + } + cch, err := cache.New(t.TempDir()) + if err != nil { + t.Fatal(err) + } + srv := New(slog.New(slog.DiscardHandler), + authority, cch, upstream.New(&http.Client{Timeout: time.Second}), nil) + proxySrv := httptest.NewServer(srv) + defer proxySrv.Close() + proxyURL, _ := url.Parse(proxySrv.URL) + + c := &http.Client{ + Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)}, + Timeout: 2 * time.Second, + } + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "https://evil.example.com/", nil) + if err != nil { + t.Fatal(err) + } + resp, err := c.Do(req) + if resp != nil { + resp.Body.Close() + } + if err == nil { + t.Fatal("expected error connecting to disallowed host") + } +} + +func mustGet(t *testing.T, c *http.Client, url string) string { + t.Helper() + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, url, nil) + if err != nil { + t.Fatalf("build request %s: %v", url, err) + } + resp, err := c.Do(req) + if err != nil { + t.Fatalf("GET %s: %v", url, err) + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("status=%d", resp.StatusCode) + } + return string(b) +} diff --git a/kds-proxy/internal/upstream/upstream.go b/kds-proxy/internal/upstream/upstream.go new file mode 100644 index 0000000000..f08b2bba7c --- /dev/null +++ b/kds-proxy/internal/upstream/upstream.go @@ -0,0 +1,63 @@ +// Copyright 2026 Edgeless Systems GmbH +// SPDX-License-Identifier: BUSL-1.1 + +package upstream + +import ( + "context" + "fmt" + "io" + "net/http" + + "golang.org/x/sync/singleflight" +) + +// Result is a captured upstream response. +type Result struct { + Status int + Header http.Header + Body []byte +} + +// Fetcher pulls upstream responses. +type Fetcher struct { + client *http.Client + group singleflight.Group +} + +// New returns a Fetcher backed by client. The client must not be pre-configured with a proxy. +func New(client *http.Client) *Fetcher { + return &Fetcher{client: client} +} + +// Get fetches the given url. +func (f *Fetcher) Get(ctx context.Context, url string) (*Result, error) { + v, err, _ := f.group.Do(url, func() (any, error) { + return f.doGet(ctx, url) + }) + if err != nil { + return nil, err + } + res, ok := v.(*Result) + if !ok { + return nil, fmt.Errorf("unexpected singleflight result type %T", v) + } + return res, nil +} + +func (f *Fetcher) doGet(ctx context.Context, url string) (*Result, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("building request: %w", err) + } + resp, err := f.client.Do(req) + if err != nil { + return nil, fmt.Errorf("upstream fetch: %w", err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading upstream body: %w", err) + } + return &Result{Status: resp.StatusCode, Header: resp.Header, Body: body}, nil +} diff --git a/kds-proxy/main.go b/kds-proxy/main.go new file mode 100644 index 0000000000..48b4cfec9c --- /dev/null +++ b/kds-proxy/main.go @@ -0,0 +1,81 @@ +// Copyright 2026 Edgeless Systems GmbH +// SPDX-License-Identifier: BUSL-1.1 + +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log/slog" + "net/http" + "os" + "os/signal" + "path/filepath" + "syscall" + "time" + + "github.com/edgelesssys/contrast/kds-proxy/internal/ca" + "github.com/edgelesssys/contrast/kds-proxy/internal/cache" + "github.com/edgelesssys/contrast/kds-proxy/internal/proxy" + "github.com/edgelesssys/contrast/kds-proxy/internal/upstream" +) + +var version = "0.0.0-dev" + +func main() { + if err := run(); err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } +} + +func run() error { + var ( + addr = flag.String("addr", ":3128", "listen address") + stateDir = flag.String("state-dir", "/var/lib/kds-proxy", "directory for CA and cache state") + upstreamTimeout = flag.Duration("upstream-timeout", 10*time.Second, "per-request upstream timeout") + ) + flag.Parse() + + log := slog.New(slog.NewTextHandler(os.Stderr, nil)) + log.Info("kds-proxy starting", "version", version, "addr", *addr, "stateDir", *stateDir) + + authority, err := ca.LoadOrGenerate(filepath.Join(*stateDir, "ca")) + if err != nil { + return fmt.Errorf("CA init: %w", err) + } + c, err := cache.New(filepath.Join(*stateDir, "cache")) + if err != nil { + return fmt.Errorf("cache init: %w", err) + } + fetcher := upstream.New(&http.Client{Timeout: *upstreamTimeout}) + + httpSrv := &http.Server{ + Addr: *addr, + Handler: proxy.New(log, authority, c, fetcher, nil), + ReadHeaderTimeout: 10 * time.Second, + } + + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + errCh := make(chan error, 1) + go func() { errCh <- httpSrv.ListenAndServe() }() + + select { + case err := <-errCh: + if !errors.Is(err, http.ErrServerClosed) { + return err + } + case <-ctx.Done(): + log.Info("shutdown signal received") + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := httpSrv.Shutdown(shutdownCtx); err != nil { + log.Error("graceful shutdown failed", "err", err) + } + } + return nil +} diff --git a/packages/by-name/kds-proxy/package.nix b/packages/by-name/kds-proxy/package.nix new file mode 100644 index 0000000000..3fda082b5b --- /dev/null +++ b/packages/by-name/kds-proxy/package.nix @@ -0,0 +1,44 @@ +# Copyright 2026 Edgeless Systems GmbH +# SPDX-License-Identifier: BUSL-1.1 + +{ lib, buildGoModule }: + +buildGoModule (finalAttrs: { + pname = "kds-proxy"; + version = builtins.readFile ../../../version.txt; + + src = + let + inherit (lib) fileset path hasSuffix; + root = ../../../.; + in + fileset.toSource { + inherit root; + fileset = fileset.unions [ + (path.append root "kds-proxy/go.mod") + (path.append root "kds-proxy/go.sum") + (fileset.fileFilter (file: hasSuffix ".go" file.name) (path.append root "kds-proxy")) + ]; + }; + + proxyVendor = true; + vendorHash = "sha256-UpUfi+SkMdjdz5xzIqfGQuPsdLC+D6mD9ObCgFeuuoQ="; + + sourceRoot = "${finalAttrs.src.name}/kds-proxy"; + subPackages = [ "." ]; + + env.CGO_ENABLED = 0; + + ldflags = [ + "-s" + "-X main.version=v${finalAttrs.version}" + ]; + + checkPhase = '' + runHook preCheck + go test ./... + runHook postCheck + ''; + + meta.mainProgram = "kds-proxy"; +}) diff --git a/packages/containers.nix b/packages/containers.nix index 7e11a6cb8d..2bb66bccc7 100644 --- a/packages/containers.nix +++ b/packages/containers.nix @@ -73,6 +73,16 @@ ]; }; + kds-proxy = contrastPkgs.buildOciImage { + name = "kds-proxy"; + tag = "v${contrastPkgs.kds-proxy.version}"; + copyToRoot = (with pkgs; [ busybox ]) ++ (with dockerTools; [ caCertificates ]); + config = { + Entrypoint = [ "${contrastPkgs.kds-proxy}/bin/kds-proxy" ]; + Env = [ "PATH=/bin" ]; + }; + }; + service-mesh-proxy = contrastPkgs.buildOciImage { name = "service-mesh-proxy"; tag = "v${contrastPkgs.service-mesh.version}";