From dee133c6be6f7966b96fd8fc31baa3ad90a34272 Mon Sep 17 00:00:00 2001 From: Bryan Cox Date: Thu, 23 Apr 2026 10:21:13 -0400 Subject: [PATCH 1/7] refactor(cli): reduce cyclomatic complexity in CLI commands - Extract helper functions from cluster create, install, and infra commands to reduce function complexity below gocyclo threshold - Add empty key validation to parseKeyValuePairs for better CLI error feedback - Tighten Azure infra output file permissions to 0600 - Add VPC cleanup defer in AWS DNS zone creation - Fix public subnet passing for proxy host creation - Improve test assertions: use floor-based counts for monitoring and RBAC tests - Add behavior-driven unit tests for all extracted functions - Enable gocyclo linter with threshold 30 in .golangci.yml Signed-off-by: Bryan Cox Commit-Message-Assisted-by: Claude (via Claude Code) --- .golangci.yml | 1 + cmd/cluster/core/create.go | 454 +++++----- cmd/cluster/core/create_test.go | 840 ++++++++++++++++++ cmd/fix/dr_oidc_iam.go | 147 +-- cmd/infra/aws/create.go | 207 +++-- cmd/infra/azure/create.go | 113 ++- cmd/install/assets/hypershift_operator.go | 612 ++++++------- .../assets/hypershift_operator_test.go | 543 +++++++++++ cmd/install/install.go | 177 ++-- cmd/install/install_test.go | 787 ++++++++++++++++ 10 files changed, 3081 insertions(+), 800 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index f3657cf77a0..1fb15433935 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -3,6 +3,7 @@ run: allow-parallel-runners: true linters: enable: + - gocyclo - misspell - unparam settings: diff --git a/cmd/cluster/core/create.go b/cmd/cluster/core/create.go index 9dd96f94c1d..4c3ccc771e0 100644 --- a/cmd/cluster/core/create.go +++ b/cmd/cluster/core/create.go @@ -236,51 +236,25 @@ func (r *resources) asObjects() []crclient.Object { func prototypeResources(ctx context.Context, opts *CreateOptions) (*resources, error) { prototype := &resources{} - // allow client side defaulting when release image is empty but release stream is set. - if len(opts.ReleaseImage) == 0 && len(opts.ReleaseStream) != 0 { - client, err := util.GetClient() - if err != nil { - return nil, fmt.Errorf("failed to get client: %w", err) - } - defaultVersion, err := supportedversion.LookupDefaultOCPVersion(ctx, opts.ReleaseStream, client) - if err != nil { - return nil, fmt.Errorf("release image is required when unable to lookup default OCP version: %w", err) - } - opts.ReleaseImage = defaultVersion.PullSpec + if err := resolveReleaseImage(ctx, opts); err != nil { + return nil, err } - annotations := map[string]string{} - for _, s := range opts.Annotations { - pair := strings.SplitN(s, "=", 2) - if len(pair) != 2 { - return nil, fmt.Errorf("invalid annotation: %s", s) - } - k, v := pair[0], pair[1] - annotations[k] = v + annotations, err := parseKeyValuePairs(opts.Annotations, "annotation") + if err != nil { + return nil, err } - - labels := map[string]string{} - for _, s := range opts.Labels { - pair := strings.SplitN(s, "=", 2) - if len(pair) != 2 { - return nil, fmt.Errorf("invalid label: %s", s) - } - k, v := pair[0], pair[1] - labels[k] = v + labels, err := parseKeyValuePairs(opts.Labels, "label") + if err != nil { + return nil, err } - if len(opts.ControlPlaneOperatorImage) > 0 { annotations[hyperv1.ControlPlaneOperatorImageAnnotation] = opts.ControlPlaneOperatorImage } - pullSecret := opts.PullSecret - var err error - // overrides if pullSecretFile is set - if len(opts.PullSecretFile) > 0 { - pullSecret, err = os.ReadFile(opts.PullSecretFile) - if err != nil { - return nil, fmt.Errorf("failed to read pull secret file: %w", err) - } + pullSecret, err := resolvePullSecret(opts) + if err != nil { + return nil, err } prototype.Namespace = &corev1.Namespace{ @@ -350,49 +324,124 @@ func prototypeResources(ctx context.Context, opts *CreateOptions) (*resources, e }, } + applyClusterCapabilities(prototype.Cluster, opts) + if err := applyEtcdConfig(prototype.Cluster, opts); err != nil { + return nil, err + } + if err := applySSHKey(prototype, opts); err != nil { + return nil, err + } + if err := applyPausedUntil(prototype.Cluster, opts); err != nil { + return nil, err + } + applyOLMConfig(prototype.Cluster, opts) + if err := applyNetworkConfig(prototype.Cluster, opts); err != nil { + return nil, err + } + applySchedulingConfig(prototype.Cluster, opts) + if err := applyTrustBundleAndImageSources(prototype, opts); err != nil { + return nil, err + } + applyFeatureSet(prototype.Cluster, opts) + + if len(opts.KubeAPIServerDNSName) > 0 { + if err := validation.IsDNS1123Subdomain(opts.KubeAPIServerDNSName); len(err) > 0 { + return nil, fmt.Errorf("KubeAPIServerDNSName failed DNS validation: %s", strings.Join(err[:], " ")) + } + prototype.Cluster.Spec.KubeAPIServerDNSName = opts.KubeAPIServerDNSName + } + + return prototype, nil +} + +func resolveReleaseImage(ctx context.Context, opts *CreateOptions) error { + if len(opts.ReleaseImage) != 0 || len(opts.ReleaseStream) == 0 { + return nil + } + client, err := util.GetClient() + if err != nil { + return fmt.Errorf("failed to get client: %w", err) + } + defaultVersion, err := supportedversion.LookupDefaultOCPVersion(ctx, opts.ReleaseStream, client) + if err != nil { + return fmt.Errorf("release image is required when unable to lookup default OCP version: %w", err) + } + opts.ReleaseImage = defaultVersion.PullSpec + return nil +} + +func parseKeyValuePairs(items []string, kind string) (map[string]string, error) { + result := map[string]string{} + for _, s := range items { + pair := strings.SplitN(s, "=", 2) + if len(pair) != 2 { + return nil, fmt.Errorf("invalid %s: %s", kind, s) + } + if pair[0] == "" { + return nil, fmt.Errorf("invalid %s: key must not be empty in %q", kind, s) + } + result[pair[0]] = pair[1] + } + return result, nil +} + +func resolvePullSecret(opts *CreateOptions) ([]byte, error) { + if len(opts.PullSecretFile) > 0 { + data, err := os.ReadFile(opts.PullSecretFile) + if err != nil { + return nil, fmt.Errorf("failed to read pull secret file: %w", err) + } + return data, nil + } + return opts.PullSecret, nil +} + +func applyClusterCapabilities(cluster *hyperv1.HostedCluster, opts *CreateOptions) { if len(opts.EnableClusterCapabilities) > 0 { caps := make([]hyperv1.OptionalCapability, len(opts.EnableClusterCapabilities)) for i, c := range opts.EnableClusterCapabilities { caps[i] = hyperv1.OptionalCapability(c) } - prototype.Cluster.Spec.Capabilities.Enabled = caps + cluster.Spec.Capabilities.Enabled = caps } - if len(opts.DisableClusterCapabilities) > 0 { caps := make([]hyperv1.OptionalCapability, len(opts.DisableClusterCapabilities)) for i, c := range opts.DisableClusterCapabilities { caps[i] = hyperv1.OptionalCapability(c) } - prototype.Cluster.Spec.Capabilities.Disabled = caps + cluster.Spec.Capabilities.Disabled = caps } +} +func applyEtcdConfig(cluster *hyperv1.HostedCluster, opts *CreateOptions) error { if opts.EtcdStorageClass != "" { - prototype.Cluster.Spec.Etcd.Managed.Storage.PersistentVolume.StorageClassName = ptr.To(opts.EtcdStorageClass) + cluster.Spec.Etcd.Managed.Storage.PersistentVolume.StorageClassName = ptr.To(opts.EtcdStorageClass) } - if opts.EtcdStorageSize != "" { etcdStorageSize, err := resource.ParseQuantity(opts.EtcdStorageSize) if err != nil { - return nil, fmt.Errorf("failed parse ectd storage size: %w", err) + return fmt.Errorf("failed parse ectd storage size: %w", err) } - prototype.Cluster.Spec.Etcd.Managed.Storage.PersistentVolume.Size = &etcdStorageSize + cluster.Spec.Etcd.Managed.Storage.PersistentVolume.Size = &etcdStorageSize } + return nil +} +func applySSHKey(prototype *resources, opts *CreateOptions) error { sshKey, sshPrivateKey := opts.PublicKey, opts.PrivateKey - // overrides secret if SSHKeyFile is set + var err error if len(opts.SSHKeyFile) > 0 { if opts.GenerateSSH { - return nil, fmt.Errorf("--generate-ssh and --ssh-key cannot be specified together") + return fmt.Errorf("--generate-ssh and --ssh-key cannot be specified together") } - key, err := os.ReadFile(opts.SSHKeyFile) + sshKey, err = os.ReadFile(opts.SSHKeyFile) if err != nil { - return nil, fmt.Errorf("failed to read ssh key file: %w", err) + return fmt.Errorf("failed to read ssh key file: %w", err) } - sshKey = key } else if opts.GenerateSSH { sshKey, sshPrivateKey, err = util.GenerateSSHKeys() if err != nil { - return nil, fmt.Errorf("failed to generate ssh keys: %w", err) + return fmt.Errorf("failed to generate ssh keys: %w", err) } } if len(sshKey) > 0 { @@ -415,105 +464,96 @@ func prototypeResources(ctx context.Context, opts *CreateOptions) (*resources, e } prototype.Cluster.Spec.SSHKey = corev1.LocalObjectReference{Name: prototype.SSHKey.Name} } + return nil +} - // validate pausedUntil value - // valid values are either "true" or RFC3339 format date - if len(opts.PausedUntil) > 0 && opts.PausedUntil != "true" { - _, err := time.Parse(time.RFC3339, opts.PausedUntil) - if err != nil { - return nil, fmt.Errorf("invalid pausedUntil value, should be \"true\" or a valid RFC3339 date format: %w", err) - } - prototype.Cluster.Spec.PausedUntil = &opts.PausedUntil +func applyPausedUntil(cluster *hyperv1.HostedCluster, opts *CreateOptions) error { + if len(opts.PausedUntil) == 0 || opts.PausedUntil == "true" { + return nil + } + _, err := time.Parse(time.RFC3339, opts.PausedUntil) + if err != nil { + return fmt.Errorf("invalid pausedUntil value, should be \"true\" or a valid RFC3339 date format: %w", err) } + cluster.Spec.PausedUntil = &opts.PausedUntil + return nil +} +func applyOLMConfig(cluster *hyperv1.HostedCluster, opts *CreateOptions) { if opts.OLMDisableDefaultSources { - prototype.Cluster.Spec.Configuration.OperatorHub = &configv1.OperatorHubSpec{ + cluster.Spec.Configuration.OperatorHub = &configv1.OperatorHubSpec{ DisableAllDefaultSources: true, } } - if len(opts.OLMCatalogPlacement) > 0 { - prototype.Cluster.Spec.OLMCatalogPlacement = opts.OLMCatalogPlacement + cluster.Spec.OLMCatalogPlacement = opts.OLMCatalogPlacement } +} - var clusterNetworkEntries []hyperv1.ClusterNetworkEntry +func applyNetworkConfig(cluster *hyperv1.HostedCluster, opts *CreateOptions) error { for _, cidr := range opts.ClusterCIDR { parsedCIDR, err := ipnet.ParseCIDR(cidr) if err != nil { - return nil, fmt.Errorf("parsing ClusterCIDR (%s): %w", cidr, err) + return fmt.Errorf("parsing ClusterCIDR (%s): %w", cidr, err) } - clusterNetworkEntries = append(clusterNetworkEntries, hyperv1.ClusterNetworkEntry{CIDR: *parsedCIDR}) + cluster.Spec.Networking.ClusterNetwork = append(cluster.Spec.Networking.ClusterNetwork, hyperv1.ClusterNetworkEntry{CIDR: *parsedCIDR}) } - prototype.Cluster.Spec.Networking.ClusterNetwork = clusterNetworkEntries - - var serviceNetworkEntries []hyperv1.ServiceNetworkEntry for _, cidr := range opts.ServiceCIDR { parsedCIDR, err := ipnet.ParseCIDR(cidr) if err != nil { - return nil, fmt.Errorf("parsing ServiceCIDR (%s): %w", cidr, err) + return fmt.Errorf("parsing ServiceCIDR (%s): %w", cidr, err) } - serviceNetworkEntries = append(serviceNetworkEntries, hyperv1.ServiceNetworkEntry{CIDR: *parsedCIDR}) + cluster.Spec.Networking.ServiceNetwork = append(cluster.Spec.Networking.ServiceNetwork, hyperv1.ServiceNetworkEntry{CIDR: *parsedCIDR}) } - prototype.Cluster.Spec.Networking.ServiceNetwork = serviceNetworkEntries - - var machineNetworkEntries []hyperv1.MachineNetworkEntry for _, cidr := range opts.MachineCIDR { parsedCIDR, err := ipnet.ParseCIDR(cidr) if err != nil { - return nil, fmt.Errorf("parsing MachineCIDR (%s): %w", cidr, err) + return fmt.Errorf("parsing MachineCIDR (%s): %w", cidr, err) } - machineNetworkEntries = append(machineNetworkEntries, hyperv1.MachineNetworkEntry{CIDR: *parsedCIDR}) + cluster.Spec.Networking.MachineNetwork = append(cluster.Spec.Networking.MachineNetwork, hyperv1.MachineNetworkEntry{CIDR: *parsedCIDR}) } - prototype.Cluster.Spec.Networking.MachineNetwork = machineNetworkEntries if opts.DisableMultiNetwork { - if prototype.Cluster.Spec.OperatorConfiguration == nil { - prototype.Cluster.Spec.OperatorConfiguration = &hyperv1.OperatorConfiguration{} - } - if prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator == nil { - prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator = &hyperv1.ClusterNetworkOperatorSpec{} - } - prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.DisableMultiNetwork = &opts.DisableMultiNetwork + ensureClusterNetworkOperatorSpec(cluster) + cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.DisableMultiNetwork = &opts.DisableMultiNetwork } - if opts.OVNKubernetesMTU > 0 { - if prototype.Cluster.Spec.OperatorConfiguration == nil { - prototype.Cluster.Spec.OperatorConfiguration = &hyperv1.OperatorConfiguration{} + ensureClusterNetworkOperatorSpec(cluster) + if cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig == nil { + cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig = &hyperv1.OVNKubernetesConfig{} } - if prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator == nil { - prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator = &hyperv1.ClusterNetworkOperatorSpec{} - } - if prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig == nil { - prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig = &hyperv1.OVNKubernetesConfig{} - } - prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig.MTU = opts.OVNKubernetesMTU + cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig.MTU = opts.OVNKubernetesMTU } - if opts.AllocateNodeCIDRs { enabled := hyperv1.AllocateNodeCIDRsEnabled - prototype.Cluster.Spec.Networking.AllocateNodeCIDRs = &enabled + cluster.Spec.Networking.AllocateNodeCIDRs = &enabled } + return nil +} - if opts.NodeSelector != nil { - prototype.Cluster.Spec.NodeSelector = opts.NodeSelector +func ensureClusterNetworkOperatorSpec(cluster *hyperv1.HostedCluster) { + if cluster.Spec.OperatorConfiguration == nil { + cluster.Spec.OperatorConfiguration = &hyperv1.OperatorConfiguration{} } - - if opts.PodsLabels != nil { - prototype.Cluster.Spec.Labels = opts.PodsLabels + if cluster.Spec.OperatorConfiguration.ClusterNetworkOperator == nil { + cluster.Spec.OperatorConfiguration.ClusterNetworkOperator = &hyperv1.ClusterNetworkOperatorSpec{} } +} - for _, tStr := range opts.Tolerations { - toleration, err := parseTolerationString(tStr) - if err != nil { - return nil, err - } - prototype.Cluster.Spec.Tolerations = append(prototype.Cluster.Spec.Tolerations, *toleration) +func applySchedulingConfig(cluster *hyperv1.HostedCluster, opts *CreateOptions) { + if opts.NodeSelector != nil { + cluster.Spec.NodeSelector = opts.NodeSelector } + if opts.PodsLabels != nil { + cluster.Spec.Labels = opts.PodsLabels + } +} +func applyTrustBundleAndImageSources(prototype *resources, opts *CreateOptions) error { if len(opts.AdditionalTrustBundle) > 0 { userCABundle, err := os.ReadFile(opts.AdditionalTrustBundle) if err != nil { - return nil, fmt.Errorf("failed to read additional trust bundle file: %w", err) + return fmt.Errorf("failed to read additional trust bundle file: %w", err) } prototype.AdditionalTrustBundle = &corev1.ConfigMap{ TypeMeta: metav1.TypeMeta{ @@ -534,44 +574,42 @@ func prototypeResources(ctx context.Context, opts *CreateOptions) (*resources, e if len(opts.ImageContentSources) > 0 { icspFileBytes, err := os.ReadFile(opts.ImageContentSources) if err != nil { - return nil, fmt.Errorf("failed to read image content sources file: %w", err) + return fmt.Errorf("failed to read image content sources file: %w", err) } - var imageContentSources []hyperv1.ImageContentSource err = yaml.Unmarshal(icspFileBytes, &imageContentSources) if err != nil { - return nil, fmt.Errorf("unable to deserialize image content sources file: %w", err) + return fmt.Errorf("unable to deserialize image content sources file: %w", err) } prototype.Cluster.Spec.ImageContentSources = imageContentSources } - if opts.FeatureSet != string(configv1.Default) { - switch opts.FeatureSet { - case string(configv1.TechPreviewNoUpgrade): - prototype.Cluster.Spec.Configuration.FeatureGate = &configv1.FeatureGateSpec{ - FeatureGateSelection: configv1.FeatureGateSelection{ - FeatureSet: configv1.TechPreviewNoUpgrade, - }, - } - case string(configv1.DevPreviewNoUpgrade): - prototype.Cluster.Spec.Configuration.FeatureGate = &configv1.FeatureGateSpec{ - FeatureGateSelection: configv1.FeatureGateSelection{ - FeatureSet: configv1.DevPreviewNoUpgrade, - }, - } - default: - return nil, fmt.Errorf("invalid feature set: %s", opts.FeatureSet) + for _, tStr := range opts.Tolerations { + toleration, err := parseTolerationString(tStr) + if err != nil { + return err } + prototype.Cluster.Spec.Tolerations = append(prototype.Cluster.Spec.Tolerations, *toleration) } - if len(opts.KubeAPIServerDNSName) > 0 { - if err := validation.IsDNS1123Subdomain(opts.KubeAPIServerDNSName); len(err) > 0 { - return nil, fmt.Errorf("KubeAPIServerDNSName failed DNS validation: %s", strings.Join(err[:], " ")) + return nil +} + +func applyFeatureSet(cluster *hyperv1.HostedCluster, opts *CreateOptions) { + switch opts.FeatureSet { + case string(configv1.TechPreviewNoUpgrade): + cluster.Spec.Configuration.FeatureGate = &configv1.FeatureGateSpec{ + FeatureGateSelection: configv1.FeatureGateSelection{ + FeatureSet: configv1.TechPreviewNoUpgrade, + }, + } + case string(configv1.DevPreviewNoUpgrade): + cluster.Spec.Configuration.FeatureGate = &configv1.FeatureGateSpec{ + FeatureGateSelection: configv1.FeatureGateSelection{ + FeatureSet: configv1.DevPreviewNoUpgrade, + }, } - prototype.Cluster.Spec.KubeAPIServerDNSName = opts.KubeAPIServerDNSName } - - return prototype, nil } func apply(ctx context.Context, l logr.Logger, infraID string, objects []crclient.Object, waitForRollout bool, mutate func(crclient.Object)) error { @@ -680,81 +718,103 @@ func (opts *RawCreateOptions) Validate(ctx context.Context) (*ValidatedCreateOpt if opts.Name == "" { return nil, errors.New("--name is required") } - if opts.PullSecretFile == "" { return nil, errors.New("--pull-secret is required") } + if err := opts.validateVersionAndWait(ctx); err != nil { + return nil, err + } + + errs := validation.IsDNS1123Label(opts.Name) + if len(errs) > 0 { + return nil, fmt.Errorf("HostedCluster name failed RFC1123 validation: %s", strings.Join(errs[:], " ")) + } + + if err := opts.validateClusterExistence(ctx); err != nil { + return nil, err + } + if err := opts.validateArchAndFeatureSet(); err != nil { + return nil, err + } + if err := opts.validateCapabilities(); err != nil { + return nil, err + } + if err := opts.validateNetworkOptions(); err != nil { + return nil, err + } + + return &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: opts, + }, + }, nil +} +func (opts *RawCreateOptions) validateVersionAndWait(ctx context.Context) error { if opts.VersionCheck { versionCLI := supportedversion.GetRevision() client, err := util.GetClient() if err != nil { - return nil, fmt.Errorf("failed to get client: %w", err) + return fmt.Errorf("failed to get client: %w", err) } if err := validateVersion(ctx, versionCLI, client); err != nil { - return nil, fmt.Errorf("version validation failed: %w", err) + return fmt.Errorf("version validation failed: %w", err) } } - if opts.Wait && opts.NodePoolReplicas < 1 { - return nil, errors.New("--wait requires --node-pool-replicas > 0") + return errors.New("--wait requires --node-pool-replicas > 0") } + return nil +} - // Validate HostedCluster name follows RFC1123 standard - // https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-label-names - errs := validation.IsDNS1123Label(opts.Name) - if len(errs) > 0 { - return nil, fmt.Errorf("HostedCluster name failed RFC1123 validation: %s", strings.Join(errs[:], " ")) +func (opts *RawCreateOptions) validateClusterExistence(ctx context.Context) error { + if opts.Render { + return nil + } + client, err := util.GetClient() + if err != nil { + return err + } + cluster := &hyperv1.HostedCluster{ObjectMeta: metav1.ObjectMeta{Namespace: opts.Namespace, Name: opts.Name}} + if err := client.Get(ctx, crclient.ObjectKeyFromObject(cluster), cluster); err == nil { + return fmt.Errorf("hostedcluster %s already exists", crclient.ObjectKeyFromObject(cluster)) + } else if !apierrors.IsNotFound(err) { + return fmt.Errorf("hostedcluster doesn't exist validation failed with error: %w", err) } - if !opts.Render { - client, err := util.GetClient() - if err != nil { - return nil, err - } - // Validate HostedCluster with this name doesn't exist in the namespace - cluster := &hyperv1.HostedCluster{ObjectMeta: metav1.ObjectMeta{Namespace: opts.Namespace, Name: opts.Name}} - if err := client.Get(ctx, crclient.ObjectKeyFromObject(cluster), cluster); err == nil { - return nil, fmt.Errorf("hostedcluster %s already exists", crclient.ObjectKeyFromObject(cluster)) - } else if !apierrors.IsNotFound(err) { - return nil, fmt.Errorf("hostedcluster doesn't exist validation failed with error: %w", err) - } - - // Validate multi-arch aspects - kc, err := hyperutil.GetKubeClientSet() - if err != nil { - return nil, fmt.Errorf("could not retrieve kube clientset: %w", err) - } - if err := validateMgmtClusterAndNodePoolCPUArchitectures(ctx, opts, kc, &hyperutil.RegistryClientImageMetadataProvider{}); err != nil { - if strings.Contains(err.Error(), "failed to retrieve manifest") { - opts.Log.Info("WARNING: Unable to access the payload, skipping the Architectures check.", "error", err.Error()) - } else { - return nil, err - } + kc, err := hyperutil.GetKubeClientSet() + if err != nil { + return fmt.Errorf("could not retrieve kube clientset: %w", err) + } + if err := validateMgmtClusterAndNodePoolCPUArchitectures(ctx, opts, kc, &hyperutil.RegistryClientImageMetadataProvider{}); err != nil { + if strings.Contains(err.Error(), "failed to retrieve manifest") { + opts.Log.Info("WARNING: Unable to access the payload, skipping the Architectures check.", "error", err.Error()) + } else { + return err } } + return nil +} +func (opts *RawCreateOptions) validateArchAndFeatureSet() error { arch := strings.ToLower(opts.Arch) switch arch { - case hyperv1.ArchitectureAMD64: - case hyperv1.ArchitectureARM64: - case hyperv1.ArchitecturePPC64LE: - case hyperv1.ArchitectureS390X: + case hyperv1.ArchitectureAMD64, hyperv1.ArchitectureARM64, hyperv1.ArchitecturePPC64LE, hyperv1.ArchitectureS390X: default: - return nil, fmt.Errorf("specified arch %q is not supported", opts.Arch) + return fmt.Errorf("specified arch %q is not supported", opts.Arch) } - // Validate feature set is "", TechPreviewNoUpgrade, or DevPreviewNoUpgrade switch opts.FeatureSet { - case string(configv1.Default): - case string(configv1.TechPreviewNoUpgrade): - case string(configv1.DevPreviewNoUpgrade): + case string(configv1.Default), string(configv1.TechPreviewNoUpgrade), string(configv1.DevPreviewNoUpgrade): case string(configv1.CustomNoUpgrade): - return nil, fmt.Errorf("only a predefined feature set is supported by the feature-set flag") + return fmt.Errorf("only a predefined feature set is supported by the feature-set flag") default: - return nil, fmt.Errorf("specified feature set %q is not supported", opts.FeatureSet) + return fmt.Errorf("specified feature set %q is not supported", opts.FeatureSet) } + return nil +} +func (opts *RawCreateOptions) validateCapabilities() error { acceptedValues := sets.NewString( string(hyperv1.ImageRegistryCapability), string(hyperv1.OpenShiftSamplesCapability), @@ -764,54 +824,44 @@ func (opts *RawCreateOptions) Validate(ctx context.Context) (*ValidatedCreateOpt string(hyperv1.NodeTuningCapability), string(hyperv1.IngressCapability), ) - if len(opts.DisableClusterCapabilities) > 0 { - for _, capability := range opts.DisableClusterCapabilities { - if !acceptedValues.Has(capability) { - return nil, fmt.Errorf("unknown disabled capability: %s, accepted values are: %v", capability, acceptedValues.List()) - } + for _, capability := range opts.DisableClusterCapabilities { + if !acceptedValues.Has(capability) { + return fmt.Errorf("unknown disabled capability: %s, accepted values are: %v", capability, acceptedValues.List()) } } - if len(opts.EnableClusterCapabilities) > 0 { - for _, capability := range opts.EnableClusterCapabilities { - if !acceptedValues.Has(capability) { - return nil, fmt.Errorf("unknown enabled capability: %s, accepted values are: %v", capability, acceptedValues.List()) - } + for _, capability := range opts.EnableClusterCapabilities { + if !acceptedValues.Has(capability) { + return fmt.Errorf("unknown enabled capability: %s, accepted values are: %v", capability, acceptedValues.List()) } } - disabledCaps := sets.NewString(opts.DisableClusterCapabilities...) if disabledCaps.Has(string(hyperv1.IngressCapability)) && !disabledCaps.Has(string(hyperv1.ConsoleCapability)) { - return nil, fmt.Errorf("ingress capability can only be disabled if Console capability is also disabled") + return fmt.Errorf("ingress capability can only be disabled if Console capability is also disabled") } - if len(opts.KubeAPIServerDNSName) > 0 { if err := validation.IsDNS1123Subdomain(opts.KubeAPIServerDNSName); len(err) > 0 { - return nil, fmt.Errorf("KubeAPIServerDNSName failed DNS validation: %s", strings.Join(err[:], " ")) + return fmt.Errorf("KubeAPIServerDNSName failed DNS validation: %s", strings.Join(err[:], " ")) } } + return nil +} +func (opts *RawCreateOptions) validateNetworkOptions() error { if opts.DisableMultiNetwork && opts.NetworkType != "Other" { - return nil, fmt.Errorf("disableMultiNetwork is only allowed when networkType is 'Other' (got '%s')", opts.NetworkType) + return fmt.Errorf("disableMultiNetwork is only allowed when networkType is 'Other' (got '%s')", opts.NetworkType) } - if opts.OVNKubernetesMTU != 0 { if opts.NetworkType != string(hyperv1.OVNKubernetes) { - return nil, fmt.Errorf("--ovn-kubernetes-mtu is only valid when --network-type is OVNKubernetes (got '%s')", opts.NetworkType) + return fmt.Errorf("--ovn-kubernetes-mtu is only valid when --network-type is OVNKubernetes (got '%s')", opts.NetworkType) } if opts.OVNKubernetesMTU < 576 || opts.OVNKubernetesMTU > 9216 { - return nil, fmt.Errorf("--ovn-kubernetes-mtu must be between 576 and 9216 (got %d)", opts.OVNKubernetesMTU) + return fmt.Errorf("--ovn-kubernetes-mtu must be between 576 and 9216 (got %d)", opts.OVNKubernetesMTU) } } - if opts.AllocateNodeCIDRs && opts.NetworkType != "Other" { - return nil, fmt.Errorf("allocateNodeCIDRs is only allowed when networkType is 'Other' (got '%s')", opts.NetworkType) + return fmt.Errorf("allocateNodeCIDRs is only allowed when networkType is 'Other' (got '%s')", opts.NetworkType) } - - return &ValidatedCreateOptions{ - validatedCreateOptions: &validatedCreateOptions{ - RawCreateOptions: opts, - }, - }, nil + return nil } // completedCreateOptions is a private wrapper that enforces a call of Complete() before CreateCluster() can be invoked. diff --git a/cmd/cluster/core/create_test.go b/cmd/cluster/core/create_test.go index e2c0e66257c..532f77e8ed6 100644 --- a/cmd/cluster/core/create_test.go +++ b/cmd/cluster/core/create_test.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "testing" + "time" . "github.com/onsi/gomega" @@ -13,6 +14,8 @@ import ( "github.com/openshift/hypershift/support/thirdparty/library-go/pkg/image/dockerv1client" "github.com/openshift/hypershift/support/util/fakeimagemetadataprovider" + configv1 "github.com/openshift/api/config/v1" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -887,3 +890,840 @@ func TestGetServicePublishingStrategyMapping(t *testing.T) { }) } } + +func TestParseKeyValuePairs(t *testing.T) { + tests := []struct { + name string + items []string + kind string + expected map[string]string + expectError string + }{ + { + name: "When valid key=value pairs are provided, it should return a map", + items: []string{"key1=value1", "key2=value2"}, + kind: "annotation", + expected: map[string]string{"key1": "value1", "key2": "value2"}, + }, + { + name: "When an empty list is provided, it should return an empty map", + items: []string{}, + kind: "label", + expected: map[string]string{}, + }, + { + name: "When a malformed pair without equals sign is provided, it should return an error", + items: []string{"badentry"}, + kind: "annotation", + expectError: "invalid annotation: badentry", + }, + { + name: "When a value contains an equals sign, it should split only on the first equals", + items: []string{"key=val=extra"}, + kind: "label", + expected: map[string]string{"key": "val=extra"}, + }, + { + name: "When a value is empty after equals sign, it should accept it", + items: []string{"key="}, + kind: "annotation", + expected: map[string]string{"key": ""}, + }, + { + name: "When the key is empty, it should return an error", + items: []string{"=value"}, + kind: "label", + expectError: "key must not be empty", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + result, err := parseKeyValuePairs(tc.items, tc.kind) + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(result).To(Equal(tc.expected)) + } + }) + } +} + +func TestApplyEtcdConfig(t *testing.T) { + tests := []struct { + name string + etcdStorageClass string + etcdStorageSize string + expectError string + expectStorageClass *string + expectSizeString string + }{ + { + name: "When no etcd storage options are provided, it should leave defaults unchanged", + etcdStorageClass: "", + etcdStorageSize: "", + }, + { + name: "When etcd storage class is provided, it should set the storage class", + etcdStorageClass: "gp3-csi", + expectStorageClass: ptr.To("gp3-csi"), + }, + { + name: "When a valid etcd storage size is provided, it should set the storage size", + etcdStorageSize: "8Gi", + expectSizeString: "8Gi", + }, + { + name: "When an invalid etcd storage size is provided, it should return an error", + etcdStorageSize: "notasize", + expectError: "failed parse ectd storage size", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Etcd: hyperv1.EtcdSpec{ + Managed: &hyperv1.ManagedEtcdSpec{ + Storage: hyperv1.ManagedEtcdStorageSpec{ + PersistentVolume: &hyperv1.PersistentVolumeEtcdStorageSpec{ + Size: &hyperv1.DefaultPersistentVolumeEtcdStorageSize, + }, + }, + }, + }, + }, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + EtcdStorageClass: tc.etcdStorageClass, + EtcdStorageSize: tc.etcdStorageSize, + }, + }, + }, + }, + } + + err := applyEtcdConfig(cluster, opts) + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + if tc.expectStorageClass != nil { + g.Expect(cluster.Spec.Etcd.Managed.Storage.PersistentVolume.StorageClassName).To(Equal(tc.expectStorageClass)) + } + if tc.expectSizeString != "" { + g.Expect(cluster.Spec.Etcd.Managed.Storage.PersistentVolume.Size.String()).To(Equal(tc.expectSizeString)) + } + // When no storage options are provided, verify defaults remain unchanged + if tc.etcdStorageClass == "" && tc.etcdStorageSize == "" { + // StorageClassName should remain nil (default) + g.Expect(cluster.Spec.Etcd.Managed.Storage.PersistentVolume.StorageClassName).To(BeNil(), "StorageClassName should remain nil when not specified") + // Size should remain at the default value + g.Expect(cluster.Spec.Etcd.Managed.Storage.PersistentVolume.Size).To(Equal(&hyperv1.DefaultPersistentVolumeEtcdStorageSize), "Size should remain at default when not specified") + } + } + }) + } +} + +func TestApplyPausedUntil(t *testing.T) { + tests := []struct { + name string + pausedUntil string + expectError string + expectSet bool + }{ + { + name: "When pausedUntil is empty, it should not set PausedUntil on the cluster", + pausedUntil: "", + expectSet: false, + }, + { + name: "When pausedUntil is 'true', it should not set PausedUntil on the cluster", + pausedUntil: "true", + expectSet: false, + }, + { + name: "When pausedUntil is a valid RFC3339 date, it should set PausedUntil on the cluster", + pausedUntil: "2026-12-31T23:59:59Z", + expectSet: true, + }, + { + name: "When pausedUntil is an invalid date format, it should return an error", + pausedUntil: "not-a-date", + expectError: "invalid pausedUntil value", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{} + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + PausedUntil: tc.pausedUntil, + }, + }, + }, + }, + } + + err := applyPausedUntil(cluster, opts) + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + if tc.expectSet { + g.Expect(cluster.Spec.PausedUntil).NotTo(BeNil()) + g.Expect(*cluster.Spec.PausedUntil).To(Equal(tc.pausedUntil)) + } else { + g.Expect(cluster.Spec.PausedUntil).To(BeNil()) + } + } + }) + } +} + +func TestApplyOLMConfig(t *testing.T) { + tests := []struct { + name string + olmDisableDefaultSources bool + olmCatalogPlacement hyperv1.OLMCatalogPlacement + expectOperatorHub bool + expectCatalogPlacement hyperv1.OLMCatalogPlacement + }{ + { + name: "When OLM default sources are disabled, it should set DisableAllDefaultSources", + olmDisableDefaultSources: true, + olmCatalogPlacement: "", + expectOperatorHub: true, + }, + { + name: "When OLM catalog placement is set to Guest, it should set OLMCatalogPlacement", + olmCatalogPlacement: hyperv1.GuestOLMCatalogPlacement, + expectCatalogPlacement: hyperv1.GuestOLMCatalogPlacement, + }, + { + name: "When no OLM options are set, it should not modify the cluster", + expectOperatorHub: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Configuration: &hyperv1.ClusterConfiguration{}, + }, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + OLMDisableDefaultSources: tc.olmDisableDefaultSources, + OLMCatalogPlacement: tc.olmCatalogPlacement, + }, + }, + }, + }, + } + + applyOLMConfig(cluster, opts) + if tc.expectOperatorHub { + g.Expect(cluster.Spec.Configuration.OperatorHub).NotTo(BeNil()) + g.Expect(cluster.Spec.Configuration.OperatorHub.DisableAllDefaultSources).To(BeTrue()) + } + if tc.expectCatalogPlacement != "" { + g.Expect(cluster.Spec.OLMCatalogPlacement).To(Equal(tc.expectCatalogPlacement)) + } + // When no OLM options are set, verify cluster remains unmodified + if !tc.olmDisableDefaultSources && tc.olmCatalogPlacement == "" { + g.Expect(cluster.Spec.Configuration.OperatorHub).To(BeNil(), "OperatorHub should remain nil when OLM options are not set") + g.Expect(cluster.Spec.OLMCatalogPlacement).To(BeEmpty(), "OLMCatalogPlacement should remain empty when not specified") + } + }) + } +} + +func TestApplyFeatureSet(t *testing.T) { + tests := []struct { + name string + featureSet string + expectFeatureGate bool + expectedFeatureSet configv1.FeatureSet + }{ + { + name: "When feature set is Default, it should not set FeatureGate", + featureSet: string(configv1.Default), + expectFeatureGate: false, + }, + { + name: "When feature set is TechPreviewNoUpgrade, it should set the TechPreviewNoUpgrade feature gate", + featureSet: string(configv1.TechPreviewNoUpgrade), + expectFeatureGate: true, + expectedFeatureSet: configv1.TechPreviewNoUpgrade, + }, + { + name: "When feature set is DevPreviewNoUpgrade, it should set the DevPreviewNoUpgrade feature gate", + featureSet: string(configv1.DevPreviewNoUpgrade), + expectFeatureGate: true, + expectedFeatureSet: configv1.DevPreviewNoUpgrade, + }, + { + name: "When feature set is an unrecognized value, it should not set FeatureGate", + featureSet: "SomethingElse", + expectFeatureGate: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Configuration: &hyperv1.ClusterConfiguration{}, + }, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + FeatureSet: tc.featureSet, + }, + }, + }, + }, + } + + applyFeatureSet(cluster, opts) + if tc.expectFeatureGate { + g.Expect(cluster.Spec.Configuration.FeatureGate).NotTo(BeNil()) + g.Expect(cluster.Spec.Configuration.FeatureGate.FeatureSet).To(Equal(tc.expectedFeatureSet)) + } else { + g.Expect(cluster.Spec.Configuration.FeatureGate).To(BeNil()) + } + }) + } +} + +func TestApplyNetworkConfig(t *testing.T) { + tests := []struct { + name string + clusterCIDR []string + serviceCIDR []string + machineCIDR []string + expectError string + expectClusters int + expectServices int + expectMachines int + }{ + { + name: "When valid CIDRs are provided, it should parse and append them", + clusterCIDR: []string{"10.128.0.0/14"}, + serviceCIDR: []string{"172.30.0.0/16"}, + machineCIDR: []string{"10.0.0.0/16"}, + expectClusters: 1, + expectServices: 1, + expectMachines: 1, + }, + { + name: "When dual-stack CIDRs are provided, it should parse both", + clusterCIDR: []string{"10.128.0.0/14", "fd01::/48"}, + serviceCIDR: []string{"172.30.0.0/16", "fd02::/112"}, + expectClusters: 2, + expectServices: 2, + expectMachines: 0, + }, + { + name: "When an invalid cluster CIDR is provided, it should return an error", + clusterCIDR: []string{"not-a-cidr"}, + expectError: "parsing ClusterCIDR", + }, + { + name: "When an invalid service CIDR is provided, it should return an error", + serviceCIDR: []string{"not-a-cidr"}, + expectError: "parsing ServiceCIDR", + }, + { + name: "When an invalid machine CIDR is provided, it should return an error", + machineCIDR: []string{"not-a-cidr"}, + expectError: "parsing MachineCIDR", + }, + { + name: "When no CIDRs are provided, it should produce empty network lists", + expectClusters: 0, + expectServices: 0, + expectMachines: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{} + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + ClusterCIDR: tc.clusterCIDR, + ServiceCIDR: tc.serviceCIDR, + MachineCIDR: tc.machineCIDR, + }, + }, + }, + }, + } + + err := applyNetworkConfig(cluster, opts) + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(cluster.Spec.Networking.ClusterNetwork).To(HaveLen(tc.expectClusters)) + g.Expect(cluster.Spec.Networking.ServiceNetwork).To(HaveLen(tc.expectServices)) + g.Expect(cluster.Spec.Networking.MachineNetwork).To(HaveLen(tc.expectMachines)) + } + }) + } +} + +func TestApplySchedulingConfig(t *testing.T) { + tests := []struct { + name string + nodeSelector map[string]string + podsLabels map[string]string + expectNodeSelector map[string]string + expectLabels map[string]string + }{ + { + name: "When node selector is provided, it should set NodeSelector on the cluster", + nodeSelector: map[string]string{"role": "cp", "disk": "fast"}, + expectNodeSelector: map[string]string{"role": "cp", "disk": "fast"}, + }, + { + name: "When pods labels are provided, it should set Labels on the cluster", + podsLabels: map[string]string{"team": "hypershift"}, + expectLabels: map[string]string{"team": "hypershift"}, + }, + { + name: "When neither is provided, it should not modify the cluster", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{} + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + NodeSelector: tc.nodeSelector, + PodsLabels: tc.podsLabels, + }, + }, + }, + }, + } + + applySchedulingConfig(cluster, opts) + if tc.expectNodeSelector != nil { + g.Expect(cluster.Spec.NodeSelector).To(Equal(tc.expectNodeSelector)) + } + if tc.expectLabels != nil { + g.Expect(cluster.Spec.Labels).To(Equal(tc.expectLabels)) + } + // When neither node selector nor labels are provided, verify cluster remains unmodified + if tc.nodeSelector == nil && tc.podsLabels == nil { + g.Expect(cluster.Spec.NodeSelector).To(BeNil(), "NodeSelector should remain nil when not specified") + g.Expect(cluster.Spec.Labels).To(BeNil(), "Labels should remain nil when not specified") + } + }) + } +} + +func TestParseTolerationString(t *testing.T) { + tests := []struct { + name string + input string + expectError string + expected *corev1.Toleration + }{ + { + name: "When a full toleration is specified, it should parse all fields", + input: "key=node-role.kubernetes.io/master,operator=Exists,effect=NoSchedule", + expected: &corev1.Toleration{ + Key: "node-role.kubernetes.io/master", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + { + name: "When operator is Equal and a value is specified, it should parse correctly", + input: "key=mykey,value=myvalue,operator=Equal,effect=NoExecute", + expected: &corev1.Toleration{ + Key: "mykey", + Value: "myvalue", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoExecute, + }, + }, + { + name: "When tolerationSeconds is specified, it should parse the integer value", + input: "key=mykey,operator=Exists,effect=NoExecute,tolerationSeconds=300", + expected: &corev1.Toleration{ + Key: "mykey", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoExecute, + TolerationSeconds: ptr.To(int64(300)), + }, + }, + { + name: "When effect is PreferNoSchedule, it should normalize the case", + input: "key=mykey,effect=preferNoSchedule", + expected: &corev1.Toleration{ + Key: "mykey", + Effect: corev1.TaintEffectPreferNoSchedule, + }, + }, + { + name: "When an unknown operator type is provided, it should return an error", + input: "key=mykey,operator=Unknown", + expectError: "unknown operator type", + }, + { + name: "When an unknown effect type is provided, it should return an error", + input: "key=mykey,effect=Unknown", + expectError: "unknown effect type", + }, + { + name: "When a malformed key-value is provided, it should return an error", + input: "badformat", + expectError: "invalid toleration cli argument", + }, + { + name: "When an unknown field is provided, it should return an error", + input: "unknownfield=value", + expectError: "unknown field", + }, + { + name: "When tolerationSeconds is not an integer, it should return an error", + input: "key=mykey,tolerationSeconds=abc", + expectError: "failed to parse tolerationSeconds", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + result, err := parseTolerationString(tc.input) + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(result).To(Equal(tc.expected)) + } + }) + } +} + +func TestPostProcess(t *testing.T) { + t.Run("When secret encryption is nil, it should default to AESCBC", func(t *testing.T) { + g := NewWithT(t) + r := &resources{ + Cluster: &hyperv1.HostedCluster{}, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + Name: "test-cluster", + Namespace: "clusters", + }, + }, + }, + }, + } + + postProcess(r, opts) + + g.Expect(r.Cluster.Spec.SecretEncryption).NotTo(BeNil()) + g.Expect(r.Cluster.Spec.SecretEncryption.Type).To(Equal(hyperv1.AESCBC)) + g.Expect(r.Cluster.Spec.SecretEncryption.AESCBC).NotTo(BeNil()) + g.Expect(r.Cluster.Spec.SecretEncryption.AESCBC.ActiveKey.Name).To(Equal("test-cluster-etcd-encryption-key")) + g.Expect(r.Resources).To(HaveLen(1)) + }) + + t.Run("When secret encryption is already set, it should not override it", func(t *testing.T) { + g := NewWithT(t) + r := &resources{ + Cluster: &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + SecretEncryption: &hyperv1.SecretEncryptionSpec{ + Type: hyperv1.KMS, + }, + }, + }, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + Name: "test-cluster", + Namespace: "clusters", + }, + }, + }, + }, + } + + postProcess(r, opts) + + g.Expect(r.Cluster.Spec.SecretEncryption.Type).To(Equal(hyperv1.KMS)) + g.Expect(r.Resources).To(BeEmpty()) + }) +} + +func TestDefaultNodePool(t *testing.T) { + tests := []struct { + name string + clusterName string + suffix string + expectedName string + }{ + { + name: "When no suffix is provided, it should use the cluster name", + clusterName: "my-cluster", + suffix: "", + expectedName: "my-cluster", + }, + { + name: "When a suffix is provided, it should append it to the cluster name", + clusterName: "my-cluster", + suffix: "us-east-1a", + expectedName: "my-cluster-us-east-1a", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + replicas := int32(3) + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + Name: tc.clusterName, + Namespace: "clusters", + NodePoolReplicas: replicas, + ReleaseImage: "quay.io/openshift/ocp:4.16", + AutoRepair: true, + Arch: "amd64", + NodeDrainTimeout: 5 * time.Minute, + }, + }, + }, + }, + } + + constructor := defaultNodePool(opts) + np := constructor(hyperv1.AWSPlatform, tc.suffix) + + g.Expect(np.Name).To(Equal(tc.expectedName)) + g.Expect(np.Namespace).To(Equal("clusters")) + g.Expect(np.Spec.ClusterName).To(Equal(tc.clusterName)) + g.Expect(np.Spec.Platform.Type).To(Equal(hyperv1.AWSPlatform)) + g.Expect(*np.Spec.Replicas).To(Equal(replicas)) + g.Expect(np.Spec.Management.AutoRepair).To(BeTrue()) + g.Expect(np.Spec.Arch).To(Equal("amd64")) + g.Expect(np.Spec.Release.Image).To(Equal("quay.io/openshift/ocp:4.16")) + }) + } +} + +func TestValidateArchAndFeatureSet(t *testing.T) { + tests := []struct { + name string + arch string + featureSet string + expectError string + }{ + { + name: "When amd64 arch and Default feature set are provided, it should pass", + arch: "amd64", + featureSet: string(configv1.Default), + }, + { + name: "When arm64 arch is provided, it should pass", + arch: "arm64", + featureSet: string(configv1.Default), + }, + { + name: "When ppc64le arch is provided, it should pass", + arch: "ppc64le", + featureSet: string(configv1.Default), + }, + { + name: "When s390x arch is provided, it should pass", + arch: "s390x", + featureSet: string(configv1.Default), + }, + { + name: "When an unsupported arch is provided, it should return an error", + arch: "mips", + featureSet: string(configv1.Default), + expectError: "specified arch", + }, + { + name: "When TechPreviewNoUpgrade feature set is provided, it should pass", + arch: "amd64", + featureSet: string(configv1.TechPreviewNoUpgrade), + }, + { + name: "When DevPreviewNoUpgrade feature set is provided, it should pass", + arch: "amd64", + featureSet: string(configv1.DevPreviewNoUpgrade), + }, + { + name: "When CustomNoUpgrade feature set is provided, it should return an error", + arch: "amd64", + featureSet: string(configv1.CustomNoUpgrade), + expectError: "only a predefined feature set is supported", + }, + { + name: "When an unknown feature set is provided, it should return an error", + arch: "amd64", + featureSet: "SomethingRandom", + expectError: "specified feature set", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + opts := &RawCreateOptions{ + Arch: tc.arch, + FeatureSet: tc.featureSet, + } + + err := opts.validateArchAndFeatureSet() + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + }) + } +} + +func TestApplyClusterCapabilities(t *testing.T) { + tests := []struct { + name string + enableCaps []string + disableCaps []string + expectEnabled []hyperv1.OptionalCapability + expectDisabled []hyperv1.OptionalCapability + }{ + { + name: "When both enable and disable capabilities are provided, it should set both", + enableCaps: []string{"baremetal", "Console"}, + disableCaps: []string{"ImageRegistry"}, + expectEnabled: []hyperv1.OptionalCapability{"baremetal", "Console"}, + expectDisabled: []hyperv1.OptionalCapability{"ImageRegistry"}, + }, + { + name: "When only disable capabilities are provided, it should set disabled only", + disableCaps: []string{"Insights", "NodeTuning"}, + expectDisabled: []hyperv1.OptionalCapability{"Insights", "NodeTuning"}, + }, + { + name: "When neither enable nor disable are provided, it should not set capabilities", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Capabilities: &hyperv1.Capabilities{}, + }, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + EnableClusterCapabilities: tc.enableCaps, + DisableClusterCapabilities: tc.disableCaps, + }, + }, + }, + }, + } + + applyClusterCapabilities(cluster, opts) + + if tc.expectEnabled != nil { + g.Expect(cluster.Spec.Capabilities.Enabled).To(Equal(tc.expectEnabled)) + } else { + g.Expect(cluster.Spec.Capabilities.Enabled).To(BeNil()) + } + if tc.expectDisabled != nil { + g.Expect(cluster.Spec.Capabilities.Disabled).To(Equal(tc.expectDisabled)) + } else { + g.Expect(cluster.Spec.Capabilities.Disabled).To(BeNil()) + } + }) + } +} + +func TestEnsureClusterNetworkOperatorSpec(t *testing.T) { + t.Run("When OperatorConfiguration is nil, it should initialize both OperatorConfiguration and ClusterNetworkOperator", func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{} + + ensureClusterNetworkOperatorSpec(cluster) + + g.Expect(cluster.Spec.OperatorConfiguration).NotTo(BeNil()) + g.Expect(cluster.Spec.OperatorConfiguration.ClusterNetworkOperator).NotTo(BeNil()) + }) + + t.Run("When OperatorConfiguration exists but ClusterNetworkOperator is nil, it should initialize ClusterNetworkOperator", func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + OperatorConfiguration: &hyperv1.OperatorConfiguration{}, + }, + } + + ensureClusterNetworkOperatorSpec(cluster) + + g.Expect(cluster.Spec.OperatorConfiguration.ClusterNetworkOperator).NotTo(BeNil()) + }) +} diff --git a/cmd/fix/dr_oidc_iam.go b/cmd/fix/dr_oidc_iam.go index 131f216a750..381a3b8905e 100644 --- a/cmd/fix/dr_oidc_iam.go +++ b/cmd/fix/dr_oidc_iam.go @@ -449,38 +449,13 @@ func (o *DrOidcIamOptions) Run(ctx context.Context) error { o.Retryer = retryerFn() }) - // Step 1: Check if OIDC documents exist in S3 - fmt.Println("Step 1: Checking OIDC documents in S3") - oidcDocsExist := o.checkOIDCDocumentsExist(ctx, s3Client) - if oidcDocsExist && !o.ForceRecreate { - fmt.Printf("- (%s) OIDC documents found in S3\n", greenCheck()) - } else { - if o.ForceRecreate { - fmt.Printf("- (%s) Force recreate enabled - will regenerate OIDC documents\n", yellowForce()) - } else { - fmt.Printf("- (%s) OIDC documents missing in S3 - will create them\n", redX()) - } - } - - // Step 2: Check if OIDC identity provider exists in IAM - fmt.Println("Step 2: Checking OIDC identity provider in IAM") - providerARN, exists, err := o.checkOIDCProvider(ctx, iamClient) + oidcDocsExist, providerARN, exists, err := o.checkOIDCState(ctx, s3Client, iamClient) if err != nil { - return fmt.Errorf("failed to check OIDC provider: %w", err) + return err } - - if exists && !o.ForceRecreate { - fmt.Printf("- (%s) OIDC identity provider exists\n", greenCheck()) - if oidcDocsExist { - fmt.Println("\nAll OIDC components are in place - no action needed!") - return nil - } - } else { - if o.ForceRecreate { - fmt.Printf("- (%s) Force recreate enabled - will regenerate OIDC provider\n", yellowForce()) - } else { - fmt.Printf("- (%s) OIDC identity provider missing - needs to be recreated\n", redX()) - } + if oidcDocsExist && exists && !o.ForceRecreate { + fmt.Println("\nAll OIDC components are in place - no action needed!") + return nil } // Step 3: Create/ensure S3 bucket exists @@ -490,20 +465,8 @@ func (o *DrOidcIamOptions) Run(ctx context.Context) error { } fmt.Printf("- (%s) S3 bucket ready with public access enabled\n", greenCheck()) - // Step 4: Generate and upload OIDC documents using the EXISTING cluster key - if !oidcDocsExist || o.ForceRecreate { - fmt.Println("Step 4: Generating and uploading OIDC documents") - if o.DryRun { - fmt.Printf("- (%s) DRY RUN: Would generate and upload OIDC documents using existing cluster signing key\n", yellowQuestion()) - } else { - if err := o.generateAndUploadOIDCDocuments(ctx, k8sClient, s3Client); err != nil { - return fmt.Errorf("failed to generate OIDC documents: %w", err) - } - fmt.Printf("- (%s) OIDC documents generated and uploaded\n", greenCheck()) - } - } else { - fmt.Println("Step 4: OIDC documents") - fmt.Printf("- (%s) OIDC documents already exist\n", greenCheck()) + if err := o.ensureOIDCDocuments(ctx, k8sClient, s3Client, oidcDocsExist); err != nil { + return err } // Step 5: Get SSL certificate thumbprint @@ -514,26 +477,8 @@ func (o *DrOidcIamOptions) Run(ctx context.Context) error { } fmt.Printf("- (%s) SSL certificate thumbprint retrieved\n", greenCheck()) - // Step 6: Create/recreate OIDC provider - if !exists || o.ForceRecreate { - fmt.Println("Step 6: Creating/recreating OIDC identity provider") - if o.DryRun { - fmt.Printf("- (%s) DRY RUN: Would create OIDC provider\n", yellowQuestion()) - fmt.Printf(" Issuer: %s\n", o.Issuer) - fmt.Printf(" Thumbprint: %s\n", thumbprint) - fmt.Printf(" Allowed clients: openshift, sts.amazonaws.com\n") - } else { - if err := o.deleteOIDCProviderIfExists(ctx, iamClient, providerARN); err != nil { - return fmt.Errorf("failed to delete existing OIDC provider: %w", err) - } - if _, err := o.createOIDCProvider(ctx, iamClient, thumbprint); err != nil { - return fmt.Errorf("failed to create OIDC provider: %w", err) - } - fmt.Printf("- (%s) OIDC identity provider successfully created\n", greenCheck()) - } - } else { - fmt.Println("Step 6: OIDC identity provider") - fmt.Printf("- (%s) OIDC identity provider already exists\n", greenCheck()) + if err := o.ensureOIDCProvider(ctx, iamClient, providerARN, thumbprint, exists); err != nil { + return err } // Step 7: Verify configuration and update HostedCluster @@ -548,6 +493,80 @@ func (o *DrOidcIamOptions) Run(ctx context.Context) error { return nil } +func (o *DrOidcIamOptions) checkOIDCState(ctx context.Context, s3Client *s3.Client, iamClient *iam.Client) (bool, string, bool, error) { + fmt.Println("Step 1: Checking OIDC documents in S3") + oidcDocsExist := o.checkOIDCDocumentsExist(ctx, s3Client) + if oidcDocsExist && !o.ForceRecreate { + fmt.Printf("- (%s) OIDC documents found in S3\n", greenCheck()) + } else if o.ForceRecreate { + fmt.Printf("- (%s) Force recreate enabled - will regenerate OIDC documents\n", yellowForce()) + } else { + fmt.Printf("- (%s) OIDC documents missing in S3 - will create them\n", redX()) + } + + fmt.Println("Step 2: Checking OIDC identity provider in IAM") + providerARN, exists, err := o.checkOIDCProvider(ctx, iamClient) + if err != nil { + return false, "", false, fmt.Errorf("failed to check OIDC provider: %w", err) + } + + if exists && !o.ForceRecreate { + fmt.Printf("- (%s) OIDC identity provider exists\n", greenCheck()) + } else if o.ForceRecreate { + fmt.Printf("- (%s) Force recreate enabled - will regenerate OIDC provider\n", yellowForce()) + } else { + fmt.Printf("- (%s) OIDC identity provider missing - needs to be recreated\n", redX()) + } + + return oidcDocsExist, providerARN, exists, nil +} + +func (o *DrOidcIamOptions) ensureOIDCDocuments(ctx context.Context, k8sClient client.Client, s3Client *s3.Client, oidcDocsExist bool) error { + if oidcDocsExist && !o.ForceRecreate { + fmt.Println("Step 4: OIDC documents") + fmt.Printf("- (%s) OIDC documents already exist\n", greenCheck()) + return nil + } + + fmt.Println("Step 4: Generating and uploading OIDC documents") + if o.DryRun { + fmt.Printf("- (%s) DRY RUN: Would generate and upload OIDC documents using existing cluster signing key\n", yellowQuestion()) + return nil + } + + if err := o.generateAndUploadOIDCDocuments(ctx, k8sClient, s3Client); err != nil { + return fmt.Errorf("failed to generate OIDC documents: %w", err) + } + fmt.Printf("- (%s) OIDC documents generated and uploaded\n", greenCheck()) + return nil +} + +func (o *DrOidcIamOptions) ensureOIDCProvider(ctx context.Context, iamClient *iam.Client, providerARN, thumbprint string, exists bool) error { + if exists && !o.ForceRecreate { + fmt.Println("Step 6: OIDC identity provider") + fmt.Printf("- (%s) OIDC identity provider already exists\n", greenCheck()) + return nil + } + + fmt.Println("Step 6: Creating/recreating OIDC identity provider") + if o.DryRun { + fmt.Printf("- (%s) DRY RUN: Would create OIDC provider\n", yellowQuestion()) + fmt.Printf(" Issuer: %s\n", o.Issuer) + fmt.Printf(" Thumbprint: %s\n", thumbprint) + fmt.Printf(" Allowed clients: openshift, sts.amazonaws.com\n") + return nil + } + + if err := o.deleteOIDCProviderIfExists(ctx, iamClient, providerARN); err != nil { + return fmt.Errorf("failed to delete existing OIDC provider: %w", err) + } + if _, err := o.createOIDCProvider(ctx, iamClient, thumbprint); err != nil { + return fmt.Errorf("failed to create OIDC provider: %w", err) + } + fmt.Printf("- (%s) OIDC identity provider successfully created\n", greenCheck()) + return nil +} + func (o *DrOidcIamOptions) verifyAndUpdateHostedCluster(ctx context.Context, k8sClient client.Client, s3Client *s3.Client, iamClient *iam.Client) error { if o.DryRun { fmt.Printf(" - (%s) DRY RUN: Would verify OIDC configuration\n", yellowQuestion()) diff --git a/cmd/infra/aws/create.go b/cmd/infra/aws/create.go index 474736c2cb7..07d3715738b 100644 --- a/cmd/infra/aws/create.go +++ b/cmd/infra/aws/create.go @@ -199,24 +199,88 @@ func (o *CreateInfraOptions) CreateInfra(ctx context.Context, l logr.Logger) (*C return nil, err } - awsSession, err := o.AWSCredentialsOpts.GetSession(ctx, "cli-create-infra", o.CredentialsSecretData, o.Region) + awsSession, vpcOwnerSession, err := o.initAWSSessions(ctx) + if err != nil { + return nil, err + } + + clusterCreatorEC2Client, ec2Client, route53Client, vpcOwnerRoute53Client := o.initAWSClients(awsSession, vpcOwnerSession) + + if err := o.parseAdditionalTags(); err != nil { + return nil, err + } + + result := &CreateInfraOutput{ + InfraID: o.InfraID, + MachineCIDR: o.VPCCIDR, + Region: o.Region, + Name: o.Name, + BaseDomain: o.BaseDomain, + BaseDomainPrefix: o.BaseDomainPrefix, + PublicOnly: o.PublicOnly, + } + if len(o.Zones) == 0 { + zone, err := o.firstZone(ctx, l, ec2Client) + if err != nil { + return nil, err + } + o.Zones = append(o.Zones, zone) + } + + igwID, err := o.createVPCResources(ctx, l, ec2Client, result) + if err != nil { + return nil, err + } + + publicSubnetIDs, endpointRouteTableIds, err := o.createPerZoneResources(ctx, l, ec2Client, result, igwID) if err != nil { return nil, err } + + result.PublicZoneID, err = o.LookupPublicZone(ctx, l, route53Client) + if err != nil { + return nil, err + } + + if vpcOwnerSession != nil { + if err := o.shareSubnets(ctx, l, vpcOwnerSession, awsSession, publicSubnetIDs, result); err != nil { + return nil, err + } + } + + if err := o.createDNSZones(ctx, l, clusterCreatorEC2Client, route53Client, vpcOwnerRoute53Client, result); err != nil { + return nil, err + } + + if err := o.createProxyResources(ctx, l, ec2Client, result, publicSubnetIDs); err != nil { + return nil, err + } + + _ = endpointRouteTableIds + return result, nil +} + +func (o *CreateInfraOptions) initAWSSessions(ctx context.Context) (*aws.Config, *aws.Config, error) { + awsSession, err := o.AWSCredentialsOpts.GetSession(ctx, "cli-create-infra", o.CredentialsSecretData, o.Region) + if err != nil { + return nil, nil, err + } var vpcOwnerSession *aws.Config if o.VPCOwnerCredentialOpts.AWSCredentialsFile != "" { vpcOwnerSession, err = o.VPCOwnerCredentialOpts.GetSession(ctx, "cli-create-infra", nil, o.Region) if err != nil { - return nil, err + return nil, nil, err } } + return awsSession, vpcOwnerSession, nil +} - var clusterCreatorEC2Client, ec2Client awsapi.EC2API - var vpcOwnerRoute53Client, route53Client awsapi.ROUTE53API +func (o *CreateInfraOptions) initAWSClients(awsSession, vpcOwnerSession *aws.Config) (awsapi.EC2API, awsapi.EC2API, awsapi.ROUTE53API, awsapi.ROUTE53API) { awsConfig := awsutil.NewConfig() - clusterCreatorEC2Client = ec2.NewFromConfig(*awsSession, func(o *ec2.Options) { + clusterCreatorEC2Client := ec2.NewFromConfig(*awsSession, func(o *ec2.Options) { o.Retryer = awsConfig() }) + var ec2Client awsapi.EC2API if vpcOwnerSession != nil { ec2Client = ec2.NewFromConfig(*vpcOwnerSession, func(o *ec2.Options) { o.Retryer = awsConfig() @@ -224,9 +288,10 @@ func (o *CreateInfraOptions) CreateInfra(ctx context.Context, l logr.Logger) (*C } else { ec2Client = clusterCreatorEC2Client } - route53Client = route53.NewFromConfig(*awsSession, func(o *route53.Options) { + route53Client := route53.NewFromConfig(*awsSession, func(o *route53.Options) { o.Retryer = awsutil.NewRoute53Config()() }) + var vpcOwnerRoute53Client awsapi.ROUTE53API if vpcOwnerSession != nil { vpcOwnerRoute53Client = route53.NewFromConfig(*vpcOwnerSession, func(o *route53.Options) { o.Retryer = awsutil.NewRoute53Config()() @@ -234,45 +299,29 @@ func (o *CreateInfraOptions) CreateInfra(ctx context.Context, l logr.Logger) (*C } else { vpcOwnerRoute53Client = route53Client } + return clusterCreatorEC2Client, ec2Client, route53Client, vpcOwnerRoute53Client +} - if err := o.parseAdditionalTags(); err != nil { - return nil, err - } - - result := &CreateInfraOutput{ - InfraID: o.InfraID, - MachineCIDR: o.VPCCIDR, - Region: o.Region, - Name: o.Name, - BaseDomain: o.BaseDomain, - BaseDomainPrefix: o.BaseDomainPrefix, - PublicOnly: o.PublicOnly, - } - if len(o.Zones) == 0 { - zone, err := o.firstZone(ctx, l, ec2Client) - if err != nil { - return nil, err - } - o.Zones = append(o.Zones, zone) - } - - // VPC resources +func (o *CreateInfraOptions) createVPCResources(ctx context.Context, l logr.Logger, ec2Client awsapi.EC2API, result *CreateInfraOutput) (string, error) { + var err error result.VPCID, err = o.createVPC(ctx, l, ec2Client) if err != nil { - return nil, err + return "", err } if err = o.CreateDHCPOptions(ctx, l, ec2Client, result.VPCID); err != nil { - return nil, err + return "", err } igwID, err := o.CreateInternetGateway(ctx, l, ec2Client, result.VPCID) if err != nil { - return nil, err + return "", err } + return igwID, nil +} - // Per zone resources +func (o *CreateInfraOptions) createPerZoneResources(ctx context.Context, l logr.Logger, ec2Client awsapi.EC2API, result *CreateInfraOutput, igwID string) ([]string, []string, error) { _, cidrNetwork, err := net.ParseCIDR(o.VPCCIDR) if err != nil { - return nil, err + return nil, nil, err } publicNetwork := copyIPNet(cidrNetwork) publicNetwork.Mask = net.CIDRMask(20, 32) @@ -285,31 +334,28 @@ func (o *CreateInfraOptions) CreateInfra(ctx context.Context, l logr.Logger) (*C var publicSubnetIDs []string var natGatewayID string for _, zone := range o.Zones { - var ( - privateSubnetID string - err error - ) + var privateSubnetID string if !o.PublicOnly { privateSubnetID, err = o.CreatePrivateSubnet(ctx, l, ec2Client, result.VPCID, zone, privateNetwork.String()) if err != nil { - return nil, err + return nil, nil, err } } publicSubnetID, err := o.CreatePublicSubnet(ctx, l, ec2Client, result.VPCID, zone, publicNetwork.String()) if err != nil { - return nil, err + return nil, nil, err } publicSubnetIDs = append(publicSubnetIDs, publicSubnetID) if !o.PublicOnly && !o.EnableProxy && !o.EnableSecureProxy && ((natGatewayID == "" && o.SingleNATGateway) || !o.SingleNATGateway) { natGatewayID, err = o.CreateNATGateway(ctx, l, ec2Client, publicSubnetID, zone) if err != nil { - return nil, err + return nil, nil, err } } if !o.PublicOnly { privateRouteTable, err := o.CreatePrivateRouteTable(ctx, l, ec2Client, result.VPCID, natGatewayID, privateSubnetID, zone) if err != nil { - return nil, err + return nil, nil, err } endpointRouteTableIds = append(endpointRouteTableIds, privateRouteTable) } @@ -321,79 +367,76 @@ func (o *CreateInfraOptions) CreateInfra(ctx context.Context, l logr.Logger) (*C Name: zone, SubnetID: zoneSubnetID, }) - // increment each subnet by /20 privateNetwork.IP[2] = privateNetwork.IP[2] + 16 publicNetwork.IP[2] = publicNetwork.IP[2] + 16 } publicRouteTable, err := o.CreatePublicRouteTable(ctx, l, ec2Client, result.VPCID, igwID, publicSubnetIDs) if err != nil { - return nil, err + return nil, nil, err } endpointRouteTableIds = append(endpointRouteTableIds, publicRouteTable) err = o.CreateVPCS3Endpoint(ctx, l, ec2Client, result.VPCID, endpointRouteTableIds) if err != nil { - return nil, err - } - result.PublicZoneID, err = o.LookupPublicZone(ctx, l, route53Client) - if err != nil { - return nil, err - } - - if vpcOwnerSession != nil { - if err := o.shareSubnets(ctx, l, vpcOwnerSession, awsSession, publicSubnetIDs, result); err != nil { - return nil, err - } + return nil, nil, err } + return publicSubnetIDs, endpointRouteTableIds, nil +} +func (o *CreateInfraOptions) createDNSZones(ctx context.Context, l logr.Logger, clusterCreatorEC2Client awsapi.EC2API, route53Client, vpcOwnerRoute53Client awsapi.ROUTE53API, result *CreateInfraOutput) error { privateZoneClient := vpcOwnerRoute53Client var initialVPC string + var err error if o.PrivateZonesInClusterAccount { privateZoneClient = route53Client - - // Create a dummy vpc that we can use to create the private hosted zones if initialVPC, err = o.createVPC(ctx, l, clusterCreatorEC2Client); err != nil { - return nil, err + return err } + defer func() { + if initialVPC != "" { + if cleanupErr := o.deleteVPC(ctx, l, clusterCreatorEC2Client, initialVPC); cleanupErr != nil { + l.Error(cleanupErr, "Failed to clean up temporary VPC", "id", initialVPC) + } + } + }() } result.PrivateZoneID, err = o.CreatePrivateZone(ctx, l, privateZoneClient, ZoneName(o.Name, o.BaseDomainPrefix, o.BaseDomain), result.VPCID, o.PrivateZonesInClusterAccount, vpcOwnerRoute53Client, initialVPC) if err != nil { - return nil, err + return err } result.LocalZoneID, err = o.CreatePrivateZone(ctx, l, privateZoneClient, fmt.Sprintf("%s.%s", o.Name, hypershiftLocalZoneName), result.VPCID, o.PrivateZonesInClusterAccount, vpcOwnerRoute53Client, initialVPC) if err != nil { - return nil, err + return err } - if initialVPC != "" { - if err := o.deleteVPC(ctx, l, clusterCreatorEC2Client, initialVPC); err != nil { - return nil, err - } + return nil +} + +func (o *CreateInfraOptions) createProxyResources(ctx context.Context, l logr.Logger, ec2Client awsapi.EC2API, result *CreateInfraOutput, publicSubnetIDs []string) error { + if !o.EnableProxy && !o.EnableSecureProxy { + return nil + } + sgGroupID, err := o.createProxySecurityGroup(ctx, l, ec2Client, result.VPCID) + if err != nil { + return fmt.Errorf("failed to create security group for proxy: %w", err) } - if o.EnableProxy || o.EnableSecureProxy { - sgGroupID, err := o.createProxySecurityGroup(ctx, l, ec2Client, result.VPCID) + if o.ProxyVPCEndpointServiceName != "" { + result.ProxyAddr, err = o.createProxyVPCEndpoint(ctx, l, ec2Client, result.VPCID, result.Zones[0].SubnetID, sgGroupID) if err != nil { - return nil, fmt.Errorf("failed to create security group for proxy: %w", err) + return err } - - if o.ProxyVPCEndpointServiceName != "" { - result.ProxyAddr, err = o.createProxyVPCEndpoint(ctx, l, ec2Client, result.VPCID, result.Zones[0].SubnetID, sgGroupID) - if err != nil { - return nil, err - } - } else { - proxyResult, err := o.createProxyHost(ctx, l, ec2Client, result.Zones[0].SubnetID, sgGroupID, o.EnableSecureProxy) - if err != nil { - return nil, fmt.Errorf("failed to create proxy host: %w", err) - } - result.ProxyAddr = proxyResult.HTTPProxyURL - result.SecureProxyAddr = proxyResult.HTTPSProxyURL - result.ProxyCA = proxyResult.CA - result.ProxyPrivateSSHKey = proxyResult.PrivateKey + } else { + proxyResult, err := o.createProxyHost(ctx, l, ec2Client, publicSubnetIDs[0], sgGroupID, o.EnableSecureProxy) + if err != nil { + return fmt.Errorf("failed to create proxy host: %w", err) } + result.ProxyAddr = proxyResult.HTTPProxyURL + result.SecureProxyAddr = proxyResult.HTTPSProxyURL + result.ProxyCA = proxyResult.CA + result.ProxyPrivateSSHKey = proxyResult.PrivateKey } - return result, nil + return nil } func (o *CreateInfraOptions) createProxySecurityGroup(ctx context.Context, l logr.Logger, client awsapi.EC2API, vpcID string) (string, error) { diff --git a/cmd/infra/azure/create.go b/cmd/infra/azure/create.go index 4f10a4ef6cf..191ed895423 100644 --- a/cmd/infra/azure/create.go +++ b/cmd/infra/azure/create.go @@ -125,7 +125,6 @@ func BindProductFlags(opts *CreateInfraOptions, flags *pflag.FlagSet) { // Run is the main function responsible for creating the Azure infrastructure resources for a HostedCluster. func (o *CreateInfraOptions) Run(ctx context.Context, l logr.Logger) (*CreateInfraOutput, error) { - // Validate deployment model flags to prevent conflicts between ARO HCP and self-managed Azure if err := o.validateDeploymentModelFlags(); err != nil { return nil, err } @@ -136,180 +135,196 @@ func (o *CreateInfraOptions) Run(ctx context.Context, l logr.Logger) (*CreateInf BaseDomain: o.BaseDomain, } - // Setup subscription ID and Azure credential information subscriptionID, azureCreds, err := util.SetupAzureCredentials(l, o.Credentials, o.CredentialsFile) if err != nil { return nil, fmt.Errorf("failed to setup Azure credentials: %w", err) } - // Initialize managers rgMgr := NewResourceGroupManager(subscriptionID, azureCreds, o.Cloud) netMgr := NewNetworkManager(subscriptionID, azureCreds, o.Cloud) rbacMgr := NewRBACManager(subscriptionID, azureCreds) - // Create main resource group + resourceGroupName, nsgResourceGroupName, vnetResourceGroupName, err := o.createNetworkResources(ctx, l, rgMgr, netMgr, &result) + if err != nil { + return nil, err + } + + if err := o.handleIdentitiesAndRBAC(ctx, rbacMgr, &result, resourceGroupName, nsgResourceGroupName, vnetResourceGroupName); err != nil { + return nil, err + } + + if err := o.createDNSAndLBResources(ctx, l, netMgr, &result, resourceGroupName); err != nil { + return nil, err + } + + if err := o.writeOutputFile(l, result); err != nil { + return nil, err + } + + return &result, nil +} + +func (o *CreateInfraOptions) createNetworkResources(ctx context.Context, l logr.Logger, rgMgr *ResourceGroupManager, netMgr *NetworkManager, result *CreateInfraOutput) (string, string, string, error) { resourceGroupName, msg, err := rgMgr.CreateOrGetResourceGroup(ctx, o, "") if err != nil { - return nil, fmt.Errorf("failed to create a resource group: %w", err) + return "", "", "", fmt.Errorf("failed to create a resource group: %w", err) } result.ResourceGroupName = resourceGroupName l.Info(msg, "name", resourceGroupName) - // Get base DNS zone ID result.PublicZoneID, err = netMgr.GetBaseDomainID(ctx, o.BaseDomain) if err != nil { - return nil, err + return "", "", "", err } - // Handle network security group nsgResourceGroupName := "" if len(o.NetworkSecurityGroupID) > 0 { result.SecurityGroupID = o.NetworkSecurityGroupID _, nsgResourceGroupName, err = azureutil.GetNameAndResourceGroupFromNetworkSecurityGroupID(o.NetworkSecurityGroupID) if err != nil { - return nil, err + return "", "", "", err } l.Info("Using existing network security group", "ID", result.SecurityGroupID) } else { nsgResourceGroupName = o.Name + "-nsg" nsgResourceGroupName, msg, err = rgMgr.CreateOrGetResourceGroup(ctx, o, nsgResourceGroupName) if err != nil { - return nil, fmt.Errorf("failed to create resource group for network security group: %w", err) + return "", "", "", fmt.Errorf("failed to create resource group for network security group: %w", err) } l.Info(msg, "name", nsgResourceGroupName) nsgID, err := netMgr.CreateSecurityGroup(ctx, nsgResourceGroupName, o.Name, o.InfraID, o.Location) if err != nil { - return nil, err + return "", "", "", err } result.SecurityGroupID = nsgID l.Info("Successfully created network security group", "ID", result.SecurityGroupID) } - // Handle subnet if len(o.SubnetID) > 0 { result.SubnetID = o.SubnetID l.Info("Using existing subnet", "ID", result.SubnetID) } - // Handle virtual network vnetResourceGroupName := "" if len(o.VnetID) > 0 { result.VNetID = o.VnetID _, vnetResourceGroupName, err = azureutil.GetVnetNameAndResourceGroupFromVnetID(o.VnetID) if err != nil { - return nil, err + return "", "", "", err } l.Info("Using existing vnet", "ID", result.VNetID) } else { vnetResourceGroupName = o.Name + "-vnet" vnetResourceGroupName, msg, err = rgMgr.CreateOrGetResourceGroup(ctx, o, vnetResourceGroupName) if err != nil { - return nil, fmt.Errorf("failed to create resource group for virtual network: %w", err) + return "", "", "", fmt.Errorf("failed to create resource group for virtual network: %w", err) } l.Info(msg, "name", vnetResourceGroupName) vnet, err := netMgr.CreateVirtualNetwork(ctx, vnetResourceGroupName, o.Name, o.InfraID, o.Location, o.SubnetID, result.SecurityGroupID) if err != nil { - return nil, err + return "", "", "", err } result.SubnetID = *vnet.Properties.Subnets[0].ID result.VNetID = *vnet.ID l.Info("Successfully created vnet", "ID", result.VNetID) } - // Handle managed identities and RBAC + return resourceGroupName, nsgResourceGroupName, vnetResourceGroupName, nil +} + +func (o *CreateInfraOptions) handleIdentitiesAndRBAC(ctx context.Context, rbacMgr *RBACManager, result *CreateInfraOutput, resourceGroupName, nsgResourceGroupName, vnetResourceGroupName string) error { if o.ManagedIdentitiesFile != "" { result.ControlPlaneMIs = &hyperv1.AzureResourceManagedIdentities{} managedIdentitiesRaw, err := os.ReadFile(o.ManagedIdentitiesFile) if err != nil { - return nil, fmt.Errorf("failed to read --managed-identities-file %s: %w", o.ManagedIdentitiesFile, err) + return fmt.Errorf("failed to read --managed-identities-file %s: %w", o.ManagedIdentitiesFile, err) } if err := yaml.Unmarshal(managedIdentitiesRaw, &result.ControlPlaneMIs.ControlPlane); err != nil { - return nil, fmt.Errorf("failed to unmarshal --managed-identities-file: %w", err) + return fmt.Errorf("failed to unmarshal --managed-identities-file: %w", err) } - if o.AssignServicePrincipalRoles { if err := rbacMgr.AssignControlPlaneRoles(ctx, o, result.ControlPlaneMIs, resourceGroupName, nsgResourceGroupName, vnetResourceGroupName); err != nil { - return nil, err + return err } } } - // Handle data plane identities if o.DataPlaneIdentitiesFile != "" { dataPlaneIdentitiesRaw, err := os.ReadFile(o.DataPlaneIdentitiesFile) if err != nil { - return nil, fmt.Errorf("failed to read --data-plane-identities-file %s: %w", o.DataPlaneIdentitiesFile, err) + return fmt.Errorf("failed to read --data-plane-identities-file %s: %w", o.DataPlaneIdentitiesFile, err) } if err := yaml.Unmarshal(dataPlaneIdentitiesRaw, &result.DataPlaneIdentities); err != nil { - return nil, fmt.Errorf("failed to unmarshal --data-plane-identities-file: %w", err) + return fmt.Errorf("failed to unmarshal --data-plane-identities-file: %w", err) } - if o.AssignServicePrincipalRoles { if err := rbacMgr.AssignDataPlaneRoles(ctx, o, result.DataPlaneIdentities, resourceGroupName); err != nil { - return nil, err + return err } } } - // Handle workload identities if o.WorkloadIdentitiesFile != "" { workloadIdentitiesRaw, err := os.ReadFile(o.WorkloadIdentitiesFile) if err != nil { - return nil, fmt.Errorf("failed to read --workload-identities-file %s: %w", o.WorkloadIdentitiesFile, err) + return fmt.Errorf("failed to read --workload-identities-file %s: %w", o.WorkloadIdentitiesFile, err) } if err := json.Unmarshal(workloadIdentitiesRaw, &result.WorkloadIdentities); err != nil { - return nil, fmt.Errorf("failed to unmarshal --workload-identities-file: %w", err) + return fmt.Errorf("failed to unmarshal --workload-identities-file: %w", err) } - if o.AssignServicePrincipalRoles { if err := rbacMgr.AssignWorkloadIdentities(ctx, o, result.WorkloadIdentities, resourceGroupName, nsgResourceGroupName, vnetResourceGroupName); err != nil { - return nil, err + return err } } } - // Create DNS infrastructure + return nil +} + +func (o *CreateInfraOptions) createDNSAndLBResources(ctx context.Context, l logr.Logger, netMgr *NetworkManager, result *CreateInfraOutput, resourceGroupName string) error { privateDNSZoneID, privateDNSZoneName, err := netMgr.CreatePrivateDNSZone(ctx, resourceGroupName, o.Name, o.BaseDomain) if err != nil { - return nil, err + return err } result.PrivateZoneID = privateDNSZoneID l.Info("Successfully created private DNS zone", "name", privateDNSZoneName) err = netMgr.CreatePrivateDNSZoneLink(ctx, resourceGroupName, o.Name, o.InfraID, result.VNetID, privateDNSZoneName) if err != nil { - return nil, err + return err } l.Info("Successfully created private DNS zone link") - // Create load balancer infrastructure publicIPAddress, err := netMgr.CreatePublicIPAddressForLB(ctx, resourceGroupName, o.InfraID, o.Location) if err != nil { - return nil, err + return err } l.Info("Successfully created public IP address for guest cluster egress load balancer") err = netMgr.CreateLoadBalancer(ctx, resourceGroupName, o.InfraID, o.Location, publicIPAddress) if err != nil { - return nil, err + return err } l.Info("Successfully created guest cluster egress load balancer") + return nil +} - // Serialize the result to the output file if it was provided - if o.OutputFile != "" { - resultSerialized, err := yaml.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to serialize result: %w", err) - } - if err := os.WriteFile(o.OutputFile, resultSerialized, 0644); err != nil { - l.Error(err, "Writing output file failed", "Output File", o.OutputFile, "data", string(resultSerialized)) - return nil, fmt.Errorf("failed to write result to --output-file: %w", err) - } +func (o *CreateInfraOptions) writeOutputFile(l logr.Logger, result CreateInfraOutput) error { + if o.OutputFile == "" { + return nil } - - return &result, nil + resultSerialized, err := yaml.Marshal(result) + if err != nil { + return fmt.Errorf("failed to serialize result: %w", err) + } + if err := os.WriteFile(o.OutputFile, resultSerialized, 0600); err != nil { + l.Error(err, "Writing output file failed", "Output File", o.OutputFile) + return fmt.Errorf("failed to write result to --output-file: %w", err) + } + return nil } // Validate validates the CreateInfraOptions before running the command diff --git a/cmd/install/assets/hypershift_operator.go b/cmd/install/assets/hypershift_operator.go index 3d55c46b9c4..592531ff749 100644 --- a/cmd/install/assets/hypershift_operator.go +++ b/cmd/install/assets/hypershift_operator.go @@ -530,345 +530,23 @@ type HyperShiftOperatorDeployment struct { } func (o HyperShiftOperatorDeployment) Build() *appsv1.Deployment { - args := []string{ - "run", - "--namespace=$(MY_NAMESPACE)", - "--pod-name=$(MY_NAME)", - "--metrics-addr=:9000", - fmt.Sprintf("--enable-dedicated-request-serving-isolation=%t", o.EnableDedicatedRequestServingIsolation), - fmt.Sprintf("--enable-ocp-cluster-monitoring=%t", o.EnableOCPClusterMonitoring), - fmt.Sprintf("--enable-ci-debug-output=%t", o.EnableCIDebugOutput), - fmt.Sprintf("--private-platform=%s", o.PrivatePlatform), - } - if o.RegistryOverrides != "" { - args = append(args, fmt.Sprintf("--registry-overrides=%s", o.RegistryOverrides)) - } - + args := o.buildArgs() + envVars := o.buildEnvVars() var volumeMounts []corev1.VolumeMount var initVolumeMounts []corev1.VolumeMount var volumes []corev1.Volume - envVars := []corev1.EnvVar{ - { - Name: "MY_NAMESPACE", - ValueFrom: &corev1.EnvVarSource{ - FieldRef: &corev1.ObjectFieldSelector{ - FieldPath: "metadata.namespace", - }, - }, - }, - { - Name: "MY_NAME", - ValueFrom: &corev1.EnvVarSource{ - FieldRef: &corev1.ObjectFieldSelector{ - FieldPath: "metadata.name", - }, - }, - }, - metrics.MetricsSetToEnv(o.MetricsSet), - { - Name: "CERT_ROTATION_SCALE", - Value: o.CertRotationScale.String(), - }, - { - Name: "KUBE_FEATURE_WatchListClient", - Value: "false", - }, - } - - // Add any additional environment variables specified if they don't already exist. - for key, value := range o.AdditionalOperatorEnvVars { - trimmedKey := strings.TrimSpace(key) - trimmedValue := strings.TrimSpace(value) - - if trimmedKey != "" && - trimmedValue != "" && - !slices.ContainsFunc(envVars, func(e corev1.EnvVar) bool { - return e.Name == trimmedKey - }) { - envVars = append(envVars, corev1.EnvVar{ - Name: trimmedKey, - Value: trimmedValue, - }) - } - } - - // Add the new HYPERSHIFT_FEATURESET env var if TPNU is set. - if o.TechPreviewNoUpgrade { - envVars = append(envVars, corev1.EnvVar{ - Name: "HYPERSHIFT_FEATURESET", - Value: string(configv1.TechPreviewNoUpgrade), - }) - } - - // Add audit log persistence env var if enabled - if o.EnableAuditLogPersistence { - envVars = append(envVars, corev1.EnvVar{ - Name: "ENABLE_AUDIT_LOG_PERSISTENCE", - Value: "true", - }) - } - if o.EnableWebhook { - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "serving-cert", - MountPath: "/var/run/secrets/serving-cert", - }) - volumes = append(volumes, corev1.Volume{ - Name: "serving-cert", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: "manager-serving-cert", - }, - }, - }) - args = append(args, - "--cert-dir=/var/run/secrets/serving-cert", - ) - - if o.EnableValidatingWebhook { - args = append(args, "--enable-validating-webhook=true") - } - } - - if len(o.OIDCBucketName) > 0 && len(o.OIDCBucketRegion) > 0 && len(o.OIDCStorageProviderS3SecretKey) > 0 && - o.OIDCStorageProviderS3Secret != nil && len(o.OIDCStorageProviderS3Secret.Name) > 0 { - args = append(args, - "--oidc-storage-provider-s3-bucket-name="+o.OIDCBucketName, - "--oidc-storage-provider-s3-region="+o.OIDCBucketRegion, - "--oidc-storage-provider-s3-credentials=/etc/oidc-storage-provider-s3-creds/"+o.OIDCStorageProviderS3SecretKey, - ) - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "oidc-storage-provider-s3-creds", - MountPath: "/etc/oidc-storage-provider-s3-creds", - }) - volumes = append(volumes, corev1.Volume{ - Name: "oidc-storage-provider-s3-creds", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: o.OIDCStorageProviderS3Secret.Name, - }, - }, - }) - } - - if o.ScaleFromZeroSecret != nil && len(o.ScaleFromZeroSecret.Name) > 0 && len(o.ScaleFromZeroSecretKey) > 0 && len(o.ScaleFromZeroProvider) > 0 { - args = append(args, - "--scale-from-zero-provider="+o.ScaleFromZeroProvider, - "--scale-from-zero-creds=/etc/scale-from-zero-creds/"+o.ScaleFromZeroSecretKey, - ) - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "scale-from-zero-creds", - MountPath: "/etc/scale-from-zero-creds", - ReadOnly: true, - }) - volumes = append(volumes, corev1.Volume{ - Name: "scale-from-zero-creds", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: o.ScaleFromZeroSecret.Name, - }, - }, - }) - } + o.addWebhookResources(&args, &volumeMounts, &volumes) + o.addOIDCResources(&args, &volumeMounts, &volumes) + o.addScaleFromZeroResources(&args, &volumeMounts, &volumes) if o.UWMTelemetry { args = append(args, "--enable-uwm-telemetry-remote-write") } - if o.EnableCVOManagementClusterMetricsAccess { - envVars = append(envVars, corev1.EnvVar{ - Name: config.EnableCVOManagementClusterMetricsAccessEnvVar, - Value: "1", - }) - } - - if len(o.ManagedService) > 0 { - envVars = append(envVars, corev1.EnvVar{ - Name: "MANAGED_SERVICE", - Value: o.ManagedService, - }) - } - - if len(o.AROHCPKeyVaultUsersClientID) > 0 { - envVars = append(envVars, corev1.EnvVar{ - Name: config.AROHCPKeyVaultManagedIdentityClientID, - Value: o.AROHCPKeyVaultUsersClientID, - }) - } - - if o.EnableSizeTagging { - envVars = append(envVars, corev1.EnvVar{ - Name: "ENABLE_SIZE_TAGGING", - Value: "1", - }) - } - - if o.EnableEtcdRecovery { - envVars = append(envVars, corev1.EnvVar{ - Name: config.EnableEtcdRecoveryEnvVar, - Value: "1", - }) - } - - if o.EnableCPOOverrides { - envVars = append(envVars, corev1.EnvVar{ - Name: controlplaneoperatoroverrides.CPOOverridesEnvVar, - Value: "1", - }) - } - - if len(o.PlatformsInstalled) > 0 { - envVars = append(envVars, corev1.EnvVar{ - Name: "PLATFORMS_INSTALLED", - Value: o.PlatformsInstalled, - }) - } - - image := o.OperatorImage - - if mapImage, ok := o.Images["hypershift-operator"]; ok { - image = mapImage - } - tagMapping := images.TagMapping() - for tag, ref := range o.Images { - if envVar, exists := tagMapping[tag]; exists { - envVars = append(envVars, corev1.EnvVar{ - Name: envVar, - Value: ref, - }) - } - } - - privatePlatformType := hyperv1.PlatformType(o.PrivatePlatform) - if privatePlatformType != hyperv1.NonePlatform { - // Add platform specific settings - switch privatePlatformType { - case hyperv1.AWSPlatform: - // Add AWS credentials secret volume - volumes = append(volumes, corev1.Volume{ - Name: "credentials", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: o.AWSPrivateSecret.Name, - }, - }, - }) - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "credentials", - MountPath: "/etc/provider", - }) - envVars = append(envVars, - corev1.EnvVar{ - Name: "AWS_SHARED_CREDENTIALS_FILE", - Value: "/etc/provider/" + o.AWSPrivateSecretKey, - }, - corev1.EnvVar{ - Name: "AWS_REGION", - Value: o.AWSPrivateRegion, - }, - corev1.EnvVar{ - Name: "AWS_SDK_LOAD_CONFIG", - Value: "1", - }) - case hyperv1.AzurePlatform: - if o.AzurePLSResourceGroup != "" { - envVars = append(envVars, corev1.EnvVar{ - Name: "AZURE_RESOURCE_GROUP", - Value: o.AzurePLSResourceGroup, - }) - } - if o.AzurePLSManagedIdentityClientID != "" { - // Workload identity mode: the SA annotation triggers Azure AD Workload Identity - // webhook to inject federated tokens. Set the client ID as an env var so the - // HO platform controller can construct credentials. - envVars = append(envVars, corev1.EnvVar{ - Name: "AZURE_PLS_CLIENT_ID", - Value: o.AzurePLSManagedIdentityClientID, - }) - if o.AzurePLSSubscriptionID != "" { - envVars = append(envVars, corev1.EnvVar{ - Name: "AZURE_SUBSCRIPTION_ID", - Value: o.AzurePLSSubscriptionID, - }) - } - } else if o.AzurePrivateSecret != nil && len(o.AzurePrivateSecret.Name) > 0 && len(o.AzurePrivateSecretKey) > 0 { - volumes = append(volumes, corev1.Volume{ - Name: "azure-credentials", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: o.AzurePrivateSecret.Name, - }, - }, - }) - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "azure-credentials", - MountPath: "/etc/azure-provider", - ReadOnly: true, - }) - envVars = append(envVars, corev1.EnvVar{ - Name: "AZURE_CREDENTIALS_FILE", - Value: "/etc/azure-provider/" + o.AzurePrivateSecretKey, - }) - } - case hyperv1.GCPPlatform: - if o.GCPProject != "" { - envVars = append(envVars, corev1.EnvVar{ - Name: "GCP_PROJECT", - Value: o.GCPProject, - }) - } - if o.GCPRegion != "" { - envVars = append(envVars, corev1.EnvVar{ - Name: "GCP_REGION", - Value: o.GCPRegion, - }) - } - } - - // Add AWS-specific volumes if AWS platform - if privatePlatformType == hyperv1.AWSPlatform { - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "token", - MountPath: "/var/run/secrets/openshift/serviceaccount", - }) - volumes = append(volumes, corev1.Volume{ - Name: "token", - VolumeSource: corev1.VolumeSource{ - Projected: &corev1.ProjectedVolumeSource{ - Sources: []corev1.VolumeProjection{ - { - ServiceAccountToken: &corev1.ServiceAccountTokenProjection{ - Audience: "openshift", - Path: "token", - }, - }, - }, - }, - }, - }) - } - } - - if o.RHOBSMonitoring { - envVars = append(envVars, corev1.EnvVar{ - Name: rhobsmonitoring.EnvironmentVariable, - Value: "1", - }) - } - - if o.CVOPrometheusURL != "" { - envVars = append(envVars, corev1.EnvVar{ - Name: config.CVOPrometheusURLEnvVar, - Value: o.CVOPrometheusURL, - }) - } - - if o.MonitoringDashboards { - envVars = append(envVars, corev1.EnvVar{ - Name: "MONITORING_DASHBOARDS", - Value: "1", - }) - } + image := o.resolveImage() + o.addImageTagEnvVars(&envVars) + o.addPrivatePlatformResources(&envVars, &volumeMounts, &volumes) deployment := &appsv1.Deployment{ TypeMeta: metav1.TypeMeta{ @@ -1067,6 +745,280 @@ func (o HyperShiftOperatorDeployment) Build() *appsv1.Deployment { return deployment } +func (o HyperShiftOperatorDeployment) buildArgs() []string { + args := []string{ + "run", + "--namespace=$(MY_NAMESPACE)", + "--pod-name=$(MY_NAME)", + "--metrics-addr=:9000", + fmt.Sprintf("--enable-dedicated-request-serving-isolation=%t", o.EnableDedicatedRequestServingIsolation), + fmt.Sprintf("--enable-ocp-cluster-monitoring=%t", o.EnableOCPClusterMonitoring), + fmt.Sprintf("--enable-ci-debug-output=%t", o.EnableCIDebugOutput), + fmt.Sprintf("--private-platform=%s", o.PrivatePlatform), + } + if o.RegistryOverrides != "" { + args = append(args, fmt.Sprintf("--registry-overrides=%s", o.RegistryOverrides)) + } + return args +} + +func (o HyperShiftOperatorDeployment) buildEnvVars() []corev1.EnvVar { + envVars := []corev1.EnvVar{ + { + Name: "MY_NAMESPACE", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "metadata.namespace", + }, + }, + }, + { + Name: "MY_NAME", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "metadata.name", + }, + }, + }, + metrics.MetricsSetToEnv(o.MetricsSet), + { + Name: "CERT_ROTATION_SCALE", + Value: o.CertRotationScale.String(), + }, + { + Name: "KUBE_FEATURE_WatchListClient", + Value: "false", + }, + } + + for key, value := range o.AdditionalOperatorEnvVars { + trimmedKey := strings.TrimSpace(key) + trimmedValue := strings.TrimSpace(value) + if trimmedKey != "" && + trimmedValue != "" && + !slices.ContainsFunc(envVars, func(e corev1.EnvVar) bool { + return e.Name == trimmedKey + }) { + envVars = append(envVars, corev1.EnvVar{ + Name: trimmedKey, + Value: trimmedValue, + }) + } + } + + if o.TechPreviewNoUpgrade { + envVars = append(envVars, corev1.EnvVar{Name: "HYPERSHIFT_FEATURESET", Value: string(configv1.TechPreviewNoUpgrade)}) + } + if o.EnableAuditLogPersistence { + envVars = append(envVars, corev1.EnvVar{Name: "ENABLE_AUDIT_LOG_PERSISTENCE", Value: "true"}) + } + if o.EnableCVOManagementClusterMetricsAccess { + envVars = append(envVars, corev1.EnvVar{Name: config.EnableCVOManagementClusterMetricsAccessEnvVar, Value: "1"}) + } + if len(o.ManagedService) > 0 { + envVars = append(envVars, corev1.EnvVar{Name: "MANAGED_SERVICE", Value: o.ManagedService}) + } + if len(o.AROHCPKeyVaultUsersClientID) > 0 { + envVars = append(envVars, corev1.EnvVar{Name: config.AROHCPKeyVaultManagedIdentityClientID, Value: o.AROHCPKeyVaultUsersClientID}) + } + if o.EnableSizeTagging { + envVars = append(envVars, corev1.EnvVar{Name: "ENABLE_SIZE_TAGGING", Value: "1"}) + } + if o.EnableEtcdRecovery { + envVars = append(envVars, corev1.EnvVar{Name: config.EnableEtcdRecoveryEnvVar, Value: "1"}) + } + if o.EnableCPOOverrides { + envVars = append(envVars, corev1.EnvVar{Name: controlplaneoperatoroverrides.CPOOverridesEnvVar, Value: "1"}) + } + if len(o.PlatformsInstalled) > 0 { + envVars = append(envVars, corev1.EnvVar{Name: "PLATFORMS_INSTALLED", Value: o.PlatformsInstalled}) + } + if o.RHOBSMonitoring { + envVars = append(envVars, corev1.EnvVar{Name: rhobsmonitoring.EnvironmentVariable, Value: "1"}) + } + if o.CVOPrometheusURL != "" { + envVars = append(envVars, corev1.EnvVar{Name: config.CVOPrometheusURLEnvVar, Value: o.CVOPrometheusURL}) + } + if o.MonitoringDashboards { + envVars = append(envVars, corev1.EnvVar{Name: "MONITORING_DASHBOARDS", Value: "1"}) + } + return envVars +} + +func (o HyperShiftOperatorDeployment) addWebhookResources(args *[]string, volumeMounts *[]corev1.VolumeMount, volumes *[]corev1.Volume) { + if !o.EnableWebhook { + return + } + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "serving-cert", + MountPath: "/var/run/secrets/serving-cert", + }) + *volumes = append(*volumes, corev1.Volume{ + Name: "serving-cert", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "manager-serving-cert", + }, + }, + }) + *args = append(*args, "--cert-dir=/var/run/secrets/serving-cert") + if o.EnableValidatingWebhook { + *args = append(*args, "--enable-validating-webhook=true") + } +} + +func (o HyperShiftOperatorDeployment) addOIDCResources(args *[]string, volumeMounts *[]corev1.VolumeMount, volumes *[]corev1.Volume) { + if len(o.OIDCBucketName) == 0 || len(o.OIDCBucketRegion) == 0 || len(o.OIDCStorageProviderS3SecretKey) == 0 || + o.OIDCStorageProviderS3Secret == nil || len(o.OIDCStorageProviderS3Secret.Name) == 0 { + return + } + *args = append(*args, + "--oidc-storage-provider-s3-bucket-name="+o.OIDCBucketName, + "--oidc-storage-provider-s3-region="+o.OIDCBucketRegion, + "--oidc-storage-provider-s3-credentials=/etc/oidc-storage-provider-s3-creds/"+o.OIDCStorageProviderS3SecretKey, + ) + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "oidc-storage-provider-s3-creds", + MountPath: "/etc/oidc-storage-provider-s3-creds", + }) + *volumes = append(*volumes, corev1.Volume{ + Name: "oidc-storage-provider-s3-creds", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: o.OIDCStorageProviderS3Secret.Name, + }, + }, + }) +} + +func (o HyperShiftOperatorDeployment) addScaleFromZeroResources(args *[]string, volumeMounts *[]corev1.VolumeMount, volumes *[]corev1.Volume) { + if o.ScaleFromZeroSecret == nil || len(o.ScaleFromZeroSecret.Name) == 0 || len(o.ScaleFromZeroSecretKey) == 0 || len(o.ScaleFromZeroProvider) == 0 { + return + } + *args = append(*args, + "--scale-from-zero-provider="+o.ScaleFromZeroProvider, + "--scale-from-zero-creds=/etc/scale-from-zero-creds/"+o.ScaleFromZeroSecretKey, + ) + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "scale-from-zero-creds", + MountPath: "/etc/scale-from-zero-creds", + ReadOnly: true, + }) + *volumes = append(*volumes, corev1.Volume{ + Name: "scale-from-zero-creds", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: o.ScaleFromZeroSecret.Name, + }, + }, + }) +} + +func (o HyperShiftOperatorDeployment) resolveImage() string { + image := o.OperatorImage + if mapImage, ok := o.Images["hypershift-operator"]; ok { + image = mapImage + } + return image +} + +func (o HyperShiftOperatorDeployment) addImageTagEnvVars(envVars *[]corev1.EnvVar) { + tagMapping := images.TagMapping() + for tag, ref := range o.Images { + if envVar, exists := tagMapping[tag]; exists { + *envVars = append(*envVars, corev1.EnvVar{ + Name: envVar, + Value: ref, + }) + } + } +} + +func (o HyperShiftOperatorDeployment) addPrivatePlatformResources(envVars *[]corev1.EnvVar, volumeMounts *[]corev1.VolumeMount, volumes *[]corev1.Volume) { + privatePlatformType := hyperv1.PlatformType(o.PrivatePlatform) + if privatePlatformType == hyperv1.NonePlatform { + return + } + switch privatePlatformType { + case hyperv1.AWSPlatform: + *volumes = append(*volumes, corev1.Volume{ + Name: "credentials", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: o.AWSPrivateSecret.Name, + }, + }, + }) + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "credentials", + MountPath: "/etc/provider", + }) + *envVars = append(*envVars, + corev1.EnvVar{Name: "AWS_SHARED_CREDENTIALS_FILE", Value: "/etc/provider/" + o.AWSPrivateSecretKey}, + corev1.EnvVar{Name: "AWS_REGION", Value: o.AWSPrivateRegion}, + corev1.EnvVar{Name: "AWS_SDK_LOAD_CONFIG", Value: "1"}, + ) + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "token", + MountPath: "/var/run/secrets/openshift/serviceaccount", + }) + *volumes = append(*volumes, corev1.Volume{ + Name: "token", + VolumeSource: corev1.VolumeSource{ + Projected: &corev1.ProjectedVolumeSource{ + Sources: []corev1.VolumeProjection{ + { + ServiceAccountToken: &corev1.ServiceAccountTokenProjection{ + Audience: "openshift", + Path: "token", + }, + }, + }, + }, + }, + }) + case hyperv1.AzurePlatform: + o.addAzurePlatformResources(envVars, volumeMounts, volumes) + case hyperv1.GCPPlatform: + if o.GCPProject != "" { + *envVars = append(*envVars, corev1.EnvVar{Name: "GCP_PROJECT", Value: o.GCPProject}) + } + if o.GCPRegion != "" { + *envVars = append(*envVars, corev1.EnvVar{Name: "GCP_REGION", Value: o.GCPRegion}) + } + } +} + +func (o HyperShiftOperatorDeployment) addAzurePlatformResources(envVars *[]corev1.EnvVar, volumeMounts *[]corev1.VolumeMount, volumes *[]corev1.Volume) { + if o.AzurePLSResourceGroup != "" { + *envVars = append(*envVars, corev1.EnvVar{Name: "AZURE_RESOURCE_GROUP", Value: o.AzurePLSResourceGroup}) + } + if o.AzurePLSManagedIdentityClientID != "" { + *envVars = append(*envVars, corev1.EnvVar{Name: "AZURE_PLS_CLIENT_ID", Value: o.AzurePLSManagedIdentityClientID}) + if o.AzurePLSSubscriptionID != "" { + *envVars = append(*envVars, corev1.EnvVar{Name: "AZURE_SUBSCRIPTION_ID", Value: o.AzurePLSSubscriptionID}) + } + } else if o.AzurePrivateSecret != nil && len(o.AzurePrivateSecret.Name) > 0 && len(o.AzurePrivateSecretKey) > 0 { + *volumes = append(*volumes, corev1.Volume{ + Name: "azure-credentials", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: o.AzurePrivateSecret.Name, + }, + }, + }) + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "azure-credentials", + MountPath: "/etc/azure-provider", + ReadOnly: true, + }) + *envVars = append(*envVars, corev1.EnvVar{ + Name: "AZURE_CREDENTIALS_FILE", + Value: "/etc/azure-provider/" + o.AzurePrivateSecretKey, + }) + } +} + type HyperShiftOperatorService struct { Namespace *corev1.Namespace } diff --git a/cmd/install/assets/hypershift_operator_test.go b/cmd/install/assets/hypershift_operator_test.go index 5ca4665d1a9..7dcca5ccf9a 100644 --- a/cmd/install/assets/hypershift_operator_test.go +++ b/cmd/install/assets/hypershift_operator_test.go @@ -3,10 +3,16 @@ package assets import ( "fmt" "testing" + "time" . "github.com/onsi/gomega" hyperv1 "github.com/openshift/hypershift/api/hypershift/v1beta1" + controlplaneoperatoroverrides "github.com/openshift/hypershift/hypershift-operator/controlplaneoperator-overrides" + "github.com/openshift/hypershift/support/config" + "github.com/openshift/hypershift/support/rhobsmonitoring" + + configv1 "github.com/openshift/api/config/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" @@ -799,3 +805,540 @@ func TestHyperShiftOperatorClusterRole_WebhookRBAC(t *testing.T) { }))) }) } + +func TestBuildArgs(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectContains []string + expectNotContains []string + }{ + { + name: "When registry overrides is set, it should include the flag in args", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.NonePlatform), + RegistryOverrides: "quay.io=mirror.example.com", + }, + expectContains: []string{"--registry-overrides=quay.io=mirror.example.com"}, + }, + { + name: "When registry overrides is empty, it should not include the flag", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.NonePlatform), + }, + expectNotContains: []string{"--registry-overrides"}, + }, + { + name: "When all standard options are set, it should include them in args", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.AWSPlatform), + EnableOCPClusterMonitoring: true, + EnableCIDebugOutput: true, + EnableDedicatedRequestServingIsolation: true, + }, + expectContains: []string{ + "run", + "--namespace=$(MY_NAMESPACE)", + "--pod-name=$(MY_NAME)", + "--metrics-addr=:9000", + "--enable-dedicated-request-serving-isolation=true", + "--enable-ocp-cluster-monitoring=true", + "--enable-ci-debug-output=true", + fmt.Sprintf("--private-platform=%s", hyperv1.AWSPlatform), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + args := tc.deployment.buildArgs() + for _, expected := range tc.expectContains { + g.Expect(args).To(ContainElement(expected)) + } + for _, notExpected := range tc.expectNotContains { + for _, arg := range args { + g.Expect(arg).NotTo(HavePrefix(notExpected)) + } + } + }) + } +} + +func TestBuildEnvVars(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectContains []corev1.EnvVar + expectAbsent []string + }{ + { + name: "When TechPreviewNoUpgrade is enabled, it should include HYPERSHIFT_FEATURESET env var", + deployment: HyperShiftOperatorDeployment{ + TechPreviewNoUpgrade: true, + }, + expectContains: []corev1.EnvVar{ + {Name: "HYPERSHIFT_FEATURESET", Value: string(configv1.TechPreviewNoUpgrade)}, + }, + }, + { + name: "When EnableAuditLogPersistence is enabled, it should include ENABLE_AUDIT_LOG_PERSISTENCE env var", + deployment: HyperShiftOperatorDeployment{ + EnableAuditLogPersistence: true, + }, + expectContains: []corev1.EnvVar{ + {Name: "ENABLE_AUDIT_LOG_PERSISTENCE", Value: "true"}, + }, + }, + { + name: "When EnableCVOManagementClusterMetricsAccess is enabled, it should include the env var", + deployment: HyperShiftOperatorDeployment{ + EnableCVOManagementClusterMetricsAccess: true, + }, + expectContains: []corev1.EnvVar{ + {Name: config.EnableCVOManagementClusterMetricsAccessEnvVar, Value: "1"}, + }, + }, + { + name: "When ManagedService is set, it should include MANAGED_SERVICE env var", + deployment: HyperShiftOperatorDeployment{ + ManagedService: hyperv1.AroHCP, + }, + expectContains: []corev1.EnvVar{ + {Name: "MANAGED_SERVICE", Value: hyperv1.AroHCP}, + }, + }, + { + name: "When EnableSizeTagging is enabled, it should include ENABLE_SIZE_TAGGING env var", + deployment: HyperShiftOperatorDeployment{ + EnableSizeTagging: true, + }, + expectContains: []corev1.EnvVar{ + {Name: "ENABLE_SIZE_TAGGING", Value: "1"}, + }, + }, + { + name: "When EnableEtcdRecovery is enabled, it should include the env var", + deployment: HyperShiftOperatorDeployment{ + EnableEtcdRecovery: true, + }, + expectContains: []corev1.EnvVar{ + {Name: config.EnableEtcdRecoveryEnvVar, Value: "1"}, + }, + }, + { + name: "When EnableCPOOverrides is enabled, it should include the env var", + deployment: HyperShiftOperatorDeployment{ + EnableCPOOverrides: true, + }, + expectContains: []corev1.EnvVar{ + {Name: controlplaneoperatoroverrides.CPOOverridesEnvVar, Value: "1"}, + }, + }, + { + name: "When PlatformsInstalled is set, it should include PLATFORMS_INSTALLED env var", + deployment: HyperShiftOperatorDeployment{ + PlatformsInstalled: "aws,azure", + }, + expectContains: []corev1.EnvVar{ + {Name: "PLATFORMS_INSTALLED", Value: "aws,azure"}, + }, + }, + { + name: "When RHOBSMonitoring is enabled, it should include the env var", + deployment: HyperShiftOperatorDeployment{ + RHOBSMonitoring: true, + }, + expectContains: []corev1.EnvVar{ + {Name: rhobsmonitoring.EnvironmentVariable, Value: "1"}, + }, + }, + { + name: "When CVOPrometheusURL is set, it should include the env var", + deployment: HyperShiftOperatorDeployment{ + CVOPrometheusURL: "https://prometheus.example.com", + }, + expectContains: []corev1.EnvVar{ + {Name: config.CVOPrometheusURLEnvVar, Value: "https://prometheus.example.com"}, + }, + }, + { + name: "When MonitoringDashboards is enabled, it should include MONITORING_DASHBOARDS env var", + deployment: HyperShiftOperatorDeployment{ + MonitoringDashboards: true, + }, + expectContains: []corev1.EnvVar{ + {Name: "MONITORING_DASHBOARDS", Value: "1"}, + }, + }, + { + name: "When no optional features are enabled, it should not include optional env vars", + deployment: HyperShiftOperatorDeployment{ + CertRotationScale: 24 * time.Hour, + }, + expectAbsent: []string{ + "HYPERSHIFT_FEATURESET", + "ENABLE_AUDIT_LOG_PERSISTENCE", + "MANAGED_SERVICE", + "ENABLE_SIZE_TAGGING", + "MONITORING_DASHBOARDS", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + envVars := tc.deployment.buildEnvVars() + for _, expected := range tc.expectContains { + g.Expect(envVars).To(ContainElement(expected)) + } + for _, absent := range tc.expectAbsent { + for _, env := range envVars { + g.Expect(env.Name).NotTo(Equal(absent)) + } + } + }) + } +} + +func TestAddWebhookResources(t *testing.T) { + tests := []struct { + name string + enableWebhook bool + enableValidatingWebhook bool + expectArgs []string + expectVolumeMountCount int + expectVolumeCount int + }{ + { + name: "When webhook is disabled, it should not add any resources", + enableWebhook: false, + expectVolumeMountCount: 0, + expectVolumeCount: 0, + }, + { + name: "When webhook is enabled without validating webhook, it should add serving-cert resources and cert-dir arg", + enableWebhook: true, + expectArgs: []string{"--cert-dir=/var/run/secrets/serving-cert"}, + expectVolumeMountCount: 1, + expectVolumeCount: 1, + }, + { + name: "When webhook and validating webhook are both enabled, it should add cert-dir and enable-validating-webhook args", + enableWebhook: true, + enableValidatingWebhook: true, + expectArgs: []string{"--cert-dir=/var/run/secrets/serving-cert", "--enable-validating-webhook=true"}, + expectVolumeMountCount: 1, + expectVolumeCount: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + d := HyperShiftOperatorDeployment{ + EnableWebhook: tc.enableWebhook, + EnableValidatingWebhook: tc.enableValidatingWebhook, + } + var args []string + var volumeMounts []corev1.VolumeMount + var volumes []corev1.Volume + + d.addWebhookResources(&args, &volumeMounts, &volumes) + + g.Expect(volumeMounts).To(HaveLen(tc.expectVolumeMountCount)) + g.Expect(volumes).To(HaveLen(tc.expectVolumeCount)) + for _, expected := range tc.expectArgs { + g.Expect(args).To(ContainElement(expected)) + } + }) + } +} + +func TestAddOIDCResources(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectVolumeMountCount int + expectVolumeCount int + expectArgCount int + }{ + { + name: "When all OIDC parameters are set, it should add OIDC resources", + deployment: HyperShiftOperatorDeployment{ + OIDCBucketName: "my-bucket", + OIDCBucketRegion: "us-east-1", + OIDCStorageProviderS3SecretKey: "mykey", + OIDCStorageProviderS3Secret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "oidc-secret"}, + }, + }, + expectVolumeMountCount: 1, + expectVolumeCount: 1, + expectArgCount: 3, + }, + { + name: "When OIDC bucket name is empty, it should not add any resources", + deployment: HyperShiftOperatorDeployment{ + OIDCBucketRegion: "us-east-1", + OIDCStorageProviderS3SecretKey: "mykey", + OIDCStorageProviderS3Secret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "oidc-secret"}, + }, + }, + expectVolumeMountCount: 0, + expectVolumeCount: 0, + expectArgCount: 0, + }, + { + name: "When OIDC secret is nil, it should not add any resources", + deployment: HyperShiftOperatorDeployment{ + OIDCBucketName: "my-bucket", + OIDCBucketRegion: "us-east-1", + OIDCStorageProviderS3SecretKey: "mykey", + }, + expectVolumeMountCount: 0, + expectVolumeCount: 0, + expectArgCount: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + var args []string + var volumeMounts []corev1.VolumeMount + var volumes []corev1.Volume + + tc.deployment.addOIDCResources(&args, &volumeMounts, &volumes) + + g.Expect(volumeMounts).To(HaveLen(tc.expectVolumeMountCount)) + g.Expect(volumes).To(HaveLen(tc.expectVolumeCount)) + g.Expect(args).To(HaveLen(tc.expectArgCount)) + }) + } +} + +func TestAddScaleFromZeroResources(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectVolumeMountCount int + expectArgCount int + }{ + { + name: "When all scale-from-zero parameters are set, it should add resources", + deployment: HyperShiftOperatorDeployment{ + ScaleFromZeroSecret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "sfz-secret"}, + }, + ScaleFromZeroSecretKey: "credentials", + ScaleFromZeroProvider: "aws", + }, + expectVolumeMountCount: 1, + expectArgCount: 2, + }, + { + name: "When scale-from-zero secret is nil, it should not add any resources", + deployment: HyperShiftOperatorDeployment{ + ScaleFromZeroSecretKey: "credentials", + ScaleFromZeroProvider: "aws", + }, + expectVolumeMountCount: 0, + expectArgCount: 0, + }, + { + name: "When scale-from-zero provider is empty, it should not add any resources", + deployment: HyperShiftOperatorDeployment{ + ScaleFromZeroSecret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "sfz-secret"}, + }, + ScaleFromZeroSecretKey: "credentials", + }, + expectVolumeMountCount: 0, + expectArgCount: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + var args []string + var volumeMounts []corev1.VolumeMount + var volumes []corev1.Volume + + tc.deployment.addScaleFromZeroResources(&args, &volumeMounts, &volumes) + + g.Expect(volumeMounts).To(HaveLen(tc.expectVolumeMountCount)) + g.Expect(args).To(HaveLen(tc.expectArgCount)) + }) + } +} + +func TestResolveImage(t *testing.T) { + tests := []struct { + name string + operatorImage string + images map[string]string + expected string + }{ + { + name: "When no image override in map, it should use OperatorImage", + operatorImage: "default-image:latest", + images: map[string]string{}, + expected: "default-image:latest", + }, + { + name: "When image override exists in map, it should use the override", + operatorImage: "default-image:latest", + images: map[string]string{"hypershift-operator": "override-image:v1"}, + expected: "override-image:v1", + }, + { + name: "When images map is nil, it should use OperatorImage", + operatorImage: "default-image:latest", + images: nil, + expected: "default-image:latest", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + d := HyperShiftOperatorDeployment{ + OperatorImage: tc.operatorImage, + Images: tc.images, + } + result := d.resolveImage() + g.Expect(result).To(Equal(tc.expected)) + }) + } +} + +func TestAddPrivatePlatformResources(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectEnvNames []string + expectAbsentEnvs []string + }{ + { + name: "When private platform is None, it should not add any platform resources", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.NonePlatform), + }, + expectAbsentEnvs: []string{"AWS_SHARED_CREDENTIALS_FILE", "AZURE_CREDENTIALS_FILE", "GCP_PROJECT"}, + }, + { + name: "When private platform is AWS, it should add AWS env vars and volumes", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.AWSPlatform), + AWSPrivateRegion: "us-east-1", + AWSPrivateSecret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "aws-creds"}, + }, + AWSPrivateSecretKey: "credentials", + }, + expectEnvNames: []string{"AWS_SHARED_CREDENTIALS_FILE", "AWS_REGION", "AWS_SDK_LOAD_CONFIG"}, + }, + { + name: "When private platform is GCP with project and region, it should add GCP env vars", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.GCPPlatform), + GCPProject: "my-project", + GCPRegion: "us-central1", + }, + expectEnvNames: []string{"GCP_PROJECT", "GCP_REGION"}, + }, + { + name: "When private platform is GCP without project, it should not add GCP_PROJECT env var", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.GCPPlatform), + }, + expectAbsentEnvs: []string{"GCP_PROJECT", "GCP_REGION"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + var envVars []corev1.EnvVar + var volumeMounts []corev1.VolumeMount + var volumes []corev1.Volume + + tc.deployment.addPrivatePlatformResources(&envVars, &volumeMounts, &volumes) + + envNames := make([]string, 0, len(envVars)) + for _, e := range envVars { + envNames = append(envNames, e.Name) + } + for _, expected := range tc.expectEnvNames { + g.Expect(envNames).To(ContainElement(expected)) + } + for _, absent := range tc.expectAbsentEnvs { + g.Expect(envNames).NotTo(ContainElement(absent)) + } + }) + } +} + +func TestAddAzurePlatformResources(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectEnvNames []string + expectVolCount int + }{ + { + name: "When Azure managed identity is provided, it should set PLS client ID and subscription ID env vars", + deployment: HyperShiftOperatorDeployment{ + AzurePLSManagedIdentityClientID: "client-id", + AzurePLSSubscriptionID: "sub-id", + AzurePLSResourceGroup: "rg-mgmt", + }, + expectEnvNames: []string{"AZURE_RESOURCE_GROUP", "AZURE_PLS_CLIENT_ID", "AZURE_SUBSCRIPTION_ID"}, + expectVolCount: 0, + }, + { + name: "When Azure credentials file is provided, it should mount the credentials volume", + deployment: HyperShiftOperatorDeployment{ + AzurePLSResourceGroup: "rg-mgmt", + AzurePrivateSecret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "azure-creds"}, + }, + AzurePrivateSecretKey: "credentials", + }, + expectEnvNames: []string{"AZURE_RESOURCE_GROUP", "AZURE_CREDENTIALS_FILE"}, + expectVolCount: 1, + }, + { + name: "When neither managed identity nor credentials file is provided, it should not add env vars or volumes", + deployment: HyperShiftOperatorDeployment{ + AzurePLSResourceGroup: "rg-mgmt", + }, + expectEnvNames: []string{"AZURE_RESOURCE_GROUP"}, + expectVolCount: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + var envVars []corev1.EnvVar + var volumeMounts []corev1.VolumeMount + var volumes []corev1.Volume + + tc.deployment.addAzurePlatformResources(&envVars, &volumeMounts, &volumes) + + envNames := make([]string, 0, len(envVars)) + for _, e := range envVars { + envNames = append(envNames, e.Name) + } + for _, expected := range tc.expectEnvNames { + g.Expect(envNames).To(ContainElement(expected)) + } + g.Expect(volumes).To(HaveLen(tc.expectVolCount)) + }) + } +} diff --git a/cmd/install/install.go b/cmd/install/install.go index e022f1ecb7b..ed798ea8f4f 100644 --- a/cmd/install/install.go +++ b/cmd/install/install.go @@ -166,46 +166,67 @@ func (o *Options) Validate() error { o.ScaleFromZeroProvider = strings.ToLower(o.ScaleFromZeroProvider) } + errs = append(errs, o.validatePlatformConfig()...) + errs = append(errs, o.validateOIDCConfig()...) + errs = append(errs, o.validateExternalDNSConfig()...) + errs = append(errs, o.validateImageConfig()...) + errs = append(errs, o.validateScaleFromZeroConfig()...) + errs = append(errs, o.validateMonitoringConfig()...) + errs = append(errs, o.validateMiscConfig()...) + + return errors.NewAggregate(errs) +} + +func (o *Options) validatePlatformConfig() []error { + var errs []error switch hyperv1.PlatformType(o.PrivatePlatform) { case hyperv1.AWSPlatform: if (len(o.AWSPrivateCreds) == 0 && len(o.AWSPrivateCredentialsSecret) == 0) || len(o.AWSPrivateRegion) == 0 { errs = append(errs, fmt.Errorf("--aws-private-region and --aws-private-creds or --aws-private-secret are required with --private-platform=%s", hyperv1.AWSPlatform)) } case hyperv1.GCPPlatform: - // GCP uses Workload Identity Federation, no credentials required. - // However, --gcp-project and --gcp-region must be set together. if (o.GCPProject == "") != (o.GCPRegion == "") { errs = append(errs, fmt.Errorf("--gcp-project and --gcp-region must be set together when --private-platform=%s", hyperv1.GCPPlatform)) } case hyperv1.AzurePlatform: - if o.ManagedService != hyperv1.AroHCP { - hasCredFile := len(o.AzurePrivateCreds) != 0 || len(o.AzurePrivateCredentialsSecret) != 0 - hasManagedIdentity := len(o.AzurePLSManagedIdentityClientID) != 0 - if !hasCredFile && !hasManagedIdentity { - errs = append(errs, fmt.Errorf("--azure-private-creds, --azure-private-secret, or --azure-pls-managed-identity-client-id is required with --private-platform=%s", hyperv1.AzurePlatform)) - } - if hasCredFile && hasManagedIdentity { - errs = append(errs, fmt.Errorf("--azure-pls-managed-identity-client-id cannot be used with --azure-private-creds or --azure-private-secret")) - } - if hasManagedIdentity && len(o.AzurePLSSubscriptionID) == 0 { - errs = append(errs, fmt.Errorf("--azure-pls-subscription-id is required when using --azure-pls-managed-identity-client-id")) - } - if len(o.AzurePrivateCreds) != 0 && len(o.AzurePrivateCredentialsSecret) != 0 { - errs = append(errs, fmt.Errorf("only one of --azure-private-creds or --azure-private-secret is supported")) - } - if len(o.AzurePLSResourceGroup) == 0 { - errs = append(errs, fmt.Errorf("--azure-pls-resource-group is required with --private-platform=%s", hyperv1.AzurePlatform)) - } - } + errs = append(errs, o.validateAzurePlatformConfig()...) case hyperv1.NonePlatform: default: errs = append(errs, fmt.Errorf("--private-platform must be either %s, %s, %s, or %s", hyperv1.AWSPlatform, hyperv1.AzurePlatform, hyperv1.GCPPlatform, hyperv1.NonePlatform)) } + return errs +} +func (o *Options) validateAzurePlatformConfig() []error { + if o.ManagedService == hyperv1.AroHCP { + return nil + } + var errs []error + hasCredFile := len(o.AzurePrivateCreds) != 0 || len(o.AzurePrivateCredentialsSecret) != 0 + hasManagedIdentity := len(o.AzurePLSManagedIdentityClientID) != 0 + if !hasCredFile && !hasManagedIdentity { + errs = append(errs, fmt.Errorf("--azure-private-creds, --azure-private-secret, or --azure-pls-managed-identity-client-id is required with --private-platform=%s", hyperv1.AzurePlatform)) + } + if hasCredFile && hasManagedIdentity { + errs = append(errs, fmt.Errorf("--azure-pls-managed-identity-client-id cannot be used with --azure-private-creds or --azure-private-secret")) + } + if hasManagedIdentity && len(o.AzurePLSSubscriptionID) == 0 { + errs = append(errs, fmt.Errorf("--azure-pls-subscription-id is required when using --azure-pls-managed-identity-client-id")) + } + if len(o.AzurePrivateCreds) != 0 && len(o.AzurePrivateCredentialsSecret) != 0 { + errs = append(errs, fmt.Errorf("only one of --azure-private-creds or --azure-private-secret is supported")) + } + if len(o.AzurePLSResourceGroup) == 0 { + errs = append(errs, fmt.Errorf("--azure-pls-resource-group is required with --private-platform=%s", hyperv1.AzurePlatform)) + } + return errs +} + +func (o *Options) validateOIDCConfig() []error { + var errs []error if len(o.OIDCStorageProviderS3CredentialsSecret) > 0 && len(o.OIDCStorageProviderS3Credentials) > 0 { errs = append(errs, fmt.Errorf("only one of --oidc-storage-provider-s3-secret or --oidc-storage-provider-s3-credentials is supported")) } - if (len(o.OIDCStorageProviderS3CredentialsSecret) > 0 || len(o.OIDCStorageProviderS3Credentials) > 0) && (len(o.OIDCStorageProviderS3BucketName) == 0 || len(o.OIDCStorageProviderS3Region) == 0 || len(o.OIDCStorageProviderS3CredentialsSecretKey) == 0) { errs = append(errs, fmt.Errorf("all required oidc information is not set")) @@ -213,91 +234,102 @@ func (o *Options) Validate() error { if strings.Contains(o.OIDCStorageProviderS3BucketName, ".") { errs = append(errs, fmt.Errorf("oidc bucket name must not contain dots (.); see the notes on HTTPS at https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html")) } + return errs +} - if len(o.ExternalDNSProvider) > 0 { - // Credentials are optional for GCP when using Workload Identity - credentialsRequired := o.ExternalDNSProvider != "google" - if credentialsRequired && len(o.ExternalDNSCredentials) == 0 && len(o.ExternalDNSCredentialsSecret) == 0 { - errs = append(errs, fmt.Errorf("--external-dns-credentials or --external-dns-credentials-secret are required with --external-dns-provider")) - } - if len(o.ExternalDNSCredentials) != 0 && len(o.ExternalDNSCredentialsSecret) != 0 { - errs = append(errs, fmt.Errorf("only one of --external-dns-credentials or --external-dns-credentials-secret is supported")) - } - if len(o.ExternalDNSDomainFilter) == 0 { - errs = append(errs, fmt.Errorf("--external-dns-domain-filter is required with --external-dns-provider")) +func (o *Options) validateExternalDNSConfig() []error { + if len(o.ExternalDNSProvider) == 0 { + return nil + } + var errs []error + credentialsRequired := o.ExternalDNSProvider != "google" + if credentialsRequired && len(o.ExternalDNSCredentials) == 0 && len(o.ExternalDNSCredentialsSecret) == 0 { + errs = append(errs, fmt.Errorf("--external-dns-credentials or --external-dns-credentials-secret are required with --external-dns-provider")) + } + if len(o.ExternalDNSCredentials) != 0 && len(o.ExternalDNSCredentialsSecret) != 0 { + errs = append(errs, fmt.Errorf("only one of --external-dns-credentials or --external-dns-credentials-secret is supported")) + } + if len(o.ExternalDNSDomainFilter) == 0 { + errs = append(errs, fmt.Errorf("--external-dns-domain-filter is required with --external-dns-provider")) + } + if len(o.ExternalDNSInterval) > 0 { + if _, err := time.ParseDuration(o.ExternalDNSInterval); err != nil { + errs = append(errs, fmt.Errorf("--external-dns-interval is not a valid duration: %w", err)) } - if len(o.ExternalDNSInterval) > 0 { - if _, err := time.ParseDuration(o.ExternalDNSInterval); err != nil { - errs = append(errs, fmt.Errorf("--external-dns-interval is not a valid duration: %w", err)) - } + } + if len(o.ExternalDNSAWSZonesCacheDuration) > 0 { + if _, err := time.ParseDuration(o.ExternalDNSAWSZonesCacheDuration); err != nil { + errs = append(errs, fmt.Errorf("--external-dns-aws-zones-cache-duration is not a valid duration: %w", err)) } - if len(o.ExternalDNSAWSZonesCacheDuration) > 0 { - if _, err := time.ParseDuration(o.ExternalDNSAWSZonesCacheDuration); err != nil { - errs = append(errs, fmt.Errorf("--external-dns-aws-zones-cache-duration is not a valid duration: %w", err)) - } - if o.ExternalDNSProvider != "aws" { - errs = append(errs, fmt.Errorf("--external-dns-aws-zones-cache-duration is only effective with --external-dns-provider=aws")) - } + if o.ExternalDNSProvider != "aws" { + errs = append(errs, fmt.Errorf("--external-dns-aws-zones-cache-duration is only effective with --external-dns-provider=aws")) } } + return errs +} + +func (o *Options) validateImageConfig() []error { + var errs []error if o.HyperShiftImage != HyperShiftImage && len(o.ImageRefsFile) > 0 { errs = append(errs, fmt.Errorf("only one of --hypershift-image or --image-refs-file should be specified")) } if o.RHOBSMonitoring && os.Getenv(rhobsmonitoring.EnvironmentVariable) != "1" { errs = append(errs, fmt.Errorf("when invoking this command with the --rhobs-monitoring flag, the RHOBS_MONITORING environment variable must be set to \"1\"")) } - if o.CertRotationScale > 24*time.Hour { errs = append(errs, fmt.Errorf("cannot set --cert-rotation-scale longer than 24h, invalid value: %s", o.CertRotationScale.String())) } + return errs +} - // Validate scale-from-zero credentials +func (o *Options) validateScaleFromZeroConfig() []error { + if len(o.ScaleFromZeroCreds) == 0 && len(o.ScaleFromZeroCredentialsSecret) == 0 { + return nil + } + var errs []error supportedProviders := set.New("aws") - if len(o.ScaleFromZeroCreds) != 0 || len(o.ScaleFromZeroCredentialsSecret) != 0 { - // Check mutual exclusivity - only one of file or secret should be provided - if len(o.ScaleFromZeroCreds) != 0 && len(o.ScaleFromZeroCredentialsSecret) != 0 { - errs = append(errs, fmt.Errorf("only one of --scale-from-zero-creds or --scale-from-zero-secret is supported")) - } - - // Provider is required when using scale-from-zero credentials - if len(o.ScaleFromZeroProvider) == 0 { - errs = append(errs, fmt.Errorf("--scale-from-zero-provider is required when using scale-from-zero credentials")) - } else if !supportedProviders.Has(o.ScaleFromZeroProvider) { - errs = append(errs, fmt.Errorf("invalid --scale-from-zero-provider: %s (must be one of: %v)", o.ScaleFromZeroProvider, supportedProviders.UnsortedList())) - } - - // Validate credentials file exists and is accessible if provided - if len(o.ScaleFromZeroCreds) > 0 { - if _, err := os.Stat(o.ScaleFromZeroCreds); err != nil { - if os.IsNotExist(err) { - errs = append(errs, fmt.Errorf("--scale-from-zero-creds file does not exist: %s", o.ScaleFromZeroCreds)) - } else { - errs = append(errs, fmt.Errorf("--scale-from-zero-creds file is not accessible: %w", err)) - } + if len(o.ScaleFromZeroCreds) != 0 && len(o.ScaleFromZeroCredentialsSecret) != 0 { + errs = append(errs, fmt.Errorf("only one of --scale-from-zero-creds or --scale-from-zero-secret is supported")) + } + if len(o.ScaleFromZeroProvider) == 0 { + errs = append(errs, fmt.Errorf("--scale-from-zero-provider is required when using scale-from-zero credentials")) + } else if !supportedProviders.Has(o.ScaleFromZeroProvider) { + errs = append(errs, fmt.Errorf("invalid --scale-from-zero-provider: %s (must be one of: %v)", o.ScaleFromZeroProvider, supportedProviders.UnsortedList())) + } + if len(o.ScaleFromZeroCreds) > 0 { + if _, err := os.Stat(o.ScaleFromZeroCreds); err != nil { + if os.IsNotExist(err) { + errs = append(errs, fmt.Errorf("--scale-from-zero-creds file does not exist: %s", o.ScaleFromZeroCreds)) + } else { + errs = append(errs, fmt.Errorf("--scale-from-zero-creds file is not accessible: %w", err)) } } } + return errs +} +func (o *Options) validateMonitoringConfig() []error { + var errs []error if o.RHOBSMonitoring && o.EnableCVOManagementClusterMetricsAccess { errs = append(errs, fmt.Errorf("when invoking this command with the --rhobs-monitoring flag, the --enable-cvo-management-cluster-metrics-access flag is not supported ")) } - if len(o.CVOPrometheusURL) > 0 && !o.RHOBSMonitoring && !o.EnableCVOManagementClusterMetricsAccess { errs = append(errs, fmt.Errorf("--cvo-prometheus-url requires either --rhobs-monitoring or --enable-cvo-management-cluster-metrics-access to be enabled")) } + return errs +} +func (o *Options) validateMiscConfig() []error { + var errs []error if len(o.ManagedService) > 0 && o.ManagedService != hyperv1.AroHCP { errs = append(errs, fmt.Errorf("not a valid managed service type: %s", o.ManagedService)) } - - // Validate all the platforms in the list are valid for _, platform := range o.PlatformsToInstall { platformToCheck := strings.ToLower(platform) if !ValidPlatforms.Has(platformToCheck) { errs = append(errs, fmt.Errorf("not a valid platform type: %s", platform)) } } - if len(o.ImagePullPolicy) > 0 { normalized := strings.ToLower(o.ImagePullPolicy) switch normalized { @@ -306,8 +338,7 @@ func (o *Options) Validate() error { errs = append(errs, fmt.Errorf("invalid --image-pull-policy: %s (want Always|Never|IfNotPresent)", o.ImagePullPolicy)) } } - - return errors.NewAggregate(errs) + return errs } func (o *Options) ApplyDefaults() { diff --git a/cmd/install/install_test.go b/cmd/install/install_test.go index b01725f0b92..da72c2ea3cc 100644 --- a/cmd/install/install_test.go +++ b/cmd/install/install_test.go @@ -13,12 +13,15 @@ import ( . "github.com/onsi/gomega" hyperv1 "github.com/openshift/hypershift/api/hypershift/v1beta1" + "github.com/openshift/hypershift/cmd/install/assets" crdassets "github.com/openshift/hypershift/cmd/install/assets/crds" "github.com/openshift/hypershift/hypershift-operator/controllers/sharedingress" + "github.com/openshift/hypershift/support/metrics" hyperapi "github.com/openshift/hypershift/support/api" operatorv1alpha1 "github.com/openshift/api/operator/v1alpha1" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" @@ -30,6 +33,8 @@ import ( crclient "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/controller-runtime/pkg/client/interceptor" + + prometheusoperatorv1 "github.com/prometheus-operator/prometheus-operator/pkg/apis/monitoring/v1" ) func TestOptions_Validate(t *testing.T) { @@ -968,3 +973,785 @@ func TestWaitForCAPIOperatorSync(t *testing.T) { }) } } +func TestApplyDefaults(t *testing.T) { + tests := []struct { + name string + opts Options + expectedReplicas int32 + }{ + { + name: "When Development mode is enabled, it should set replicas to 0", + opts: Options{Development: true}, + expectedReplicas: 0, + }, + { + name: "When defaulting webhook is enabled, it should set replicas to 2", + opts: Options{EnableDefaultingWebhook: true}, + expectedReplicas: 2, + }, + { + name: "When conversion webhook is enabled, it should set replicas to 2", + opts: Options{EnableConversionWebhook: true}, + expectedReplicas: 2, + }, + { + name: "When validating webhook is enabled, it should set replicas to 2", + opts: Options{EnableValidatingWebhook: true}, + expectedReplicas: 2, + }, + { + name: "When no special options are set, it should default replicas to 1", + opts: Options{}, + expectedReplicas: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + tc.opts.ApplyDefaults() + g.Expect(tc.opts.HyperShiftOperatorReplicas).To(Equal(tc.expectedReplicas)) + g.Expect(tc.opts.RenderNamespace).To(BeTrue()) + }) + } +} + +func TestIsAWSPlatformEnabled(t *testing.T) { + tests := []struct { + name string + platformsToInstall []string + expected bool + }{ + { + name: "When no platforms are specified, it should return true (all platforms enabled)", + platformsToInstall: nil, + expected: true, + }, + { + name: "When empty platforms list is specified, it should return true", + platformsToInstall: []string{}, + expected: true, + }, + { + name: "When AWS is in the list, it should return true", + platformsToInstall: []string{"AWS", "Azure"}, + expected: true, + }, + { + name: "When aws (lowercase) is in the list, it should return true", + platformsToInstall: []string{"aws"}, + expected: true, + }, + { + name: "When AWS is not in the list, it should return false", + platformsToInstall: []string{"Azure", "GCP"}, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + result := isAWSPlatformEnabled(tc.platformsToInstall) + g.Expect(result).To(Equal(tc.expected)) + }) + } +} + +func TestIsAzurePlatformEnabled(t *testing.T) { + tests := []struct { + name string + platformsToInstall []string + expected bool + }{ + { + name: "When no platforms are specified, it should return true (all platforms enabled)", + platformsToInstall: nil, + expected: true, + }, + { + name: "When Azure is in the list, it should return true", + platformsToInstall: []string{"Azure", "AWS"}, + expected: true, + }, + { + name: "When azure (lowercase) is in the list, it should return true", + platformsToInstall: []string{"azure"}, + expected: true, + }, + { + name: "When Azure is not in the list, it should return false", + platformsToInstall: []string{"AWS", "GCP"}, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + result := isAzurePlatformEnabled(tc.platformsToInstall) + g.Expect(result).To(Equal(tc.expected)) + }) + } +} + +func TestValidatePlatformConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When private platform is None, it should pass", + opts: Options{PrivatePlatform: string(hyperv1.NonePlatform)}, + expectError: false, + }, + { + name: "When private platform is an unsupported value, it should fail", + opts: Options{PrivatePlatform: "Unsupported"}, + expectError: true, + }, + { + name: "When private platform is AWS with creds and region, it should pass", + opts: Options{ + PrivatePlatform: string(hyperv1.AWSPlatform), + AWSPrivateCreds: "/path/to/creds", + AWSPrivateRegion: "us-east-1", + }, + expectError: false, + }, + { + name: "When private platform is AWS without creds or region, it should fail", + opts: Options{ + PrivatePlatform: string(hyperv1.AWSPlatform), + }, + expectError: true, + }, + { + name: "When private platform is GCP with only project set, it should fail", + opts: Options{ + PrivatePlatform: string(hyperv1.GCPPlatform), + GCPProject: "my-project", + }, + expectError: true, + }, + { + name: "When private platform is GCP with both project and region, it should pass", + opts: Options{ + PrivatePlatform: string(hyperv1.GCPPlatform), + GCPProject: "my-project", + GCPRegion: "us-central1", + }, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validatePlatformConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestValidateOIDCConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When no OIDC options are set, it should pass", + opts: Options{}, + expectError: false, + }, + { + name: "When both OIDC secret and credentials are set, it should fail", + opts: Options{ + OIDCStorageProviderS3CredentialsSecret: "my-secret", + OIDCStorageProviderS3Credentials: "my-creds", + }, + expectError: true, + }, + { + name: "When OIDC credentials are set without bucket name, it should fail", + opts: Options{ + OIDCStorageProviderS3Credentials: "my-creds", + }, + expectError: true, + }, + { + name: "When OIDC bucket name contains dots, it should fail", + opts: Options{ + OIDCStorageProviderS3BucketName: "my.bucket.name", + }, + expectError: true, + }, + { + name: "When all OIDC parameters are provided correctly, it should pass", + opts: Options{ + OIDCStorageProviderS3Credentials: "my-creds", + OIDCStorageProviderS3BucketName: "mybucket", + OIDCStorageProviderS3Region: "us-east-1", + OIDCStorageProviderS3CredentialsSecretKey: "mykey", + }, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validateOIDCConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestValidateExternalDNSConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When no external DNS provider is set, it should pass", + opts: Options{}, + expectError: false, + }, + { + name: "When external DNS provider is set without credentials or domain filter, it should fail", + opts: Options{ + ExternalDNSProvider: "aws", + }, + expectError: true, + }, + { + name: "When external DNS provider is set with credentials and domain filter, it should pass", + opts: Options{ + ExternalDNSProvider: "aws", + ExternalDNSCredentials: "/path/to/creds", + ExternalDNSDomainFilter: "example.com", + }, + expectError: false, + }, + { + name: "When external DNS interval is an invalid duration, it should fail", + opts: Options{ + ExternalDNSProvider: "aws", + ExternalDNSCredentials: "/path/to/creds", + ExternalDNSDomainFilter: "example.com", + ExternalDNSInterval: "not-a-duration", + }, + expectError: true, + }, + { + name: "When AWS zones cache duration is set with non-AWS provider, it should fail", + opts: Options{ + ExternalDNSProvider: "azure", + ExternalDNSCredentials: "/path/to/creds", + ExternalDNSDomainFilter: "example.com", + ExternalDNSAWSZonesCacheDuration: "1h", + }, + expectError: true, + }, + { + name: "When google provider is set without credentials, it should pass (Workload Identity)", + opts: Options{ + ExternalDNSProvider: "google", + ExternalDNSDomainFilter: "example.com", + }, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validateExternalDNSConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestValidateMonitoringConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When no monitoring options are set, it should pass", + opts: Options{}, + expectError: false, + }, + { + name: "When both RHOBS monitoring and CVO management cluster metrics access are set, it should fail", + opts: Options{ + RHOBSMonitoring: true, + EnableCVOManagementClusterMetricsAccess: true, + }, + expectError: true, + }, + { + name: "When CVO prometheus URL is set without RHOBS or CVO metrics, it should fail", + opts: Options{ + CVOPrometheusURL: "https://prometheus.example.com", + }, + expectError: true, + }, + { + name: "When CVO prometheus URL is set with RHOBS monitoring, it should pass", + opts: Options{ + CVOPrometheusURL: "https://prometheus.example.com", + RHOBSMonitoring: true, + }, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validateMonitoringConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestValidateMiscConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When no misc options are set, it should pass", + opts: Options{}, + expectError: false, + }, + { + name: "When managed service is an invalid value, it should fail", + opts: Options{ + ManagedService: "INVALID", + }, + expectError: true, + }, + { + name: "When managed service is ARO-HCP, it should pass", + opts: Options{ + ManagedService: hyperv1.AroHCP, + }, + expectError: false, + }, + { + name: "When an invalid platform is specified, it should fail", + opts: Options{ + PlatformsToInstall: []string{"invalid-platform"}, + }, + expectError: true, + }, + { + name: "When valid platforms are specified, it should pass", + opts: Options{ + PlatformsToInstall: []string{"aws", "Azure"}, + }, + expectError: false, + }, + { + name: "When invalid image pull policy is specified, it should fail", + opts: Options{ + ImagePullPolicy: "WheneverYouFeel", + }, + expectError: true, + }, + { + name: "When Always image pull policy is specified, it should pass", + opts: Options{ + ImagePullPolicy: "Always", + }, + expectError: false, + }, + { + name: "When Never image pull policy is specified, it should pass", + opts: Options{ + ImagePullPolicy: "Never", + }, + expectError: false, + }, + { + name: "When IfNotPresent image pull policy is specified, it should pass", + opts: Options{ + ImagePullPolicy: "IfNotPresent", + }, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validateMiscConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestSetupMonitoring(t *testing.T) { + tests := []struct { + name string + opts Options + expectSLOAlerts bool + expectDashboards bool + minResourceCount int + }{ + { + name: "When SLOs alerts and monitoring dashboards are disabled, it should return base monitoring resources", + opts: Options{ + PlatformMonitoring: metrics.PlatformMonitoringAll, + }, + expectSLOAlerts: false, + expectDashboards: false, + minResourceCount: 4, + }, + { + name: "When SLOs alerts are enabled, it should include alerting rule", + opts: Options{ + SLOsAlerts: true, + }, + expectSLOAlerts: true, + minResourceCount: 5, + }, + { + name: "When monitoring dashboards are enabled, it should include dashboard template", + opts: Options{ + Namespace: "hypershift", + MonitoringDashboards: true, + }, + expectDashboards: true, + minResourceCount: 5, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + ns := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "hypershift"}} + objects := setupMonitoring(tc.opts, ns) + + g.Expect(len(objects)).To(BeNumerically(">=", tc.minResourceCount)) + + // Check SLO alerts + foundAlertingRule := false + for _, obj := range objects { + if rule, ok := obj.(*prometheusoperatorv1.PrometheusRule); ok && rule.Namespace == "openshift-monitoring" { + foundAlertingRule = true + break + } + } + if tc.expectSLOAlerts { + g.Expect(foundAlertingRule).To(BeTrue(), "expected SLO alerting rule to be present when SLOsAlerts is enabled") + } else { + g.Expect(foundAlertingRule).To(BeFalse(), "expected no SLO alerting rule when SLOsAlerts is disabled") + } + + // Check dashboards + foundDashboard := false + for _, obj := range objects { + if cm, ok := obj.(*corev1.ConfigMap); ok && cm.Name == "monitoring-dashboard-template" { + foundDashboard = true + break + } + } + if tc.expectDashboards { + g.Expect(foundDashboard).To(BeTrue(), "expected monitoring dashboard to be present when MonitoringDashboards is enabled") + } else { + g.Expect(foundDashboard).To(BeFalse(), "expected no monitoring dashboard when MonitoringDashboards is disabled") + } + }) + } +} + +func TestSetupRBAC(t *testing.T) { + tests := []struct { + name string + enableAdminRBAC bool + azureManagedIdentityClientID string + expectSAAnnotation bool + minObjectCount int + }{ + { + name: "When admin RBAC is disabled, it should return base RBAC resources", + minObjectCount: 6, + }, + { + name: "When admin RBAC is enabled, it should include client and reader RBAC resources", + enableAdminRBAC: true, + minObjectCount: 11, + }, + { + name: "When Azure managed identity is set, it should annotate the service account", + azureManagedIdentityClientID: "test-client-id", + expectSAAnnotation: true, + minObjectCount: 6, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + ns := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "hypershift"}} + opts := Options{ + EnableAdminRBACGeneration: tc.enableAdminRBAC, + AzurePLSManagedIdentityClientID: tc.azureManagedIdentityClientID, + } + + sa, objects := setupRBAC(opts, ns) + g.Expect(len(objects)).To(BeNumerically(">=", tc.minObjectCount)) + g.Expect(sa).NotTo(BeNil()) + if tc.expectSAAnnotation { + g.Expect(sa.Annotations).To(HaveKeyWithValue("azure.workload.identity/client-id", tc.azureManagedIdentityClientID)) + } + + // When admin RBAC is enabled, verify specific admin RBAC objects exist + if tc.enableAdminRBAC { + foundClientClusterRole := false + foundReaderClusterRole := false + for _, obj := range objects { + if obj.GetObjectKind().GroupVersionKind().Kind == "ClusterRole" { + if obj.GetName() == "hypershift-client" { + foundClientClusterRole = true + } + if obj.GetName() == "hypershift-readers" { + foundReaderClusterRole = true + } + } + } + g.Expect(foundClientClusterRole).To(BeTrue(), "expected hypershift-client ClusterRole to be present when admin RBAC is enabled") + g.Expect(foundReaderClusterRole).To(BeTrue(), "expected hypershift-readers ClusterRole to be present when admin RBAC is enabled") + } + }) + } +} + +func TestSetupCA(t *testing.T) { + t.Run("When no additional trust bundle is provided, it should return only the managed trust bundle", func(t *testing.T) { + g := NewGomegaWithT(t) + ns := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "hypershift"}} + userCA, trustedCA, objects, err := setupCA(Options{}, ns) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(userCA).To(BeNil()) + g.Expect(trustedCA).NotTo(BeNil()) + g.Expect(trustedCA.Name).To(Equal("openshift-config-managed-trusted-ca-bundle")) + g.Expect(objects).To(HaveLen(1)) + }) +} + +func TestValidateImageConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When default image and no refs file are set, it should pass", + opts: Options{HyperShiftImage: HyperShiftImage}, + expectError: false, + }, + { + name: "When both custom image and refs file are set, it should fail", + opts: Options{ + HyperShiftImage: "custom-image:latest", + ImageRefsFile: "/path/to/refs", + }, + expectError: true, + }, + { + name: "When cert rotation scale is longer than 24h, it should fail", + opts: Options{ + HyperShiftImage: HyperShiftImage, + CertRotationScale: 48 * time.Hour, + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validateImageConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestSetupSharedIngress(t *testing.T) { + t.Run("When setupSharedIngress is called, it should return namespace, ClusterRole, and ClusterRoleBinding", func(t *testing.T) { + g := NewGomegaWithT(t) + objects := setupSharedIngress() + g.Expect(objects).To(HaveLen(3)) + + g.Expect(objects[0].GetName()).To(Equal(sharedingress.RouterNamespace)) + g.Expect(objects[0].GetLabels()).To(HaveKeyWithValue("hypershift.openshift.io/component", "shared-ingress")) + + g.Expect(objects[1].GetName()).To(Equal(sharedingress.ConfigGeneratorName)) + g.Expect(objects[2].GetName()).To(Equal(sharedingress.ConfigGeneratorName)) + + crb := objects[2].(*rbacv1.ClusterRoleBinding) + g.Expect(crb.RoleRef.Name).To(Equal(sharedingress.ConfigGeneratorName)) + g.Expect(crb.Subjects).To(HaveLen(1)) + g.Expect(crb.Subjects[0].Name).To(Equal("router")) + g.Expect(crb.Subjects[0].Namespace).To(Equal(sharedingress.RouterNamespace)) + }) +} + +func TestSetupAdminRBAC(t *testing.T) { + t.Run("When setupAdminRBAC is called, it should return client and reader RBAC resources", func(t *testing.T) { + g := NewGomegaWithT(t) + ns := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "hypershift"}} + objects := setupAdminRBAC(ns) + + g.Expect(objects).To(HaveLen(5)) + + objectNames := make([]string, len(objects)) + for i, obj := range objects { + objectNames[i] = obj.GetName() + } + g.Expect(objectNames).To(ContainElement("hypershift-client")) + g.Expect(objectNames).To(ContainElement("hypershift-readers")) + }) +} + +func TestGetDeploymentCondition(t *testing.T) { + tests := []struct { + name string + deployConditions []appsv1.DeploymentCondition + condType string + expectFound bool + }{ + { + name: "When the condition exists, it should return it", + deployConditions: []appsv1.DeploymentCondition{ + {Type: "Progressing", Reason: "NewReplicaSetAvailable"}, + {Type: "Available", Reason: "MinimumReplicasAvailable"}, + }, + condType: "Available", + expectFound: true, + }, + { + name: "When the condition does not exist, it should return nil", + deployConditions: []appsv1.DeploymentCondition{ + {Type: "Progressing", Reason: "NewReplicaSetAvailable"}, + }, + condType: "Available", + expectFound: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + status := appsv1.DeploymentStatus{} + for _, c := range tc.deployConditions { + status.Conditions = append(status.Conditions, appsv1.DeploymentCondition{ + Type: appsv1.DeploymentConditionType(c.Type), + Reason: c.Reason, + }) + } + cond := GetDeploymentCondition(status, appsv1.DeploymentConditionType(tc.condType)) + if tc.expectFound { + g.Expect(cond).NotTo(BeNil()) + } else { + g.Expect(cond).To(BeNil()) + } + }) + } +} + +func TestNewInstallOptionsWithDefaults(t *testing.T) { + t.Run("When NewInstallOptionsWithDefaults is called, it should set all expected defaults", func(t *testing.T) { + g := NewGomegaWithT(t) + opts := NewInstallOptionsWithDefaults() + + g.Expect(opts.Namespace).To(Equal("hypershift")) + g.Expect(opts.PrivatePlatform).To(Equal(string(hyperv1.NonePlatform))) + g.Expect(opts.HyperShiftImage).To(Equal(HyperShiftImage)) + g.Expect(opts.ExternalDNSImage).To(Equal(ExternalDNSImage)) + g.Expect(opts.CertRotationScale).To(Equal(24 * time.Hour)) + g.Expect(opts.ImagePullPolicy).To(Equal("IfNotPresent")) + g.Expect(opts.EnableConversionWebhook).To(BeTrue()) + g.Expect(opts.EnableDedicatedRequestServingIsolation).To(BeTrue()) + g.Expect(opts.EnableEtcdRecovery).To(BeTrue()) + g.Expect(opts.MetricsSet).To(Equal(metrics.DefaultMetricsSet)) + g.Expect(opts.AWSPrivateCredentialsSecretKey).To(Equal("credentials")) + g.Expect(opts.ScaleFromZeroCredentialsSecretKey).To(Equal("credentials")) + g.Expect(opts.OIDCStorageProviderS3CredentialsSecretKey).To(Equal("credentials")) + g.Expect(opts.Development).To(BeFalse()) + g.Expect(opts.EnableAdminRBACGeneration).To(BeFalse()) + g.Expect(opts.AdditionalOperatorEnvVars).NotTo(BeNil()) + }) +} + +func TestHyperShiftNamespaceBuild(t *testing.T) { + tests := []struct { + name string + nsConfig assets.HyperShiftNamespace + expectClusterMonitoringLabel bool + }{ + { + name: "When OCP cluster monitoring is enabled, it should include the monitoring label", + nsConfig: assets.HyperShiftNamespace{ + Name: "hypershift", + EnableOCPClusterMonitoring: true, + }, + expectClusterMonitoringLabel: true, + }, + { + name: "When OCP cluster monitoring is disabled, it should not include the monitoring label", + nsConfig: assets.HyperShiftNamespace{ + Name: "hypershift", + EnableOCPClusterMonitoring: false, + }, + expectClusterMonitoringLabel: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + ns := tc.nsConfig.Build() + g.Expect(ns.Name).To(Equal(tc.nsConfig.Name)) + g.Expect(ns.Labels).To(HaveKeyWithValue("hypershift.openshift.io/component", "operator")) + if tc.expectClusterMonitoringLabel { + g.Expect(ns.Labels).To(HaveKeyWithValue("openshift.io/cluster-monitoring", "true")) + } else { + g.Expect(ns.Labels).NotTo(HaveKey("openshift.io/cluster-monitoring")) + } + }) + } +} From 615bc5bc5a4734f21df84c08c953efd2a029614c Mon Sep 17 00:00:00 2001 From: Bryan Cox Date: Thu, 23 Apr 2026 10:21:25 -0400 Subject: [PATCH 2/7] refactor(hypershift-operator): reduce cyclomatic complexity and enable gocyclo linter - Extract helper functions across HO controllers to reduce function complexity below gocyclo threshold of 30 - Fix etcd recovery short-circuit: handleExistingEtcdRecoveryJob now returns done flag to prevent fall-through into detectAndTrigger - Fix operator crash on non-OpenShift clusters (AKS): tolerate NoMatchError for IngressController CRD in addition to NotFound - Add RBAC/NetworkPolicy cleanup on terminal backup failure in etcd backup controller - Return early when listing HostedClusters fails in metrics collector - Add nil guard for managementClusterNetwork in network policy reconciliation - Fix ignoreError type in sizing controller to use concrete struct instead of type alias - Add nil guard for lastTransitionTime and preserve timed requeue in sizing controller - Add type checks in network policy test createOrUpdate stubs - Use rbacv1.GroupName constant in etcd recovery test assertions - Add behavior-driven unit tests for all extracted functions Signed-off-by: Bryan Cox Commit-Message-Assisted-by: Claude (via Claude Code) --- codecov.yml | 9 + .../snapshot_controller.go | 166 ++- .../snapshot_controller_test.go | 254 +++++ .../controllers/etcdbackup/reconciler.go | 203 ++-- .../controllers/etcdbackup/reconciler_test.go | 288 +++++ .../hostedcluster/etcd_recovery.go | 163 +-- .../hostedcluster/etcd_recovery_test.go | 563 ++++++++++ .../hostedcluster/hostedcluster_controller.go | 2 + .../internal/platform/platform.go | 2 + .../hostedcluster/metrics/metrics.go | 597 ++++++----- .../hostedcluster/network_policies.go | 414 ++++---- .../hostedcluster/network_policies_test.go | 699 +++++++++++++ .../hostedclustersizing_controller.go | 370 ++++--- .../hostedclustersizing_controller_test.go | 613 ++++++++++- .../controllers/nodepool/aws.go | 229 ++-- .../controllers/nodepool/aws_test.go | 629 +++++++++++ .../controllers/nodepool/capi.go | 152 ++- .../controllers/nodepool/capi_test.go | 612 +++++++++++ .../controllers/nodepool/metrics/metrics.go | 370 +++---- .../nodepool/nodepool_controller.go | 1 + .../controllers/platform/aws/controller.go | 239 +++-- .../platform/aws/controller_test.go | 251 +++++ .../controllers/scheduler/aws/autoscaler.go | 165 ++- .../scheduler/aws/autoscaler_test.go | 363 +++++++ .../aws/dedicated_request_serving_nodes.go | 410 ++++---- .../dedicated_request_serving_nodes_test.go | 987 ++++++++++++++++++ hypershift-operator/main.go | 585 ++++++----- 27 files changed, 7402 insertions(+), 1934 deletions(-) create mode 100644 hypershift-operator/controllers/hostedcluster/etcd_recovery_test.go diff --git a/codecov.yml b/codecov.yml index 31d7ec7e77f..71f1202bc9a 100644 --- a/codecov.yml +++ b/codecov.yml @@ -44,6 +44,15 @@ ignore: - "hypershift-operator/controllers/nodepool/instancetype/instancetype.go" - "api/ibmcapi/types.go" +coverage: + status: + project: + default: + informational: true + patch: + default: + informational: true + comment: layout: "condensed_header, diff, files, flags, components" behavior: default diff --git a/hypershift-operator/controllers/auditlogpersistence/snapshot_controller.go b/hypershift-operator/controllers/auditlogpersistence/snapshot_controller.go index 401c44b5f68..9bb17347941 100644 --- a/hypershift-operator/controllers/auditlogpersistence/snapshot_controller.go +++ b/hypershift-operator/controllers/auditlogpersistence/snapshot_controller.go @@ -67,10 +67,77 @@ func SetupSnapshotController(mgr ctrl.Manager) error { return nil } +func (r *SnapshotReconciler) getSnapshotConfig(ctx context.Context) (*auditlogpersistencev1alpha1.AuditLogPersistenceConfigSpec, error) { + config := &auditlogpersistencev1alpha1.AuditLogPersistenceConfig{} + if err := r.client.Get(ctx, types.NamespacedName{Name: "cluster"}, config); err != nil { + if apierrors.IsNotFound(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to get AuditLogPersistenceConfig: %w", err) + } + + spec := config.Spec.DeepCopy() + ApplyDefaults(spec) + + if !IsEnabled(spec) || !IsSnapshotsEnabled(spec) { + return nil, nil + } + return spec, nil +} + +func (r *SnapshotReconciler) getLastObservedRestartCount(ctx context.Context, pod *corev1.Pod, log logr.Logger) int32 { + val, ok := pod.Annotations[lastObservedRestartCountAnnotation] + if !ok { + return 0 + } + count, err := parseInt32(val) + if err != nil { + log.V(1).Info("Failed to parse last observed restart count annotation, resetting to 0", "annotationValue", val, "error", err) + podCopy := pod.DeepCopy() + if podCopy.Annotations == nil { + podCopy.Annotations = make(map[string]string) + } + podCopy.Annotations[lastObservedRestartCountAnnotation] = "0" + if patchErr := r.client.Patch(ctx, podCopy, client.MergeFrom(pod)); patchErr != nil { + log.Error(patchErr, "Failed to reset corrupted annotation") + } + return 0 + } + return count +} + +func (r *SnapshotReconciler) checkSnapshotInterval(ctx context.Context, pod *corev1.Pod, spec *auditlogpersistencev1alpha1.AuditLogPersistenceConfigSpec, log logr.Logger) (shouldSnapshot bool, skipReconcile bool) { + lastSnapshotTimeStr, ok := pod.Annotations[lastSnapshotTimeAnnotation] + if !ok { + return true, false + } + lastSnapshotTime, err := time.Parse(time.RFC3339, lastSnapshotTimeStr) + if err != nil { + log.V(1).Info("Failed to parse last snapshot time annotation, will create snapshot", "annotationValue", lastSnapshotTimeStr, "error", err) + podCopy := pod.DeepCopy() + if podCopy.Annotations == nil { + podCopy.Annotations = make(map[string]string) + } + delete(podCopy.Annotations, lastSnapshotTimeAnnotation) + if patchErr := r.client.Patch(ctx, podCopy, client.MergeFrom(pod)); patchErr != nil { + log.Error(patchErr, "Failed to remove corrupted last snapshot time annotation") + } + return true, false + } + minInterval, err := time.ParseDuration(spec.Snapshots.MinInterval) + if err != nil { + log.Error(err, "Failed to parse minimum interval from config, will create snapshot", "minInterval", spec.Snapshots.MinInterval) + return true, false + } + if time.Since(lastSnapshotTime) >= minInterval { + return true, false + } + return false, true +} + func (r *SnapshotReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { log := r.log.WithValues("pod", req.NamespacedName) - // Get the pod pod := &corev1.Pod{} if err := r.client.Get(ctx, req.NamespacedName, pod); err != nil { if apierrors.IsNotFound(err) { @@ -79,12 +146,10 @@ func (r *SnapshotReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c return ctrl.Result{}, fmt.Errorf("failed to get pod: %w", err) } - // Check if this is a kube-apiserver pod if !isKubeAPIServerPod(pod) { return ctrl.Result{}, nil } - // Check if namespace is a control plane namespace ns := &corev1.Namespace{} if err := r.client.Get(ctx, types.NamespacedName{Name: pod.Namespace}, ns); err != nil { return ctrl.Result{}, fmt.Errorf("failed to get namespace: %w", err) @@ -94,30 +159,14 @@ func (r *SnapshotReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c return ctrl.Result{}, nil } - // Get the AuditLogPersistenceConfig - config := &auditlogpersistencev1alpha1.AuditLogPersistenceConfig{} - if err := r.client.Get(ctx, types.NamespacedName{Name: "cluster"}, config); err != nil { - if apierrors.IsNotFound(err) { - return ctrl.Result{}, nil - } - return ctrl.Result{}, fmt.Errorf("failed to get AuditLogPersistenceConfig: %w", err) - } - - // Apply defaults to a copy of the spec to avoid modifying the original - spec := config.Spec.DeepCopy() - ApplyDefaults(spec) - - // Check if feature is enabled - if !IsEnabled(spec) { - return ctrl.Result{}, nil + spec, err := r.getSnapshotConfig(ctx) + if err != nil { + return ctrl.Result{}, err } - - // Check if snapshots are enabled - if !IsSnapshotsEnabled(spec) { + if spec == nil { return ctrl.Result{}, nil } - // Get the kube-apiserver container restart count var restartCount int32 for _, containerStatus := range pod.Status.ContainerStatuses { if containerStatus.Name == "kube-apiserver" { @@ -126,33 +175,12 @@ func (r *SnapshotReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c } } - // Get the last observed restart count from annotation - lastObservedRestartCount := int32(0) - if val, ok := pod.Annotations[lastObservedRestartCountAnnotation]; ok { - var err error - lastObservedRestartCount, err = parseInt32(val) - if err != nil { - log.V(1).Info("Failed to parse last observed restart count annotation, resetting to 0", "annotationValue", val, "error", err) - // Reset corrupted annotation to 0 - podCopy := pod.DeepCopy() - if podCopy.Annotations == nil { - podCopy.Annotations = make(map[string]string) - } - podCopy.Annotations[lastObservedRestartCountAnnotation] = "0" - if patchErr := r.client.Patch(ctx, podCopy, client.MergeFrom(pod)); patchErr != nil { - log.Error(patchErr, "Failed to reset corrupted annotation") - // Continue anyway - the annotation will be fixed on next reconciliation - } - lastObservedRestartCount = 0 - } - } + lastObservedRestartCount := r.getLastObservedRestartCount(ctx, pod, log) - // Check if restart count increased (indicating a crash) if restartCount <= lastObservedRestartCount { return ctrl.Result{}, nil } - // Always update the last observed restart count when we see a new restart podCopy := pod.DeepCopy() if podCopy.Annotations == nil { podCopy.Annotations = make(map[string]string) @@ -160,51 +188,17 @@ func (r *SnapshotReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c podCopy.Annotations[lastObservedRestartCountAnnotation] = fmt.Sprintf("%d", restartCount) if patchErr := r.client.Patch(ctx, podCopy, client.MergeFrom(pod)); patchErr != nil { log.Error(patchErr, "Failed to update last observed restart count annotation") - // Continue anyway - we'll try again on next reconciliation - } - - // Check if we should create a snapshot based on minimum interval - shouldSnapshot := false - if lastSnapshotTimeStr, ok := pod.Annotations[lastSnapshotTimeAnnotation]; ok { - lastSnapshotTime, err := time.Parse(time.RFC3339, lastSnapshotTimeStr) - if err != nil { - log.V(1).Info("Failed to parse last snapshot time annotation, will create snapshot", "annotationValue", lastSnapshotTimeStr, "error", err) - // Remove corrupted annotation - it will be set correctly after snapshot creation - podCopy := pod.DeepCopy() - if podCopy.Annotations == nil { - podCopy.Annotations = make(map[string]string) - } - delete(podCopy.Annotations, lastSnapshotTimeAnnotation) - if patchErr := r.client.Patch(ctx, podCopy, client.MergeFrom(pod)); patchErr != nil { - log.Error(patchErr, "Failed to remove corrupted last snapshot time annotation") - // Continue anyway - the annotation will be fixed on next reconciliation - } - shouldSnapshot = true - } else { - minInterval, err := time.ParseDuration(spec.Snapshots.MinInterval) - if err != nil { - log.Error(err, "Failed to parse minimum interval from config, will create snapshot", "minInterval", spec.Snapshots.MinInterval) - shouldSnapshot = true - } else { - if time.Since(lastSnapshotTime) >= minInterval { - shouldSnapshot = true - } else { - log.V(1).Info("Skipping snapshot due to minimum interval", "timeSinceLastSnapshot", time.Since(lastSnapshotTime), "minInterval", minInterval, "restartCount", restartCount) - return ctrl.Result{}, nil - } - } - } - } else { - // No previous snapshot, create one - shouldSnapshot = true } - // If we shouldn't snapshot, return early (we've already updated lastObservedRestartCount) + shouldSnapshot, skipReconcile := r.checkSnapshotInterval(ctx, pod, spec, log) + if skipReconcile { + log.V(1).Info("Skipping snapshot due to minimum interval", "restartCount", restartCount) + return ctrl.Result{}, nil + } if !shouldSnapshot { return ctrl.Result{}, nil } - // Find the PVC for this pod pvcName := pvcNamePrefix + pod.Name pvc := &corev1.PersistentVolumeClaim{} if err := r.client.Get(ctx, types.NamespacedName{Name: pvcName, Namespace: pod.Namespace}, pvc); err != nil { @@ -215,12 +209,10 @@ func (r *SnapshotReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c return ctrl.Result{}, fmt.Errorf("failed to get PVC: %w", err) } - // Create snapshot if err := r.createSnapshot(ctx, pod, pvc, spec); err != nil { return ctrl.Result{}, fmt.Errorf("failed to create snapshot: %w", err) } - // Update pod annotation with snapshot time (lastObservedRestartCount was already updated above) podCopy = pod.DeepCopy() if podCopy.Annotations == nil { podCopy.Annotations = make(map[string]string) @@ -230,10 +222,8 @@ func (r *SnapshotReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c return ctrl.Result{}, fmt.Errorf("failed to update pod annotation: %w", err) } - // Manage retention if err := r.manageRetention(ctx, pod, pvc, spec); err != nil { log.Error(err, "Failed to manage snapshot retention") - // Don't fail reconciliation on retention errors } log.Info("Successfully created snapshot for pod crash", "restartCount", restartCount, "previousObservedRestartCount", lastObservedRestartCount) diff --git a/hypershift-operator/controllers/auditlogpersistence/snapshot_controller_test.go b/hypershift-operator/controllers/auditlogpersistence/snapshot_controller_test.go index bdef8fa8b0f..35835b5e73a 100644 --- a/hypershift-operator/controllers/auditlogpersistence/snapshot_controller_test.go +++ b/hypershift-operator/controllers/auditlogpersistence/snapshot_controller_test.go @@ -1308,6 +1308,260 @@ func TestIsKubeAPIServerPod(t *testing.T) { } } +func TestGetLastObservedRestartCount(t *testing.T) { + tests := []struct { + name string + annotations map[string]string + expected int32 + }{ + { + name: "When annotation is missing, it should return 0", + annotations: map[string]string{}, + expected: 0, + }, + { + name: "When annotation has a valid integer, it should return the parsed value", + annotations: map[string]string{lastObservedRestartCountAnnotation: "5"}, + expected: 5, + }, + { + name: "When annotation is zero, it should return 0", + annotations: map[string]string{lastObservedRestartCountAnnotation: "0"}, + expected: 0, + }, + { + name: "When annotation is corrupted, it should reset to 0 and return 0", + annotations: map[string]string{lastObservedRestartCountAnnotation: "not-a-number"}, + expected: 0, + }, + { + name: "When annotations map is nil, it should return 0", + annotations: nil, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver-test", + Namespace: "hcp-namespace", + Annotations: tt.annotations, + }, + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(hyperapi.Scheme). + WithObjects(pod). + Build() + + reconciler := &SnapshotReconciler{ + client: fakeClient, + log: logr.Discard(), + } + + // Save original annotation value before calling function to avoid checking mutated map + var originalAnnotationValue string + if tt.annotations != nil { + originalAnnotationValue = tt.annotations[lastObservedRestartCountAnnotation] + } + + result := reconciler.getLastObservedRestartCount(context.Background(), pod, logr.Discard()) + g.Expect(result).To(Equal(tt.expected)) + + // When annotation was corrupted, verify it was reset on the pod + if originalAnnotationValue == "not-a-number" { + updatedPod := &corev1.Pod{} + err := fakeClient.Get(context.Background(), types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace}, updatedPod) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(updatedPod.Annotations[lastObservedRestartCountAnnotation]).To(Equal("0")) + } + }) + } +} + +func TestCheckSnapshotInterval(t *testing.T) { + tests := []struct { + name string + annotations map[string]string + minInterval string + expectedShouldCreate bool + expectedSkip bool + }{ + { + name: "When no last snapshot time annotation exists, it should allow snapshot creation", + annotations: map[string]string{}, + minInterval: "1h", + expectedShouldCreate: true, + expectedSkip: false, + }, + { + name: "When last snapshot time is older than min interval, it should allow snapshot creation", + annotations: map[string]string{ + lastSnapshotTimeAnnotation: time.Now().Add(-2 * time.Hour).Format(time.RFC3339), + }, + minInterval: "1h", + expectedShouldCreate: true, + expectedSkip: false, + }, + { + name: "When last snapshot time is within min interval, it should skip reconciliation", + annotations: map[string]string{ + lastSnapshotTimeAnnotation: time.Now().Add(-10 * time.Minute).Format(time.RFC3339), + }, + minInterval: "1h", + expectedShouldCreate: false, + expectedSkip: true, + }, + { + name: "When last snapshot time annotation is corrupted, it should allow snapshot creation", + annotations: map[string]string{ + lastSnapshotTimeAnnotation: "invalid-time", + }, + minInterval: "1h", + expectedShouldCreate: true, + expectedSkip: false, + }, + { + name: "When min interval is unparsable, it should allow snapshot creation", + annotations: map[string]string{ + lastSnapshotTimeAnnotation: time.Now().Format(time.RFC3339), + }, + minInterval: "not-a-duration", + expectedShouldCreate: true, + expectedSkip: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver-test", + Namespace: "hcp-namespace", + Annotations: tt.annotations, + }, + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(hyperapi.Scheme). + WithObjects(pod). + Build() + + reconciler := &SnapshotReconciler{ + client: fakeClient, + log: logr.Discard(), + } + + spec := &auditlogpersistencev1alpha1.AuditLogPersistenceConfigSpec{ + Snapshots: auditlogpersistencev1alpha1.SnapshotConfig{ + MinInterval: tt.minInterval, + }, + } + + shouldSnapshot, skipReconcile := reconciler.checkSnapshotInterval(context.Background(), pod, spec, logr.Discard()) + g.Expect(shouldSnapshot).To(Equal(tt.expectedShouldCreate)) + g.Expect(skipReconcile).To(Equal(tt.expectedSkip)) + }) + } +} + +func TestGetSnapshotConfig(t *testing.T) { + tests := []struct { + name string + config *auditlogpersistencev1alpha1.AuditLogPersistenceConfig + expectNil bool + expectError bool + }{ + { + name: "When config does not exist, it should return nil without error", + config: nil, + expectNil: true, + }, + { + name: "When feature is disabled, it should return nil", + config: &auditlogpersistencev1alpha1.AuditLogPersistenceConfig{ + ObjectMeta: metav1.ObjectMeta{Name: "cluster"}, + Spec: auditlogpersistencev1alpha1.AuditLogPersistenceConfigSpec{ + Enabled: false, + Snapshots: auditlogpersistencev1alpha1.SnapshotConfig{ + Enabled: true, + }, + }, + }, + expectNil: true, + }, + { + name: "When snapshots are disabled, it should return nil", + config: &auditlogpersistencev1alpha1.AuditLogPersistenceConfig{ + ObjectMeta: metav1.ObjectMeta{Name: "cluster"}, + Spec: auditlogpersistencev1alpha1.AuditLogPersistenceConfigSpec{ + Enabled: true, + Snapshots: auditlogpersistencev1alpha1.SnapshotConfig{ + Enabled: false, + }, + }, + }, + expectNil: true, + }, + { + name: "When both feature and snapshots are enabled, it should return spec with defaults applied", + config: &auditlogpersistencev1alpha1.AuditLogPersistenceConfig{ + ObjectMeta: metav1.ObjectMeta{Name: "cluster"}, + Spec: auditlogpersistencev1alpha1.AuditLogPersistenceConfigSpec{ + Enabled: true, + Snapshots: auditlogpersistencev1alpha1.SnapshotConfig{ + Enabled: true, + }, + }, + }, + expectNil: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + var objects []client.Object + if tt.config != nil { + objects = append(objects, tt.config) + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(hyperapi.Scheme). + WithObjects(objects...). + Build() + + reconciler := &SnapshotReconciler{ + client: fakeClient, + log: logr.Discard(), + } + + spec, err := reconciler.getSnapshotConfig(context.Background()) + + if tt.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + } + + if tt.expectNil { + g.Expect(spec).To(BeNil()) + } else { + g.Expect(spec).ToNot(BeNil()) + // Verify defaults were applied (MinInterval should have a default) + g.Expect(spec.Snapshots.MinInterval).ToNot(BeEmpty()) + } + }) + } +} + func TestSnapshotReconciler_createSnapshot(t *testing.T) { tests := []struct { name string diff --git a/hypershift-operator/controllers/etcdbackup/reconciler.go b/hypershift-operator/controllers/etcdbackup/reconciler.go index 89ed91b4600..bd665bdda2d 100644 --- a/hypershift-operator/controllers/etcdbackup/reconciler.go +++ b/hypershift-operator/controllers/etcdbackup/reconciler.go @@ -109,15 +109,100 @@ func (r *HCPEtcdBackupReconciler) SetupWithManager(mgr ctrl.Manager) error { Complete(r) } +func (r *HCPEtcdBackupReconciler) setFailedConditionAndUpdate(ctx context.Context, backup *hyperv1.HCPEtcdBackup, reason, message string) (ctrl.Result, error) { + r.setCondition(backup, metav1.Condition{ + Type: string(hyperv1.BackupCompleted), + Status: metav1.ConditionFalse, + Reason: reason, + Message: message, + }) + if err := r.Status().Update(ctx, backup); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to update status: %w", err) + } + return ctrl.Result{}, nil +} + +func (r *HCPEtcdBackupReconciler) validatePrerequisites(ctx context.Context, backup *hyperv1.HCPEtcdBackup) (ctrl.Result, bool, error) { + credentialSecretName, err := r.getCredentialSecretName(backup) + if err != nil { + result, updateErr := r.setFailedConditionAndUpdate(ctx, backup, hyperv1.BackupFailedReason, err.Error()) + if updateErr != nil { + return result, true, updateErr + } + return result, true, nil + } + + credSecret := &corev1.Secret{} + if err := r.Get(ctx, types.NamespacedName{Name: credentialSecretName, Namespace: r.OperatorNamespace}, credSecret); err != nil { + if apierrors.IsNotFound(err) { + result, updateErr := r.setFailedConditionAndUpdate(ctx, backup, hyperv1.BackupFailedReason, + fmt.Sprintf("credential Secret %q not found in namespace %q", credentialSecretName, r.OperatorNamespace)) + if updateErr != nil { + return result, true, updateErr + } + return result, true, nil + } + return ctrl.Result{}, true, fmt.Errorf("failed to get credential Secret: %w", err) + } + return ctrl.Result{}, false, nil +} + +func (r *HCPEtcdBackupReconciler) createResourcesAndJob(ctx context.Context, backup *hyperv1.HCPEtcdBackup, hcp *hyperv1.HostedControlPlane) (ctrl.Result, error) { + logger := log.FromContext(ctx) + logger.Info("creating backup resources", "backup", backup.Name, "namespace", backup.Namespace) + + if err := r.ensureServiceAccount(ctx); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to ensure ServiceAccount: %w", err) + } + + if err := r.ensureRBAC(ctx, backup); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to ensure RBAC: %w", err) + } + + if err := r.ensureNetworkPolicy(ctx, backup); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to ensure NetworkPolicy: %w", err) + } + + if err := r.createBackupJob(ctx, backup, hcp); err != nil { + if apierrors.IsNotFound(err) { + // Clean up RBAC and NetworkPolicy created above before marking terminal. + if cleanupErr := r.cleanupResources(ctx, backup); cleanupErr != nil { + logger.Error(cleanupErr, "failed to cleanup resources after terminal backup failure") + } + return r.setFailedConditionAndUpdate(ctx, backup, hyperv1.BackupFailedReason, err.Error()) + } + return ctrl.Result{}, fmt.Errorf("failed to create backup Job: %w", err) + } + + r.setCondition(backup, metav1.Condition{ + Type: string(hyperv1.BackupCompleted), + Status: metav1.ConditionFalse, + Reason: hyperv1.BackupInProgressReason, + Message: "Backup Job created, waiting for completion", + }) + if err := r.Status().Update(ctx, backup); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to update status: %w", err) + } + + if err := r.updateHCPBackupCondition(ctx, hcp, metav1.Condition{ + Type: string(hyperv1.EtcdBackupSucceeded), + Status: metav1.ConditionFalse, + Reason: hyperv1.BackupInProgressReason, + Message: fmt.Sprintf("Backup %q is in progress", backup.Name), + }); err != nil { + logger.Error(err, "failed to update HCP backup condition") + } + + return ctrl.Result{RequeueAfter: requeueInterval}, nil +} + func (r *HCPEtcdBackupReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { logger := log.FromContext(ctx) - // Feature gate check if !featuregate.Gate().Enabled(featuregate.HCPEtcdBackup) { return ctrl.Result{}, nil } - // Fetch the HCPEtcdBackup CR backup := &hyperv1.HCPEtcdBackup{} if err := r.Get(ctx, req.NamespacedName, backup); err != nil { if apierrors.IsNotFound(err) { @@ -139,25 +224,15 @@ func (r *HCPEtcdBackupReconciler) Reconcile(ctx context.Context, req ctrl.Reques return ctrl.Result{}, nil } - // Look up the HostedControlPlane in the same namespace hcp, err := r.getHostedControlPlane(ctx, backup.Namespace) if err != nil { return ctrl.Result{}, fmt.Errorf("failed to look up HostedControlPlane: %w", err) } if hcp == nil { - r.setCondition(backup, metav1.Condition{ - Type: string(hyperv1.BackupCompleted), - Status: metav1.ConditionFalse, - Reason: hyperv1.BackupFailedReason, - Message: "HostedControlPlane not found in namespace " + backup.Namespace, - }) - if err := r.Status().Update(ctx, backup); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status: %w", err) - } - return ctrl.Result{}, nil + return r.setFailedConditionAndUpdate(ctx, backup, hyperv1.BackupFailedReason, + "HostedControlPlane not found in namespace "+backup.Namespace) } - // Phase 1 health check: etcd StatefulSet readiness healthy, msg, err := r.checkEtcdHealth(ctx, backup.Namespace) if err != nil { return ctrl.Result{}, fmt.Errorf("failed to check etcd health: %w", err) @@ -175,14 +250,11 @@ func (r *HCPEtcdBackupReconciler) Reconcile(ctx context.Context, req ctrl.Reques return ctrl.Result{RequeueAfter: requeueInterval}, nil } - // Check if we already created a Job for this backup existingJob, err := r.findJobForBackup(ctx, backup) if err != nil { return ctrl.Result{}, fmt.Errorf("failed to find job for backup: %w", err) } - if existingJob != nil { - // Monitor existing Job status return r.handleJobStatus(ctx, backup, existingJob, hcp) } @@ -194,104 +266,17 @@ func (r *HCPEtcdBackupReconciler) Reconcile(ctx context.Context, req ctrl.Reques } if activeJob != nil { logger.Info("rejecting backup: another backup Job is already active", "activeJob", activeJob.Name) - r.setCondition(backup, metav1.Condition{ - Type: string(hyperv1.BackupCompleted), - Status: metav1.ConditionFalse, - Reason: hyperv1.BackupRejectedReason, - Message: fmt.Sprintf("rejected: backup Job %q is already running for this HCP; delete this CR and retry after the active backup completes", activeJob.Name), - }) - if err := r.Status().Update(ctx, backup); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status: %w", err) - } - return ctrl.Result{}, nil + return r.setFailedConditionAndUpdate(ctx, backup, hyperv1.BackupRejectedReason, + fmt.Sprintf("rejected: backup Job %q is already running for this HCP; delete this CR and retry after the active backup completes", activeJob.Name)) } // Validate prerequisites before creating any resources. // Check credential Secret early so we don't create RBAC/NetworkPolicy unnecessarily. - credentialSecretName, err := r.getCredentialSecretName(backup) - if err != nil { - r.setCondition(backup, metav1.Condition{ - Type: string(hyperv1.BackupCompleted), - Status: metav1.ConditionFalse, - Reason: hyperv1.BackupFailedReason, - Message: err.Error(), - }) - if statusErr := r.Status().Update(ctx, backup); statusErr != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status: %w", statusErr) - } - return ctrl.Result{}, nil - } - - credSecret := &corev1.Secret{} - if err := r.Get(ctx, types.NamespacedName{Name: credentialSecretName, Namespace: r.OperatorNamespace}, credSecret); err != nil { - if apierrors.IsNotFound(err) { - r.setCondition(backup, metav1.Condition{ - Type: string(hyperv1.BackupCompleted), - Status: metav1.ConditionFalse, - Reason: hyperv1.BackupFailedReason, - Message: fmt.Sprintf("credential Secret %q not found in namespace %q", credentialSecretName, r.OperatorNamespace), - }) - if statusErr := r.Status().Update(ctx, backup); statusErr != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status: %w", statusErr) - } - return ctrl.Result{}, nil - } - return ctrl.Result{}, fmt.Errorf("failed to get credential Secret: %w", err) - } - - // Create resources and Job - logger.Info("creating backup resources", "backup", backup.Name, "namespace", backup.Namespace) - - if err := r.ensureServiceAccount(ctx); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to ensure ServiceAccount: %w", err) - } - - if err := r.ensureRBAC(ctx, backup); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to ensure RBAC: %w", err) + if result, done, err := r.validatePrerequisites(ctx, backup); done || err != nil { + return result, err } - if err := r.ensureNetworkPolicy(ctx, backup); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to ensure NetworkPolicy: %w", err) - } - - if err := r.createBackupJob(ctx, backup, hcp); err != nil { - if apierrors.IsNotFound(err) { - r.setCondition(backup, metav1.Condition{ - Type: string(hyperv1.BackupCompleted), - Status: metav1.ConditionFalse, - Reason: hyperv1.BackupFailedReason, - Message: err.Error(), - }) - if statusErr := r.Status().Update(ctx, backup); statusErr != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status: %w", statusErr) - } - return ctrl.Result{}, nil - } - return ctrl.Result{}, fmt.Errorf("failed to create backup Job: %w", err) - } - - // Set status to indicate backup is in progress - r.setCondition(backup, metav1.Condition{ - Type: string(hyperv1.BackupCompleted), - Status: metav1.ConditionFalse, - Reason: hyperv1.BackupInProgressReason, - Message: "Backup Job created, waiting for completion", - }) - if err := r.Status().Update(ctx, backup); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status: %w", err) - } - - // Bubble up to HCP - if err := r.updateHCPBackupCondition(ctx, hcp, metav1.Condition{ - Type: string(hyperv1.EtcdBackupSucceeded), - Status: metav1.ConditionFalse, - Reason: hyperv1.BackupInProgressReason, - Message: fmt.Sprintf("Backup %q is in progress", backup.Name), - }); err != nil { - logger.Error(err, "failed to update HCP backup condition") - } - - return ctrl.Result{RequeueAfter: requeueInterval}, nil + return r.createResourcesAndJob(ctx, backup, hcp) } // isTerminal returns true if the backup is in a terminal state. diff --git a/hypershift-operator/controllers/etcdbackup/reconciler_test.go b/hypershift-operator/controllers/etcdbackup/reconciler_test.go index 19003bc949a..d585b24c524 100644 --- a/hypershift-operator/controllers/etcdbackup/reconciler_test.go +++ b/hypershift-operator/controllers/etcdbackup/reconciler_test.go @@ -1498,3 +1498,291 @@ func TestUpdateHostedClusterBackupURL(t *testing.T) { g.Expect(updatedHC.Status.LastSuccessfulEtcdBackupURL).To(Equal("s3://bucket/snapshot.db")) }) } + +func TestGetCredentialSecretName(t *testing.T) { + tests := []struct { + name string + backup *hyperv1.HCPEtcdBackup + expected string + expectError bool + }{ + { + name: "When storage type is S3, it should return S3 credential secret name", + backup: &hyperv1.HCPEtcdBackup{ + Spec: hyperv1.HCPEtcdBackupSpec{ + Storage: hyperv1.HCPEtcdBackupStorage{ + StorageType: hyperv1.S3BackupStorage, + S3: hyperv1.HCPEtcdBackupS3{ + Credentials: hyperv1.SecretReference{Name: "s3-creds"}, + }, + }, + }, + }, + expected: "s3-creds", + }, + { + name: "When storage type is AzureBlob, it should return Azure credential secret name", + backup: &hyperv1.HCPEtcdBackup{ + Spec: hyperv1.HCPEtcdBackupSpec{ + Storage: hyperv1.HCPEtcdBackupStorage{ + StorageType: hyperv1.AzureBlobBackupStorage, + AzureBlob: hyperv1.HCPEtcdBackupAzureBlob{ + Credentials: hyperv1.SecretReference{Name: "azure-creds"}, + }, + }, + }, + }, + expected: "azure-creds", + }, + { + name: "When storage type is unsupported, it should return an error", + backup: &hyperv1.HCPEtcdBackup{ + Spec: hyperv1.HCPEtcdBackupSpec{ + Storage: hyperv1.HCPEtcdBackupStorage{ + StorageType: "UnsupportedType", + }, + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + r := newReconciler() + + name, err := r.getCredentialSecretName(tt.backup) + + if tt.expectError { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("unsupported")) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(name).To(Equal(tt.expected)) + } + }) + } +} + +func TestGetHostedControlPlane(t *testing.T) { + tests := []struct { + name string + objects []client.Object + namespace string + expectNil bool + expectErr bool + }{ + { + name: "When no HCPs exist in the namespace, it should return nil", + objects: []client.Object{}, + namespace: testHCPNamespace, + expectNil: true, + }, + { + name: "When one HCP exists in the namespace, it should return it", + objects: []client.Object{ + newHostedControlPlane(), + }, + namespace: testHCPNamespace, + expectNil: false, + }, + { + name: "When HCP exists in a different namespace, it should return nil", + objects: []client.Object{ + newHostedControlPlane(), + }, + namespace: "other-namespace", + expectNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + r := newReconciler(tt.objects...) + + result, err := r.getHostedControlPlane(t.Context(), tt.namespace) + + if tt.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + if tt.expectNil { + g.Expect(result).To(BeNil()) + } else { + g.Expect(result).ToNot(BeNil()) + } + } + }) + } +} + +func TestValidatePrerequisites(t *testing.T) { + tests := []struct { + name string + backup *hyperv1.HCPEtcdBackup + objects []client.Object + expectDone bool + expectError bool + expectFail bool + }{ + { + name: "When credential secret exists, it should return done=false (proceed)", + backup: newHCPEtcdBackup(), + objects: []client.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "aws-creds", + Namespace: testHONamespace, + }, + }, + }, + expectDone: false, + }, + { + name: "When credential secret is missing, it should set BackupFailed and return done=true", + backup: newHCPEtcdBackup(), + objects: []client.Object{}, + expectDone: true, + expectFail: true, + }, + { + name: "When storage type is unsupported, it should set BackupFailed and return done=true", + backup: &hyperv1.HCPEtcdBackup{ + ObjectMeta: metav1.ObjectMeta{ + Name: testBackupName, + Namespace: testHCPNamespace, + }, + Spec: hyperv1.HCPEtcdBackupSpec{ + Storage: hyperv1.HCPEtcdBackupStorage{ + StorageType: "UnsupportedType", + }, + }, + }, + objects: []client.Object{}, + expectDone: true, + expectFail: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + + allObjects := append([]client.Object{tt.backup}, tt.objects...) + r := newReconciler(allObjects...) + + _, done, err := r.validatePrerequisites(t.Context(), tt.backup) + + if tt.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + } + g.Expect(done).To(Equal(tt.expectDone)) + + if tt.expectFail { + updated := &hyperv1.HCPEtcdBackup{} + g.Expect(r.Get(t.Context(), types.NamespacedName{Name: tt.backup.Name, Namespace: tt.backup.Namespace}, updated)).To(Succeed()) + cond := meta.FindStatusCondition(updated.Status.Conditions, string(hyperv1.BackupCompleted)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Reason).To(Equal(hyperv1.BackupFailedReason)) + } + }) + } +} + +func TestGetSnapshotURLFromPod(t *testing.T) { + tests := []struct { + name string + pods []client.Object + jobName string + expected string + }{ + { + name: "When upload container has termination message, it should return the URL", + pods: []client.Object{ + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "backup-pod", + Namespace: testHONamespace, + Labels: map[string]string{"batch.kubernetes.io/job-name": "my-job"}, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: "upload", Image: "test:latest"}}, + RestartPolicy: corev1.RestartPolicyNever, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + { + Name: "upload", + State: corev1.ContainerState{ + Terminated: &corev1.ContainerStateTerminated{ + Message: " s3://bucket/path/snapshot.db ", + }, + }, + }, + }, + }, + }, + }, + jobName: "my-job", + expected: "s3://bucket/path/snapshot.db", + }, + { + name: "When no pods exist for the job, it should return empty string", + pods: []client.Object{}, + jobName: "my-job", + expected: "", + }, + { + name: "When upload container has no termination message, it should return empty string", + pods: []client.Object{ + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "backup-pod", + Namespace: testHONamespace, + Labels: map[string]string{"batch.kubernetes.io/job-name": "my-job"}, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: "upload", Image: "test:latest"}}, + RestartPolicy: corev1.RestartPolicyNever, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + { + Name: "upload", + State: corev1.ContainerState{ + Terminated: &corev1.ContainerStateTerminated{ + ExitCode: 0, + }, + }, + }, + }, + }, + }, + }, + jobName: "my-job", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + r := newReconciler(tt.pods...) + + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: tt.jobName, + Namespace: testHONamespace, + }, + } + + url, err := r.getSnapshotURLFromPod(t.Context(), job) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(url).To(Equal(tt.expected)) + }) + } +} diff --git a/hypershift-operator/controllers/hostedcluster/etcd_recovery.go b/hypershift-operator/controllers/hostedcluster/etcd_recovery.go index e7c25b134ff..d4196de80b6 100644 --- a/hypershift-operator/controllers/hostedcluster/etcd_recovery.go +++ b/hypershift-operator/controllers/hostedcluster/etcd_recovery.go @@ -25,6 +25,8 @@ import ( ctrl "sigs.k8s.io/controller-runtime" crclient "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/go-logr/logr" ) type etcdJobStatus struct { @@ -37,64 +39,76 @@ func (r *HostedClusterReconciler) reconcileETCDMemberRecovery(ctx context.Contex log := ctrl.LoggerFrom(ctx) hcpNS := manifests.HostedControlPlaneNamespace(hcluster.Namespace, hcluster.Name) - // Check the recovery job recoveryJob := etcdrecoverymanifests.EtcdRecoveryJob(hcpNS) jobStatus, err := r.etcdRecoveryJobStatus(ctx, recoveryJob) if err != nil { return nil, err } - etcdRecoveryActiveCondition := metav1.Condition{ - Type: string(hyperv1.EtcdRecoveryActive), - ObservedGeneration: hcluster.Generation, - } - if jobStatus.exists { - if !jobStatus.finished { - log.Info("waiting for etcd recovery job to complete") + done, err := r.handleExistingEtcdRecoveryJob(ctx, log, hcluster, recoveryJob, jobStatus) + if err != nil { + return nil, err + } + if done { return nil, nil } + } - if !jobStatus.successful { - etcdRecoveryActiveCondition.Status = metav1.ConditionFalse - etcdRecoveryActiveCondition.Reason = hyperv1.EtcdRecoveryJobFailedReason - etcdRecoveryActiveCondition.Message = "Error in Etcd Recovery job: the Etcd cluster requires manual intervention." - etcdRecoveryActiveCondition.LastTransitionTime = r.now() - - oldCondition := meta.FindStatusCondition(hcluster.Status.Conditions, string(hyperv1.EtcdRecoveryActive)) + return r.detectAndTriggerEtcdRecovery(ctx, log, hcluster, hcpNS, recoveryJob, createOrUpdate) +} - if oldCondition == nil || oldCondition.Status != etcdRecoveryActiveCondition.Status { - meta.SetStatusCondition(&hcluster.Status.Conditions, etcdRecoveryActiveCondition) - if err := r.Client.Status().Update(ctx, hcluster); err != nil { - return nil, fmt.Errorf("failed to update etcd recovery job condition: %w", err) - } - } +// handleExistingEtcdRecoveryJob processes an existing recovery job. +// It returns (done, err) where done=true means the caller should return immediately +// without falling through to detectAndTriggerEtcdRecovery. +func (r *HostedClusterReconciler) handleExistingEtcdRecoveryJob(ctx context.Context, log logr.Logger, hcluster *hyperv1.HostedCluster, recoveryJob *batchv1.Job, jobStatus *etcdJobStatus) (bool, error) { + if !jobStatus.finished { + log.Info("waiting for etcd recovery job to complete") + return true, nil + } - // There is no benefit in requeuing, since the cluster needs manual intervention - log.Error(errors.New("etcd recovery failed"), "failed recovery job exists", "job", crclient.ObjectKeyFromObject(recoveryJob).String()) - return nil, nil + if !jobStatus.successful { + if err := r.setEtcdRecoveryCondition(ctx, hcluster, metav1.ConditionFalse, hyperv1.EtcdRecoveryJobFailedReason, "Error in Etcd Recovery job: the Etcd cluster requires manual intervention."); err != nil { + return false, err } + // There is no benefit in requeuing, since the cluster needs manual intervention + log.Error(errors.New("etcd recovery failed"), "failed recovery job exists", "job", crclient.ObjectKeyFromObject(recoveryJob).String()) + return true, nil + } - // Cleanup ETCD Recovery objects - if err := r.cleanupEtcdRecoveryObjects(ctx, hcluster); err != nil { - return nil, fmt.Errorf("failed to cleanup etcd recovery job: %w", err) - } + if err := r.cleanupEtcdRecoveryObjects(ctx, hcluster); err != nil { + return false, fmt.Errorf("failed to cleanup etcd recovery job: %w", err) + } - etcdRecoveryActiveCondition.Status = metav1.ConditionFalse - etcdRecoveryActiveCondition.Reason = hyperv1.AsExpectedReason - etcdRecoveryActiveCondition.Message = "ETCD Recovery job succeeded." - etcdRecoveryActiveCondition.LastTransitionTime = r.now() + if err := r.setEtcdRecoveryCondition(ctx, hcluster, metav1.ConditionFalse, hyperv1.AsExpectedReason, "ETCD Recovery job succeeded."); err != nil { + return false, err + } - oldCondition := meta.FindStatusCondition(hcluster.Status.Conditions, string(hyperv1.EtcdRecoveryActive)) + // After successful cleanup, fall through to detectAndTriggerEtcdRecovery + return false, nil +} - if oldCondition == nil || oldCondition.Status != etcdRecoveryActiveCondition.Status { - meta.SetStatusCondition(&hcluster.Status.Conditions, etcdRecoveryActiveCondition) - if err := r.Client.Status().Update(ctx, hcluster); err != nil { - return nil, fmt.Errorf("failed to update etcd recovery job condition: %w", err) - } +func (r *HostedClusterReconciler) setEtcdRecoveryCondition(ctx context.Context, hcluster *hyperv1.HostedCluster, status metav1.ConditionStatus, reason, message string) error { + condition := metav1.Condition{ + Type: string(hyperv1.EtcdRecoveryActive), + ObservedGeneration: hcluster.Generation, + Status: status, + Reason: reason, + Message: message, + LastTransitionTime: r.now(), + } + + oldCondition := meta.FindStatusCondition(hcluster.Status.Conditions, string(hyperv1.EtcdRecoveryActive)) + if oldCondition == nil || oldCondition.Status != condition.Status { + meta.SetStatusCondition(&hcluster.Status.Conditions, condition) + if err := r.Client.Status().Update(ctx, hcluster); err != nil { + return fmt.Errorf("failed to update etcd recovery job condition: %w", err) } } + return nil +} +func (r *HostedClusterReconciler) detectAndTriggerEtcdRecovery(ctx context.Context, log logr.Logger, hcluster *hyperv1.HostedCluster, hcpNS string, recoveryJob *batchv1.Job, createOrUpdate upsert.CreateOrUpdateFN) (*time.Duration, error) { etcdStatefulSet := etcdrecoverymanifests.EtcdStatefulSet(hcpNS) if err := r.Get(ctx, crclient.ObjectKeyFromObject(etcdStatefulSet), etcdStatefulSet); err != nil { if apierrors.IsNotFound(err) { @@ -108,6 +122,32 @@ func (r *HostedClusterReconciler) reconcileETCDMemberRecovery(ctx context.Contex } requeueAfter := etcdCheckRequeueInterval + failingEtcdPod, err := r.findFailingEtcdPod(ctx, log, hcpNS) + if err != nil { + return nil, err + } + + if failingEtcdPod == nil { + if !fullyAvailable { + return &requeueAfter, nil + } + return nil, nil + } + + log.Info("there are symptoms of etcd cluster degradation, triggering recovery job") + + if err := r.createEtcdRecoveryResources(ctx, hcluster, hcpNS, recoveryJob, createOrUpdate); err != nil { + return nil, err + } + + if err := r.setEtcdRecoveryCondition(ctx, hcluster, metav1.ConditionTrue, hyperv1.AsExpectedReason, "ETCD Recovery job in progress."); err != nil { + return nil, err + } + + return nil, nil +} + +func (r *HostedClusterReconciler) findFailingEtcdPod(ctx context.Context, log logr.Logger, hcpNS string) (*corev1.Pod, error) { etcdPodList := &corev1.PodList{} if err := r.List(ctx, etcdPodList, crclient.InNamespace(hcpNS), crclient.MatchingLabels{ "app": "etcd", @@ -116,38 +156,27 @@ func (r *HostedClusterReconciler) reconcileETCDMemberRecovery(ctx context.Contex } if len(etcdPodList.Items) < 3 { - // Cannot initiate recovery without all etcd pods, let's requeue - return &requeueAfter, nil + return nil, nil } - var failingEtcdPod *corev1.Pod for _, pod := range etcdPodList.Items { for _, containerStatus := range pod.Status.ContainerStatuses { if containerStatus.State.Waiting != nil && containerStatus.RestartCount > 0 && containerStatus.Name == "etcd" { - failingEtcdPod = &pod log.Info("detected etcd failing pod", "name", pod.Name, "namespace", pod.Namespace) - break + return &pod, nil } } } + return nil, nil +} - if failingEtcdPod == nil { - // No failing etcd pods detected - // However, if the statefulset is not reporting fully available, check later - if !fullyAvailable { - return &requeueAfter, nil - } - return nil, nil - } - - log.Info("there are symptoms of etcd cluster degradation, triggering recovery job") - +func (r *HostedClusterReconciler) createEtcdRecoveryResources(ctx context.Context, hcluster *hyperv1.HostedCluster, hcpNS string, recoveryJob *batchv1.Job, createOrUpdate upsert.CreateOrUpdateFN) error { recoveryRole := etcdrecoverymanifests.EtcdRecoveryRole(hcpNS) if _, err := createOrUpdate(ctx, r.Client, recoveryRole, func() error { r.reconcileEtcdRecoveryRole(recoveryRole) return nil }); err != nil { - return nil, fmt.Errorf("failed to reconcile etcd recovery role: %w", err) + return fmt.Errorf("failed to reconcile etcd recovery role: %w", err) } recoverySA := etcdrecoverymanifests.EtcdRecoveryServiceAccount(hcpNS) @@ -155,7 +184,7 @@ func (r *HostedClusterReconciler) reconcileETCDMemberRecovery(ctx context.Contex k8sutil.EnsurePullSecret(recoverySA, common.PullSecret("").Name) return nil }); err != nil { - return nil, fmt.Errorf("failed to reconcile etcd-recovery job service account: %w", err) + return fmt.Errorf("failed to reconcile etcd-recovery job service account: %w", err) } recoveryRoleBinding := etcdrecoverymanifests.EtcdRecoveryRoleBinding(hcpNS) @@ -163,32 +192,16 @@ func (r *HostedClusterReconciler) reconcileETCDMemberRecovery(ctx context.Contex r.reconcileEtcdRecoveryRoleBinding(recoveryRoleBinding, recoveryRole, recoverySA) return nil }); err != nil { - return nil, fmt.Errorf("failed to reconcile etcd recovery role binding: %w", err) + return fmt.Errorf("failed to reconcile etcd recovery role binding: %w", err) } if _, err := createOrUpdate(ctx, r.Client, recoveryJob, func() error { return r.reconcileEtcdRecoveryJob(recoveryJob, hcluster) }); err != nil { - return nil, fmt.Errorf("failed to reconcile etcd recovery job: %w", err) + return fmt.Errorf("failed to reconcile etcd recovery job: %w", err) } - // Creating the condition for the first time or in the case of the ETCD fails intermitently - etcdRecoveryActiveCondition.Status = metav1.ConditionTrue - etcdRecoveryActiveCondition.Reason = hyperv1.AsExpectedReason - etcdRecoveryActiveCondition.Message = "ETCD Recovery job in progress." - etcdRecoveryActiveCondition.LastTransitionTime = r.now() - - // If the ETCD keeps failing and recovering, we can see the hcluster.Generation increasing indefinitely. - oldCondition := meta.FindStatusCondition(hcluster.Status.Conditions, string(hyperv1.EtcdRecoveryActive)) - - if oldCondition == nil || oldCondition.Status != etcdRecoveryActiveCondition.Status { - meta.SetStatusCondition(&hcluster.Status.Conditions, etcdRecoveryActiveCondition) - if err := r.Client.Status().Update(ctx, hcluster); err != nil { - return nil, fmt.Errorf("failed to update etcd recovery job condition: %w", err) - } - } - - return nil, nil + return nil } // etcdRecoveryJobStatus checks the status of the ETCD recovery job and returns a status diff --git a/hypershift-operator/controllers/hostedcluster/etcd_recovery_test.go b/hypershift-operator/controllers/hostedcluster/etcd_recovery_test.go new file mode 100644 index 00000000000..39066f480e2 --- /dev/null +++ b/hypershift-operator/controllers/hostedcluster/etcd_recovery_test.go @@ -0,0 +1,563 @@ +package hostedcluster + +import ( + "testing" + + . "github.com/onsi/gomega" + + hyperv1 "github.com/openshift/hypershift/api/hypershift/v1beta1" + "github.com/openshift/hypershift/support/api" + + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + crclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/log/zap" +) + +func TestEtcdRecoveryJobStatus(t *testing.T) { + tests := []struct { + name string + job *batchv1.Job + jobExists bool + expectedExists bool + expectedFinished bool + expectedSuccessful bool + expectError bool + }{ + { + name: "When job does not exist, it should return exists=false", + jobExists: false, + expectedExists: false, + }, + { + name: "When job exists with no conditions, it should return exists=true, finished=false", + jobExists: true, + job: &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-recovery", + Namespace: "clusters-test", + }, + Spec: batchv1.JobSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: "etcd-recovery", Image: "test:latest"}}, + RestartPolicy: corev1.RestartPolicyNever, + }, + }, + }, + Status: batchv1.JobStatus{Active: 1}, + }, + expectedExists: true, + expectedFinished: false, + }, + { + name: "When job completed successfully, it should return finished=true, successful=true", + jobExists: true, + job: &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-recovery", + Namespace: "clusters-test", + }, + Spec: batchv1.JobSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: "etcd-recovery", Image: "test:latest"}}, + RestartPolicy: corev1.RestartPolicyNever, + }, + }, + }, + Status: batchv1.JobStatus{ + Conditions: []batchv1.JobCondition{ + { + Type: batchv1.JobComplete, + Status: corev1.ConditionTrue, + }, + }, + }, + }, + expectedExists: true, + expectedFinished: true, + expectedSuccessful: true, + }, + { + name: "When job failed, it should return finished=true, successful=false", + jobExists: true, + job: &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-recovery", + Namespace: "clusters-test", + }, + Spec: batchv1.JobSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: "etcd-recovery", Image: "test:latest"}}, + RestartPolicy: corev1.RestartPolicyNever, + }, + }, + }, + Status: batchv1.JobStatus{ + Conditions: []batchv1.JobCondition{ + { + Type: batchv1.JobFailed, + Status: corev1.ConditionTrue, + }, + }, + }, + }, + expectedExists: true, + expectedFinished: true, + expectedSuccessful: false, + }, + { + name: "When job has both Complete and Failed conditions with False status, it should return finished=false", + jobExists: true, + job: &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-recovery", + Namespace: "clusters-test", + }, + Spec: batchv1.JobSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: "etcd-recovery", Image: "test:latest"}}, + RestartPolicy: corev1.RestartPolicyNever, + }, + }, + }, + Status: batchv1.JobStatus{ + Conditions: []batchv1.JobCondition{ + { + Type: batchv1.JobComplete, + Status: corev1.ConditionFalse, + }, + { + Type: batchv1.JobFailed, + Status: corev1.ConditionFalse, + }, + }, + }, + }, + expectedExists: true, + expectedFinished: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + var objects []crclient.Object + if tt.jobExists && tt.job != nil { + objects = append(objects, tt.job) + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(objects...). + Build() + + r := &HostedClusterReconciler{ + Client: fakeClient, + } + + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-recovery", + Namespace: "clusters-test", + }, + } + + status, err := r.etcdRecoveryJobStatus(t.Context(), job) + + if tt.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(status).ToNot(BeNil()) + g.Expect(status.exists).To(Equal(tt.expectedExists)) + g.Expect(status.finished).To(Equal(tt.expectedFinished)) + g.Expect(status.successful).To(Equal(tt.expectedSuccessful)) + } + }) + } +} + +func TestFindFailingEtcdPod(t *testing.T) { + tests := []struct { + name string + pods []crclient.Object + namespace string + expectFound bool + expectedName string + }{ + { + name: "When fewer than 3 etcd pods exist, it should return nil", + namespace: "clusters-test", + pods: []crclient.Object{ + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-0", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-1", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + }, + }, + expectFound: false, + }, + { + name: "When 3 healthy etcd pods exist, it should return nil", + namespace: "clusters-test", + pods: []crclient.Object{ + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-0", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + {Name: "etcd", State: corev1.ContainerState{Running: &corev1.ContainerStateRunning{}}, RestartCount: 0}, + }, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-1", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + {Name: "etcd", State: corev1.ContainerState{Running: &corev1.ContainerStateRunning{}}, RestartCount: 0}, + }, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-2", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + {Name: "etcd", State: corev1.ContainerState{Running: &corev1.ContainerStateRunning{}}, RestartCount: 0}, + }, + }, + }, + }, + expectFound: false, + }, + { + name: "When one etcd pod is in Waiting state with restarts, it should return that pod", + namespace: "clusters-test", + pods: []crclient.Object{ + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-0", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + {Name: "etcd", State: corev1.ContainerState{Running: &corev1.ContainerStateRunning{}}, RestartCount: 0}, + }, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-1", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + { + Name: "etcd", + State: corev1.ContainerState{Waiting: &corev1.ContainerStateWaiting{Reason: "CrashLoopBackOff"}}, + RestartCount: 5, + }, + }, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-2", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + {Name: "etcd", State: corev1.ContainerState{Running: &corev1.ContainerStateRunning{}}, RestartCount: 0}, + }, + }, + }, + }, + expectFound: true, + expectedName: "etcd-1", + }, + { + name: "When etcd pod has waiting state but zero restarts, it should not detect it as failing", + namespace: "clusters-test", + pods: []crclient.Object{ + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-0", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + {Name: "etcd", State: corev1.ContainerState{Running: &corev1.ContainerStateRunning{}}, RestartCount: 0}, + }, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-1", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + { + Name: "etcd", + State: corev1.ContainerState{Waiting: &corev1.ContainerStateWaiting{Reason: "ContainerCreating"}}, + RestartCount: 0, + }, + }, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-2", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + {Name: "etcd", State: corev1.ContainerState{Running: &corev1.ContainerStateRunning{}}, RestartCount: 0}, + }, + }, + }, + }, + expectFound: false, + }, + { + name: "When a non-etcd container is failing, it should not detect the pod as failing", + namespace: "clusters-test", + pods: []crclient.Object{ + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-0", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + {Name: "etcd", State: corev1.ContainerState{Running: &corev1.ContainerStateRunning{}}, RestartCount: 0}, + { + Name: "sidecar", + State: corev1.ContainerState{Waiting: &corev1.ContainerStateWaiting{Reason: "CrashLoopBackOff"}}, + RestartCount: 5, + }, + }, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-1", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-2", Namespace: "clusters-test", + Labels: map[string]string{"app": "etcd"}, + }, + }, + }, + expectFound: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + fakeClient := fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(tt.pods...). + Build() + + r := &HostedClusterReconciler{ + Client: fakeClient, + } + + log := zap.New(zap.UseDevMode(true)) + pod, err := r.findFailingEtcdPod(t.Context(), log, tt.namespace) + + g.Expect(err).ToNot(HaveOccurred()) + if tt.expectFound { + g.Expect(pod).ToNot(BeNil()) + g.Expect(pod.Name).To(Equal(tt.expectedName)) + } else { + g.Expect(pod).To(BeNil()) + } + }) + } +} + +func TestHandleExistingEtcdRecoveryJob(t *testing.T) { + tests := []struct { + name string + jobStatus *etcdJobStatus + expectDone bool + expectCondition bool + expectedReason string + expectedStatus metav1.ConditionStatus + }{ + { + name: "When job is not finished, it should return done=true and not set any condition", + jobStatus: &etcdJobStatus{ + exists: true, + finished: false, + }, + expectDone: true, + expectCondition: false, + }, + { + name: "When job failed, it should return done=true and set EtcdRecoveryActive with failure reason", + jobStatus: &etcdJobStatus{ + exists: true, + finished: true, + successful: false, + }, + expectDone: true, + expectCondition: true, + expectedReason: hyperv1.EtcdRecoveryJobFailedReason, + expectedStatus: metav1.ConditionFalse, + }, + { + name: "When job succeeded, it should return done=false and set EtcdRecoveryActive with success reason", + jobStatus: &etcdJobStatus{ + exists: true, + finished: true, + successful: true, + }, + expectDone: false, + expectCondition: true, + expectedReason: hyperv1.AsExpectedReason, + expectedStatus: metav1.ConditionFalse, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + hcluster := &hyperv1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "clusters", + }, + } + + recoveryJob := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-recovery", + Namespace: "clusters-test-cluster", + }, + } + + var objects []crclient.Object + objects = append(objects, hcluster) + if tt.jobStatus.finished && tt.jobStatus.successful { + // For cleanup, the job must exist in the client + objects = append(objects, recoveryJob) + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(objects...). + WithStatusSubresource(&hyperv1.HostedCluster{}). + Build() + + r := &HostedClusterReconciler{ + Client: fakeClient, + now: metav1.Now, + } + + log := zap.New(zap.UseDevMode(true)) + done, err := r.handleExistingEtcdRecoveryJob(t.Context(), log, hcluster, recoveryJob, tt.jobStatus) + + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(done).To(Equal(tt.expectDone)) + + if tt.expectCondition { + updatedHC := &hyperv1.HostedCluster{} + g.Expect(fakeClient.Get(t.Context(), crclient.ObjectKeyFromObject(hcluster), updatedHC)).To(Succeed()) + cond := findCondition(updatedHC.Status.Conditions, string(hyperv1.EtcdRecoveryActive)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Reason).To(Equal(tt.expectedReason)) + g.Expect(cond.Status).To(Equal(tt.expectedStatus)) + } else { + updatedHC := &hyperv1.HostedCluster{} + g.Expect(fakeClient.Get(t.Context(), crclient.ObjectKeyFromObject(hcluster), updatedHC)).To(Succeed()) + cond := findCondition(updatedHC.Status.Conditions, string(hyperv1.EtcdRecoveryActive)) + g.Expect(cond).To(BeNil(), "EtcdRecoveryActive condition should not be set") + } + }) + } +} + +func TestReconcileEtcdRecoveryRole(t *testing.T) { + t.Run("When called, it should set expected RBAC rules", func(t *testing.T) { + g := NewWithT(t) + + r := &HostedClusterReconciler{} + role := &rbacv1.Role{} + r.reconcileEtcdRecoveryRole(role) + + g.Expect(role.Rules).To(HaveLen(2)) + + // First rule: pods and PVCs + g.Expect(role.Rules[0].APIGroups).To(ConsistOf("")) + g.Expect(role.Rules[0].Resources).To(ConsistOf("pods", "persistentvolumeclaims")) + g.Expect(role.Rules[0].Verbs).To(ConsistOf("get", "list", "delete")) + + // Second rule: statefulsets + g.Expect(role.Rules[1].APIGroups).To(ConsistOf("apps")) + g.Expect(role.Rules[1].Resources).To(ConsistOf("statefulsets")) + g.Expect(role.Rules[1].Verbs).To(ConsistOf("get", "list")) + }) +} + +func TestReconcileEtcdRecoveryRoleBinding(t *testing.T) { + t.Run("When called, it should bind the role to the service account", func(t *testing.T) { + g := NewWithT(t) + + r := &HostedClusterReconciler{} + roleBinding := &rbacv1.RoleBinding{} + role := &rbacv1.Role{ObjectMeta: metav1.ObjectMeta{Name: "etcd-recovery"}} + sa := &corev1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{Name: "etcd-recovery-sa", Namespace: "clusters-test"}} + + r.reconcileEtcdRecoveryRoleBinding(roleBinding, role, sa) + + g.Expect(roleBinding.RoleRef.Kind).To(Equal("Role")) + g.Expect(roleBinding.RoleRef.Name).To(Equal("etcd-recovery")) + g.Expect(roleBinding.RoleRef.APIGroup).To(Equal(rbacv1.GroupName)) + g.Expect(roleBinding.Subjects).To(HaveLen(1)) + g.Expect(roleBinding.Subjects[0].Kind).To(Equal("ServiceAccount")) + g.Expect(roleBinding.Subjects[0].Name).To(Equal("etcd-recovery-sa")) + g.Expect(roleBinding.Subjects[0].Namespace).To(Equal("clusters-test")) + }) +} + +// findCondition is a test helper that looks up a condition by type. +func findCondition(conditions []metav1.Condition, condType string) *metav1.Condition { + for i := range conditions { + if conditions[i].Type == condType { + return &conditions[i] + } + } + return nil +} diff --git a/hypershift-operator/controllers/hostedcluster/hostedcluster_controller.go b/hypershift-operator/controllers/hostedcluster/hostedcluster_controller.go index 022b442d182..a41acb79a83 100644 --- a/hypershift-operator/controllers/hostedcluster/hostedcluster_controller.go +++ b/hypershift-operator/controllers/hostedcluster/hostedcluster_controller.go @@ -383,6 +383,7 @@ func (r *HostedClusterReconciler) Reconcile(ctx context.Context, req ctrl.Reques return res, err } +//nolint:gocyclo func (r *HostedClusterReconciler) reconcile(ctx context.Context, req ctrl.Request, log logr.Logger, hcluster *hyperv1.HostedCluster) (ctrl.Result, error) { controlPlaneNamespace := manifests.HostedControlPlaneNamespaceObject(hcluster.Namespace, hcluster.Name) hcp := controlplaneoperator.HostedControlPlane(controlPlaneNamespace.Name, hcluster.Name) @@ -3334,6 +3335,7 @@ func deleteControlPlaneOperatorRBAC(ctx context.Context, c client.Client, rbacNa return nil } +//nolint:gocyclo func (r *HostedClusterReconciler) delete(ctx context.Context, hc *hyperv1.HostedCluster) (bool, error) { controlPlaneNamespace := manifests.HostedControlPlaneNamespace(hc.Namespace, hc.Name) log := ctrl.LoggerFrom(ctx) diff --git a/hypershift-operator/controllers/hostedcluster/internal/platform/platform.go b/hypershift-operator/controllers/hostedcluster/internal/platform/platform.go index c58ff77be3d..2ca701d7478 100644 --- a/hypershift-operator/controllers/hostedcluster/internal/platform/platform.go +++ b/hypershift-operator/controllers/hostedcluster/internal/platform/platform.go @@ -88,6 +88,8 @@ type OrphanDeleter interface { } // GetPlatform gets and initializes the cloud platform the hosted cluster was created on +// +//nolint:gocyclo func GetPlatform(ctx context.Context, hcluster *hyperv1.HostedCluster, releaseProvider releaseinfo.Provider, utilitiesImage string, pullSecretBytes []byte) (Platform, error) { var ( platform Platform diff --git a/hypershift-operator/controllers/hostedcluster/metrics/metrics.go b/hypershift-operator/controllers/hostedcluster/metrics/metrics.go index e68ab048950..8c26a7bb0f4 100644 --- a/hypershift-operator/controllers/hostedcluster/metrics/metrics.go +++ b/hypershift-operator/controllers/hostedcluster/metrics/metrics.go @@ -260,25 +260,54 @@ func (c *hostedClustersMetricsCollector) Collect(ch chan<- prometheus.Metric) { currentCollectTime := c.clock.Now() log := ctrllog.Log - // countByIdentityProviderMetric - init - identityProviderToHClustersCount := make(map[configv1.IdentityProviderType]int) + identityProviderToHClustersCount := initIdentityProviderCounts() + platformToHClustersCount := initPlatformCounts() + platformToFailureConditionToHClustersCount := initPlatformFailureConditionCounts() + + hclusters := &hyperv1.HostedClusterList{} + if err := c.List(context.Background(), hclusters); err != nil { + log.Error(err, "failed to list hosted clusters while collecting metrics") + return + } - for k := range knownIdentityProviders { - identityProviderToHClustersCount[knownIdentityProviders[k]] = 0 + for k := range hclusters.Items { + hcluster := &hclusters.Items[k] + + collectIdentityProviderCounts(hcluster, identityProviderToHClustersCount) + platform := hcluster.Spec.Platform.Type + platformToHClustersCount[platform] = platformToHClustersCount[platform] + 1 + collectFailureConditionCounts(hcluster, platform, platformToFailureConditionToHClustersCount) + c.collectTransitionDurationMetrics(hcluster, currentCollectTime) + + hclusterLabelValues := []string{hcluster.Namespace, hcluster.Name, hcluster.Spec.ClusterID} + c.collectPerClusterMetrics(ch, hcluster, hclusterLabelValues) } - // countByPlatformMetric - init - platformToHClustersCount := make(map[hyperv1.PlatformType]int) + emitAggregatedMetrics(ch, identityProviderToHClustersCount, platformToHClustersCount, platformToFailureConditionToHClustersCount) + c.transitionDurationMetric.Collect(ch) + c.lastCollectTime = currentCollectTime +} - for k := range knownPlatforms { - platformToHClustersCount[knownPlatforms[k]] = 0 +func initIdentityProviderCounts() map[configv1.IdentityProviderType]int { + counts := make(map[configv1.IdentityProviderType]int) + for k := range knownIdentityProviders { + counts[knownIdentityProviders[k]] = 0 } + return counts +} - // countByPlatformAndFailureConditionMetric - init - platformToFailureConditionToHClustersCount := make(map[hyperv1.PlatformType]*map[string]int) +func initPlatformCounts() map[hyperv1.PlatformType]int { + counts := make(map[hyperv1.PlatformType]int) + for k := range knownPlatforms { + counts[knownPlatforms[k]] = 0 + } + return counts +} +func initPlatformFailureConditionCounts() map[hyperv1.PlatformType]*map[string]int { + counts := make(map[hyperv1.PlatformType]*map[string]int) for k := range knownPlatforms { - platformToFailureConditionToHClustersCount[knownPlatforms[k]] = createFailureConditionToHClustersCountMap(conditions.ExpectedHCConditions(&hyperv1.HostedCluster{ + counts[knownPlatforms[k]] = createFailureConditionToHClustersCountMap(conditions.ExpectedHCConditions(&hyperv1.HostedCluster{ Spec: hyperv1.HostedClusterSpec{ Platform: hyperv1.PlatformSpec{ Type: knownPlatforms[k], @@ -286,319 +315,284 @@ func (c *hostedClustersMetricsCollector) Collect(ch chan<- prometheus.Metric) { }, })) } + return counts +} - // MAIN LOOP - Hosted clusters loop - { - hclusters := &hyperv1.HostedClusterList{} - - if err := c.List(context.Background(), hclusters); err != nil { - log.Error(err, "failed to list hosted clusters while collecting metrics") +func collectIdentityProviderCounts(hcluster *hyperv1.HostedCluster, counts map[configv1.IdentityProviderType]int) { + if hcluster.Spec.Configuration != nil && hcluster.Spec.Configuration.OAuth != nil { + for _, identityProvider := range hcluster.Spec.Configuration.OAuth.IdentityProviders { + counts[identityProvider.Type] = counts[identityProvider.Type] + 1 } + } +} - for k := range hclusters.Items { - hcluster := &hclusters.Items[k] +func collectFailureConditionCounts(hcluster *hyperv1.HostedCluster, platform hyperv1.PlatformType, platformToFailureConditionToHClustersCount map[hyperv1.PlatformType]*map[string]int) { + expectedConditions := conditions.ExpectedHCConditions(hcluster) + _, isKnownPlatform := platformToFailureConditionToHClustersCount[platform] + if !isKnownPlatform { + platformToFailureConditionToHClustersCount[platform] = createFailureConditionToHClustersCountMap(expectedConditions) + } - // countByIdentityProviderMetric - aggregation - if hcluster.Spec.Configuration != nil && hcluster.Spec.Configuration.OAuth != nil { - for _, identityProvider := range hcluster.Spec.Configuration.OAuth.IdentityProviders { - identityProviderToHClustersCount[identityProvider.Type] = identityProviderToHClustersCount[identityProvider.Type] + 1 - } + failureConditionToHClustersCount := platformToFailureConditionToHClustersCount[platform] + for _, condition := range hcluster.Status.Conditions { + expectedStatus, isKnownCondition := expectedConditions[hyperv1.ConditionType(condition.Type)] + if isKnownCondition && condition.Status != expectedStatus { + failureCondPrefix := "" + if expectedStatus == metav1.ConditionTrue { + failureCondPrefix = "not_" } + failureCondition := failureCondPrefix + condition.Type + (*failureConditionToHClustersCount)[failureCondition] = (*failureConditionToHClustersCount)[failureCondition] + 1 + } + } +} - // countByPlatformMetric - aggregation - platform := hcluster.Spec.Platform.Type - platformToHClustersCount[platform] = platformToHClustersCount[platform] + 1 - - // countByPlatformAndFailureConditionMetric - aggregation - { - expectedConditions := conditions.ExpectedHCConditions(hcluster) - _, isKnownPlatform := platformToFailureConditionToHClustersCount[platform] - - if !isKnownPlatform { - platformToFailureConditionToHClustersCount[platform] = createFailureConditionToHClustersCountMap(expectedConditions) - } - - failureConditionToHClustersCount := platformToFailureConditionToHClustersCount[platform] - - for _, condition := range hcluster.Status.Conditions { - expectedStatus, isKnownCondition := expectedConditions[hyperv1.ConditionType(condition.Type)] - - if isKnownCondition && condition.Status != expectedStatus { - failureCondPrefix := "" - - if expectedStatus == metav1.ConditionTrue { - failureCondPrefix = "not_" - } - - failureCondition := failureCondPrefix + condition.Type - - (*failureConditionToHClustersCount)[failureCondition] = (*failureConditionToHClustersCount)[failureCondition] + 1 - } - } +func (c *hostedClustersMetricsCollector) collectTransitionDurationMetrics(hcluster *hyperv1.HostedCluster, currentCollectTime time.Time) { + for _, conditionType := range []hyperv1.ConditionType{hyperv1.EtcdAvailable, hyperv1.InfrastructureReady, hyperv1.ExternalDNSReachable, hyperv1.AWSEndpointServiceAvailable, hyperv1.AWSEndpointAvailable} { + condition := meta.FindStatusCondition(hcluster.Status.Conditions, string(conditionType)) + if condition != nil && condition.Status == metav1.ConditionTrue { + t := condition.LastTransitionTime.Time + if c.lastCollectTime.Before(t) && (t.Before(currentCollectTime) || t.Equal(currentCollectTime)) { + c.transitionDurationMetric.With(map[string]string{"condition": string(conditionType)}).Observe(t.Sub(hcluster.CreationTimestamp.Time).Seconds()) } + } + } +} - // transitionDurationMetric - aggregation - for _, conditionType := range []hyperv1.ConditionType{hyperv1.EtcdAvailable, hyperv1.InfrastructureReady, hyperv1.ExternalDNSReachable, hyperv1.AWSEndpointServiceAvailable, hyperv1.AWSEndpointAvailable} { - condition := meta.FindStatusCondition(hcluster.Status.Conditions, string(conditionType)) - - if condition != nil && condition.Status == metav1.ConditionTrue { - t := condition.LastTransitionTime.Time +func (c *hostedClustersMetricsCollector) collectPerClusterMetrics(ch chan<- prometheus.Metric, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + collectInitialAvailabilityMetric(ch, c.clock, hcluster, hclusterLabelValues) + collectInitialRollingOutMetric(ch, c.clock, hcluster, hclusterLabelValues) + collectUpgradingDurationMetric(ch, c.clock, hcluster, hclusterLabelValues) + collectLimitedSupportMetric(ch, hcluster, hclusterLabelValues) + collectSilenceAlertsMetric(ch, hcluster, hclusterLabelValues) + c.collectProxyMetrics(ch, hcluster, hclusterLabelValues) + collectRosaMetrics(ch, hcluster, hclusterLabelValues) + collectAzureInfoMetrics(ch, hcluster, hclusterLabelValues) + collectAwsCredsMetric(ch, hcluster, hclusterLabelValues) + collectDeletingMetrics(ch, c.clock, hcluster, hclusterLabelValues) +} - if c.lastCollectTime.Before(t) && (t.Before(currentCollectTime) || t.Equal(currentCollectTime)) { - c.transitionDurationMetric.With(map[string]string{"condition": string(conditionType)}).Observe(t.Sub(hcluster.CreationTimestamp.Time).Seconds()) - } - } - } +func collectInitialAvailabilityMetric(ch chan<- prometheus.Metric, clk clock.Clock, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + if _, hasBeenAvailable := hcluster.Annotations[HasBeenAvailableAnnotation]; !hasBeenAvailable { + ch <- prometheus.MustNewConstMetric( + waitingInitialAvailabilityDurationMetricDesc, + prometheus.GaugeValue, + clk.Since(hcluster.CreationTimestamp.Time).Seconds(), + hclusterLabelValues..., + ) + } +} - hclusterLabelValues := []string{hcluster.Namespace, hcluster.Name, hcluster.Spec.ClusterID} +func collectInitialRollingOutMetric(ch chan<- prometheus.Metric, clk clock.Clock, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + if hcluster.Status.Version == nil || len(hcluster.Status.Version.History) == 0 || hcluster.Status.Version.History[0].CompletionTime == nil { + ch <- prometheus.MustNewConstMetric( + initialRollingOutDurationMetricDesc, + prometheus.GaugeValue, + clk.Since(hcluster.CreationTimestamp.Time).Seconds(), + hclusterLabelValues..., + ) + } +} - // waitingInitialAvailabilityDurationMetric - { - _, hasBeenAvailable := hcluster.Annotations[HasBeenAvailableAnnotation] +func collectUpgradingDurationMetric(ch chan<- prometheus.Metric, clk clock.Clock, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + if hcluster.Status.Version == nil || len(hcluster.Status.Version.History) <= 1 { + return + } + newVersionEntry := hcluster.Status.Version.History[len(hcluster.Status.Version.History)-1] + if newVersionEntry.CompletionTime != nil { + return + } + previousVersionEntry := hcluster.Status.Version.History[len(hcluster.Status.Version.History)-2] + ch <- prometheus.MustNewConstMetric( + upgradingDurationMetricDesc, + prometheus.GaugeValue, + clk.Since(newVersionEntry.StartedTime.Time).Seconds(), + append(hclusterLabelValues, previousVersionEntry.Version, newVersionEntry.Version)..., + ) +} - if !hasBeenAvailable { - initializingDuration := c.clock.Since(hcluster.CreationTimestamp.Time).Seconds() +func collectLimitedSupportMetric(ch chan<- prometheus.Metric, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + limitedSupportValue := 0.0 + if _, ok := hcluster.Labels[hyperv1.LimitedSupportLabel]; ok { + limitedSupportValue = 1.0 + } + ch <- prometheus.MustNewConstMetric( + limitedSupportEnabledMetricDesc, + prometheus.GaugeValue, + limitedSupportValue, + hclusterLabelValues..., + ) +} - ch <- prometheus.MustNewConstMetric( - waitingInitialAvailabilityDurationMetricDesc, - prometheus.GaugeValue, - initializingDuration, - hclusterLabelValues..., - ) - } - } +func collectSilenceAlertsMetric(ch chan<- prometheus.Metric, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + silenceAlertsValue := 0.0 + if _, ok := hcluster.Labels[hyperv1.SilenceClusterAlertsLabel]; ok { + silenceAlertsValue = 1.0 + } + ch <- prometheus.MustNewConstMetric( + silenceAlertsMetricDesc, + prometheus.GaugeValue, + silenceAlertsValue, + hclusterLabelValues..., + ) +} - // initialRollingOutDurationMetric - if hcluster.Status.Version == nil || len(hcluster.Status.Version.History) == 0 || hcluster.Status.Version.History[0].CompletionTime == nil { - initializingDuration := c.clock.Since(hcluster.CreationTimestamp.Time).Seconds() +func (c *hostedClustersMetricsCollector) collectProxyMetrics(ch chan<- prometheus.Metric, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + var proxyHTTP, proxyHTTPS, proxyTrustedCA string + proxyValue := 0.0 + if hcluster.Spec.Configuration != nil && hcluster.Spec.Configuration.Proxy != nil { + if hcluster.Spec.Configuration.Proxy.HTTPProxy != "" { + proxyHTTP = "1" + } + if hcluster.Spec.Configuration.Proxy.HTTPSProxy != "" { + proxyHTTPS = "1" + } + if hcluster.Spec.Configuration.Proxy.TrustedCA.Name != "" { + proxyTrustedCA = "1" + c.collectProxyCAMetrics(ch, hcluster, hclusterLabelValues) + } + proxyValue = 1.0 + } - ch <- prometheus.MustNewConstMetric( - initialRollingOutDurationMetricDesc, - prometheus.GaugeValue, - initializingDuration, - hclusterLabelValues..., - ) - } + ch <- prometheus.MustNewConstMetric( + proxyMetricDesc, + prometheus.GaugeValue, + proxyValue, + append(hclusterLabelValues, proxyHTTP, proxyHTTPS, proxyTrustedCA)..., + ) +} - // upgradingDurationMetric - // The upgrade is adding a new entry in the history on top of the initial rollout. - if hcluster.Status.Version != nil && len(hcluster.Status.Version.History) > 1 { - newVersionEntry := hcluster.Status.Version.History[len(hcluster.Status.Version.History)-1] - - if newVersionEntry.CompletionTime == nil { - previousVersionEntry := hcluster.Status.Version.History[len(hcluster.Status.Version.History)-2] - upgradingDuration := c.clock.Since(newVersionEntry.StartedTime.Time).Seconds() - - ch <- prometheus.MustNewConstMetric( - upgradingDurationMetricDesc, - prometheus.GaugeValue, - upgradingDuration, - append(hclusterLabelValues, previousVersionEntry.Version, newVersionEntry.Version)..., - ) - } - } +func (c *hostedClustersMetricsCollector) collectProxyCAMetrics(ch chan<- prometheus.Metric, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + proxyCAValid := 0.0 + validProxyCondition := meta.FindStatusCondition(hcluster.Status.Conditions, string(hyperv1.ValidProxyConfiguration)) + if validProxyCondition != nil && validProxyCondition.Status == metav1.ConditionTrue { + proxyCAValid = 1.0 + } + ch <- prometheus.MustNewConstMetric( + proxyCAMetricDesc, + prometheus.GaugeValue, + proxyCAValid, + hclusterLabelValues..., + ) + + proxyExpiryTime := 0.0 + expiryTime, err := c.expiryTimeProxyCA(hcluster) + if err == nil { + proxyExpiryTime = float64(expiryTime.Unix()) + } + ch <- prometheus.MustNewConstMetric( + proxyCAExpiryMetricDesc, + prometheus.GaugeValue, + proxyExpiryTime, + hclusterLabelValues..., + ) +} - // limitedSupportEnabledMetric - { - limitedSupportValue := 0.0 - if _, ok := hcluster.Labels[hyperv1.LimitedSupportLabel]; ok { - limitedSupportValue = 1.0 - } - - ch <- prometheus.MustNewConstMetric( - limitedSupportEnabledMetricDesc, - prometheus.GaugeValue, - limitedSupportValue, - hclusterLabelValues..., - ) +func collectRosaMetrics(ch chan<- prometheus.Metric, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + metricLabels := make(map[string]string, 0) + if hcluster.Spec.Platform.Type == hyperv1.AWSPlatform && hcluster.Spec.Platform.AWS != nil && hcluster.Spec.Platform.AWS.ResourceTags != nil { + for _, resourceTag := range hcluster.Spec.Platform.AWS.ResourceTags { + switch resourceTag.Key { + case "api.openshift.com/environment": + metricLabels["environment"] = resourceTag.Value + case "api.openshift.com/id": + metricLabels["internal_id"] = resourceTag.Value + case "red-hat-clustertype": + metricLabels["cluster_type"] = resourceTag.Value } + } + } - // silenceAlertsMetric - { - silenceAlertsValue := 0.0 - if _, ok := hcluster.Labels[hyperv1.SilenceClusterAlertsLabel]; ok { - silenceAlertsValue = 1.0 - } - - ch <- prometheus.MustNewConstMetric( - silenceAlertsMetricDesc, - prometheus.GaugeValue, - silenceAlertsValue, - hclusterLabelValues..., - ) - } + if metricLabels["cluster_type"] != "rosa" { + return + } - // proxyMetric - { - var proxyHTTP, proxyHTTPS, proxyTrustedCA string - proxyValue := 0.0 - proxyCAValid := 0.0 - proxyExpiryTime := 0.0 - if hcluster.Spec.Configuration != nil && hcluster.Spec.Configuration.Proxy != nil { - if hcluster.Spec.Configuration.Proxy.HTTPProxy != "" { - proxyHTTP = "1" - } - if hcluster.Spec.Configuration.Proxy.HTTPSProxy != "" { - proxyHTTPS = "1" - } - if hcluster.Spec.Configuration.Proxy.TrustedCA.Name != "" { - proxyTrustedCA = "1" - - // Read validation result from the ValidProxyConfiguration condition - validProxyCondition := meta.FindStatusCondition(hcluster.Status.Conditions, string(hyperv1.ValidProxyConfiguration)) - if validProxyCondition != nil && validProxyCondition.Status == metav1.ConditionTrue { - proxyCAValid = 1.0 - } else { - proxyCAValid = 0.0 - } - - // Only report CA validity if a CA is actually configured - ch <- prometheus.MustNewConstMetric( - proxyCAMetricDesc, - prometheus.GaugeValue, - proxyCAValid, - hclusterLabelValues..., - ) - - expiryTime, err := c.expiryTimeProxyCA(hcluster) - if err != nil { - // Silently skip expiry time if we can't fetch it - proxyExpiryTime = 0.0 - } else { - proxyExpiryTime = float64(expiryTime.Unix()) - } - ch <- prometheus.MustNewConstMetric( - proxyCAExpiryMetricDesc, - prometheus.GaugeValue, - proxyExpiryTime, - hclusterLabelValues..., - ) - } - proxyValue = 1.0 - } - - ch <- prometheus.MustNewConstMetric( - proxyMetricDesc, - prometheus.GaugeValue, - proxyValue, - append(hclusterLabelValues, proxyHTTP, proxyHTTPS, proxyTrustedCA)..., - ) - } + etcdRecoveryActiveCondition := meta.FindStatusCondition(hcluster.Status.Conditions, string(hyperv1.EtcdRecoveryActive)) + if etcdRecoveryActiveCondition != nil && etcdRecoveryActiveCondition.Status == metav1.ConditionFalse && etcdRecoveryActiveCondition.Reason == hyperv1.EtcdRecoveryJobFailedReason { + ch <- prometheus.MustNewConstMetric( + etcdManualInterventionRequiredMetricDesc, + prometheus.GaugeValue, + 1.0, + append(hclusterLabelValues, metricLabels["environment"], metricLabels["internal_id"])..., + ) + } - // etcdManualInterventionRequiredMetric - // clusterSizeOverrideMetric - { - metricLabels := make(map[string]string, 0) - if hcluster.Spec.Platform.Type == hyperv1.AWSPlatform && hcluster.Spec.Platform.AWS != nil && hcluster.Spec.Platform.AWS.ResourceTags != nil { - for _, resourceTag := range hcluster.Spec.Platform.AWS.ResourceTags { - switch resourceTag.Key { - case "api.openshift.com/environment": - metricLabels["environment"] = resourceTag.Value - case "api.openshift.com/id": - metricLabels["internal_id"] = resourceTag.Value - case "red-hat-clustertype": - metricLabels["cluster_type"] = resourceTag.Value - } - } - } - - if metricLabels["cluster_type"] == "rosa" { - etcdRecoveryActiveCondition := meta.FindStatusCondition(hcluster.Status.Conditions, string(hyperv1.EtcdRecoveryActive)) - if etcdRecoveryActiveCondition != nil && etcdRecoveryActiveCondition.Status == metav1.ConditionFalse && etcdRecoveryActiveCondition.Reason == hyperv1.EtcdRecoveryJobFailedReason { - etcdManualInterventionRequiredValue := 1.0 - ch <- prometheus.MustNewConstMetric( - etcdManualInterventionRequiredMetricDesc, - prometheus.GaugeValue, - etcdManualInterventionRequiredValue, - append(hclusterLabelValues, metricLabels["environment"], metricLabels["internal_id"])..., - ) - - } - - if sizeOverride := hcluster.Annotations[hyperv1.ClusterSizeOverrideAnnotation]; sizeOverride != "" { - ch <- prometheus.MustNewConstMetric( - clusterSizeOverrideMetricDesc, - prometheus.GaugeValue, - 1.0, - append(hclusterLabelValues, metricLabels["environment"], metricLabels["internal_id"], sizeOverride)..., - ) - } - } - } + if sizeOverride := hcluster.Annotations[hyperv1.ClusterSizeOverrideAnnotation]; sizeOverride != "" { + ch <- prometheus.MustNewConstMetric( + clusterSizeOverrideMetricDesc, + prometheus.GaugeValue, + 1.0, + append(hclusterLabelValues, metricLabels["environment"], metricLabels["internal_id"], sizeOverride)..., + ) + } +} - if hcluster.Spec.Platform.Azure != nil { - azInfo := hcluster.Spec.Platform.Azure - subID := azInfo.SubscriptionID - resGroup := azInfo.ResourceGroupName - if azureutil.IsAroHCP() { - // see https://github.com/Azure/ARO-HCP/blob/4134b5bb53782858047a0493f31b250c811eb84c/api/redhatopenshift/resource-manager/Microsoft.RedHatOpenShift/hcpclusters/preview/2024-06-10-preview/openapi.json#L131 - resourceID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.RedHatOpenshift/hcpOpenShiftClusters/%s", - subID, resGroup, hcluster.Name) - ch <- prometheus.MustNewConstMetric( - managedAzureHostedClusterInfoDesc, - prometheus.GaugeValue, - 1.0, - append(hclusterLabelValues, - azInfo.Location, - subID, - resGroup, - HostedClusterManagedAzureResourceType, - resourceID)...) - } else { - ch <- prometheus.MustNewConstMetric( - azureHostedClusterInfoDesc, - prometheus.GaugeValue, - 1.0, - append(hclusterLabelValues, - azInfo.Location, - subID, - resGroup)...) - } - } +func collectAzureInfoMetrics(ch chan<- prometheus.Metric, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + if hcluster.Spec.Platform.Azure == nil { + return + } + azInfo := hcluster.Spec.Platform.Azure + subID := azInfo.SubscriptionID + resGroup := azInfo.ResourceGroupName + if azureutil.IsAroHCP() { + resourceID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.RedHatOpenshift/hcpOpenShiftClusters/%s", + subID, resGroup, hcluster.Name) + ch <- prometheus.MustNewConstMetric( + managedAzureHostedClusterInfoDesc, + prometheus.GaugeValue, + 1.0, + append(hclusterLabelValues, + azInfo.Location, + subID, + resGroup, + HostedClusterManagedAzureResourceType, + resourceID)...) + } else { + ch <- prometheus.MustNewConstMetric( + azureHostedClusterInfoDesc, + prometheus.GaugeValue, + 1.0, + append(hclusterLabelValues, + azInfo.Location, + subID, + resGroup)...) + } +} - // invalidAwsCredsMetric - { - // Use detailed credential status: 0=valid, 1=invalid, 2=unknown - credStatus := platformaws.GetCredentialStatus(hcluster) - invalidAwsCredsValue := float64(credStatus) - - ch <- prometheus.MustNewConstMetric( - invalidAwsCredsMetricDesc, - prometheus.GaugeValue, - invalidAwsCredsValue, - hclusterLabelValues..., - ) - } +func collectAwsCredsMetric(ch chan<- prometheus.Metric, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + credStatus := platformaws.GetCredentialStatus(hcluster) + ch <- prometheus.MustNewConstMetric( + invalidAwsCredsMetricDesc, + prometheus.GaugeValue, + float64(credStatus), + hclusterLabelValues..., + ) +} - if !hcluster.DeletionTimestamp.IsZero() { - // deletingDurationMetric - deletingDuration := c.clock.Since(hcluster.DeletionTimestamp.Time).Seconds() - - ch <- prometheus.MustNewConstMetric( - deletingDurationMetricDesc, - prometheus.GaugeValue, - deletingDuration, - hclusterLabelValues..., - ) - - // guestCloudResourcesDeletingDurationMetric - condition := meta.FindStatusCondition(hcluster.Status.Conditions, string(hyperv1.CloudResourcesDestroyed)) - - if condition == nil || condition.Status != metav1.ConditionTrue { - ch <- prometheus.MustNewConstMetric( - guestCloudResourcesDeletingDurationMetricDesc, - prometheus.GaugeValue, - deletingDuration, - hclusterLabelValues..., - ) - } - } - } +func collectDeletingMetrics(ch chan<- prometheus.Metric, clk clock.Clock, hcluster *hyperv1.HostedCluster, hclusterLabelValues []string) { + if hcluster.DeletionTimestamp.IsZero() { + return } + deletingDuration := clk.Since(hcluster.DeletionTimestamp.Time).Seconds() + ch <- prometheus.MustNewConstMetric( + deletingDurationMetricDesc, + prometheus.GaugeValue, + deletingDuration, + hclusterLabelValues..., + ) + + condition := meta.FindStatusCondition(hcluster.Status.Conditions, string(hyperv1.CloudResourcesDestroyed)) + if condition == nil || condition.Status != metav1.ConditionTrue { + ch <- prometheus.MustNewConstMetric( + guestCloudResourcesDeletingDurationMetricDesc, + prometheus.GaugeValue, + deletingDuration, + hclusterLabelValues..., + ) + } +} - // AGGREGATED METRICS - - // countByIdentityProviderMetric +func emitAggregatedMetrics(ch chan<- prometheus.Metric, identityProviderToHClustersCount map[configv1.IdentityProviderType]int, platformToHClustersCount map[hyperv1.PlatformType]int, platformToFailureConditionToHClustersCount map[hyperv1.PlatformType]*map[string]int) { for identityProvider, hclustersCount := range identityProviderToHClustersCount { ch <- prometheus.MustNewConstMetric( countByIdentityProviderMetricDesc, @@ -608,7 +602,6 @@ func (c *hostedClustersMetricsCollector) Collect(ch chan<- prometheus.Metric) { ) } - // countByPlatformMetric for platform, hclustersCount := range platformToHClustersCount { ch <- prometheus.MustNewConstMetric( countByPlatformMetricDesc, @@ -618,7 +611,6 @@ func (c *hostedClustersMetricsCollector) Collect(ch chan<- prometheus.Metric) { ) } - // countByPlatformAndFailureConditionMetric for platform, failureConditionToHClustersCount := range platformToFailureConditionToHClustersCount { for failureCondition, hclustersCount := range *failureConditionToHClustersCount { ch <- prometheus.MustNewConstMetric( @@ -630,11 +622,6 @@ func (c *hostedClustersMetricsCollector) Collect(ch chan<- prometheus.Metric) { ) } } - - // transitionDurationMetric - c.transitionDurationMetric.Collect(ch) - - c.lastCollectTime = currentCollectTime } // Load the CA bundle for the hosted cluster and find the earliest expiring certificate time. diff --git a/hypershift-operator/controllers/hostedcluster/network_policies.go b/hypershift-operator/controllers/hostedcluster/network_policies.go index 544296fefef..7c0c5e897dc 100644 --- a/hypershift-operator/controllers/hostedcluster/network_policies.go +++ b/hypershift-operator/controllers/hostedcluster/network_policies.go @@ -44,38 +44,22 @@ const ( func (r *HostedClusterReconciler) reconcileNetworkPolicies(ctx context.Context, log logr.Logger, createOrUpdate upsert.CreateOrUpdateFN, hcluster *hyperv1.HostedCluster, hcp *hyperv1.HostedControlPlane, version semver.Version, controlPlaneOperatorAppliesManagementKASNetworkPolicyLabel bool) error { controlPlaneNamespaceName := manifests.HostedControlPlaneNamespace(hcluster.Namespace, hcluster.Name) - // Reconcile openshift-ingress Network Policy. - // Only needed when routes are served by the management cluster's default ingress controller, - // i.e., when routes are NOT labeled for the HCP router. - policy := networkpolicy.OpenshiftIngressNetworkPolicy(controlPlaneNamespaceName) - if !netutil.LabelHCPRoutes(hcp) { - if _, err := createOrUpdate(ctx, r.Client, policy, func() error { - return reconcileOpenshiftIngressNetworkPolicy(policy) - }); err != nil { - return fmt.Errorf("failed to reconcile ingress network policy: %w", err) - } - } else { - if _, err := k8sutil.DeleteIfNeeded(ctx, r.Client, policy); err != nil { - return fmt.Errorf("failed to delete ingress network policy: %w", err) - } + if err := r.reconcileIngressNetworkPolicy(ctx, createOrUpdate, hcp, controlPlaneNamespaceName); err != nil { + return err } - // Reconcile same-namespace Network Policy - policy = networkpolicy.SameNamespaceNetworkPolicy(controlPlaneNamespaceName) + policy := networkpolicy.SameNamespaceNetworkPolicy(controlPlaneNamespaceName) if _, err := createOrUpdate(ctx, r.Client, policy, func() error { return reconcileSameNamespaceNetworkPolicy(policy) }); err != nil { return fmt.Errorf("failed to reconcile same namespace network policy: %w", err) } - // Reconcile KAS Network Policy - var managementClusterNetwork *configv1.Network - if r.ManagementClusterCapabilities.Has(capabilities.CapabilityNetworks) { - managementClusterNetwork = &configv1.Network{ObjectMeta: metav1.ObjectMeta{Name: "cluster"}} - if err := r.Get(ctx, client.ObjectKeyFromObject(managementClusterNetwork), managementClusterNetwork); err != nil { - return fmt.Errorf("failed to get management cluster network config: %w", err) - } + managementClusterNetwork, err := r.getManagementClusterNetwork(ctx) + if err != nil { + return err } + policy = networkpolicy.KASNetworkPolicy(controlPlaneNamespaceName) if _, err := createOrUpdate(ctx, r.Client, policy, func() error { return reconcileKASNetworkPolicy(policy, hcluster, r.ManagementClusterCapabilities.Has(capabilities.CapabilityDNS), managementClusterNetwork) @@ -83,38 +67,17 @@ func (r *HostedClusterReconciler) reconcileNetworkPolicies(ctx context.Context, return fmt.Errorf("failed to reconcile kube-apiserver network policy: %w", err) } - // Reconcile management KAS network policy //nolint:staticcheck // SA1019: corev1.Endpoints is intentionally used for backward compatibility kubernetesEndpoint := &corev1.Endpoints{ObjectMeta: metav1.ObjectMeta{Name: "kubernetes", Namespace: "default"}} if err := r.Get(ctx, client.ObjectKeyFromObject(kubernetesEndpoint), kubernetesEndpoint); err != nil { return fmt.Errorf("failed to get management cluster network config: %w", err) } - // ManagementKASNetworkPolicy restricts traffic for pods unless they have a known annotation. - if controlPlaneOperatorAppliesManagementKASNetworkPolicyLabel && hcluster.Spec.Platform.Type == hyperv1.AWSPlatform { - policy = networkpolicy.ManagementKASNetworkPolicy(controlPlaneNamespaceName) - if _, err := createOrUpdate(ctx, r.Client, policy, func() error { - return reconcileManagementKASNetworkPolicy(policy, managementClusterNetwork, kubernetesEndpoint, r.ManagementClusterCapabilities.Has(capabilities.CapabilityDNS)) - }); err != nil { - return fmt.Errorf("failed to reconcile kube-apiserver network policy: %w", err) - } - - // Allow egress communication to the HCP metrics server for pods that have a known annotation. - // Enable if either RHOBS monitoring is enabled for ROSA HCP or the explicit flag is set. - enableMetricsAccess := r.EnableCVOManagementClusterMetricsAccess || (os.Getenv(rhobsmonitoring.EnvironmentVariable) == "1" && awsutil.IsROSAHCP(hcp)) - if enableMetricsAccess { - policy = networkpolicy.MetricsServerNetworkPolicy(controlPlaneNamespaceName) - if _, err := createOrUpdate(ctx, r.Client, policy, func() error { - return reconcileMetricsServerNetworkPolicy(policy, hcp) - }); err != nil { - return fmt.Errorf("failed to reconcile metrics server network policy: %w", err) - } - } + if err := r.reconcileManagementKASPolicies(ctx, createOrUpdate, hcluster, hcp, controlPlaneNamespaceName, managementClusterNetwork, kubernetesEndpoint, controlPlaneOperatorAppliesManagementKASNetworkPolicyLabel); err != nil { + return err } if sharedingress.UseSharedIngress() { - // Reconcile shared-ingress Network Policy. - // Let all ingress from shared-ingress namespace. policy := networkpolicy.SharedIngressNetworkPolicy(controlPlaneNamespaceName) if _, err := createOrUpdate(ctx, r.Client, policy, func() error { return reconcileSharedIngressNetworkPolicy(policy, hcluster) @@ -123,7 +86,6 @@ func (r *HostedClusterReconciler) reconcileNetworkPolicies(ctx context.Context, } } - // Reconcile openshift-monitoring Network Policy policy = networkpolicy.OpenshiftMonitoringNetworkPolicy(controlPlaneNamespaceName) if _, err := createOrUpdate(ctx, r.Client, policy, func() error { return reconcileOpenshiftMonitoringNetworkPolicy(policy, hcluster) @@ -131,12 +93,70 @@ func (r *HostedClusterReconciler) reconcileNetworkPolicies(ctx context.Context, return fmt.Errorf("failed to reconcile monitoring network policy: %w", err) } - // Reconcile private-router Network Policy + if err := r.reconcilePlatformNetworkPolicies(ctx, log, createOrUpdate, hcluster, kubernetesEndpoint, managementClusterNetwork, version, controlPlaneNamespaceName); err != nil { + return err + } + + return r.reconcileServiceNetworkPolicies(ctx, createOrUpdate, hcluster, controlPlaneNamespaceName) +} + +func (r *HostedClusterReconciler) reconcileIngressNetworkPolicy(ctx context.Context, createOrUpdate upsert.CreateOrUpdateFN, hcp *hyperv1.HostedControlPlane, controlPlaneNamespaceName string) error { + policy := networkpolicy.OpenshiftIngressNetworkPolicy(controlPlaneNamespaceName) + if !netutil.LabelHCPRoutes(hcp) { + if _, err := createOrUpdate(ctx, r.Client, policy, func() error { + return reconcileOpenshiftIngressNetworkPolicy(policy) + }); err != nil { + return fmt.Errorf("failed to reconcile ingress network policy: %w", err) + } + } else { + if _, err := k8sutil.DeleteIfNeeded(ctx, r.Client, policy); err != nil { + return fmt.Errorf("failed to delete ingress network policy: %w", err) + } + } + return nil +} + +func (r *HostedClusterReconciler) getManagementClusterNetwork(ctx context.Context) (*configv1.Network, error) { + if !r.ManagementClusterCapabilities.Has(capabilities.CapabilityNetworks) { + return nil, nil + } + managementClusterNetwork := &configv1.Network{ObjectMeta: metav1.ObjectMeta{Name: "cluster"}} + if err := r.Get(ctx, client.ObjectKeyFromObject(managementClusterNetwork), managementClusterNetwork); err != nil { + return nil, fmt.Errorf("failed to get management cluster network config: %w", err) + } + return managementClusterNetwork, nil +} + +//nolint:staticcheck // SA1019: corev1.Endpoints is intentionally used for backward compatibility +func (r *HostedClusterReconciler) reconcileManagementKASPolicies(ctx context.Context, createOrUpdate upsert.CreateOrUpdateFN, hcluster *hyperv1.HostedCluster, hcp *hyperv1.HostedControlPlane, controlPlaneNamespaceName string, managementClusterNetwork *configv1.Network, kubernetesEndpoint *corev1.Endpoints, controlPlaneOperatorAppliesManagementKASNetworkPolicyLabel bool) error { + if !controlPlaneOperatorAppliesManagementKASNetworkPolicyLabel || hcluster.Spec.Platform.Type != hyperv1.AWSPlatform { + return nil + } + + policy := networkpolicy.ManagementKASNetworkPolicy(controlPlaneNamespaceName) + if _, err := createOrUpdate(ctx, r.Client, policy, func() error { + return reconcileManagementKASNetworkPolicy(policy, managementClusterNetwork, kubernetesEndpoint, r.ManagementClusterCapabilities.Has(capabilities.CapabilityDNS)) + }); err != nil { + return fmt.Errorf("failed to reconcile kube-apiserver network policy: %w", err) + } + + enableMetricsAccess := r.EnableCVOManagementClusterMetricsAccess || (os.Getenv(rhobsmonitoring.EnvironmentVariable) == "1" && awsutil.IsROSAHCP(hcp)) + if enableMetricsAccess { + policy = networkpolicy.MetricsServerNetworkPolicy(controlPlaneNamespaceName) + if _, err := createOrUpdate(ctx, r.Client, policy, func() error { + return reconcileMetricsServerNetworkPolicy(policy, hcp) + }); err != nil { + return fmt.Errorf("failed to reconcile metrics server network policy: %w", err) + } + } + return nil +} + +//nolint:staticcheck // SA1019: corev1.Endpoints is intentionally used for backward compatibility +func (r *HostedClusterReconciler) reconcilePlatformNetworkPolicies(ctx context.Context, log logr.Logger, createOrUpdate upsert.CreateOrUpdateFN, hcluster *hyperv1.HostedCluster, kubernetesEndpoint *corev1.Endpoints, managementClusterNetwork *configv1.Network, version semver.Version, controlPlaneNamespaceName string) error { switch hcluster.Spec.Platform.Type { case hyperv1.AWSPlatform, hyperv1.AzurePlatform, hyperv1.GCPPlatform: - policy = networkpolicy.PrivateRouterNetworkPolicy(controlPlaneNamespaceName) - // TODO: Network policy code should move to the control plane operator. For now, - // only setup ingress rules (and not egress rules) when version is < 4.14 + policy := networkpolicy.PrivateRouterNetworkPolicy(controlPlaneNamespaceName) ingressOnly := version.Major == 4 && version.Minor < 14 if _, err := createOrUpdate(ctx, r.Client, policy, func() error { return reconcilePrivateRouterNetworkPolicy(policy, hcluster, kubernetesEndpoint, r.ManagementClusterCapabilities.Has(capabilities.CapabilityDNS), managementClusterNetwork, ingressOnly) @@ -146,152 +166,97 @@ func (r *HostedClusterReconciler) reconcileNetworkPolicies(ctx context.Context, case hyperv1.KubevirtPlatform: if hcluster.Spec.Platform.Kubevirt.Credentials == nil { // Centralized infra: policy targets the control plane namespace on the management cluster - policy = networkpolicy.VirtLauncherNetworkPolicy(controlPlaneNamespaceName) + policy := networkpolicy.VirtLauncherNetworkPolicy(controlPlaneNamespaceName) if _, err := createOrUpdate(ctx, r.Client, policy, func() error { return reconcileVirtLauncherNetworkPolicy(log, policy, hcluster, managementClusterNetwork) }); err != nil { return fmt.Errorf("failed to reconcile virt launcher policy: %w", err) } } else { - // External infra: policy targets the infra namespace on the infrastructure cluster - kvInfraClient, err := r.KubevirtInfraClients.DiscoverKubevirtClusterClient(ctx, - r.Client, - hcluster.Spec.InfraID, - hcluster.Spec.Platform.Kubevirt.Credentials, - hcluster.Namespace, - hcluster.Namespace) - if err != nil { - return fmt.Errorf("failed to get kubevirt infra client for network policy: %w", err) - } - - infraClient := kvInfraClient.GetInfraClient() - infraNamespace := kvInfraClient.GetInfraNamespace() - - // networks.config.openshift.io is cluster-scoped, so the infra - // kubeconfig needs a ClusterRole with get permission on that - // resource. When the permission is missing we still create the - // NetworkPolicy but without CIDR-based egress blocking, and - // surface the RBAC gap as a condition on the HostedCluster. - var infraClusterNetwork *configv1.Network - networkObj := &configv1.Network{ObjectMeta: metav1.ObjectMeta{Name: "cluster"}} - if err := infraClient.Get(ctx, client.ObjectKeyFromObject(networkObj), networkObj); err != nil { - if apierrors.IsForbidden(err) || apierrors.IsNotFound(err) || meta.IsNoMatchError(err) { - rbacMsg := fmt.Sprintf( - "The external infrastructure kubeconfig lacks permission to read "+ - "networks.config.openshift.io/cluster. The virt-launcher NetworkPolicy "+ - "has been created without CIDR-based egress restrictions, resulting in "+ - "weaker tenant isolation. Grant a ClusterRole with get on "+ - "networks.config.openshift.io to the infra service account for full isolation. "+ - "Error: %v", err) - log.Info(rbacMsg) - meta.SetStatusCondition(&hcluster.Status.Conditions, metav1.Condition{ - Type: string(hyperv1.ValidKubeVirtInfraNetworkPolicyRBAC), - Status: metav1.ConditionFalse, - Reason: hyperv1.InfraClusterNetworkReadFailedReason, - ObservedGeneration: hcluster.Generation, - Message: rbacMsg, - }) - emitInfraClusterWarningEvent(ctx, infraClient, infraNamespace, hcluster.Spec.InfraID, rbacMsg, log) - } else { - return fmt.Errorf("failed to get infrastructure cluster network config: %w", err) - } - } else { - infraClusterNetwork = networkObj - meta.SetStatusCondition(&hcluster.Status.Conditions, metav1.Condition{ - Type: string(hyperv1.ValidKubeVirtInfraNetworkPolicyRBAC), - Status: metav1.ConditionTrue, - Reason: hyperv1.AsExpectedReason, - ObservedGeneration: hcluster.Generation, - Message: "Infrastructure cluster network configuration is readable; CIDR-based egress restrictions are active.", - }) - } - - policy = networkpolicy.VirtLauncherNetworkPolicy(infraNamespace) - if _, err := createOrUpdate(ctx, infraClient, policy, func() error { - return reconcileVirtLauncherNetworkPolicyExternalInfra(log, policy, hcluster, infraClusterNetwork) - }); err != nil { - if apierrors.IsForbidden(err) { - rbacMsg := fmt.Sprintf( - "Unable to create/update virt-launcher NetworkPolicy on the infrastructure cluster: "+ - "the external infra kubeconfig lacks networking.k8s.io/networkpolicies permissions. "+ - "Grant create/update/get/list/watch on networkpolicies in the infra namespace. "+ - "Error: %v", err) - log.Info(rbacMsg) - meta.SetStatusCondition(&hcluster.Status.Conditions, metav1.Condition{ - Type: string(hyperv1.ValidKubeVirtInfraNetworkPolicyRBAC), - Status: metav1.ConditionFalse, - Reason: hyperv1.InfraClusterNetworkPolicyCreateFailedReason, - ObservedGeneration: hcluster.Generation, - Message: rbacMsg, - }) - emitInfraClusterWarningEvent(ctx, infraClient, infraNamespace, hcluster.Spec.InfraID, rbacMsg, log) - } else { - return fmt.Errorf("failed to reconcile virt launcher policy on external infra: %w", err) - } + // External infra (credentials != nil): policy targets the infra namespace on the infrastructure cluster + if err := r.reconcileExternalInfraVirtLauncherPolicy(ctx, log, createOrUpdate, hcluster); err != nil { + return err } } } + return nil +} +func (r *HostedClusterReconciler) reconcileServiceNetworkPolicies(ctx context.Context, createOrUpdate upsert.CreateOrUpdateFN, hcluster *hyperv1.HostedCluster, controlPlaneNamespaceName string) error { for _, svc := range hcluster.Spec.Services { switch svc.Service { case hyperv1.OAuthServer: - if svc.ServicePublishingStrategy.Type == hyperv1.NodePort { - // Reconcile nodeport-oauth Network Policy - policy = networkpolicy.NodePortOauthNetworkPolicy(controlPlaneNamespaceName) - if _, err := createOrUpdate(ctx, r.Client, policy, func() error { - return reconcileNodePortOauthNetworkPolicy(policy, hcluster) - }); err != nil { - return fmt.Errorf("failed to reconcile oauth server nodeport network policy: %w", err) - } - } - if svc.ServicePublishingStrategy.Type == hyperv1.LoadBalancer { - // Reconcile loadbalancer-oauth Network Policy - policy = networkpolicy.LoadBalancerOauthNetworkPolicy(controlPlaneNamespaceName) - if _, err := createOrUpdate(ctx, r.Client, policy, func() error { - return reconcileLoadBalancerOauthNetworkPolicy(policy) - }); err != nil { - return fmt.Errorf("failed to reconcile oauth server loadbalancer network policy: %w", err) - } + if err := r.reconcileOAuthNetworkPolicies(ctx, createOrUpdate, hcluster, svc, controlPlaneNamespaceName); err != nil { + return err } case hyperv1.Ignition: - if svc.ServicePublishingStrategy.Type == hyperv1.NodePort { - // Reconcile nodeport-ignition Network Policy - policy = networkpolicy.NodePortIgnitionNetworkPolicy(controlPlaneNamespaceName) - if _, err := createOrUpdate(ctx, r.Client, policy, func() error { - return reconcileNodePortIgnitionNetworkPolicy(policy, hcluster) - }); err != nil { - return fmt.Errorf("failed to reconcile ignition nodeport network policy: %w", err) - } - // Reconcile nodeport-ignition-proxy Network Policy - policy = networkpolicy.NodePortIgnitionProxyNetworkPolicy(controlPlaneNamespaceName) - if _, err := createOrUpdate(ctx, r.Client, policy, func() error { - return reconcileNodePortIgnitionProxyNetworkPolicy(policy, hcluster) - }); err != nil { - return fmt.Errorf("failed to reconcile ignition proxy nodeport network policy: %w", err) - } + if err := r.reconcileIgnitionNetworkPolicies(ctx, createOrUpdate, hcluster, svc, controlPlaneNamespaceName); err != nil { + return err } case hyperv1.Konnectivity: - if svc.ServicePublishingStrategy.Type == hyperv1.NodePort { - // Reconcile nodeport-konnectivity Network Policy - policy = networkpolicy.NodePortKonnectivityNetworkPolicy(controlPlaneNamespaceName) - if _, err := createOrUpdate(ctx, r.Client, policy, func() error { - return reconcileNodePortKonnectivityNetworkPolicy(policy, hcluster) - }); err != nil { - return fmt.Errorf("failed to reconcile konnectivity nodeport network policy: %w", err) - } - - // Reconcile nodeport-konnectivity Network Policy when konnectivity is hosted in the kas pod - policy = networkpolicy.NodePortKonnectivityKASNetworkPolicy(controlPlaneNamespaceName) - if _, err := createOrUpdate(ctx, r.Client, policy, func() error { - return reconcileNodePortKonnectivityKASNetworkPolicy(policy, hcluster) - }); err != nil { - return fmt.Errorf("failed to reconcile konnectivity nodeport network policy: %w", err) - } - + if err := r.reconcileKonnectivityNetworkPolicies(ctx, createOrUpdate, hcluster, svc, controlPlaneNamespaceName); err != nil { + return err } } } + return nil +} +func (r *HostedClusterReconciler) reconcileOAuthNetworkPolicies(ctx context.Context, createOrUpdate upsert.CreateOrUpdateFN, hcluster *hyperv1.HostedCluster, svc hyperv1.ServicePublishingStrategyMapping, controlPlaneNamespaceName string) error { + if svc.ServicePublishingStrategy.Type == hyperv1.NodePort { + policy := networkpolicy.NodePortOauthNetworkPolicy(controlPlaneNamespaceName) + if _, err := createOrUpdate(ctx, r.Client, policy, func() error { + return reconcileNodePortOauthNetworkPolicy(policy, hcluster) + }); err != nil { + return fmt.Errorf("failed to reconcile oauth server nodeport network policy: %w", err) + } + } + if svc.ServicePublishingStrategy.Type == hyperv1.LoadBalancer { + policy := networkpolicy.LoadBalancerOauthNetworkPolicy(controlPlaneNamespaceName) + if _, err := createOrUpdate(ctx, r.Client, policy, func() error { + return reconcileLoadBalancerOauthNetworkPolicy(policy) + }); err != nil { + return fmt.Errorf("failed to reconcile oauth server loadbalancer network policy: %w", err) + } + } + return nil +} + +func (r *HostedClusterReconciler) reconcileIgnitionNetworkPolicies(ctx context.Context, createOrUpdate upsert.CreateOrUpdateFN, hcluster *hyperv1.HostedCluster, svc hyperv1.ServicePublishingStrategyMapping, controlPlaneNamespaceName string) error { + if svc.ServicePublishingStrategy.Type != hyperv1.NodePort { + return nil + } + policy := networkpolicy.NodePortIgnitionNetworkPolicy(controlPlaneNamespaceName) + if _, err := createOrUpdate(ctx, r.Client, policy, func() error { + return reconcileNodePortIgnitionNetworkPolicy(policy, hcluster) + }); err != nil { + return fmt.Errorf("failed to reconcile ignition nodeport network policy: %w", err) + } + policy = networkpolicy.NodePortIgnitionProxyNetworkPolicy(controlPlaneNamespaceName) + if _, err := createOrUpdate(ctx, r.Client, policy, func() error { + return reconcileNodePortIgnitionProxyNetworkPolicy(policy, hcluster) + }); err != nil { + return fmt.Errorf("failed to reconcile ignition proxy nodeport network policy: %w", err) + } + return nil +} + +func (r *HostedClusterReconciler) reconcileKonnectivityNetworkPolicies(ctx context.Context, createOrUpdate upsert.CreateOrUpdateFN, hcluster *hyperv1.HostedCluster, svc hyperv1.ServicePublishingStrategyMapping, controlPlaneNamespaceName string) error { + if svc.ServicePublishingStrategy.Type != hyperv1.NodePort { + return nil + } + policy := networkpolicy.NodePortKonnectivityNetworkPolicy(controlPlaneNamespaceName) + if _, err := createOrUpdate(ctx, r.Client, policy, func() error { + return reconcileNodePortKonnectivityNetworkPolicy(policy, hcluster) + }); err != nil { + return fmt.Errorf("failed to reconcile konnectivity nodeport network policy: %w", err) + } + policy = networkpolicy.NodePortKonnectivityKASNetworkPolicy(controlPlaneNamespaceName) + if _, err := createOrUpdate(ctx, r.Client, policy, func() error { + return reconcileNodePortKonnectivityKASNetworkPolicy(policy, hcluster) + }); err != nil { + return fmt.Errorf("failed to reconcile konnectivity nodeport network policy: %w", err) + } return nil } @@ -684,16 +649,6 @@ func reconcileVirtLauncherNetworkPolicy(log logr.Logger, policy *networkingv1.Ne return buildVirtLauncherNetworkPolicyBase(log, policy, hcluster, blockedIPv4Networks, blockedIPv6Networks, controlPlanePeers) } -// reconcileVirtLauncherNetworkPolicyExternalInfra builds the virt-launcher -// NetworkPolicy for deployments where the KubeVirt VMs run on a separate -// infrastructure cluster. Unlike the centralized variant, this omits egress -// rules for control-plane pods (kube-apiserver, oauth, ignition-server-proxy) -// because those pods reside on the management cluster and are reached via -// external IPs already permitted by the broad 0.0.0.0/0 allow rule. -// -// infraClusterNetwork may be nil when the infra kubeconfig lacks cluster- -// scoped read access to networks.config.openshift.io. In that case the -// policy is still created but without CIDR-based egress blocking. func reconcileVirtLauncherNetworkPolicyExternalInfra(log logr.Logger, policy *networkingv1.NetworkPolicy, hcluster *hyperv1.HostedCluster, infraClusterNetwork *configv1.Network) error { blockedIPv4Networks := []string{} blockedIPv6Networks := []string{} @@ -709,9 +664,8 @@ func reconcileVirtLauncherNetworkPolicyExternalInfra(log logr.Logger, policy *ne return buildVirtLauncherNetworkPolicyBase(log, policy, hcluster, blockedIPv4Networks, blockedIPv6Networks, nil) } -// buildVirtLauncherNetworkPolicyBase constructs the common virt-launcher -// NetworkPolicy structure shared by both centralized and external infra -// deployments. extraEgressPeers are appended to the primary egress rule +// buildVirtLauncherNetworkPolicyBase constructs the virt-launcher +// NetworkPolicy structure. extraEgressPeers are appended to the primary egress rule // (e.g. control-plane pod selectors for centralized infra). func buildVirtLauncherNetworkPolicyBase(log logr.Logger, policy *networkingv1.NetworkPolicy, hcluster *hyperv1.HostedCluster, blockedIPv4Networks, blockedIPv6Networks []string, extraEgressPeers []networkingv1.NetworkPolicyPeer) error { protocolTCP := corev1.ProtocolTCP @@ -1065,6 +1019,88 @@ func reconcileMetricsServerNetworkPolicy(policy *networkingv1.NetworkPolicy, hcp return nil } +// reconcileExternalInfraVirtLauncherPolicy discovers the KubeVirt infra client +// and creates/updates the virt-launcher NetworkPolicy on the infrastructure cluster. +// It handles RBAC errors gracefully by setting status conditions. +func (r *HostedClusterReconciler) reconcileExternalInfraVirtLauncherPolicy(ctx context.Context, log logr.Logger, createOrUpdate upsert.CreateOrUpdateFN, hcluster *hyperv1.HostedCluster) error { + kvInfraClient, err := r.KubevirtInfraClients.DiscoverKubevirtClusterClient(ctx, + r.Client, + hcluster.Spec.InfraID, + hcluster.Spec.Platform.Kubevirt.Credentials, + hcluster.Namespace, + hcluster.Namespace) + if err != nil { + return fmt.Errorf("failed to get kubevirt infra client for network policy: %w", err) + } + + infraClient := kvInfraClient.GetInfraClient() + infraNamespace := kvInfraClient.GetInfraNamespace() + + infraClusterNetwork := fetchInfraClusterNetwork(ctx, infraClient, hcluster, log) + + policy := networkpolicy.VirtLauncherNetworkPolicy(infraNamespace) + if _, err := createOrUpdate(ctx, infraClient, policy, func() error { + return reconcileVirtLauncherNetworkPolicyExternalInfra(log, policy, hcluster, infraClusterNetwork) + }); err != nil { + if apierrors.IsForbidden(err) { + rbacMsg := fmt.Sprintf( + "Unable to create/update virt-launcher NetworkPolicy on the infrastructure cluster: "+ + "the external infra kubeconfig lacks networking.k8s.io/networkpolicies permissions. "+ + "Grant create/update/get/list/watch on networkpolicies in the infra namespace. "+ + "Error: %v", err) + log.Info(rbacMsg) + meta.SetStatusCondition(&hcluster.Status.Conditions, metav1.Condition{ + Type: string(hyperv1.ValidKubeVirtInfraNetworkPolicyRBAC), + Status: metav1.ConditionFalse, + Reason: hyperv1.InfraClusterNetworkPolicyCreateFailedReason, + ObservedGeneration: hcluster.Generation, + Message: rbacMsg, + }) + emitInfraClusterWarningEvent(ctx, infraClient, infraNamespace, hcluster.Spec.InfraID, rbacMsg, log) + } else { + return fmt.Errorf("failed to reconcile virt launcher policy on external infra: %w", err) + } + } + return nil +} + +// fetchInfraClusterNetwork reads networks.config.openshift.io/cluster from the +// infrastructure cluster. When the permission is missing we still allow the +// NetworkPolicy to be created but without CIDR-based egress blocking, and +// surface the RBAC gap as a condition on the HostedCluster. +func fetchInfraClusterNetwork(ctx context.Context, infraClient client.Client, hcluster *hyperv1.HostedCluster, log logr.Logger) *configv1.Network { + networkObj := &configv1.Network{ObjectMeta: metav1.ObjectMeta{Name: "cluster"}} + if err := infraClient.Get(ctx, client.ObjectKeyFromObject(networkObj), networkObj); err != nil { + if apierrors.IsForbidden(err) || apierrors.IsNotFound(err) || meta.IsNoMatchError(err) { + rbacMsg := fmt.Sprintf( + "The external infrastructure kubeconfig lacks permission to read "+ + "networks.config.openshift.io/cluster. The virt-launcher NetworkPolicy "+ + "has been created without CIDR-based egress restrictions, resulting in "+ + "weaker tenant isolation. Grant a ClusterRole with get on "+ + "networks.config.openshift.io to the infra service account for full isolation. "+ + "Error: %v", err) + log.Info(rbacMsg) + meta.SetStatusCondition(&hcluster.Status.Conditions, metav1.Condition{ + Type: string(hyperv1.ValidKubeVirtInfraNetworkPolicyRBAC), + Status: metav1.ConditionFalse, + Reason: hyperv1.InfraClusterNetworkReadFailedReason, + ObservedGeneration: hcluster.Generation, + Message: rbacMsg, + }) + emitInfraClusterWarningEvent(ctx, infraClient, hcluster.Namespace, hcluster.Spec.InfraID, rbacMsg, log) + } + return nil + } + meta.SetStatusCondition(&hcluster.Status.Conditions, metav1.Condition{ + Type: string(hyperv1.ValidKubeVirtInfraNetworkPolicyRBAC), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + ObservedGeneration: hcluster.Generation, + Message: "Infrastructure cluster network configuration is readable; CIDR-based egress restrictions are active.", + }) + return networkObj +} + // emitInfraClusterWarningEvent creates or updates a warning Event in the // infrastructure cluster namespace so that operators monitoring the infra // cluster can see the RBAC gap without access to the management cluster. diff --git a/hypershift-operator/controllers/hostedcluster/network_policies_test.go b/hypershift-operator/controllers/hostedcluster/network_policies_test.go index 56b283958d3..f093771e660 100644 --- a/hypershift-operator/controllers/hostedcluster/network_policies_test.go +++ b/hypershift-operator/controllers/hostedcluster/network_policies_test.go @@ -9,6 +9,8 @@ import ( hyperv1 "github.com/openshift/hypershift/api/hypershift/v1beta1" "github.com/openshift/hypershift/hypershift-operator/controllers/manifests" "github.com/openshift/hypershift/hypershift-operator/controllers/manifests/networkpolicy" + kvinfra "github.com/openshift/hypershift/kubevirtexternalinfra" + "github.com/openshift/hypershift/support/capabilities" fakecapabilities "github.com/openshift/hypershift/support/capabilities/fake" "github.com/openshift/hypershift/support/upsert" @@ -626,6 +628,703 @@ func TestReconcileLoadBalancerOauthNetworkPolicy(t *testing.T) { } } +func TestGetManagementClusterNetwork(t *testing.T) { + testCases := []struct { + name string + capabilities fakecapabilities.FakeCapabilitiesSupportAllExcept + objects []client.Object + expectNetwork bool + expectError bool + expectedName string + }{ + { + name: "When CapabilityNetworks is not supported, it should return nil", + capabilities: fakecapabilities.FakeCapabilitiesSupportAllExcept{NotSupported: map[capabilities.CapabilityType]struct{}{capabilities.CapabilityNetworks: {}}}, + objects: nil, + expectNetwork: false, + }, + { + name: "When CapabilityNetworks is supported and network exists, it should return the network", + capabilities: fakecapabilities.FakeCapabilitiesSupportAllExcept{NotSupported: map[capabilities.CapabilityType]struct{}{}}, + objects: []client.Object{ + &configv1.Network{ + ObjectMeta: metav1.ObjectMeta{Name: "cluster"}, + Spec: configv1.NetworkSpec{ + ClusterNetwork: []configv1.ClusterNetworkEntry{{CIDR: "10.128.0.0/14"}}, + }, + }, + }, + expectNetwork: true, + expectedName: "cluster", + }, + { + name: "When CapabilityNetworks is supported but network does not exist, it should return an error", + capabilities: fakecapabilities.FakeCapabilitiesSupportAllExcept{NotSupported: map[capabilities.CapabilityType]struct{}{}}, + objects: nil, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + scheme := runtime.NewScheme() + g.Expect(corev1.AddToScheme(scheme)).To(Succeed()) + g.Expect(configv1.AddToScheme(scheme)).To(Succeed()) + + builder := fake.NewClientBuilder().WithScheme(scheme) + if tc.objects != nil { + builder = builder.WithObjects(tc.objects...) + } + fakeClient := builder.Build() + + reconciler := &HostedClusterReconciler{ + Client: fakeClient, + ManagementClusterCapabilities: &tc.capabilities, + } + + network, err := reconciler.getManagementClusterNetwork(t.Context()) + + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + } + + if tc.expectNetwork { + g.Expect(network).ToNot(BeNil()) + g.Expect(network.Name).To(Equal(tc.expectedName)) + } else if !tc.expectError { + g.Expect(network).To(BeNil()) + } + }) + } +} + +func TestReconcileManagementKASPolicies(t *testing.T) { + testCases := []struct { + name string + platformType hyperv1.PlatformType + controlPlaneOperatorAppliesManagementKASNetworkPolicyLabel bool + enableCVOManagementClusterMetricsAccess bool + expectManagementKAS bool + expectMetricsServer bool + }{ + { + name: "When label is not applied, it should create no policies", + platformType: hyperv1.AWSPlatform, + controlPlaneOperatorAppliesManagementKASNetworkPolicyLabel: false, + expectManagementKAS: false, + expectMetricsServer: false, + }, + { + name: "When platform is not AWS, it should create no policies", + platformType: hyperv1.AzurePlatform, + controlPlaneOperatorAppliesManagementKASNetworkPolicyLabel: true, + expectManagementKAS: false, + expectMetricsServer: false, + }, + { + name: "When AWS platform with label applied, it should create management-kas policy", + platformType: hyperv1.AWSPlatform, + controlPlaneOperatorAppliesManagementKASNetworkPolicyLabel: true, + expectManagementKAS: true, + expectMetricsServer: false, + }, + { + name: "When AWS platform with label and CVO metrics enabled, it should create both policies", + platformType: hyperv1.AWSPlatform, + controlPlaneOperatorAppliesManagementKASNetworkPolicyLabel: true, + enableCVOManagementClusterMetricsAccess: true, + expectManagementKAS: true, + expectMetricsServer: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + controlPlaneNamespaceName := "test-cp-ns" + hcluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Platform: hyperv1.PlatformSpec{Type: tc.platformType}, + }, + } + hcp := &hyperv1.HostedControlPlane{ + Spec: hyperv1.HostedControlPlaneSpec{ + Platform: hyperv1.PlatformSpec{Type: tc.platformType}, + }, + } + + managementClusterNetwork := &configv1.Network{ + ObjectMeta: metav1.ObjectMeta{Name: "cluster"}, + Spec: configv1.NetworkSpec{ + ClusterNetwork: []configv1.ClusterNetworkEntry{{CIDR: "10.128.0.0/14"}}, + }, + } + + //nolint:staticcheck // SA1019: corev1.Endpoints is intentionally used for backward compatibility + kubernetesEndpoint := &corev1.Endpoints{ + ObjectMeta: metav1.ObjectMeta{Name: "kubernetes", Namespace: "default"}, + //nolint:staticcheck // SA1019: corev1.EndpointSubset is intentionally used for backward compatibility + Subsets: []corev1.EndpointSubset{ + {Addresses: []corev1.EndpointAddress{{IP: "10.0.0.1"}}}, + }, + } + + scheme := runtime.NewScheme() + g.Expect(corev1.AddToScheme(scheme)).To(Succeed()) + g.Expect(networkingv1.AddToScheme(scheme)).To(Succeed()) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + reconciler := &HostedClusterReconciler{ + Client: fakeClient, + ManagementClusterCapabilities: fakecapabilities.NewSupportAllExcept(), + EnableCVOManagementClusterMetricsAccess: tc.enableCVOManagementClusterMetricsAccess, + } + + createdPolicies := make(map[string]*networkingv1.NetworkPolicy) + createOrUpdate := upsert.CreateOrUpdateFN(func(ctx context.Context, c client.Client, obj client.Object, f controllerutil.MutateFn) (controllerutil.OperationResult, error) { + netPol, ok := obj.(*networkingv1.NetworkPolicy) + if !ok { + t.Fatalf("unexpected object type: %T", obj) + } + if err := f(); err != nil { + return controllerutil.OperationResultNone, err + } + createdPolicies[netPol.Name] = netPol + return controllerutil.OperationResultCreated, nil + }) + + err := reconciler.reconcileManagementKASPolicies(t.Context(), createOrUpdate, hcluster, hcp, controlPlaneNamespaceName, managementClusterNetwork, kubernetesEndpoint, tc.controlPlaneOperatorAppliesManagementKASNetworkPolicyLabel) + g.Expect(err).ToNot(HaveOccurred()) + + _, hasManagementKAS := createdPolicies["management-kas"] + g.Expect(hasManagementKAS).To(Equal(tc.expectManagementKAS), "management-kas policy presence mismatch") + + _, hasMetricsServer := createdPolicies["metrics-server"] + g.Expect(hasMetricsServer).To(Equal(tc.expectMetricsServer), "metrics-server policy presence mismatch") + }) + } +} + +func TestReconcilePlatformNetworkPolicies(t *testing.T) { + testCases := []struct { + name string + platformType hyperv1.PlatformType + kubevirtCredentials *hyperv1.KubevirtPlatformCredentials + version string + expectPrivateRouter bool + expectVirtLauncher bool + expectIngressOnly bool + }{ + { + name: "When platform is AWS, it should create private-router policy", + platformType: hyperv1.AWSPlatform, + version: "4.15.0", + expectPrivateRouter: true, + }, + { + name: "When platform is Azure, it should create private-router policy", + platformType: hyperv1.AzurePlatform, + version: "4.15.0", + expectPrivateRouter: true, + }, + { + name: "When platform is GCP, it should create private-router policy", + platformType: hyperv1.GCPPlatform, + version: "4.15.0", + expectPrivateRouter: true, + }, + { + name: "When platform is AWS with version < 4.14, it should create ingress-only private-router policy", + platformType: hyperv1.AWSPlatform, + version: "4.13.0", + expectPrivateRouter: true, + expectIngressOnly: true, + }, + { + name: "When platform is KubeVirt without credentials, it should create virt-launcher policy", + platformType: hyperv1.KubevirtPlatform, + version: "4.15.0", + expectVirtLauncher: true, + }, + { + name: "When platform is KubeVirt with credentials, it should create virt-launcher policy on external infra", + platformType: hyperv1.KubevirtPlatform, + kubevirtCredentials: &hyperv1.KubevirtPlatformCredentials{}, + version: "4.15.0", + expectVirtLauncher: true, + }, + { + name: "When platform is IBMCloud, it should not create private-router or virt-launcher policies", + platformType: hyperv1.IBMCloudPlatform, + version: "4.15.0", + expectPrivateRouter: false, + expectVirtLauncher: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + controlPlaneNamespaceName := "test-cp-ns" + hcluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Platform: hyperv1.PlatformSpec{Type: tc.platformType}, + InfraID: "test-infra", + }, + } + if tc.platformType == hyperv1.KubevirtPlatform { + hcluster.Spec.Platform.Kubevirt = &hyperv1.KubevirtPlatformSpec{ + Credentials: tc.kubevirtCredentials, + } + } + + managementClusterNetwork := &configv1.Network{ + ObjectMeta: metav1.ObjectMeta{Name: "cluster"}, + Spec: configv1.NetworkSpec{ + ClusterNetwork: []configv1.ClusterNetworkEntry{{CIDR: "10.128.0.0/14"}}, + ServiceNetwork: []string{"172.30.0.0/16"}, + }, + } + + //nolint:staticcheck // SA1019: corev1.Endpoints is intentionally used for backward compatibility + kubernetesEndpoint := &corev1.Endpoints{ + ObjectMeta: metav1.ObjectMeta{Name: "kubernetes", Namespace: "default"}, + //nolint:staticcheck // SA1019: corev1.EndpointSubset is intentionally used for backward compatibility + Subsets: []corev1.EndpointSubset{ + {Addresses: []corev1.EndpointAddress{{IP: "10.0.0.1"}}}, + }, + } + + scheme := runtime.NewScheme() + g.Expect(corev1.AddToScheme(scheme)).To(Succeed()) + g.Expect(networkingv1.AddToScheme(scheme)).To(Succeed()) + g.Expect(configv1.AddToScheme(scheme)).To(Succeed()) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + reconciler := &HostedClusterReconciler{ + Client: fakeClient, + ManagementClusterCapabilities: fakecapabilities.NewSupportAllExcept(), + } + if tc.kubevirtCredentials != nil { + reconciler.KubevirtInfraClients = kvinfra.NewMockKubevirtInfraClientMap(fakeClient, "1.0.0", "1.30.0") + } + + createdPolicies := make(map[string]*networkingv1.NetworkPolicy) + createOrUpdate := upsert.CreateOrUpdateFN(func(ctx context.Context, c client.Client, obj client.Object, f controllerutil.MutateFn) (controllerutil.OperationResult, error) { + netPol, ok := obj.(*networkingv1.NetworkPolicy) + if !ok { + t.Fatalf("unexpected object type: %T", obj) + } + if err := f(); err != nil { + return controllerutil.OperationResultNone, err + } + createdPolicies[netPol.Name] = netPol + return controllerutil.OperationResultCreated, nil + }) + + log := ctrl.Log.WithName("test") + version := semver.MustParse(tc.version) + + err := reconciler.reconcilePlatformNetworkPolicies(t.Context(), log, createOrUpdate, hcluster, kubernetesEndpoint, managementClusterNetwork, version, controlPlaneNamespaceName) + g.Expect(err).ToNot(HaveOccurred()) + + _, hasPrivateRouter := createdPolicies["private-router"] + g.Expect(hasPrivateRouter).To(Equal(tc.expectPrivateRouter), "private-router policy presence mismatch") + + if tc.expectIngressOnly && hasPrivateRouter { + policy := createdPolicies["private-router"] + g.Expect(policy.Spec.PolicyTypes).To(Equal([]networkingv1.PolicyType{networkingv1.PolicyTypeIngress})) + g.Expect(policy.Spec.Egress).To(BeEmpty()) + } + + _, hasVirtLauncher := createdPolicies["virt-launcher"] + g.Expect(hasVirtLauncher).To(Equal(tc.expectVirtLauncher), "virt-launcher policy presence mismatch") + }) + } +} + +func TestReconcileServiceNetworkPolicies(t *testing.T) { + testCases := []struct { + name string + services []hyperv1.ServicePublishingStrategyMapping + expectedPolicies []string + absentPolicies []string + }{ + { + name: "When OAuth uses NodePort, it should create nodeport-oauth policy", + services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.OAuthServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: hyperv1.NodePort, NodePort: &hyperv1.NodePortPublishingStrategy{Address: "10.0.0.1"}}, + }, + }, + expectedPolicies: []string{"nodeport-oauth"}, + absentPolicies: []string{"loadbalancer-oauth"}, + }, + { + name: "When OAuth uses LoadBalancer, it should create loadbalancer-oauth policy", + services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.OAuthServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: hyperv1.LoadBalancer}, + }, + }, + expectedPolicies: []string{"loadbalancer-oauth"}, + absentPolicies: []string{"nodeport-oauth"}, + }, + { + name: "When Ignition uses NodePort, it should create ignition and ignition-proxy policies", + services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.Ignition, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: hyperv1.NodePort, NodePort: &hyperv1.NodePortPublishingStrategy{Address: "10.0.0.1"}}, + }, + }, + expectedPolicies: []string{"nodeport-ignition", "nodeport-ignition-proxy"}, + }, + { + name: "When Ignition uses Route, it should not create ignition policies", + services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.Ignition, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: hyperv1.Route}, + }, + }, + absentPolicies: []string{"nodeport-ignition", "nodeport-ignition-proxy"}, + }, + { + name: "When Konnectivity uses NodePort, it should create konnectivity and konnectivity-kas policies", + services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.Konnectivity, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: hyperv1.NodePort, NodePort: &hyperv1.NodePortPublishingStrategy{Address: "10.0.0.1"}}, + }, + }, + expectedPolicies: []string{"nodeport-konnectivity", "nodeport-konnectivity-kas"}, + }, + { + name: "When Konnectivity uses Route, it should not create konnectivity policies", + services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.Konnectivity, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: hyperv1.Route}, + }, + }, + absentPolicies: []string{"nodeport-konnectivity", "nodeport-konnectivity-kas"}, + }, + { + name: "When multiple services use NodePort, it should create all corresponding policies", + services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.OAuthServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: hyperv1.NodePort, NodePort: &hyperv1.NodePortPublishingStrategy{Address: "10.0.0.1"}}, + }, + { + Service: hyperv1.Ignition, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: hyperv1.NodePort, NodePort: &hyperv1.NodePortPublishingStrategy{Address: "10.0.0.1"}}, + }, + { + Service: hyperv1.Konnectivity, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: hyperv1.NodePort, NodePort: &hyperv1.NodePortPublishingStrategy{Address: "10.0.0.1"}}, + }, + }, + expectedPolicies: []string{"nodeport-oauth", "nodeport-ignition", "nodeport-ignition-proxy", "nodeport-konnectivity", "nodeport-konnectivity-kas"}, + }, + { + name: "When no services are specified, it should create no service policies", + services: []hyperv1.ServicePublishingStrategyMapping{}, + absentPolicies: []string{"nodeport-oauth", "loadbalancer-oauth", "nodeport-ignition", "nodeport-ignition-proxy", "nodeport-konnectivity", "nodeport-konnectivity-kas"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + controlPlaneNamespaceName := "test-cp-ns" + hcluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Platform: hyperv1.PlatformSpec{Type: hyperv1.AWSPlatform}, + Services: tc.services, + }, + } + + scheme := runtime.NewScheme() + g.Expect(corev1.AddToScheme(scheme)).To(Succeed()) + g.Expect(networkingv1.AddToScheme(scheme)).To(Succeed()) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + reconciler := &HostedClusterReconciler{ + Client: fakeClient, + ManagementClusterCapabilities: fakecapabilities.NewSupportAllExcept(), + } + + createdPolicies := make(map[string]*networkingv1.NetworkPolicy) + createOrUpdate := upsert.CreateOrUpdateFN(func(ctx context.Context, c client.Client, obj client.Object, f controllerutil.MutateFn) (controllerutil.OperationResult, error) { + netPol, ok := obj.(*networkingv1.NetworkPolicy) + if !ok { + t.Fatalf("unexpected object type: %T", obj) + } + if err := f(); err != nil { + return controllerutil.OperationResultNone, err + } + createdPolicies[netPol.Name] = netPol + return controllerutil.OperationResultCreated, nil + }) + + err := reconciler.reconcileServiceNetworkPolicies(t.Context(), createOrUpdate, hcluster, controlPlaneNamespaceName) + g.Expect(err).ToNot(HaveOccurred()) + + for _, expected := range tc.expectedPolicies { + _, found := createdPolicies[expected] + g.Expect(found).To(BeTrue(), "expected %s policy to be created", expected) + } + + for _, absent := range tc.absentPolicies { + _, found := createdPolicies[absent] + g.Expect(found).To(BeFalse(), "expected %s policy to NOT be created", absent) + } + }) + } +} + +func TestReconcileOAuthNetworkPolicies(t *testing.T) { + testCases := []struct { + name string + serviceType hyperv1.PublishingStrategyType + expectNodePortOauth bool + expectLoadBalancerOauth bool + }{ + { + name: "When OAuth uses NodePort, it should create nodeport-oauth policy only", + serviceType: hyperv1.NodePort, + expectNodePortOauth: true, + expectLoadBalancerOauth: false, + }, + { + name: "When OAuth uses LoadBalancer, it should create loadbalancer-oauth policy only", + serviceType: hyperv1.LoadBalancer, + expectNodePortOauth: false, + expectLoadBalancerOauth: true, + }, + { + name: "When OAuth uses Route, it should create no OAuth policies", + serviceType: hyperv1.Route, + expectNodePortOauth: false, + expectLoadBalancerOauth: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + controlPlaneNamespaceName := "test-cp-ns" + hcluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Platform: hyperv1.PlatformSpec{Type: hyperv1.AWSPlatform}, + }, + } + svc := hyperv1.ServicePublishingStrategyMapping{ + Service: hyperv1.OAuthServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: tc.serviceType}, + } + + scheme := runtime.NewScheme() + g.Expect(corev1.AddToScheme(scheme)).To(Succeed()) + g.Expect(networkingv1.AddToScheme(scheme)).To(Succeed()) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + reconciler := &HostedClusterReconciler{ + Client: fakeClient, + ManagementClusterCapabilities: fakecapabilities.NewSupportAllExcept(), + } + + createdPolicies := make(map[string]*networkingv1.NetworkPolicy) + createOrUpdate := upsert.CreateOrUpdateFN(func(ctx context.Context, c client.Client, obj client.Object, f controllerutil.MutateFn) (controllerutil.OperationResult, error) { + netPol, ok := obj.(*networkingv1.NetworkPolicy) + if !ok { + t.Fatalf("unexpected object type: %T", obj) + } + if err := f(); err != nil { + return controllerutil.OperationResultNone, err + } + createdPolicies[netPol.Name] = netPol + return controllerutil.OperationResultCreated, nil + }) + + err := reconciler.reconcileOAuthNetworkPolicies(t.Context(), createOrUpdate, hcluster, svc, controlPlaneNamespaceName) + g.Expect(err).ToNot(HaveOccurred()) + + _, hasNodePort := createdPolicies["nodeport-oauth"] + g.Expect(hasNodePort).To(Equal(tc.expectNodePortOauth), "nodeport-oauth policy presence mismatch") + + _, hasLoadBalancer := createdPolicies["loadbalancer-oauth"] + g.Expect(hasLoadBalancer).To(Equal(tc.expectLoadBalancerOauth), "loadbalancer-oauth policy presence mismatch") + }) + } +} + +func TestReconcileIgnitionNetworkPolicies(t *testing.T) { + testCases := []struct { + name string + serviceType hyperv1.PublishingStrategyType + expectIgnition bool + expectIgnitionProxy bool + }{ + { + name: "When Ignition uses NodePort, it should create both ignition and ignition-proxy policies", + serviceType: hyperv1.NodePort, + expectIgnition: true, + expectIgnitionProxy: true, + }, + { + name: "When Ignition uses Route, it should create no ignition policies", + serviceType: hyperv1.Route, + expectIgnition: false, + expectIgnitionProxy: false, + }, + { + name: "When Ignition uses LoadBalancer, it should create no ignition policies", + serviceType: hyperv1.LoadBalancer, + expectIgnition: false, + expectIgnitionProxy: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + controlPlaneNamespaceName := "test-cp-ns" + hcluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Platform: hyperv1.PlatformSpec{Type: hyperv1.AWSPlatform}, + }, + } + svc := hyperv1.ServicePublishingStrategyMapping{ + Service: hyperv1.Ignition, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: tc.serviceType}, + } + + scheme := runtime.NewScheme() + g.Expect(corev1.AddToScheme(scheme)).To(Succeed()) + g.Expect(networkingv1.AddToScheme(scheme)).To(Succeed()) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + reconciler := &HostedClusterReconciler{ + Client: fakeClient, + ManagementClusterCapabilities: fakecapabilities.NewSupportAllExcept(), + } + + createdPolicies := make(map[string]*networkingv1.NetworkPolicy) + createOrUpdate := upsert.CreateOrUpdateFN(func(ctx context.Context, c client.Client, obj client.Object, f controllerutil.MutateFn) (controllerutil.OperationResult, error) { + netPol, ok := obj.(*networkingv1.NetworkPolicy) + if !ok { + t.Fatalf("unexpected object type: %T", obj) + } + if err := f(); err != nil { + return controllerutil.OperationResultNone, err + } + createdPolicies[netPol.Name] = netPol + return controllerutil.OperationResultCreated, nil + }) + + err := reconciler.reconcileIgnitionNetworkPolicies(t.Context(), createOrUpdate, hcluster, svc, controlPlaneNamespaceName) + g.Expect(err).ToNot(HaveOccurred()) + + _, hasIgnition := createdPolicies["nodeport-ignition"] + g.Expect(hasIgnition).To(Equal(tc.expectIgnition), "nodeport-ignition policy presence mismatch") + + _, hasProxy := createdPolicies["nodeport-ignition-proxy"] + g.Expect(hasProxy).To(Equal(tc.expectIgnitionProxy), "nodeport-ignition-proxy policy presence mismatch") + }) + } +} + +func TestReconcileKonnectivityNetworkPolicies(t *testing.T) { + testCases := []struct { + name string + serviceType hyperv1.PublishingStrategyType + expectKonnectivity bool + expectKonnectivityKAS bool + }{ + { + name: "When Konnectivity uses NodePort, it should create both konnectivity and konnectivity-kas policies", + serviceType: hyperv1.NodePort, + expectKonnectivity: true, + expectKonnectivityKAS: true, + }, + { + name: "When Konnectivity uses Route, it should create no konnectivity policies", + serviceType: hyperv1.Route, + expectKonnectivity: false, + expectKonnectivityKAS: false, + }, + { + name: "When Konnectivity uses LoadBalancer, it should create no konnectivity policies", + serviceType: hyperv1.LoadBalancer, + expectKonnectivity: false, + expectKonnectivityKAS: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + controlPlaneNamespaceName := "test-cp-ns" + hcluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Platform: hyperv1.PlatformSpec{Type: hyperv1.AWSPlatform}, + }, + } + svc := hyperv1.ServicePublishingStrategyMapping{ + Service: hyperv1.Konnectivity, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{Type: tc.serviceType}, + } + + scheme := runtime.NewScheme() + g.Expect(corev1.AddToScheme(scheme)).To(Succeed()) + g.Expect(networkingv1.AddToScheme(scheme)).To(Succeed()) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + reconciler := &HostedClusterReconciler{ + Client: fakeClient, + ManagementClusterCapabilities: fakecapabilities.NewSupportAllExcept(), + } + + createdPolicies := make(map[string]*networkingv1.NetworkPolicy) + createOrUpdate := upsert.CreateOrUpdateFN(func(ctx context.Context, c client.Client, obj client.Object, f controllerutil.MutateFn) (controllerutil.OperationResult, error) { + netPol, ok := obj.(*networkingv1.NetworkPolicy) + if !ok { + t.Fatalf("unexpected object type: %T", obj) + } + if err := f(); err != nil { + return controllerutil.OperationResultNone, err + } + createdPolicies[netPol.Name] = netPol + return controllerutil.OperationResultCreated, nil + }) + + err := reconciler.reconcileKonnectivityNetworkPolicies(t.Context(), createOrUpdate, hcluster, svc, controlPlaneNamespaceName) + g.Expect(err).ToNot(HaveOccurred()) + + _, hasKonnectivity := createdPolicies["nodeport-konnectivity"] + g.Expect(hasKonnectivity).To(Equal(tc.expectKonnectivity), "nodeport-konnectivity policy presence mismatch") + + _, hasKonnectivityKAS := createdPolicies["nodeport-konnectivity-kas"] + g.Expect(hasKonnectivityKAS).To(Equal(tc.expectKonnectivityKAS), "nodeport-konnectivity-kas policy presence mismatch") + }) + } +} + func TestReconcileNetworkPolicies_LoadBalancerOauth(t *testing.T) { testCases := []struct { name string diff --git a/hypershift-operator/controllers/hostedclustersizing/hostedclustersizing_controller.go b/hypershift-operator/controllers/hostedclustersizing/hostedclustersizing_controller.go index 34294c89c26..64983a9a73c 100644 --- a/hypershift-operator/controllers/hostedclustersizing/hostedclustersizing_controller.go +++ b/hypershift-operator/controllers/hostedclustersizing/hostedclustersizing_controller.go @@ -2,6 +2,7 @@ package hostedclustersizing import ( "context" + "errors" "fmt" "sort" "time" @@ -15,7 +16,7 @@ import ( hyperutil "github.com/openshift/hypershift/support/util" corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/errors" + apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" @@ -25,6 +26,8 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" + + "github.com/go-logr/logr" ) const ( @@ -129,7 +132,7 @@ func (r *reconciler) Reconcile(ctx context.Context, request reconcile.Request) ( hostedCluster, err := r.getHostedCluster(ctx, request.NamespacedName) if err != nil { - if errors.IsNotFound(err) { + if apierrors.IsNotFound(err) { return reconcile.Result{}, nil } return reconcile.Result{}, err @@ -157,7 +160,15 @@ func (r *reconciler) Reconcile(ctx context.Context, request reconcile.Request) ( return reconcile.Result{}, nil } -type ignoreError error +// ignoreError is a concrete error type used to signal that the error should +// be logged but not propagated. Using a named struct (rather than a type alias +// of error) ensures that errors.As only matches intentionally wrapped errors. +type ignoreError struct { + err error +} + +func (e ignoreError) Error() string { return e.err.Error() } +func (e ignoreError) Unwrap() error { return e.err } type action struct { requeueAfter time.Duration @@ -168,16 +179,7 @@ func (r *reconciler) reconcile( ctx context.Context, _ reconcile.Request, config *schedulingv1alpha1.ClusterSizingConfiguration, hostedCluster *hypershiftv1beta1.HostedCluster, ) (*action, error) { - var configValid bool - for _, condition := range config.Status.Conditions { - if condition.Type == schedulingv1alpha1.ClusterSizingConfigurationValidType && condition.Status == metav1.ConditionTrue { - configValid = true - break - } - } - if !configValid { - // we can't put clusters into t-shirt sizes unless we have a valid configuration; we'll re-trigger when - // the configuration object changes and can process clusters then + if !isConfigValid(config) { return nil, nil } @@ -189,7 +191,7 @@ func (r *reconciler) reconcile( isPaused, duration, err := hyperutil.ProcessPausedUntilField(hostedCluster.Spec.PausedUntil, r.now()) if err != nil { logger.Error(err, "error processing hosted cluster paused field") - return nil, nil // user needs to reformat the field, returning error is useless + return nil, nil } if isPaused { logger.Info("Reconciliation paused", "pausedUntil", *hostedCluster.Spec.PausedUntil) @@ -199,93 +201,144 @@ func (r *reconciler) reconcile( lastTransitionTime, lastSizeClass := previousTransitionFor(hostedCluster) currentSizeClass, sizeClassLabelPresent := hostedCluster.ObjectMeta.Labels[hypershiftv1beta1.HostedClusterSizeLabel] if lastTransitionTime != nil && !sizeClassLabelPresent || currentSizeClass != lastSizeClass { - // we can't update both the status and the labels in one call, so when we have updated status but - // have not yet updated the labels, we just need to do that first return &action{ applyCfg: hypershiftv1beta1applyconfigurations.HostedCluster(hostedCluster.Name, hostedCluster.Namespace). WithLabels(map[string]string{hypershiftv1beta1.HostedClusterSizeLabel: lastSizeClass}), }, nil } - var sizeClass *schedulingv1alpha1.SizeConfiguration + sizeClass, err := r.determineSizeClass(ctx, logger, config, hostedCluster, sizeClassLabelPresent) + if err != nil { + return nil, err + } + if sizeClass == nil { + return nil, nil + } + + if sizeClassLabelPresent && sizeClass.Name == currentSizeClass { + return r.clearTransientConditions(hostedCluster, lastTransitionTime) + } + + return r.transitionSizeClass(ctx, config, hostedCluster, sizeClass, currentSizeClass, sizeClassLabelPresent, lastTransitionTime) +} + +func isConfigValid(config *schedulingv1alpha1.ClusterSizingConfiguration) bool { + for _, condition := range config.Status.Conditions { + if condition.Type == schedulingv1alpha1.ClusterSizingConfigurationValidType && condition.Status == metav1.ConditionTrue { + return true + } + } + return false +} + +func (r *reconciler) determineSizeClass( + ctx context.Context, logger logr.Logger, + config *schedulingv1alpha1.ClusterSizingConfiguration, hostedCluster *hypershiftv1beta1.HostedCluster, + sizeClassLabelPresent bool, +) (*schedulingv1alpha1.SizeConfiguration, error) { if overrideSize := hostedCluster.Annotations[hypershiftv1beta1.ClusterSizeOverrideAnnotation]; overrideSize != "" { - // given the override size, get the size configuration for i, class := range config.Spec.Sizes { if class.Name == overrideSize { - sizeClass = &config.Spec.Sizes[i] - } - } - } else if autoScaling := hostedCluster.Annotations[hypershiftv1beta1.ResourceBasedControlPlaneAutoscalingAnnotation]; autoScaling == "true" { - if len(config.Spec.Sizes) == 0 { - logger.Error(fmt.Errorf("could not find a size class for hosted cluster"), "no size can be set on hosted cluster") - return nil, nil - } - recommendedSize := hostedCluster.Annotations[hypershiftv1beta1.RecommendedClusterSizeAnnotation] - - // First, try to find the recommended size in the configuration - if recommendedSize != "" { - for i, class := range config.Spec.Sizes { - if class.Name == recommendedSize { - sizeClass = &config.Spec.Sizes[i] - logger.V(1).Info("Using recommended cluster size", "size", recommendedSize) - break - } + return &config.Spec.Sizes[i], nil } } + logger.Error(fmt.Errorf("could not find a size class for hosted cluster"), "no size can be set on hosted cluster") + return nil, nil + } - // If no recommended size is set, or the recommended size wasn't found, use the first size class as fallback - if sizeClass == nil { - sizeClass = &config.Spec.Sizes[0] - if recommendedSize == "" { - logger.Info("Resource-based autoscaling enabled but no recommended size set, using first size class", "defaultSize", sizeClass.Name) - } else { - logger.Info("Recommended size not found in configuration, falling back to first size class", "requestedSize", recommendedSize, "fallbackSize", sizeClass.Name) + if autoScaling := hostedCluster.Annotations[hypershiftv1beta1.ResourceBasedControlPlaneAutoscalingAnnotation]; autoScaling == "true" { + return r.determineSizeClassFromAutoscaling(logger, config, hostedCluster) + } + + return r.determineSizeClassFromNodeCount(ctx, logger, config, hostedCluster, sizeClassLabelPresent) +} + +func (r *reconciler) determineSizeClassFromAutoscaling( + logger logr.Logger, + config *schedulingv1alpha1.ClusterSizingConfiguration, hostedCluster *hypershiftv1beta1.HostedCluster, +) (*schedulingv1alpha1.SizeConfiguration, error) { + if len(config.Spec.Sizes) == 0 { + logger.Error(fmt.Errorf("could not find a size class for hosted cluster"), "no size can be set on hosted cluster") + return nil, nil + } + recommendedSize := hostedCluster.Annotations[hypershiftv1beta1.RecommendedClusterSizeAnnotation] + + if recommendedSize != "" { + for i, class := range config.Spec.Sizes { + if class.Name == recommendedSize { + logger.V(1).Info("Using recommended cluster size", "size", recommendedSize) + return &config.Spec.Sizes[i], nil } } + } + + sizeClass := &config.Spec.Sizes[0] + if recommendedSize == "" { + logger.Info("Resource-based autoscaling enabled but no recommended size set, using first size class", "defaultSize", sizeClass.Name) } else { - nodeCount, err := r.determineNodeCount(ctx, hostedCluster, sizeClassLabelPresent) - if err != nil { - if _, ignore := err.(ignoreError); ignore { - logger.Info("Ignoring error", "error", err.Error()) - return nil, nil - } - return nil, err - } + logger.Info("Recommended size not found in configuration, falling back to first size class", "requestedSize", recommendedSize, "fallbackSize", sizeClass.Name) + } + return sizeClass, nil +} - // given the node count we need to figure out if we need to transition to another t-shirt size - for i, class := range config.Spec.Sizes { - if class.Criteria.From <= nodeCount && (class.Criteria.To == nil || *class.Criteria.To >= nodeCount) { - sizeClass = &config.Spec.Sizes[i] - } +func (r *reconciler) determineSizeClassFromNodeCount( + ctx context.Context, logger logr.Logger, + config *schedulingv1alpha1.ClusterSizingConfiguration, hostedCluster *hypershiftv1beta1.HostedCluster, + sizeClassLabelPresent bool, +) (*schedulingv1alpha1.SizeConfiguration, error) { + nodeCount, err := r.determineNodeCount(ctx, hostedCluster, sizeClassLabelPresent) + if err != nil { + var ignore ignoreError + if errors.As(err, &ignore) { + logger.Info("Ignoring error", "error", err.Error()) + return nil, nil } + return nil, err } + var sizeClass *schedulingv1alpha1.SizeConfiguration + for i, class := range config.Spec.Sizes { + if class.Criteria.From <= nodeCount && (class.Criteria.To == nil || *class.Criteria.To >= nodeCount) { + sizeClass = &config.Spec.Sizes[i] + } + } if sizeClass == nil { logger.Error(fmt.Errorf("could not find a size class for hosted cluster"), "no size can be set on hosted cluster") - return nil, nil } - if sizeClassLabelPresent && sizeClass.Name == currentSizeClass { - // no transition necessary, clear transient conditions - cfg := applyCfgFor(hostedCluster, - metav1applyconfigurations.Condition(). - WithType(hypershiftv1beta1.ClusterSizeTransitionPending). - WithStatus(metav1.ConditionFalse). - WithReason("ClusterSizeTransitioned"). - WithMessage("The HostedCluster has transitioned to a new t-shirt size."). - WithLastTransitionTime(metav1.NewTime(*lastTransitionTime)), - metav1applyconfigurations.Condition(). - WithType(hypershiftv1beta1.ClusterSizeTransitionRequired). - WithStatus(metav1.ConditionFalse). - WithReason(hypershiftv1beta1.AsExpectedReason). - WithMessage("The HostedCluster has transitioned to a new t-shirt size."). - WithLastTransitionTime(metav1.NewTime(*lastTransitionTime)), - ) - if cfg != nil { - return &action{applyCfg: cfg}, nil - } - return nil, nil + return sizeClass, nil +} + +func (r *reconciler) clearTransientConditions(hostedCluster *hypershiftv1beta1.HostedCluster, lastTransitionTime *time.Time) (*action, error) { + transitionTime := r.now() + if lastTransitionTime != nil { + transitionTime = *lastTransitionTime + } + cfg := applyCfgFor(hostedCluster, + metav1applyconfigurations.Condition(). + WithType(hypershiftv1beta1.ClusterSizeTransitionPending). + WithStatus(metav1.ConditionFalse). + WithReason("ClusterSizeTransitioned"). + WithMessage("The HostedCluster has transitioned to a new t-shirt size."). + WithLastTransitionTime(metav1.NewTime(transitionTime)), + metav1applyconfigurations.Condition(). + WithType(hypershiftv1beta1.ClusterSizeTransitionRequired). + WithStatus(metav1.ConditionFalse). + WithReason(hypershiftv1beta1.AsExpectedReason). + WithMessage("The HostedCluster has transitioned to a new t-shirt size."). + WithLastTransitionTime(metav1.NewTime(transitionTime)), + ) + if cfg != nil { + return &action{applyCfg: cfg}, nil } + return nil, nil +} +func (r *reconciler) transitionSizeClass( + ctx context.Context, + config *schedulingv1alpha1.ClusterSizingConfiguration, hostedCluster *hypershiftv1beta1.HostedCluster, + sizeClass *schedulingv1alpha1.SizeConfiguration, currentSizeClass string, sizeClassLabelPresent bool, + lastTransitionTime *time.Time, +) (*action, error) { previousMinimumSize := uint32(0) if sizeClassLabelPresent { for _, class := range config.Spec.Sizes { @@ -296,17 +349,57 @@ func (r *reconciler) reconcile( } increasingSize := previousMinimumSize < sizeClass.Criteria.From - // third, we need to know if we're ready to transition the cluster: - // - the hosted cluster has limits to how quickly it can transition up and down, and - // - the management plane has limits to how many clusters can be transitioning at any time + if result := r.checkTransitionDelay(config, hostedCluster, sizeClass, increasingSize, lastTransitionTime); result != nil { + if result.applyCfg == nil { + return &action{requeueAfter: result.requeueAfter}, nil + } + return result, nil + } + + if result, err := r.checkConcurrencyLimit(ctx, config, hostedCluster, sizeClass); err != nil || result != nil { + if result != nil && result.applyCfg == nil { + return &action{requeueAfter: result.requeueAfter}, nil + } + return result, err + } + + cfg := applyCfgFor(hostedCluster, + metav1applyconfigurations.Condition(). + WithType(hypershiftv1beta1.ClusterSizeComputed). + WithStatus(metav1.ConditionTrue). + WithReason(sizeClass.Name). + WithMessage("The HostedCluster has transitioned to a new t-shirt size."). + WithLastTransitionTime(metav1.NewTime(r.now())), + metav1applyconfigurations.Condition(). + WithType(hypershiftv1beta1.ClusterSizeTransitionPending). + WithStatus(metav1.ConditionFalse). + WithReason("ClusterSizeTransitioned"). + WithMessage("The HostedCluster has transitioned to a new t-shirt size."). + WithLastTransitionTime(metav1.NewTime(r.now())), + metav1applyconfigurations.Condition(). + WithType(hypershiftv1beta1.ClusterSizeTransitionRequired). + WithStatus(metav1.ConditionFalse). + WithReason(hypershiftv1beta1.AsExpectedReason). + WithMessage("The HostedCluster has transitioned to a new t-shirt size."). + WithLastTransitionTime(metav1.NewTime(r.now())), + ) + if cfg != nil { + return &action{applyCfg: cfg}, nil + } + return nil, nil +} + +func (r *reconciler) checkTransitionDelay( + config *schedulingv1alpha1.ClusterSizingConfiguration, hostedCluster *hypershiftv1beta1.HostedCluster, + sizeClass *schedulingv1alpha1.SizeConfiguration, increasingSize bool, + lastTransitionTime *time.Time, +) *action { delayStart := time.Time{} if lastTransitionTime != nil { - // if we transitioned in the past, we need to enforce the delay from there delayStart = *lastTransitionTime } lastComputedTime, lastComputedSizeClass := previousComputedSizeFor(hostedCluster) if lastComputedTime != nil && lastComputedSizeClass == sizeClass.Name { - // we computed that the cluster should transition already; enforce the delay from that point delayStart = *lastComputedTime } var delay time.Duration @@ -318,84 +411,69 @@ func (r *reconciler) reconcile( transition = "decrease" delay = config.Spec.TransitionDelay.Decrease.Duration } - if r.now().Sub(delayStart) < delay { - cfg := applyCfgFor(hostedCluster, - metav1applyconfigurations.Condition(). - WithType(hypershiftv1beta1.ClusterSizeTransitionPending). - WithStatus(metav1.ConditionTrue). - WithReason("TransitionDelayNotElapsed"). - WithMessage(fmt.Sprintf("HostedClusters must wait at least %s to %s in size after the cluster size changes.", delay.String(), transition)). - WithLastTransitionTime(metav1.NewTime(r.now())), - metav1applyconfigurations.Condition(). - WithType(hypershiftv1beta1.ClusterSizeTransitionRequired). - WithStatus(metav1.ConditionTrue). - WithReason(sizeClass.Name). - WithMessage("The HostedCluster will transition to a new t-shirt size."). - WithLastTransitionTime(metav1.NewTime(r.now())), - ) - if cfg != nil { - return &action{applyCfg: cfg, requeueAfter: delayStart.Add(delay).Sub(r.now())}, nil - } else { - return nil, nil - } - } - - // For new clusters being added to the fleet, we have an SLA on creation time and can't afford to delay - // the first transition, as it is required for the control plane to schedule. For other clusters, though, - // we want to limit the amount of churn happening in order to promote the stability of the management plane. - if scheduled := hostedCluster.Annotations[hypershiftv1beta1.HostedClusterScheduledAnnotation]; scheduled == "true" { - hostedClusters, err := r.listHostedClusters(ctx) - if err != nil { - return nil, fmt.Errorf("failed to list hosted clusters when calculating concurrency: %w", err) - } - - if changes, durationUntilChanges := transitionsWithinSlidingWindow(hostedClusters, config.Spec.Concurrency.SlidingWindow.Duration, r.now()); int32(changes) >= config.Spec.Concurrency.Limit { - cfg := applyCfgFor(hostedCluster, - metav1applyconfigurations.Condition(). - WithType(hypershiftv1beta1.ClusterSizeTransitionPending). - WithStatus(metav1.ConditionTrue). - WithReason("ConcurrencyLimitReached"). - WithMessage(fmt.Sprintf("%d HostedClusters have already transitioned sizes in the last %s, more time must elapse before the next transition.", changes, config.Spec.Concurrency.SlidingWindow.Duration.String())). - WithLastTransitionTime(metav1.NewTime(r.now())), - metav1applyconfigurations.Condition(). - WithType(hypershiftv1beta1.ClusterSizeTransitionRequired). - WithStatus(metav1.ConditionTrue). - WithReason(sizeClass.Name). - WithMessage("The HostedCluster will transition to a new t-shirt size."). - WithLastTransitionTime(metav1.NewTime(r.now())), - ) - if cfg != nil { - return &action{applyCfg: cfg, requeueAfter: durationUntilChanges}, nil - } else { - return nil, nil - } - } + if r.now().Sub(delayStart) >= delay { + return nil } cfg := applyCfgFor(hostedCluster, metav1applyconfigurations.Condition(). - WithType(hypershiftv1beta1.ClusterSizeComputed). + WithType(hypershiftv1beta1.ClusterSizeTransitionPending). + WithStatus(metav1.ConditionTrue). + WithReason("TransitionDelayNotElapsed"). + WithMessage(fmt.Sprintf("HostedClusters must wait at least %s to %s in size after the cluster size changes.", delay.String(), transition)). + WithLastTransitionTime(metav1.NewTime(r.now())), + metav1applyconfigurations.Condition(). + WithType(hypershiftv1beta1.ClusterSizeTransitionRequired). WithStatus(metav1.ConditionTrue). WithReason(sizeClass.Name). - WithMessage("The HostedCluster has transitioned to a new t-shirt size."). + WithMessage("The HostedCluster will transition to a new t-shirt size."). WithLastTransitionTime(metav1.NewTime(r.now())), + ) + if cfg != nil { + return &action{applyCfg: cfg, requeueAfter: delayStart.Add(delay).Sub(r.now())} + } + // Conditions already match status; no update needed but delay has not elapsed, so requeue and stop processing. + return &action{requeueAfter: delayStart.Add(delay).Sub(r.now())} +} + +func (r *reconciler) checkConcurrencyLimit( + ctx context.Context, + config *schedulingv1alpha1.ClusterSizingConfiguration, hostedCluster *hypershiftv1beta1.HostedCluster, + sizeClass *schedulingv1alpha1.SizeConfiguration, +) (*action, error) { + if scheduled := hostedCluster.Annotations[hypershiftv1beta1.HostedClusterScheduledAnnotation]; scheduled != "true" { + return nil, nil + } + + hostedClusters, err := r.listHostedClusters(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list hosted clusters when calculating concurrency: %w", err) + } + + changes, durationUntilChanges := transitionsWithinSlidingWindow(hostedClusters, config.Spec.Concurrency.SlidingWindow.Duration, r.now()) + if int32(changes) < config.Spec.Concurrency.Limit { + return nil, nil + } + + cfg := applyCfgFor(hostedCluster, metav1applyconfigurations.Condition(). WithType(hypershiftv1beta1.ClusterSizeTransitionPending). - WithStatus(metav1.ConditionFalse). - WithReason("ClusterSizeTransitioned"). - WithMessage("The HostedCluster has transitioned to a new t-shirt size."). + WithStatus(metav1.ConditionTrue). + WithReason("ConcurrencyLimitReached"). + WithMessage(fmt.Sprintf("%d HostedClusters have already transitioned sizes in the last %s, more time must elapse before the next transition.", changes, config.Spec.Concurrency.SlidingWindow.Duration.String())). WithLastTransitionTime(metav1.NewTime(r.now())), metav1applyconfigurations.Condition(). WithType(hypershiftv1beta1.ClusterSizeTransitionRequired). - WithStatus(metav1.ConditionFalse). - WithReason(hypershiftv1beta1.AsExpectedReason). - WithMessage("The HostedCluster has transitioned to a new t-shirt size."). + WithStatus(metav1.ConditionTrue). + WithReason(sizeClass.Name). + WithMessage("The HostedCluster will transition to a new t-shirt size."). WithLastTransitionTime(metav1.NewTime(r.now())), ) if cfg != nil { - return &action{applyCfg: cfg}, nil + return &action{applyCfg: cfg, requeueAfter: durationUntilChanges}, nil } - return nil, nil + // Conditions already match status; no update needed but concurrency limit reached, so requeue and stop processing. + return &action{requeueAfter: durationUntilChanges}, nil } func (r *reconciler) determineNodeCount(ctx context.Context, hostedCluster *hypershiftv1beta1.HostedCluster, sizeClassLabelPresent bool) (uint32, error) { @@ -424,7 +502,7 @@ func (r *reconciler) determineNodeCount(ctx context.Context, hostedCluster *hype if hccoReportsNodeCount { hostedControlPlane, err := r.hostedControlPlaneForHostedCluster(ctx, hostedCluster) if err != nil { - return 0, ignoreError(fmt.Errorf("failed to get hosted control plane: %w", err)) + return 0, ignoreError{err: fmt.Errorf("failed to get hosted control plane: %w", err)} } if hostedControlPlane.Status.NodeCount != nil && *hostedControlPlane.Status.NodeCount > 0 { @@ -442,7 +520,7 @@ func (r *reconciler) determineNodeCount(ctx context.Context, hostedCluster *hype if nodePool.Spec.AutoScaling != nil { // If the Kube API Server is not available, and we already have a size label, skip processing if !kasAvailable && sizeClassLabelPresent { - return 0, ignoreError(fmt.Errorf("KAS is not available, and no size class label is set yet")) + return 0, ignoreError{err: fmt.Errorf("KAS is not available, and no size class label is set yet")} } replicas = uint32(nodePool.Status.Replicas) } else if nodePool.Spec.Replicas != nil { diff --git a/hypershift-operator/controllers/hostedclustersizing/hostedclustersizing_controller_test.go b/hypershift-operator/controllers/hostedclustersizing/hostedclustersizing_controller_test.go index b76cb5eb5eb..cdacd1affe9 100644 --- a/hypershift-operator/controllers/hostedclustersizing/hostedclustersizing_controller_test.go +++ b/hypershift-operator/controllers/hostedclustersizing/hostedclustersizing_controller_test.go @@ -2,9 +2,12 @@ package hostedclustersizing import ( "context" + "fmt" "testing" "time" + . "github.com/onsi/gomega" + hypershiftv1beta1 "github.com/openshift/hypershift/api/hypershift/v1beta1" schedulingv1alpha1 "github.com/openshift/hypershift/api/scheduling/v1alpha1" hypershiftv1beta1applyconfigurations "github.com/openshift/hypershift/client/applyconfiguration/hypershift/v1beta1" @@ -951,7 +954,7 @@ func TestSizingController_Reconcile(t *testing.T) { }}, }, { - name: "no-op, delay already exposed in status", + name: "no-op, delay already exposed in status, preserves requeue", config: validCommonConfig, hostedCluster: &hypershiftv1beta1.HostedCluster{ ObjectMeta: metav1.ObjectMeta{ @@ -992,6 +995,7 @@ func TestSizingController_Reconcile(t *testing.T) { Status: hypershiftv1beta1.HostedControlPlaneStatus{NodeCount: ptr.To(3)}, }, nil }, + expected: &action{requeueAfter: 8 * time.Minute}, }, { name: "delay for concurrency", @@ -1114,7 +1118,7 @@ func TestSizingController_Reconcile(t *testing.T) { }, requeueAfter: 5 * time.Minute}, }, { - name: "delay for concurrency, no-op since condition already present", + name: "delay for concurrency, no-op since condition already present, preserves requeue", config: validCommonConfig, hostedCluster: &hypershiftv1beta1.HostedCluster{ ObjectMeta: metav1.ObjectMeta{ @@ -1162,6 +1166,7 @@ func TestSizingController_Reconcile(t *testing.T) { Status: hypershiftv1beta1.HostedControlPlaneStatus{NodeCount: ptr.To(3)}, }, nil }, + expected: &action{requeueAfter: 5 * time.Minute}, }, { name: "delay for concurrency, undo conditions since cluster returned to original size during delay", @@ -1609,6 +1614,610 @@ func TestSizingController_Reconcile(t *testing.T) { } } +func TestIsConfigValid(t *testing.T) { + for _, tc := range []struct { + name string + config *schedulingv1alpha1.ClusterSizingConfiguration + expected bool + }{ + { + name: "When the Valid condition is True, it should return true", + config: &schedulingv1alpha1.ClusterSizingConfiguration{ + Status: schedulingv1alpha1.ClusterSizingConfigurationStatus{ + Conditions: []metav1.Condition{ + {Type: schedulingv1alpha1.ClusterSizingConfigurationValidType, Status: metav1.ConditionTrue}, + }, + }, + }, + expected: true, + }, + { + name: "When the Valid condition is False, it should return false", + config: &schedulingv1alpha1.ClusterSizingConfiguration{ + Status: schedulingv1alpha1.ClusterSizingConfigurationStatus{ + Conditions: []metav1.Condition{ + {Type: schedulingv1alpha1.ClusterSizingConfigurationValidType, Status: metav1.ConditionFalse}, + }, + }, + }, + expected: false, + }, + { + name: "When no conditions exist, it should return false", + config: &schedulingv1alpha1.ClusterSizingConfiguration{ + Status: schedulingv1alpha1.ClusterSizingConfigurationStatus{}, + }, + expected: false, + }, + { + name: "When the Valid condition is Unknown, it should return false", + config: &schedulingv1alpha1.ClusterSizingConfiguration{ + Status: schedulingv1alpha1.ClusterSizingConfigurationStatus{ + Conditions: []metav1.Condition{ + {Type: schedulingv1alpha1.ClusterSizingConfigurationValidType, Status: metav1.ConditionUnknown}, + }, + }, + }, + expected: false, + }, + { + name: "When only unrelated conditions exist, it should return false", + config: &schedulingv1alpha1.ClusterSizingConfiguration{ + Status: schedulingv1alpha1.ClusterSizingConfigurationStatus{ + Conditions: []metav1.Condition{ + {Type: "SomeOtherType", Status: metav1.ConditionTrue}, + }, + }, + }, + expected: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + g.Expect(isConfigValid(tc.config)).To(Equal(tc.expected)) + }) + } +} + +func TestDetermineSizeClass(t *testing.T) { + ctrl.SetLogger(zap.New(zap.UseDevMode(true), zap.JSONEncoder(func(o *zapcore.EncoderConfig) { + o.EncodeTime = zapcore.RFC3339TimeEncoder + }))) + logger := ctrl.Log + + configWithSizes := &schedulingv1alpha1.ClusterSizingConfiguration{ + Spec: schedulingv1alpha1.ClusterSizingConfigurationSpec{ + Sizes: []schedulingv1alpha1.SizeConfiguration{ + {Name: "small", Criteria: schedulingv1alpha1.NodeCountCriteria{From: 0, To: ptr.To(uint32(10))}}, + {Name: "medium", Criteria: schedulingv1alpha1.NodeCountCriteria{From: 11, To: ptr.To(uint32(100))}}, + {Name: "large", Criteria: schedulingv1alpha1.NodeCountCriteria{From: 101}}, + }, + }, + } + + for _, tc := range []struct { + name string + config *schedulingv1alpha1.ClusterSizingConfiguration + hostedCluster *hypershiftv1beta1.HostedCluster + sizeClassLabelPresent bool + expectedName string + expectNil bool + }{ + { + name: "When override annotation references a valid size, it should return that size", + config: configWithSizes, + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{hypershiftv1beta1.ClusterSizeOverrideAnnotation: "medium"}, + }, + }, + expectedName: "medium", + }, + { + name: "When override annotation references an invalid size, it should return nil", + config: configWithSizes, + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{hypershiftv1beta1.ClusterSizeOverrideAnnotation: "nonexistent"}, + }, + }, + expectNil: true, + }, + { + name: "When autoscaling annotation is true with valid recommended size, it should return that size", + config: configWithSizes, + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + hypershiftv1beta1.ResourceBasedControlPlaneAutoscalingAnnotation: "true", + hypershiftv1beta1.RecommendedClusterSizeAnnotation: "large", + }, + }, + }, + expectedName: "large", + }, + { + name: "When autoscaling annotation is not 'true', it should fall through to node count path", + config: configWithSizes, + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + hypershiftv1beta1.ResourceBasedControlPlaneAutoscalingAnnotation: "false", + hypershiftv1beta1.RecommendedClusterSizeAnnotation: "large", + }, + }, + }, + // Falls through to determineSizeClassFromNodeCount; with nodeCount=0 -> "small" + expectedName: "small", + }, + { + name: "When override takes priority over autoscaling, it should use the override", + config: configWithSizes, + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + hypershiftv1beta1.ClusterSizeOverrideAnnotation: "large", + hypershiftv1beta1.ResourceBasedControlPlaneAutoscalingAnnotation: "true", + hypershiftv1beta1.RecommendedClusterSizeAnnotation: "small", + }, + }, + }, + expectedName: "large", + }, + } { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + r := &reconciler{ + now: time.Now, + hccoReportsNodeCount: func(_ context.Context, _ *hypershiftv1beta1.HostedCluster) (bool, error) { + return false, nil + }, + nodePoolsForHostedCluster: func(_ context.Context, _ *hypershiftv1beta1.HostedCluster) (*hypershiftv1beta1.NodePoolList, error) { + return &hypershiftv1beta1.NodePoolList{}, nil + }, + } + + result, err := r.determineSizeClass(t.Context(), logger, tc.config, tc.hostedCluster, tc.sizeClassLabelPresent) + g.Expect(err).ToNot(HaveOccurred()) + if tc.expectNil { + g.Expect(result).To(BeNil()) + } else { + g.Expect(result).ToNot(BeNil()) + g.Expect(result.Name).To(Equal(tc.expectedName)) + } + }) + } +} + +func TestDetermineSizeClassFromAutoscaling(t *testing.T) { + ctrl.SetLogger(zap.New(zap.UseDevMode(true), zap.JSONEncoder(func(o *zapcore.EncoderConfig) { + o.EncodeTime = zapcore.RFC3339TimeEncoder + }))) + logger := ctrl.Log + + for _, tc := range []struct { + name string + config *schedulingv1alpha1.ClusterSizingConfiguration + hc *hypershiftv1beta1.HostedCluster + expectedName string + expectNil bool + }{ + { + name: "When sizes list is empty, it should return nil", + config: &schedulingv1alpha1.ClusterSizingConfiguration{ + Spec: schedulingv1alpha1.ClusterSizingConfigurationSpec{ + Sizes: []schedulingv1alpha1.SizeConfiguration{}, + }, + }, + hc: &hypershiftv1beta1.HostedCluster{}, + expectNil: true, + }, + { + name: "When recommended size annotation matches a configured size, it should return that size", + config: &schedulingv1alpha1.ClusterSizingConfiguration{ + Spec: schedulingv1alpha1.ClusterSizingConfigurationSpec{ + Sizes: []schedulingv1alpha1.SizeConfiguration{ + {Name: "small"}, + {Name: "medium"}, + {Name: "large"}, + }, + }, + }, + hc: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + hypershiftv1beta1.RecommendedClusterSizeAnnotation: "medium", + }, + }, + }, + expectedName: "medium", + }, + { + name: "When recommended size annotation does not match any configured size, it should fall back to first size", + config: &schedulingv1alpha1.ClusterSizingConfiguration{ + Spec: schedulingv1alpha1.ClusterSizingConfigurationSpec{ + Sizes: []schedulingv1alpha1.SizeConfiguration{ + {Name: "small"}, + {Name: "large"}, + }, + }, + }, + hc: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + hypershiftv1beta1.RecommendedClusterSizeAnnotation: "nonexistent", + }, + }, + }, + expectedName: "small", + }, + { + name: "When no recommended size annotation is set, it should fall back to first size", + config: &schedulingv1alpha1.ClusterSizingConfiguration{ + Spec: schedulingv1alpha1.ClusterSizingConfigurationSpec{ + Sizes: []schedulingv1alpha1.SizeConfiguration{ + {Name: "tiny"}, + {Name: "huge"}, + }, + }, + }, + hc: &hypershiftv1beta1.HostedCluster{}, + expectedName: "tiny", + }, + } { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + r := &reconciler{} + + result, err := r.determineSizeClassFromAutoscaling(logger, tc.config, tc.hc) + g.Expect(err).ToNot(HaveOccurred()) + if tc.expectNil { + g.Expect(result).To(BeNil()) + } else { + g.Expect(result).ToNot(BeNil()) + g.Expect(result.Name).To(Equal(tc.expectedName)) + } + }) + } +} + +func TestCheckTransitionDelay(t *testing.T) { + theTime, err := time.Parse(time.RFC3339Nano, "2006-01-02T15:04:05.000000000Z") + if err != nil { + t.Fatalf("could not parse time: %v", err) + } + fakeClock := testingclock.NewFakeClock(theTime) + + config := &schedulingv1alpha1.ClusterSizingConfiguration{ + Spec: schedulingv1alpha1.ClusterSizingConfigurationSpec{ + TransitionDelay: schedulingv1alpha1.TransitionDelayConfiguration{ + Increase: metav1.Duration{Duration: 30 * time.Second}, + Decrease: metav1.Duration{Duration: 10 * time.Minute}, + }, + }, + } + + for _, tc := range []struct { + name string + hostedCluster *hypershiftv1beta1.HostedCluster + sizeClass *schedulingv1alpha1.SizeConfiguration + increasingSize bool + lastTransitionTime *time.Time + expectNil bool + expectNoOp bool + expectRequeue time.Duration + }{ + { + name: "When increase delay has elapsed since last transition, it should return nil to allow transition", + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{Namespace: "ns", Name: "hc"}, + }, + sizeClass: &schedulingv1alpha1.SizeConfiguration{Name: "large"}, + increasingSize: true, + lastTransitionTime: ptr.To(fakeClock.Now().Add(-1 * time.Minute)), + expectNil: true, + }, + { + name: "When decrease delay has not elapsed since last transition, it should return action with requeue", + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{Namespace: "ns", Name: "hc"}, + }, + sizeClass: &schedulingv1alpha1.SizeConfiguration{Name: "small"}, + increasingSize: false, + lastTransitionTime: ptr.To(fakeClock.Now().Add(-1 * time.Minute)), + expectRequeue: 9 * time.Minute, + }, + { + name: "When no previous transition exists and increase delay has not elapsed from zero time, it should return nil", + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{Namespace: "ns", Name: "hc"}, + }, + sizeClass: &schedulingv1alpha1.SizeConfiguration{Name: "large"}, + increasingSize: true, + expectNil: true, + }, + { + name: "When delay has not elapsed and conditions already match status, it should return action with requeue but no applyCfg", + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{Namespace: "ns", Name: "hc"}, + Status: hypershiftv1beta1.HostedClusterStatus{Conditions: []metav1.Condition{ + { + Type: hypershiftv1beta1.ClusterSizeTransitionPending, + Status: metav1.ConditionTrue, + Reason: "TransitionDelayNotElapsed", + Message: "HostedClusters must wait at least 10m0s to decrease in size after the cluster size changes.", + LastTransitionTime: metav1.NewTime(fakeClock.Now().Add(-30 * time.Second)), + }, + { + Type: hypershiftv1beta1.ClusterSizeTransitionRequired, + Status: metav1.ConditionTrue, + Reason: "small", + Message: "The HostedCluster will transition to a new t-shirt size.", + LastTransitionTime: metav1.NewTime(fakeClock.Now().Add(-1 * time.Minute)), + }, + }}, + }, + sizeClass: &schedulingv1alpha1.SizeConfiguration{Name: "small"}, + increasingSize: false, + lastTransitionTime: ptr.To(fakeClock.Now().Add(-5 * time.Minute)), + expectNoOp: true, + expectRequeue: 9 * time.Minute, + }, + { + name: "When computed size matches the target and delay start is from computed time, it should use computed time as delay start", + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{Namespace: "ns", Name: "hc"}, + Status: hypershiftv1beta1.HostedClusterStatus{Conditions: []metav1.Condition{ + { + Type: hypershiftv1beta1.ClusterSizeTransitionRequired, + Status: metav1.ConditionTrue, + Reason: "large", + Message: "The HostedCluster will transition to a new t-shirt size.", + LastTransitionTime: metav1.NewTime(fakeClock.Now().Add(-20 * time.Second)), + }, + }}, + }, + sizeClass: &schedulingv1alpha1.SizeConfiguration{Name: "large"}, + increasingSize: true, + lastTransitionTime: ptr.To(fakeClock.Now().Add(-5 * time.Minute)), + expectRequeue: 10 * time.Second, + }, + } { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + r := &reconciler{now: fakeClock.Now} + + result := r.checkTransitionDelay(config, tc.hostedCluster, tc.sizeClass, tc.increasingSize, tc.lastTransitionTime) + if tc.expectNil { + g.Expect(result).To(BeNil()) + } else if tc.expectNoOp { + g.Expect(result).ToNot(BeNil()) + g.Expect(result.applyCfg).To(BeNil(), "no-op action should have nil applyCfg") + g.Expect(result.requeueAfter).To(Equal(tc.expectRequeue), "no-op action should preserve requeueAfter") + } else { + g.Expect(result).ToNot(BeNil()) + g.Expect(result.applyCfg).ToNot(BeNil()) + g.Expect(result.requeueAfter).To(Equal(tc.expectRequeue)) + } + }) + } +} + +func TestCheckConcurrencyLimit(t *testing.T) { + theTime, err := time.Parse(time.RFC3339Nano, "2006-01-02T15:04:05.000000000Z") + if err != nil { + t.Fatalf("could not parse time: %v", err) + } + fakeClock := testingclock.NewFakeClock(theTime) + + config := &schedulingv1alpha1.ClusterSizingConfiguration{ + Spec: schedulingv1alpha1.ClusterSizingConfigurationSpec{ + Concurrency: schedulingv1alpha1.ConcurrencyConfiguration{ + SlidingWindow: metav1.Duration{Duration: 10 * time.Minute}, + Limit: 3, + }, + }, + } + + sizeClass := &schedulingv1alpha1.SizeConfiguration{Name: "large"} + + for _, tc := range []struct { + name string + hostedCluster *hypershiftv1beta1.HostedCluster + listHostedClusters func(context.Context) (*hypershiftv1beta1.HostedClusterList, error) + expectNil bool + expectNoOp bool + expectRequeue bool + expectErr bool + }{ + { + name: "When cluster is not scheduled, it should return nil to skip concurrency check", + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{Namespace: "ns", Name: "hc"}, + }, + expectNil: true, + }, + { + name: "When cluster is scheduled but transitions are under the limit, it should return nil to allow transition", + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "hc", + Annotations: map[string]string{hypershiftv1beta1.HostedClusterScheduledAnnotation: "true"}, + }, + }, + listHostedClusters: func(_ context.Context) (*hypershiftv1beta1.HostedClusterList, error) { + return &hypershiftv1beta1.HostedClusterList{Items: []hypershiftv1beta1.HostedCluster{ + hostedClusterWithTransition("first", fakeClock.Now().Add(-1*time.Minute)), + hostedClusterWithTransition("second", fakeClock.Now().Add(-2*time.Minute)), + }}, nil + }, + expectNil: true, + }, + { + name: "When cluster is scheduled and transitions are at the limit, it should return action with requeue", + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "hc", + Annotations: map[string]string{hypershiftv1beta1.HostedClusterScheduledAnnotation: "true"}, + }, + }, + listHostedClusters: func(_ context.Context) (*hypershiftv1beta1.HostedClusterList, error) { + return &hypershiftv1beta1.HostedClusterList{Items: []hypershiftv1beta1.HostedCluster{ + hostedClusterWithTransition("first", fakeClock.Now().Add(-1*time.Minute)), + hostedClusterWithTransition("second", fakeClock.Now().Add(-2*time.Minute)), + hostedClusterWithTransition("third", fakeClock.Now().Add(-3*time.Minute)), + }}, nil + }, + expectRequeue: true, + }, + { + name: "When cluster is scheduled and conditions already match, it should return action with requeue but no applyCfg", + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "hc", + Annotations: map[string]string{hypershiftv1beta1.HostedClusterScheduledAnnotation: "true"}, + }, + Status: hypershiftv1beta1.HostedClusterStatus{Conditions: []metav1.Condition{ + { + Type: hypershiftv1beta1.ClusterSizeTransitionPending, + Status: metav1.ConditionTrue, + Reason: "ConcurrencyLimitReached", + Message: "3 HostedClusters have already transitioned sizes in the last 10m0s, more time must elapse before the next transition.", + }, + { + Type: hypershiftv1beta1.ClusterSizeTransitionRequired, + Status: metav1.ConditionTrue, + Reason: "large", + Message: "The HostedCluster will transition to a new t-shirt size.", + }, + }}, + }, + listHostedClusters: func(_ context.Context) (*hypershiftv1beta1.HostedClusterList, error) { + return &hypershiftv1beta1.HostedClusterList{Items: []hypershiftv1beta1.HostedCluster{ + hostedClusterWithTransition("first", fakeClock.Now().Add(-1*time.Minute)), + hostedClusterWithTransition("second", fakeClock.Now().Add(-2*time.Minute)), + hostedClusterWithTransition("third", fakeClock.Now().Add(-3*time.Minute)), + }}, nil + }, + expectNoOp: true, + }, + { + name: "When listing hosted clusters fails, it should return an error", + hostedCluster: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "hc", + Annotations: map[string]string{hypershiftv1beta1.HostedClusterScheduledAnnotation: "true"}, + }, + }, + listHostedClusters: func(_ context.Context) (*hypershiftv1beta1.HostedClusterList, error) { + return nil, fmt.Errorf("list error") + }, + expectErr: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + r := &reconciler{ + now: fakeClock.Now, + listHostedClusters: tc.listHostedClusters, + } + + result, err := r.checkConcurrencyLimit(t.Context(), config, tc.hostedCluster, sizeClass) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + if tc.expectNil { + g.Expect(result).To(BeNil()) + } else if tc.expectNoOp { + g.Expect(result).ToNot(BeNil()) + g.Expect(result.applyCfg).To(BeNil(), "no-op action should have nil applyCfg") + g.Expect(result.requeueAfter).To(BeNumerically(">", 0), "no-op action should preserve requeueAfter") + } else if tc.expectRequeue { + g.Expect(result).ToNot(BeNil()) + g.Expect(result.applyCfg).ToNot(BeNil()) + g.Expect(result.requeueAfter).To(BeNumerically(">", 0)) + } + }) + } +} + +func TestClearTransientConditions(t *testing.T) { + theTime, err := time.Parse(time.RFC3339Nano, "2006-01-02T15:04:05.000000000Z") + if err != nil { + t.Fatalf("could not parse time: %v", err) + } + fakeClock := testingclock.NewFakeClock(theTime) + transitionTime := fakeClock.Now().Add(-5 * time.Minute) + + for _, tc := range []struct { + name string + hc *hypershiftv1beta1.HostedCluster + expectNil bool + }{ + { + name: "When transient conditions need clearing, it should return an action with updated conditions", + hc: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{Namespace: "ns", Name: "hc"}, + Status: hypershiftv1beta1.HostedClusterStatus{ + Conditions: []metav1.Condition{ + { + Type: hypershiftv1beta1.ClusterSizeTransitionPending, + Status: metav1.ConditionTrue, + Reason: "TransitionDelayNotElapsed", + Message: "some message", + }, + }, + }, + }, + }, + { + name: "When transient conditions already match desired state, it should return nil", + hc: &hypershiftv1beta1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{Namespace: "ns", Name: "hc"}, + Status: hypershiftv1beta1.HostedClusterStatus{ + Conditions: []metav1.Condition{ + { + Type: hypershiftv1beta1.ClusterSizeTransitionPending, + Status: metav1.ConditionFalse, + Reason: "ClusterSizeTransitioned", + Message: "The HostedCluster has transitioned to a new t-shirt size.", + }, + { + Type: hypershiftv1beta1.ClusterSizeTransitionRequired, + Status: metav1.ConditionFalse, + Reason: hypershiftv1beta1.AsExpectedReason, + Message: "The HostedCluster has transitioned to a new t-shirt size.", + }, + }, + }, + }, + expectNil: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + r := &reconciler{now: fakeClock.Now} + + result, err := r.clearTransientConditions(tc.hc, &transitionTime) + g.Expect(err).ToNot(HaveOccurred()) + if tc.expectNil { + g.Expect(result).To(BeNil()) + } else { + g.Expect(result).ToNot(BeNil()) + g.Expect(result.applyCfg).ToNot(BeNil()) + g.Expect(result.applyCfg.Status).ToNot(BeNil()) + } + }) + } +} + func hostedClusterWithTransition(name string, transition time.Time) hypershiftv1beta1.HostedCluster { return hypershiftv1beta1.HostedCluster{ ObjectMeta: metav1.ObjectMeta{ diff --git a/hypershift-operator/controllers/nodepool/aws.go b/hypershift-operator/controllers/nodepool/aws.go index 2c86a808f50..f1a847404e4 100644 --- a/hypershift-operator/controllers/nodepool/aws.go +++ b/hypershift-operator/controllers/nodepool/aws.go @@ -61,29 +61,84 @@ func isSpotEnabled(nodePool *hyperv1.NodePool) bool { } func awsMachineTemplateSpec(infraName string, hostedCluster *hyperv1.HostedCluster, nodePool *hyperv1.NodePool, defaultSG bool, releaseImage *releaseinfo.ReleaseImage) (*capiaws.AWSMachineTemplateSpec, error) { + ami, err := resolveAWSAMI(hostedCluster, nodePool, releaseImage) + if err != nil { + return nil, err + } + + subnet := buildAWSSubnet(nodePool) + rootVolume := buildAWSRootVolume(nodePool) + + securityGroups, err := buildAWSSecurityGroups(nodePool, hostedCluster, defaultSG) + if err != nil { + return nil, err + } + + instanceProfile := fmt.Sprintf("%s-worker-profile", infraName) + if nodePool.Spec.Platform.AWS.InstanceProfile != "" { + instanceProfile = nodePool.Spec.Platform.AWS.InstanceProfile + } + + instanceMetadataOptions := &capiaws.InstanceMetadataOptions{ + HTTPTokens: capiaws.HTTPTokensStateOptional, + HTTPPutResponseHopLimit: 2, + HTTPEndpoint: capiaws.InstanceMetadataEndpointStateEnabled, + InstanceMetadataTags: capiaws.InstanceMetadataEndpointStateDisabled, + } + if value, found := nodePool.Annotations[ec2InstanceMetadataHTTPTokensAnnotation]; found && value == string(capiaws.HTTPTokensStateRequired) { + instanceMetadataOptions.HTTPTokens = capiaws.HTTPTokensStateRequired + } + + awsMachineTemplateSpec := &capiaws.AWSMachineTemplateSpec{ + Template: capiaws.AWSMachineTemplateResource{ + Spec: capiaws.AWSMachineSpec{ + UncompressedUserData: ptr.To(true), + CloudInit: capiaws.CloudInit{ + InsecureSkipSecretsManager: true, + SecureSecretsBackend: "secrets-manager", + }, + IAMInstanceProfile: instanceProfile, + InstanceType: nodePool.Spec.Platform.AWS.InstanceType, + AMI: capiaws.AMIReference{ID: ptr.To(ami)}, + AdditionalSecurityGroups: securityGroups, + Subnet: subnet, + RootVolume: rootVolume, + AdditionalTags: awsAdditionalTags(nodePool, hostedCluster, infraName), + InstanceMetadataOptions: instanceMetadataOptions, + }, + }, + } + + applyAWSPlacementOptions(nodePool, awsMachineTemplateSpec) - var ami string + if hostedCluster.Annotations[hyperv1.AWSMachinePublicIPs] == "true" { + awsMachineTemplateSpec.Template.Spec.PublicIP = ptr.To(true) + } + + return awsMachineTemplateSpec, nil +} + +func resolveAWSAMI(hostedCluster *hyperv1.HostedCluster, nodePool *hyperv1.NodePool, releaseImage *releaseinfo.ReleaseImage) (string, error) { region := hostedCluster.Spec.Platform.AWS.Region arch := nodePool.Spec.Arch if nodePool.Spec.Platform.AWS.AMI != "" { - ami = nodePool.Spec.Platform.AWS.AMI - } else if nodePool.Spec.Platform.AWS.ImageType == hyperv1.ImageTypeWindows { - // Use Windows AMI mapping when ImageType is set to Windows - var err error - ami, err = getWindowsAMI(region, arch, releaseImage) - if err != nil { - return nil, fmt.Errorf("couldn't discover a Windows AMI for release image: %w", err) - } - } else { - // Default behavior for Linux/RHCOS AMIs - // TODO: Should the region be included in the NodePool platform information? - var err error - ami, err = defaultNodePoolAMI(region, arch, releaseImage) + return nodePool.Spec.Platform.AWS.AMI, nil + } + if nodePool.Spec.Platform.AWS.ImageType == hyperv1.ImageTypeWindows { + ami, err := getWindowsAMI(region, arch, releaseImage) if err != nil { - return nil, fmt.Errorf("couldn't discover an AMI for release image: %w", err) + return "", fmt.Errorf("couldn't discover a Windows AMI for release image: %w", err) } + return ami, nil + } + ami, err := defaultNodePoolAMI(region, arch, releaseImage) + if err != nil { + return "", fmt.Errorf("couldn't discover an AMI for release image: %w", err) } + return ami, nil +} +func buildAWSSubnet(nodePool *hyperv1.NodePool) *capiaws.AWSResourceReference { subnet := &capiaws.AWSResourceReference{} subnet.ID = nodePool.Spec.Platform.AWS.Subnet.ID for k := range nodePool.Spec.Platform.AWS.Subnet.Filters { @@ -93,27 +148,33 @@ func awsMachineTemplateSpec(infraName string, hostedCluster *hyperv1.HostedClust } subnet.Filters = append(subnet.Filters, filter) } + return subnet +} +func buildAWSRootVolume(nodePool *hyperv1.NodePool) *capiaws.Volume { rootVolume := &capiaws.Volume{ Size: EC2VolumeDefaultSize, } - if nodePool.Spec.Platform.AWS.RootVolume != nil { - if nodePool.Spec.Platform.AWS.RootVolume.Type != "" { - rootVolume.Type = capiaws.VolumeType(nodePool.Spec.Platform.AWS.RootVolume.Type) - } else { - rootVolume.Type = capiaws.VolumeType(EC2VolumeDefaultType) - } - if nodePool.Spec.Platform.AWS.RootVolume.Size > 0 { - rootVolume.Size = nodePool.Spec.Platform.AWS.RootVolume.Size - } - if nodePool.Spec.Platform.AWS.RootVolume.IOPS > 0 { - rootVolume.IOPS = nodePool.Spec.Platform.AWS.RootVolume.IOPS - } - - rootVolume.Encrypted = nodePool.Spec.Platform.AWS.RootVolume.Encrypted - rootVolume.EncryptionKey = nodePool.Spec.Platform.AWS.RootVolume.EncryptionKey + if nodePool.Spec.Platform.AWS.RootVolume == nil { + return rootVolume + } + if nodePool.Spec.Platform.AWS.RootVolume.Type != "" { + rootVolume.Type = capiaws.VolumeType(nodePool.Spec.Platform.AWS.RootVolume.Type) + } else { + rootVolume.Type = capiaws.VolumeType(EC2VolumeDefaultType) + } + if nodePool.Spec.Platform.AWS.RootVolume.Size > 0 { + rootVolume.Size = nodePool.Spec.Platform.AWS.RootVolume.Size } + if nodePool.Spec.Platform.AWS.RootVolume.IOPS > 0 { + rootVolume.IOPS = nodePool.Spec.Platform.AWS.RootVolume.IOPS + } + rootVolume.Encrypted = nodePool.Spec.Platform.AWS.RootVolume.Encrypted + rootVolume.EncryptionKey = nodePool.Spec.Platform.AWS.RootVolume.EncryptionKey + return rootVolume +} +func buildAWSSecurityGroups(nodePool *hyperv1.NodePool, hostedCluster *hyperv1.HostedCluster, defaultSG bool) ([]capiaws.AWSResourceReference, error) { securityGroups := []capiaws.AWSResourceReference{} for _, sg := range nodePool.Spec.Platform.AWS.SecurityGroups { var filters []capiaws.Filter @@ -137,92 +198,46 @@ func awsMachineTemplateSpec(infraName string, hostedCluster *hyperv1.HostedClust ID: &sgID, }) } + return securityGroups, nil +} - instanceProfile := fmt.Sprintf("%s-worker-profile", infraName) - if nodePool.Spec.Platform.AWS.InstanceProfile != "" { - instanceProfile = nodePool.Spec.Platform.AWS.InstanceProfile - } - - instanceType := nodePool.Spec.Platform.AWS.InstanceType - - instanceMetadataOptions := &capiaws.InstanceMetadataOptions{ - HTTPTokens: capiaws.HTTPTokensStateOptional, - HTTPPutResponseHopLimit: 2, // set to 2 as per AWS recommendation for container envs https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html#imds-considerations - HTTPEndpoint: capiaws.InstanceMetadataEndpointStateEnabled, - InstanceMetadataTags: capiaws.InstanceMetadataEndpointStateDisabled, - } - if value, found := nodePool.Annotations[ec2InstanceMetadataHTTPTokensAnnotation]; found && value == string(capiaws.HTTPTokensStateRequired) { - instanceMetadataOptions.HTTPTokens = capiaws.HTTPTokensStateRequired +func applyAWSPlacementOptions(nodePool *hyperv1.NodePool, spec *capiaws.AWSMachineTemplateSpec) { + placement := nodePool.Spec.Platform.AWS.Placement + if placement == nil { + return } + spec.Template.Spec.Tenancy = placement.Tenancy - awsMachineTemplateSpec := &capiaws.AWSMachineTemplateSpec{ - Template: capiaws.AWSMachineTemplateResource{ - Spec: capiaws.AWSMachineSpec{ - UncompressedUserData: ptr.To(true), - CloudInit: capiaws.CloudInit{ - InsecureSkipSecretsManager: true, - SecureSecretsBackend: "secrets-manager", - }, - IAMInstanceProfile: instanceProfile, - InstanceType: instanceType, - AMI: capiaws.AMIReference{ - ID: ptr.To(ami), - }, - AdditionalSecurityGroups: securityGroups, - Subnet: subnet, - RootVolume: rootVolume, - AdditionalTags: awsAdditionalTags(nodePool, hostedCluster, infraName), - InstanceMetadataOptions: instanceMetadataOptions, - }, - }, - } - - if placement := nodePool.Spec.Platform.AWS.Placement; placement != nil { - awsMachineTemplateSpec.Template.Spec.Tenancy = placement.Tenancy - - // Handle market type - placement.MarketType takes precedence over capacityReservation.MarketType - switch placement.MarketType { - case hyperv1.MarketTypeSpot: - // Spot instances - awsMachineTemplateSpec.Template.Spec.SpotMarketOptions = &capiaws.SpotMarketOptions{} - if placement.Spot.MaxPrice != "" { - awsMachineTemplateSpec.Template.Spec.SpotMarketOptions.MaxPrice = ptr.To(placement.Spot.MaxPrice) - } - - case hyperv1.MarketTypeCapacityBlock: - awsMachineTemplateSpec.Template.Spec.MarketType = capiaws.MarketTypeCapacityBlock - case hyperv1.MarketTypeOnDemand: - awsMachineTemplateSpec.Template.Spec.MarketType = capiaws.MarketTypeOnDemand - default: - // If placement.MarketType is not set, fall back to capacityReservation.MarketType (deprecated) - if capacityReservation := placement.CapacityReservation; capacityReservation != nil { - //nolint:staticcheck // SA1019: capacityReservation.MarketType is deprecated but supported for backward compatibility - switch capacityReservation.MarketType { - case hyperv1.MarketTypeCapacityBlock: - awsMachineTemplateSpec.Template.Spec.MarketType = capiaws.MarketTypeCapacityBlock - case hyperv1.MarketTypeOnDemand: - awsMachineTemplateSpec.Template.Spec.MarketType = capiaws.MarketTypeOnDemand - default: - if placement.Tenancy != "host" && capacityReservation.ID != nil { - // if the tenancy is not host and the ID is set, default the market type to CapacityBlock - awsMachineTemplateSpec.Template.Spec.MarketType = capiaws.MarketTypeCapacityBlock - } - } - } + switch placement.MarketType { + case hyperv1.MarketTypeSpot: + spec.Template.Spec.SpotMarketOptions = &capiaws.SpotMarketOptions{} + if placement.Spot.MaxPrice != "" { + spec.Template.Spec.SpotMarketOptions.MaxPrice = ptr.To(placement.Spot.MaxPrice) } - - // Handle capacity reservation options + case hyperv1.MarketTypeCapacityBlock: + spec.Template.Spec.MarketType = capiaws.MarketTypeCapacityBlock + case hyperv1.MarketTypeOnDemand: + spec.Template.Spec.MarketType = capiaws.MarketTypeOnDemand + default: if capacityReservation := placement.CapacityReservation; capacityReservation != nil { - awsMachineTemplateSpec.Template.Spec.CapacityReservationID = capacityReservation.ID - awsMachineTemplateSpec.Template.Spec.CapacityReservationPreference = capiaws.CapacityReservationPreference(capacityReservation.Preference) + //nolint:staticcheck // SA1019: capacityReservation.MarketType is deprecated but supported for backward compatibility + switch capacityReservation.MarketType { + case hyperv1.MarketTypeCapacityBlock: + spec.Template.Spec.MarketType = capiaws.MarketTypeCapacityBlock + case hyperv1.MarketTypeOnDemand: + spec.Template.Spec.MarketType = capiaws.MarketTypeOnDemand + default: + if placement.Tenancy != "host" && capacityReservation.ID != nil { + spec.Template.Spec.MarketType = capiaws.MarketTypeCapacityBlock + } + } } } - if hostedCluster.Annotations[hyperv1.AWSMachinePublicIPs] == "true" { - awsMachineTemplateSpec.Template.Spec.PublicIP = ptr.To(true) + if capacityReservation := placement.CapacityReservation; capacityReservation != nil { + spec.Template.Spec.CapacityReservationID = capacityReservation.ID + spec.Template.Spec.CapacityReservationPreference = capiaws.CapacityReservationPreference(capacityReservation.Preference) } - - return awsMachineTemplateSpec, nil } func awsAdditionalTags(nodePool *hyperv1.NodePool, hostedCluster *hyperv1.HostedCluster, infraName string) capiaws.Tags { diff --git a/hypershift-operator/controllers/nodepool/aws_test.go b/hypershift-operator/controllers/nodepool/aws_test.go index e23ffd3f4ea..00763009a15 100644 --- a/hypershift-operator/controllers/nodepool/aws_test.go +++ b/hypershift-operator/controllers/nodepool/aws_test.go @@ -5,6 +5,8 @@ import ( "strings" "testing" + . "github.com/onsi/gomega" + hyperv1 "github.com/openshift/hypershift/api/hypershift/v1beta1" "github.com/openshift/hypershift/support/api" "github.com/openshift/hypershift/support/releaseinfo" @@ -1039,3 +1041,630 @@ func TestIsSpotEnabled(t *testing.T) { }) } } + +func TestResolveAWSAMI(t *testing.T) { + releaseImageWithMetadata := &releaseinfo.ReleaseImage{ + ImageStream: &v1.ImageStream{ + ObjectMeta: metav1.ObjectMeta{Name: "4.17.0"}, + }, + StreamMetadata: &releaseinfo.CoreOSStreamMetadata{ + Architectures: map[string]releaseinfo.CoreOSArchitecture{ + "x86_64": { + RHCOS: releaseinfo.CoreRHCOSImage{ + AWSWinLi: releaseinfo.CoreAWSWinLi{ + Regions: map[string]releaseinfo.CoreAWSWinLiRegion{ + "us-east-1": { + Release: "418.94.202410090804-0", + Image: "ami-windows-us-east-1", + }, + }, + }, + }, + }, + }, + }, + } + + testCases := []struct { + name string + hostedCluster *hyperv1.HostedCluster + nodePool *hyperv1.NodePool + releaseImage *releaseinfo.ReleaseImage + expectedAMI string + expectError bool + }{ + { + name: "When nodePool has explicit AMI, it should return that AMI directly", + hostedCluster: &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Platform: hyperv1.PlatformSpec{AWS: &hyperv1.AWSPlatformSpec{Region: "us-east-1"}}, + }, + }, + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{AWS: &hyperv1.AWSNodePoolPlatform{AMI: "ami-explicit"}}, + }, + }, + releaseImage: releaseImageWithMetadata, + expectedAMI: "ami-explicit", + }, + { + name: "When nodePool has Windows ImageType, it should resolve Windows AMI from metadata", + hostedCluster: &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Platform: hyperv1.PlatformSpec{AWS: &hyperv1.AWSPlatformSpec{Region: "us-east-1"}}, + }, + }, + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Arch: hyperv1.ArchitectureAMD64, + Platform: hyperv1.NodePoolPlatform{AWS: &hyperv1.AWSNodePoolPlatform{ImageType: hyperv1.ImageTypeWindows}}, + }, + }, + releaseImage: releaseImageWithMetadata, + expectedAMI: "ami-windows-us-east-1", + }, + { + name: "When nodePool has Windows ImageType with unsupported region, it should return error", + hostedCluster: &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Platform: hyperv1.PlatformSpec{AWS: &hyperv1.AWSPlatformSpec{Region: "ap-southeast-99"}}, + }, + }, + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Arch: hyperv1.ArchitectureAMD64, + Platform: hyperv1.NodePoolPlatform{AWS: &hyperv1.AWSNodePoolPlatform{ImageType: hyperv1.ImageTypeWindows}}, + }, + }, + releaseImage: releaseImageWithMetadata, + expectError: true, + }, + { + name: "When nodePool has no AMI and default Linux type with nil stream metadata, it should return error", + hostedCluster: &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Platform: hyperv1.PlatformSpec{AWS: &hyperv1.AWSPlatformSpec{Region: "us-east-1"}}, + }, + }, + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Arch: hyperv1.ArchitectureAMD64, + Platform: hyperv1.NodePoolPlatform{AWS: &hyperv1.AWSNodePoolPlatform{}}, + }, + }, + releaseImage: &releaseinfo.ReleaseImage{ + ImageStream: &v1.ImageStream{ObjectMeta: metav1.ObjectMeta{Name: "4.17.0"}}, + StreamMetadata: nil, + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + ami, err := resolveAWSAMI(tc.hostedCluster, tc.nodePool, tc.releaseImage) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(ami).To(Equal(tc.expectedAMI)) + } + }) + } +} + +func TestBuildAWSSubnet(t *testing.T) { + testCases := []struct { + name string + nodePool *hyperv1.NodePool + expectedSubnet *capiaws.AWSResourceReference + }{ + { + name: "When subnet has only ID, it should return subnet with ID set", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Subnet: hyperv1.AWSResourceReference{ + ID: ptr.To("subnet-abc123"), + }, + }, + }, + }, + }, + expectedSubnet: &capiaws.AWSResourceReference{ + ID: ptr.To("subnet-abc123"), + }, + }, + { + name: "When subnet has filters, it should copy filters to CAPI format", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Subnet: hyperv1.AWSResourceReference{ + Filters: []hyperv1.Filter{ + {Name: "tag:Name", Values: []string{"my-subnet"}}, + {Name: "vpc-id", Values: []string{"vpc-123"}}, + }, + }, + }, + }, + }, + }, + expectedSubnet: &capiaws.AWSResourceReference{ + Filters: []capiaws.Filter{ + {Name: "tag:Name", Values: []string{"my-subnet"}}, + {Name: "vpc-id", Values: []string{"vpc-123"}}, + }, + }, + }, + { + name: "When subnet has no ID and no filters, it should return empty subnet reference", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Subnet: hyperv1.AWSResourceReference{}, + }, + }, + }, + }, + expectedSubnet: &capiaws.AWSResourceReference{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + subnet := buildAWSSubnet(tc.nodePool) + g.Expect(subnet).To(Equal(tc.expectedSubnet)) + }) + } +} + +func TestBuildAWSRootVolume(t *testing.T) { + testCases := []struct { + name string + nodePool *hyperv1.NodePool + expectedVolume *capiaws.Volume + }{ + { + name: "When RootVolume is nil, it should return default volume with default size", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + RootVolume: nil, + }, + }, + }, + }, + expectedVolume: &capiaws.Volume{ + Size: EC2VolumeDefaultSize, + }, + }, + { + name: "When RootVolume has custom type and size, it should use them", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + RootVolume: &hyperv1.Volume{ + Type: "io1", + Size: 100, + IOPS: 5000, + }, + }, + }, + }, + }, + expectedVolume: &capiaws.Volume{ + Type: capiaws.VolumeType("io1"), + Size: 100, + IOPS: 5000, + }, + }, + { + name: "When RootVolume has empty type, it should use default type", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + RootVolume: &hyperv1.Volume{ + Type: "", + Size: 50, + }, + }, + }, + }, + }, + expectedVolume: &capiaws.Volume{ + Type: capiaws.VolumeType(EC2VolumeDefaultType), + Size: 50, + }, + }, + { + name: "When RootVolume has zero size, it should keep default size", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + RootVolume: &hyperv1.Volume{ + Type: "gp3", + Size: 0, + }, + }, + }, + }, + }, + expectedVolume: &capiaws.Volume{ + Type: capiaws.VolumeType("gp3"), + Size: EC2VolumeDefaultSize, + }, + }, + { + name: "When RootVolume has encryption settings, it should propagate them", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + RootVolume: &hyperv1.Volume{ + Type: "gp3", + Size: 64, + Encrypted: ptr.To(true), + EncryptionKey: "arn:aws:kms:us-east-1:123:key/abc", + }, + }, + }, + }, + }, + expectedVolume: &capiaws.Volume{ + Type: capiaws.VolumeType("gp3"), + Size: 64, + Encrypted: ptr.To(true), + EncryptionKey: "arn:aws:kms:us-east-1:123:key/abc", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + volume := buildAWSRootVolume(tc.nodePool) + g.Expect(volume).To(Equal(tc.expectedVolume)) + }) + } +} + +func TestBuildAWSSecurityGroups(t *testing.T) { + testCases := []struct { + name string + nodePool *hyperv1.NodePool + hostedCluster *hyperv1.HostedCluster + defaultSG bool + expectedSGs []capiaws.AWSResourceReference + expectError bool + expectNotReady bool + }{ + { + name: "When nodePool has security groups and defaultSG is true, it should include both", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + SecurityGroups: []hyperv1.AWSResourceReference{ + {ID: ptr.To("sg-custom")}, + }, + }, + }, + }, + }, + hostedCluster: &hyperv1.HostedCluster{ + Status: hyperv1.HostedClusterStatus{ + Platform: &hyperv1.PlatformStatus{ + AWS: &hyperv1.AWSPlatformStatus{ + DefaultWorkerSecurityGroupID: "sg-default", + }, + }, + }, + }, + defaultSG: true, + expectedSGs: []capiaws.AWSResourceReference{ + {ID: ptr.To("sg-custom")}, + {ID: ptr.To("sg-default")}, + }, + }, + { + name: "When nodePool has no security groups and defaultSG is true, it should use only default", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{}, + }, + }, + }, + hostedCluster: &hyperv1.HostedCluster{ + Status: hyperv1.HostedClusterStatus{ + Platform: &hyperv1.PlatformStatus{ + AWS: &hyperv1.AWSPlatformStatus{ + DefaultWorkerSecurityGroupID: "sg-default", + }, + }, + }, + }, + defaultSG: true, + expectedSGs: []capiaws.AWSResourceReference{ + {ID: ptr.To("sg-default")}, + }, + }, + { + name: "When defaultSG is true but no default SG available, it should return NotReadyError", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{}, + }, + }, + }, + hostedCluster: &hyperv1.HostedCluster{ + Status: hyperv1.HostedClusterStatus{ + Platform: &hyperv1.PlatformStatus{ + AWS: &hyperv1.AWSPlatformStatus{ + DefaultWorkerSecurityGroupID: "", + }, + }, + }, + }, + defaultSG: true, + expectError: true, + expectNotReady: true, + }, + { + name: "When defaultSG is false, it should only return nodePool security groups", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + SecurityGroups: []hyperv1.AWSResourceReference{ + {ID: ptr.To("sg-1")}, + {ID: ptr.To("sg-2")}, + }, + }, + }, + }, + }, + hostedCluster: &hyperv1.HostedCluster{}, + defaultSG: false, + expectedSGs: []capiaws.AWSResourceReference{ + {ID: ptr.To("sg-1")}, + {ID: ptr.To("sg-2")}, + }, + }, + { + name: "When security group has filters, it should copy filters to CAPI format", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + SecurityGroups: []hyperv1.AWSResourceReference{ + { + Filters: []hyperv1.Filter{ + {Name: "tag:Role", Values: []string{"worker"}}, + }, + }, + }, + }, + }, + }, + }, + hostedCluster: &hyperv1.HostedCluster{}, + defaultSG: false, + expectedSGs: []capiaws.AWSResourceReference{ + { + Filters: []capiaws.Filter{ + {Name: "tag:Role", Values: []string{"worker"}}, + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + sgs, err := buildAWSSecurityGroups(tc.nodePool, tc.hostedCluster, tc.defaultSG) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + if tc.expectNotReady { + _, isNotReady := err.(*NotReadyError) + g.Expect(isNotReady).To(BeTrue()) + } + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(sgs).To(Equal(tc.expectedSGs)) + } + }) + } +} + +func TestApplyAWSPlacementOptions(t *testing.T) { + capacityReservationID := "cr-0123456789abcdef0" + + testCases := []struct { + name string + nodePool *hyperv1.NodePool + expectedSpotMarketOptions *capiaws.SpotMarketOptions + expectedMarketType capiaws.MarketType + expectedTenancy string + expectedCapacityReservationID *string + expectedCapReservationPreference capiaws.CapacityReservationPreference + }{ + { + name: "When placement is nil, it should not modify spec", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Placement: nil, + }, + }, + }, + }, + }, + { + name: "When marketType is Spot with no MaxPrice, it should set empty SpotMarketOptions", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Placement: &hyperv1.PlacementOptions{ + MarketType: hyperv1.MarketTypeSpot, + Spot: hyperv1.SpotOptions{}, + }, + }, + }, + }, + }, + expectedSpotMarketOptions: &capiaws.SpotMarketOptions{}, + }, + { + name: "When marketType is Spot with MaxPrice, it should set SpotMarketOptions with MaxPrice", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Placement: &hyperv1.PlacementOptions{ + MarketType: hyperv1.MarketTypeSpot, + Spot: hyperv1.SpotOptions{ + MaxPrice: "1.50", + }, + }, + }, + }, + }, + }, + expectedSpotMarketOptions: &capiaws.SpotMarketOptions{ + MaxPrice: ptr.To("1.50"), + }, + }, + { + name: "When marketType is CapacityBlock, it should set MarketType to CapacityBlock", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Placement: &hyperv1.PlacementOptions{ + MarketType: hyperv1.MarketTypeCapacityBlock, + }, + }, + }, + }, + }, + expectedMarketType: capiaws.MarketTypeCapacityBlock, + }, + { + name: "When marketType is OnDemand, it should set MarketType to OnDemand", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Placement: &hyperv1.PlacementOptions{ + MarketType: hyperv1.MarketTypeOnDemand, + }, + }, + }, + }, + }, + expectedMarketType: capiaws.MarketTypeOnDemand, + }, + { + name: "When tenancy is dedicated, it should set tenancy on spec", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Placement: &hyperv1.PlacementOptions{ + Tenancy: "dedicated", + }, + }, + }, + }, + }, + expectedTenancy: "dedicated", + }, + { + name: "When capacityReservation has ID and preference, it should set both on spec", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Placement: &hyperv1.PlacementOptions{ + CapacityReservation: &hyperv1.CapacityReservationOptions{ + ID: &capacityReservationID, + Preference: hyperv1.CapacityReservationPreferenceOnly, + }, + }, + }, + }, + }, + }, + expectedCapacityReservationID: &capacityReservationID, + expectedCapReservationPreference: capiaws.CapacityReservationPreference(hyperv1.CapacityReservationPreferenceOnly), + expectedMarketType: capiaws.MarketTypeCapacityBlock, + }, + { + name: "When deprecated capacityReservation.MarketType is CapacityBlock and no top-level marketType, it should use deprecated value", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Placement: &hyperv1.PlacementOptions{ + CapacityReservation: &hyperv1.CapacityReservationOptions{ + MarketType: hyperv1.MarketTypeCapacityBlock, + }, + }, + }, + }, + }, + }, + expectedMarketType: capiaws.MarketTypeCapacityBlock, + }, + { + name: "When tenancy is host with capacityReservation ID but no marketType, it should not default to CapacityBlock", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + AWS: &hyperv1.AWSNodePoolPlatform{ + Placement: &hyperv1.PlacementOptions{ + Tenancy: "host", + CapacityReservation: &hyperv1.CapacityReservationOptions{ + ID: &capacityReservationID, + }, + }, + }, + }, + }, + }, + expectedTenancy: "host", + expectedMarketType: "", + expectedCapacityReservationID: &capacityReservationID, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + spec := &capiaws.AWSMachineTemplateSpec{} + applyAWSPlacementOptions(tc.nodePool, spec) + + g.Expect(spec.Template.Spec.SpotMarketOptions).To(Equal(tc.expectedSpotMarketOptions)) + g.Expect(spec.Template.Spec.MarketType).To(Equal(tc.expectedMarketType)) + g.Expect(spec.Template.Spec.Tenancy).To(Equal(tc.expectedTenancy)) + g.Expect(spec.Template.Spec.CapacityReservationID).To(Equal(tc.expectedCapacityReservationID)) + g.Expect(spec.Template.Spec.CapacityReservationPreference).To(Equal(tc.expectedCapReservationPreference)) + }) + } +} diff --git a/hypershift-operator/controllers/nodepool/capi.go b/hypershift-operator/controllers/nodepool/capi.go index 74a00c6b1eb..b66ebcd1d68 100644 --- a/hypershift-operator/controllers/nodepool/capi.go +++ b/hypershift-operator/controllers/nodepool/capi.go @@ -407,23 +407,10 @@ func (c *CAPI) reconcileMachineDeployment(ctx context.Context, log logr.Logger, nodePool := c.nodePool capiClusterName := c.capiClusterName - // Set annotations and labels - if machineDeployment.GetAnnotations() == nil { - machineDeployment.Annotations = map[string]string{} - } - machineDeployment.Annotations[nodePoolAnnotation] = client.ObjectKeyFromObject(nodePool).String() - // Delete any paused annotation - delete(machineDeployment.Annotations, capiv1.PausedAnnotation) - if machineDeployment.GetLabels() == nil { - machineDeployment.Labels = map[string]string{} - } - machineDeployment.Labels[capiv1.ClusterNameLabel] = capiClusterName - // Set defaults. These are normally set by the CAPI machinedeployment webhook. - // However, since we don't run the webhook, CAPI updates the machinedeployment - // after it has been created with defaults. - machineDeployment.Spec.MinReadySeconds = ptr.To[int32](0) + c.setMachineDeploymentMetadata(machineDeployment, capiClusterName) + machineDeployment.Spec.MinReadySeconds = ptr.To[int32](0) machineDeployment.Spec.ClusterName = capiClusterName if machineDeployment.Spec.Selector.MatchLabels == nil { machineDeployment.Spec.Selector.MatchLabels = map[string]string{} @@ -442,8 +429,6 @@ func (c *CAPI) reconcileMachineDeployment(ctx context.Context, log logr.Logger, resourcesName: resourcesName, capiv1.ClusterNameLabel: capiClusterName, }, - // Annotations here propagate down to Machines - // https://cluster-api.sigs.k8s.io/developer/architecture/controllers/metadata-propagation.html#machinedeployment. Annotations: map[string]string{ nodePoolAnnotation: client.ObjectKeyFromObject(nodePool).String(), hyperv1.NodePoolReleaseVersionAnnotation: c.Version(), @@ -452,47 +437,77 @@ func (c *CAPI) reconcileMachineDeployment(ctx context.Context, log logr.Logger, Spec: capiv1.MachineSpec{ ClusterName: capiClusterName, Bootstrap: capiv1.Bootstrap{ - // Keep current user data for later check. DataSecretName: machineDeployment.Spec.Template.Spec.Bootstrap.DataSecretName, }, InfrastructureRef: corev1.ObjectReference{ Kind: gvk.Kind, APIVersion: gvk.GroupVersion().String(), Namespace: machineTemplateCR.GetNamespace(), - // keep current template name for later check. - Name: machineDeployment.Spec.Template.Spec.InfrastructureRef.Name, + Name: machineDeployment.Spec.Template.Spec.InfrastructureRef.Name, }, - // Keep current version for later check. Version: machineDeployment.Spec.Template.Spec.Version, NodeDrainTimeout: nodePool.Spec.NodeDrainTimeout, NodeVolumeDetachTimeout: nodePool.Spec.NodeVolumeDetachTimeout, }, } - // Add interruptible-instance label for spot instances - // This label must be on the MachineDeployment template so the spot MHC can select machines if isSpotEnabled(nodePool) { machineDeployment.Spec.Template.Labels[interruptibleInstanceLabel] = "" } - // The CAPI provider for OpenStack uses the FailureDomain field to set the availability zone. - if c.nodePool.Spec.Platform.Type == hyperv1.OpenStackPlatform && c.nodePool.Spec.Platform.OpenStack != nil { - if c.nodePool.Spec.Platform.OpenStack.AvailabilityZone != "" { - machineDeployment.Spec.Template.Spec.FailureDomain = ptr.To(c.nodePool.Spec.Platform.OpenStack.AvailabilityZone) + setMachineDeploymentFailureDomain(c.nodePool, machineDeployment) + + if err := c.propagateLabelsAndTaintsToMachines(ctx, log, machineDeployment); err != nil { + return err + } + + machineDeployment.Spec.Strategy = &capiv1.MachineDeploymentStrategy{} + machineDeployment.Spec.Strategy.Type = capiv1.MachineDeploymentStrategyType(nodePool.Spec.Management.Replace.Strategy) + if nodePool.Spec.Management.Replace.RollingUpdate != nil { + machineDeployment.Spec.Strategy.RollingUpdate = &capiv1.MachineRollingUpdateDeployment{ + MaxUnavailable: nodePool.Spec.Management.Replace.RollingUpdate.MaxUnavailable, + MaxSurge: nodePool.Spec.Management.Replace.RollingUpdate.MaxSurge, } } - // The CAPI provider for GCP uses the FailureDomain field to set the zone. - if c.nodePool.Spec.Platform.Type == hyperv1.GCPPlatform && c.nodePool.Spec.Platform.GCP != nil { - if c.nodePool.Spec.Platform.GCP.Zone != "" { - machineDeployment.Spec.Template.Spec.FailureDomain = ptr.To(c.nodePool.Spec.Platform.GCP.Zone) + setMachineDeploymentReplicas(nodePool, machineDeployment) + + if updated := c.propagateVersionAndTemplate(log, machineDeployment, machineTemplateCR); updated { + return nil + } + + c.reconcileMachineDeploymentStatus(log, machineDeployment, machineTemplateCR) + + return nil +} + +func (c *CAPI) setMachineDeploymentMetadata(machineDeployment *capiv1.MachineDeployment, capiClusterName string) { + if machineDeployment.GetAnnotations() == nil { + machineDeployment.Annotations = map[string]string{} + } + machineDeployment.Annotations[nodePoolAnnotation] = client.ObjectKeyFromObject(c.nodePool).String() + delete(machineDeployment.Annotations, capiv1.PausedAnnotation) + if machineDeployment.GetLabels() == nil { + machineDeployment.Labels = map[string]string{} + } + machineDeployment.Labels[capiv1.ClusterNameLabel] = capiClusterName +} + +func setMachineDeploymentFailureDomain(nodePool *hyperv1.NodePool, machineDeployment *capiv1.MachineDeployment) { + if nodePool.Spec.Platform.Type == hyperv1.OpenStackPlatform && nodePool.Spec.Platform.OpenStack != nil { + if nodePool.Spec.Platform.OpenStack.AvailabilityZone != "" { + machineDeployment.Spec.Template.Spec.FailureDomain = ptr.To(nodePool.Spec.Platform.OpenStack.AvailabilityZone) } } + if nodePool.Spec.Platform.Type == hyperv1.GCPPlatform && nodePool.Spec.Platform.GCP != nil { + if nodePool.Spec.Platform.GCP.Zone != "" { + machineDeployment.Spec.Template.Spec.FailureDomain = ptr.To(nodePool.Spec.Platform.GCP.Zone) + } + } +} - // After a MachineDeployment is created we propagate label/taints directly into Machines. - // This is to avoid a NodePool label/taints to trigger a rolling upgrade. - // TODO(Alberto): drop this an rely on core in-place propagation once CAPI 1.4.0 https://github.com/kubernetes-sigs/cluster-api/releases comes through the payload. - // https://issues.redhat.com/browse/HOSTEDCP-971 +func (c *CAPI) propagateLabelsAndTaintsToMachines(ctx context.Context, log logr.Logger, machineDeployment *capiv1.MachineDeployment) error { + nodePool := c.nodePool machineList := &capiv1.MachineList{} if err := c.List(ctx, machineList, client.InNamespace(machineDeployment.Namespace)); err != nil { return err @@ -510,28 +525,20 @@ func (c *CAPI) reconcileMachineDeployment(ctx context.Context, log logr.Logger, machine.Annotations = make(map[string]string) } - // Propagate labels. for k, v := range nodePool.Spec.NodeLabels { - // Propagated managed labels down to Machines with a known hardcoded prefix - // so the CPO HCCO Node controller can recognize them and apply them to Nodes. labelKey := fmt.Sprintf("%s.%s", labelManagedPrefix, k) machine.Labels[labelKey] = v } - // Propagate globalPS managed label to Machines so the HCCO Node controller - // applies it to Nodes. This enables the GlobalPullSecret DaemonSet to - // schedule on Replace nodes. Only AWS and Azure platforms support this. if nodePool.Spec.Platform.Type == hyperv1.AWSPlatform || nodePool.Spec.Platform.Type == hyperv1.AzurePlatform { globalPSLabelKey := fmt.Sprintf("%s.%s", labelManagedPrefix, globalPSNodeLabel) machine.Labels[globalPSLabelKey] = "true" } - // Propagate taints. taintsInJSON, err := taintsToJSON(nodePool.Spec.Taints) if err != nil { return err } - machine.Annotations[nodePoolAnnotationTaints] = taintsInJSON return nil }); err != nil { @@ -541,25 +548,16 @@ func (c *CAPI) reconcileMachineDeployment(ctx context.Context, log logr.Logger, log.Info("Reconciled Machine", "result", result) } } + return nil +} - // Set strategy - machineDeployment.Spec.Strategy = &capiv1.MachineDeploymentStrategy{} - machineDeployment.Spec.Strategy.Type = capiv1.MachineDeploymentStrategyType(nodePool.Spec.Management.Replace.Strategy) - if nodePool.Spec.Management.Replace.RollingUpdate != nil { - machineDeployment.Spec.Strategy.RollingUpdate = &capiv1.MachineRollingUpdateDeployment{ - MaxUnavailable: nodePool.Spec.Management.Replace.RollingUpdate.MaxUnavailable, - MaxSurge: nodePool.Spec.Management.Replace.RollingUpdate.MaxSurge, - } - } - - setMachineDeploymentReplicas(nodePool, machineDeployment) - - isUpdating := false - // Propagate version and userData Secret to the machineDeployment. +func (c *CAPI) propagateVersionAndTemplate(log logr.Logger, machineDeployment *capiv1.MachineDeployment, machineTemplateCR client.Object) bool { + nodePool := c.nodePool userDataSecret := c.UserDataSecret() targetVersion := c.Version() targetConfigHash := c.HashWithoutVersion() - targetConfigVersionHash := c.Hash() + isUpdating := false + if userDataSecret.Name != ptr.Deref(machineDeployment.Spec.Template.Spec.Bootstrap.DataSecretName, "") { log.Info("New user data Secret has been generated", "current", machineDeployment.Spec.Template.Spec.Bootstrap.DataSecretName, @@ -579,25 +577,23 @@ func (c *CAPI) reconcileMachineDeployment(ctx context.Context, log logr.Logger, isUpdating = true } - // template spec has changed, signal a rolling upgrade. if machineTemplateCR.GetName() != machineDeployment.Spec.Template.Spec.InfrastructureRef.Name { log.Info("New machine template has been generated", "current", machineDeployment.Spec.Template.Spec.InfrastructureRef.Name, "target", machineTemplateCR.GetName()) - machineDeployment.Spec.Template.Spec.InfrastructureRef.Name = machineTemplateCR.GetName() isUpdating = true } - if isUpdating { - // We return early here during a version/config/MachineTemplate update to persist the resource with new user data Secret / MachineTemplate, - // so in the next reconciling loop we get a new MachineDeployment.Generation - // and we can do a legit MachineDeploymentComplete/MachineDeployment.Status.ObservedGeneration check. - return nil - } + return isUpdating +} + +func (c *CAPI) reconcileMachineDeploymentStatus(log logr.Logger, machineDeployment *capiv1.MachineDeployment, machineTemplateCR client.Object) { + nodePool := c.nodePool + targetVersion := c.Version() + targetConfigHash := c.HashWithoutVersion() + targetConfigVersionHash := c.Hash() - // If the MachineDeployment is now processing we know - // is at the expected version (spec.version) and config (userData Secret) so we reconcile status and annotation. if MachineDeploymentComplete(machineDeployment) { if nodePool.Status.Version != targetVersion { log.Info("Version update complete", @@ -622,31 +618,23 @@ func (c *CAPI) reconcileMachineDeployment(ctx context.Context, log logr.Logger, } } - // Bubble up AvailableReplicas and Ready condition from MachineDeployment. nodePool.Status.Replicas = machineDeployment.Status.AvailableReplicas - for _, c := range machineDeployment.Status.Conditions { - // This condition should aggregate and summarize readiness from underlying MachineSets and Machines - // https://github.com/kubernetes-sigs/cluster-api/issues/3486. - if c.Type == capiv1.ReadyCondition { - // this is so api server does not complain - // invalid value: \"\": status.conditions.reason in body should be at least 1 chars long" + for _, cond := range machineDeployment.Status.Conditions { + if cond.Type == capiv1.ReadyCondition { reason := hyperv1.AsExpectedReason - if c.Reason != "" { - reason = c.Reason + if cond.Reason != "" { + reason = cond.Reason } - SetStatusCondition(&nodePool.Status.Conditions, hyperv1.NodePoolCondition{ Type: hyperv1.NodePoolReadyConditionType, - Status: c.Status, + Status: cond.Status, ObservedGeneration: nodePool.Generation, - Message: c.Message, + Message: cond.Message, Reason: reason, }) break } } - - return nil } func taintsToJSON(taints []hyperv1.Taint) (string, error) { diff --git a/hypershift-operator/controllers/nodepool/capi_test.go b/hypershift-operator/controllers/nodepool/capi_test.go index b67d2c9f66d..c57e4d2eec6 100644 --- a/hypershift-operator/controllers/nodepool/capi_test.go +++ b/hypershift-operator/controllers/nodepool/capi_test.go @@ -2331,6 +2331,618 @@ func TestPause(t *testing.T) { g.Expect(ms.Annotations).To(HaveKeyWithValue(capiv1.PausedAnnotation, "true")) } +func TestSetMachineDeploymentMetadata(t *testing.T) { + testCases := []struct { + name string + nodePool *hyperv1.NodePool + machineDeployment *capiv1.MachineDeployment + capiClusterName string + expectAnnotationKey string + expectAnnotationVal string + expectLabelKey string + expectLabelVal string + expectPausedRemoved bool + }{ + { + name: "When MachineDeployment has nil annotations and labels, it should initialize them and set nodePool annotation and cluster label", + nodePool: &hyperv1.NodePool{ + ObjectMeta: metav1.ObjectMeta{ + Name: "my-nodepool", + Namespace: "my-ns", + }, + }, + machineDeployment: &capiv1.MachineDeployment{}, + capiClusterName: "my-cluster", + }, + { + name: "When MachineDeployment has existing annotations including paused, it should remove paused annotation and set nodePool annotation", + nodePool: &hyperv1.NodePool{ + ObjectMeta: metav1.ObjectMeta{ + Name: "my-nodepool", + Namespace: "my-ns", + }, + }, + machineDeployment: &capiv1.MachineDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + capiv1.PausedAnnotation: "true", + "existing-key": "existing-value", + }, + Labels: map[string]string{ + "existing-label": "val", + }, + }, + }, + capiClusterName: "my-cluster", + expectPausedRemoved: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + capi := &CAPI{ + Token: &Token{ + ConfigGenerator: &ConfigGenerator{ + nodePool: tc.nodePool, + }, + }, + capiClusterName: tc.capiClusterName, + } + + capi.setMachineDeploymentMetadata(tc.machineDeployment, tc.capiClusterName) + + g.Expect(tc.machineDeployment.Annotations).To(HaveKeyWithValue( + nodePoolAnnotation, client.ObjectKeyFromObject(tc.nodePool).String())) + g.Expect(tc.machineDeployment.Annotations).ToNot(HaveKey(capiv1.PausedAnnotation)) + g.Expect(tc.machineDeployment.Labels).To(HaveKeyWithValue( + capiv1.ClusterNameLabel, tc.capiClusterName)) + }) + } +} + +func TestSetMachineDeploymentFailureDomain(t *testing.T) { + testCases := []struct { + name string + nodePool *hyperv1.NodePool + expectedFailureDomain *string + }{ + { + name: "When platform is OpenStack with AvailabilityZone set, it should set failure domain", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + Type: hyperv1.OpenStackPlatform, + OpenStack: &hyperv1.OpenStackNodePoolPlatform{ + AvailabilityZone: "az-1", + }, + }, + }, + }, + expectedFailureDomain: ptr.To("az-1"), + }, + { + name: "When platform is OpenStack with empty AvailabilityZone, it should not set failure domain", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + Type: hyperv1.OpenStackPlatform, + OpenStack: &hyperv1.OpenStackNodePoolPlatform{ + AvailabilityZone: "", + }, + }, + }, + }, + expectedFailureDomain: nil, + }, + { + name: "When platform is GCP with Zone set, it should set failure domain", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + Type: hyperv1.GCPPlatform, + GCP: &hyperv1.GCPNodePoolPlatform{ + Zone: "us-central1-a", + }, + }, + }, + }, + expectedFailureDomain: ptr.To("us-central1-a"), + }, + { + name: "When platform is GCP with empty Zone, it should not set failure domain", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + Type: hyperv1.GCPPlatform, + GCP: &hyperv1.GCPNodePoolPlatform{ + Zone: "", + }, + }, + }, + }, + expectedFailureDomain: nil, + }, + { + name: "When platform is AWS, it should not set failure domain", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSNodePoolPlatform{}, + }, + }, + }, + expectedFailureDomain: nil, + }, + { + name: "When platform is OpenStack but spec is nil, it should not set failure domain", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + Type: hyperv1.OpenStackPlatform, + }, + }, + }, + expectedFailureDomain: nil, + }, + { + name: "When platform is GCP but spec is nil, it should not set failure domain", + nodePool: &hyperv1.NodePool{ + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + Type: hyperv1.GCPPlatform, + }, + }, + }, + expectedFailureDomain: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + md := &capiv1.MachineDeployment{} + setMachineDeploymentFailureDomain(tc.nodePool, md) + g.Expect(md.Spec.Template.Spec.FailureDomain).To(Equal(tc.expectedFailureDomain)) + }) + } +} + +func TestPropagateVersionAndTemplate(t *testing.T) { + testCases := []struct { + name string + currentBootstrapName string + currentVersion string + templateName string + currentInfraRefName string + useDifferentUserData bool + expectedUpdating bool + expectedInfraRefName string + }{ + { + name: "When user data secret name differs from current bootstrap, it should propagate version and return true", + currentBootstrapName: "old-userdata", + currentVersion: "4.16.0", + templateName: "template-1", + currentInfraRefName: "template-1", + useDifferentUserData: true, + expectedUpdating: true, + expectedInfraRefName: "template-1", + }, + { + name: "When machine template name differs from infra ref, it should propagate template and return true", + currentBootstrapName: "", // will be set to match computed name + currentVersion: "4.17.0", + templateName: "new-template", + currentInfraRefName: "old-template", + expectedUpdating: true, + expectedInfraRefName: "new-template", + }, + { + name: "When both user data and template differ, it should propagate both and return true", + currentBootstrapName: "old-userdata", + currentVersion: "4.16.0", + templateName: "new-template", + currentInfraRefName: "old-template", + useDifferentUserData: true, + expectedUpdating: true, + expectedInfraRefName: "new-template", + }, + { + name: "When nothing differs, it should not update and return false", + currentBootstrapName: "", // will be set to match computed name + currentVersion: "4.17.0", + templateName: "same-template", + currentInfraRefName: "same-template", + expectedUpdating: false, + expectedInfraRefName: "same-template", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + nodePool := &hyperv1.NodePool{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-np", + Namespace: "test-ns", + }, + } + + capi := &CAPI{ + Token: &Token{ + ConfigGenerator: &ConfigGenerator{ + nodePool: nodePool, + controlplaneNamespace: "cp-ns", + rolloutConfig: &rolloutConfig{ + releaseImage: &releaseinfo.ReleaseImage{ + ImageStream: &imageapi.ImageStream{ + ObjectMeta: metav1.ObjectMeta{ + Name: "4.17.0", + }, + }, + }, + }, + }, + }, + } + + // Compute the actual UserDataSecret name from the CAPI struct. + computedUserDataName := capi.UserDataSecret().Name + + // If the test wants the current bootstrap to match, use the computed name. + bootstrapName := tc.currentBootstrapName + if !tc.useDifferentUserData && bootstrapName == "" { + bootstrapName = computedUserDataName + } + + md := &capiv1.MachineDeployment{ + Spec: capiv1.MachineDeploymentSpec{ + Template: capiv1.MachineTemplateSpec{ + Spec: capiv1.MachineSpec{ + Bootstrap: capiv1.Bootstrap{ + DataSecretName: ptr.To(bootstrapName), + }, + InfrastructureRef: corev1.ObjectReference{ + Name: tc.currentInfraRefName, + }, + Version: ptr.To(tc.currentVersion), + }, + }, + }, + } + + templateCR := &capiaws.AWSMachineTemplate{ + ObjectMeta: metav1.ObjectMeta{ + Name: tc.templateName, + }, + } + + result := capi.propagateVersionAndTemplate(logr.Discard(), md, templateCR) + g.Expect(result).To(Equal(tc.expectedUpdating)) + g.Expect(md.Spec.Template.Spec.InfrastructureRef.Name).To(Equal(tc.expectedInfraRefName)) + + if tc.expectedUpdating && tc.useDifferentUserData { + // When updating, bootstrap should be set to the computed user data name. + g.Expect(*md.Spec.Template.Spec.Bootstrap.DataSecretName).To(Equal(computedUserDataName)) + g.Expect(*md.Spec.Template.Spec.Version).To(Equal("4.17.0")) + } + }) + } +} + +func TestReconcileMachineDeploymentStatus(t *testing.T) { + testCases := []struct { + name string + machineDeployment *capiv1.MachineDeployment + nodePoolVersion string + nodePoolAnnotations map[string]string + targetVersion string + expectedVersion string + expectedReplicas int32 + expectedConfigAnnotation bool + expectedTemplateAnnotation bool + expectedReadyConditionStatus corev1.ConditionStatus + expectedReadyConditionSet bool + }{ + { + name: "When MachineDeployment is complete, it should update nodePool version and annotations", + machineDeployment: &capiv1.MachineDeployment{ + ObjectMeta: metav1.ObjectMeta{Generation: 1}, + Spec: capiv1.MachineDeploymentSpec{ + Replicas: ptr.To[int32](3), + }, + Status: capiv1.MachineDeploymentStatus{ + Replicas: 3, + UpdatedReplicas: 3, + ReadyReplicas: 3, + AvailableReplicas: 3, + ObservedGeneration: 1, + }, + }, + nodePoolVersion: "", + nodePoolAnnotations: map[string]string{}, + targetVersion: "4.17.0", + expectedVersion: "4.17.0", + expectedReplicas: 3, + expectedConfigAnnotation: true, + expectedTemplateAnnotation: true, + }, + { + name: "When MachineDeployment is not complete, it should only update replicas", + machineDeployment: &capiv1.MachineDeployment{ + ObjectMeta: metav1.ObjectMeta{Generation: 2}, + Spec: capiv1.MachineDeploymentSpec{ + Replicas: ptr.To[int32](3), + }, + Status: capiv1.MachineDeploymentStatus{ + Replicas: 3, + UpdatedReplicas: 1, + ReadyReplicas: 1, + AvailableReplicas: 2, + ObservedGeneration: 1, + }, + }, + nodePoolVersion: "4.16.0", + nodePoolAnnotations: map[string]string{}, + targetVersion: "4.17.0", + expectedVersion: "4.16.0", + expectedReplicas: 2, + expectedConfigAnnotation: false, + expectedTemplateAnnotation: false, + }, + { + name: "When MachineDeployment has Ready condition, it should propagate it to nodePool", + machineDeployment: &capiv1.MachineDeployment{ + ObjectMeta: metav1.ObjectMeta{Generation: 2}, + Spec: capiv1.MachineDeploymentSpec{ + Replicas: ptr.To[int32](3), + }, + Status: capiv1.MachineDeploymentStatus{ + AvailableReplicas: 2, + Conditions: capiv1.Conditions{ + { + Type: capiv1.ReadyCondition, + Status: corev1.ConditionTrue, + Reason: "SomeReason", + Message: "all good", + }, + }, + }, + }, + nodePoolVersion: "4.16.0", + nodePoolAnnotations: map[string]string{}, + targetVersion: "4.17.0", + expectedVersion: "4.16.0", + expectedReplicas: 2, + expectedReadyConditionSet: true, + expectedReadyConditionStatus: corev1.ConditionTrue, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + nodePool := &hyperv1.NodePool{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-np", + Namespace: "test-ns", + Annotations: tc.nodePoolAnnotations, + }, + Status: hyperv1.NodePoolStatus{ + Version: tc.nodePoolVersion, + }, + } + + capi := &CAPI{ + Token: &Token{ + ConfigGenerator: &ConfigGenerator{ + nodePool: nodePool, + controlplaneNamespace: "cp-ns", + rolloutConfig: &rolloutConfig{ + releaseImage: &releaseinfo.ReleaseImage{ + ImageStream: &imageapi.ImageStream{ + ObjectMeta: metav1.ObjectMeta{ + Name: tc.targetVersion, + }, + }, + }, + }, + }, + }, + } + + templateCR := &capiaws.AWSMachineTemplate{ + ObjectMeta: metav1.ObjectMeta{ + Name: "template-name", + }, + } + + capi.reconcileMachineDeploymentStatus(logr.Discard(), tc.machineDeployment, templateCR) + + g.Expect(nodePool.Status.Replicas).To(Equal(tc.expectedReplicas)) + g.Expect(nodePool.Status.Version).To(Equal(tc.expectedVersion)) + + if tc.expectedConfigAnnotation { + g.Expect(nodePool.Annotations).To(HaveKey(nodePoolAnnotationCurrentConfig)) + } + if tc.expectedTemplateAnnotation { + g.Expect(nodePool.Annotations).To(HaveKeyWithValue( + nodePoolAnnotationPlatformMachineTemplate, "template-name")) + } + + if tc.expectedReadyConditionSet { + readyCond := FindStatusCondition(nodePool.Status.Conditions, hyperv1.NodePoolReadyConditionType) + g.Expect(readyCond).ToNot(BeNil()) + g.Expect(readyCond.Status).To(Equal(tc.expectedReadyConditionStatus)) + } + }) + } +} + +func TestPropagateLabelsAndTaintsToMachines(t *testing.T) { + testCases := []struct { + name string + nodePool *hyperv1.NodePool + machines []capiv1.Machine + expectLabels map[string]string + expectTaintsJSON string + }{ + { + name: "When NodePool has labels and taints on AWS platform, it should propagate them to owned machines with globalPS label", + nodePool: &hyperv1.NodePool{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-np", + Namespace: "test-ns", + }, + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + Type: hyperv1.AWSPlatform, + }, + NodeLabels: map[string]string{ + "custom-label": "custom-value", + }, + Taints: []hyperv1.Taint{ + {Key: "key1", Value: "val1", Effect: "NoSchedule"}, + }, + }, + }, + machines: []capiv1.Machine{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "machine-1", + Namespace: "cp-ns", + Annotations: map[string]string{ + nodePoolAnnotation: "test-ns/test-np", + }, + }, + }, + }, + expectLabels: map[string]string{ + "managed.hypershift.openshift.io.custom-label": "custom-value", + "managed.hypershift.openshift.io.hypershift.openshift.io/nodepool-globalps-enabled": "true", + }, + expectTaintsJSON: `[{"key":"key1","value":"val1","effect":"NoSchedule"}]`, + }, + { + name: "When NodePool is on KubeVirt platform, it should propagate labels but not globalPS label", + nodePool: &hyperv1.NodePool{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-np", + Namespace: "test-ns", + }, + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + Type: hyperv1.KubevirtPlatform, + }, + NodeLabels: map[string]string{ + "my-label": "my-value", + }, + }, + }, + machines: []capiv1.Machine{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "machine-1", + Namespace: "cp-ns", + Annotations: map[string]string{ + nodePoolAnnotation: "test-ns/test-np", + }, + }, + }, + }, + expectLabels: map[string]string{ + "managed.hypershift.openshift.io.my-label": "my-value", + }, + }, + { + name: "When machine does not belong to the NodePool, it should not be modified", + nodePool: &hyperv1.NodePool{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-np", + Namespace: "test-ns", + }, + Spec: hyperv1.NodePoolSpec{ + Platform: hyperv1.NodePoolPlatform{ + Type: hyperv1.AWSPlatform, + }, + NodeLabels: map[string]string{ + "custom-label": "custom-value", + }, + }, + }, + machines: []capiv1.Machine{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "other-machine", + Namespace: "cp-ns", + Annotations: map[string]string{ + nodePoolAnnotation: "other-ns/other-np", + }, + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + var objects []client.Object + for i := range tc.machines { + objects = append(objects, &tc.machines[i]) + } + + c := fake.NewClientBuilder().WithScheme(api.Scheme).WithObjects(objects...).Build() + + capi := &CAPI{ + Token: &Token{ + ConfigGenerator: &ConfigGenerator{ + Client: c, + nodePool: tc.nodePool, + }, + }, + } + + md := &capiv1.MachineDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "cp-ns", + }, + } + + err := capi.propagateLabelsAndTaintsToMachines(t.Context(), logr.Discard(), md) + g.Expect(err).ToNot(HaveOccurred()) + + // Check machines that belong to the NodePool. + machineList := &capiv1.MachineList{} + g.Expect(c.List(t.Context(), machineList, client.InNamespace("cp-ns"))).To(Succeed()) + + npKey := client.ObjectKeyFromObject(tc.nodePool).String() + for _, m := range machineList.Items { + if m.Annotations[nodePoolAnnotation] != npKey { + // Machine doesn't belong to this NodePool - should have no managed labels. + for k := range m.Labels { + g.Expect(k).ToNot(HavePrefix(labelManagedPrefix)) + } + continue + } + + for k, v := range tc.expectLabels { + g.Expect(m.Labels).To(HaveKeyWithValue(k, v)) + } + + if tc.expectTaintsJSON != "" { + g.Expect(m.Annotations).To(HaveKeyWithValue(nodePoolAnnotationTaints, tc.expectTaintsJSON)) + } + } + }) + } +} + func TestNewCAPI(t *testing.T) { t.Parallel() testCases := []struct { diff --git a/hypershift-operator/controllers/nodepool/metrics/metrics.go b/hypershift-operator/controllers/nodepool/metrics/metrics.go index ed4819561fa..d52b458db75 100644 --- a/hypershift-operator/controllers/nodepool/metrics/metrics.go +++ b/hypershift-operator/controllers/nodepool/metrics/metrics.go @@ -337,81 +337,17 @@ func (c *nodePoolsMetricsCollector) Collect(ch chan<- prometheus.Metric) { defer c.mu.Unlock() ctx := context.Background() currentCollectTime := c.clock.Now() - log := ctrllog.Log - // Data retrieved from objects other than node pools in below loops - hclusterPathToData := make(map[string]*hclusterData) - machineSetPathToReplicasCount := make(map[string]int32) - machineDeploymentPathToReplicasCount := make(map[string]int32) - - // Hosted clusters loop - { - hclusters := &hyperv1.HostedClusterList{} - - if err := c.List(ctx, hclusters); err != nil { - log.Error(err, "failed to list hosted clusters while collecting metrics") - } - - for k := range hclusters.Items { - hcluster := &hclusters.Items[k] - - data := &hclusterData{ - id: hcluster.Spec.ClusterID, - namespace: hcluster.Namespace, - name: hcluster.Name, - platform: hcluster.Spec.Platform.Type, - } - // Seed with Karpenter-managed vCPUs from AutoNode status. - // Native NodePool vCPUs accumulate on top in the NodePool loop below. - if hcluster.Status.AutoNode.VCPUs != nil { - data.vCpusCount = *hcluster.Status.AutoNode.VCPUs - } - hclusterPathToData[hcluster.Namespace+"/"+hcluster.Name] = data - } - } - - // Machine sets loop - { - machineSets := &capiv1.MachineSetList{} - - if err := c.List(ctx, machineSets); err != nil { - log.Error(err, "failed to list machine sets while collecting metrics") - } - - for k := range machineSets.Items { - machineSet := &machineSets.Items[k] - msPath := machineSet.Namespace + "/" + machineSet.Name - - machineSetPathToReplicasCount[msPath] = *machineSet.Spec.Replicas - } - } - - // Machine deployments loop - { - machineDeployments := &capiv1.MachineDeploymentList{} - - if err := c.List(ctx, machineDeployments); err != nil { - log.Error(err, "failed to list machine deployments while collecting metrics") - } - - for k := range machineDeployments.Items { - machineDeployment := &machineDeployments.Items[k] - mdPath := machineDeployment.Namespace + "/" + machineDeployment.Name - - machineDeploymentPathToReplicasCount[mdPath] = *machineDeployment.Spec.Replicas - } - } + hclusterPathToData := c.collectHostedClusterData(ctx) + machineSetPathToReplicasCount := c.collectMachineSetReplicas(ctx) + machineDeploymentPathToReplicasCount := c.collectMachineDeploymentReplicas(ctx) - // countByPlatformMetric - init platformToNodePoolsCount := make(map[hyperv1.PlatformType]int) - for k := range knownPlatforms { platformToNodePoolsCount[knownPlatforms[k]] = 0 } - // countByPlatformAndFailureConditionMetric - init platformToFailureConditionToNodePoolsCount := make(map[hyperv1.PlatformType]*map[string]int) - for k := range knownPlatforms { platformToFailureConditionToNodePoolsCount[knownPlatforms[k]] = createFailureConditionToNodePoolsCountMap(conditions.ExpectedNodePoolConditions(&hyperv1.NodePool{ Spec: hyperv1.NodePoolSpec{ @@ -422,151 +358,179 @@ func (c *nodePoolsMetricsCollector) Collect(ch chan<- prometheus.Metric) { })) } - // MAIN LOOP - node pools loop - { - npList := &hyperv1.NodePoolList{} - - if err := c.List(ctx, npList); err != nil { - log.Error(err, "failed to list node pools while collecting metrics") - } + npList := &hyperv1.NodePoolList{} + if err := c.List(ctx, npList); err != nil { + ctrllog.Log.Error(err, "failed to list node pools while collecting metrics") + } - for k := range npList.Items { - nodePool := &npList.Items[k] - hclusterId := "" + for k := range npList.Items { + nodePool := &npList.Items[k] + hclusterId := "" - // countByPlatformMetric - aggregation - platform := nodePool.Spec.Platform.Type - platformToNodePoolsCount[platform] += 1 + platform := nodePool.Spec.Platform.Type + platformToNodePoolsCount[platform] += 1 - // countByPlatformAndFailureConditionMetric - aggregation - { - knownConditionToExpectedStatus := conditions.ExpectedNodePoolConditions(nodePool) - _, isKnownPlatform := platformToFailureConditionToNodePoolsCount[platform] + c.aggregateFailureConditions(nodePool, platform, platformToFailureConditionToNodePoolsCount) - if !isKnownPlatform { - platformToFailureConditionToNodePoolsCount[platform] = createFailureConditionToNodePoolsCountMap(knownConditionToExpectedStatus) - } + if hcData := hclusterPathToData[nodePool.Namespace+"/"+nodePool.Spec.ClusterName]; hcData != nil { + hclusterId = hcData.id + hcData.nodePoolsCount += 1 + c.aggregateVCpus(ctx, nodePool, hcData) + } - failureConditionToNodePoolsCount := platformToFailureConditionToNodePoolsCount[platform] + c.observeTransitionDurations(nodePool, currentCollectTime) - for _, condition := range nodePool.Status.Conditions { - expectedStatus, isKnownCondition := knownConditionToExpectedStatus[condition.Type] + nodePoolLabelValues := []string{nodePool.Namespace, nodePool.Name, hclusterId, nodePool.Spec.ClusterName, string(nodePool.Spec.Platform.Type)} + c.collectPerNodePoolMetrics(ch, nodePool, nodePoolLabelValues, machineSetPathToReplicasCount, machineDeploymentPathToReplicasCount) + } - if isKnownCondition && condition.Status != expectedStatus { - failureCondPrefix := "" + c.emitAggregatedMetrics(ch, platformToNodePoolsCount, platformToFailureConditionToNodePoolsCount, hclusterPathToData) - if expectedStatus == corev1.ConditionTrue { - failureCondPrefix = "not_" - } + c.transitionDurationMetric.Collect(ch) + c.lastCollectTime = currentCollectTime +} - failureCondition := failureCondPrefix + condition.Type +func (c *nodePoolsMetricsCollector) collectHostedClusterData(ctx context.Context) map[string]*hclusterData { + hclusterPathToData := make(map[string]*hclusterData) + hclusters := &hyperv1.HostedClusterList{} + if err := c.List(ctx, hclusters); err != nil { + ctrllog.Log.Error(err, "failed to list hosted clusters while collecting metrics") + } + for k := range hclusters.Items { + hcluster := &hclusters.Items[k] + data := &hclusterData{ + id: hcluster.Spec.ClusterID, + namespace: hcluster.Namespace, + name: hcluster.Name, + platform: hcluster.Spec.Platform.Type, + } + // Seed with Karpenter-managed vCPUs from AutoNode status. + // Native NodePool vCPUs accumulate on top in the NodePool loop below. + if hcluster.Status.AutoNode.VCPUs != nil { + data.vCpusCount = *hcluster.Status.AutoNode.VCPUs + } + hclusterPathToData[hcluster.Namespace+"/"+hcluster.Name] = data + } + return hclusterPathToData +} - (*failureConditionToNodePoolsCount)[failureCondition] += 1 - } - } - } +func (c *nodePoolsMetricsCollector) collectMachineSetReplicas(ctx context.Context) map[string]int32 { + result := make(map[string]int32) + machineSets := &capiv1.MachineSetList{} + if err := c.List(ctx, machineSets); err != nil { + ctrllog.Log.Error(err, "failed to list machine sets while collecting metrics") + } + for k := range machineSets.Items { + machineSet := &machineSets.Items[k] + result[machineSet.Namespace+"/"+machineSet.Name] = *machineSet.Spec.Replicas + } + return result +} - if hclusterData := hclusterPathToData[nodePool.Namespace+"/"+nodePool.Spec.ClusterName]; hclusterData != nil { - hclusterId = hclusterData.id - - // countByHClusterMetric - aggregation - hclusterData.nodePoolsCount += 1 - - // vCpusCountByHClusterMetric - aggregation - if hclusterData.vCpusCount >= 0 && nodePool.Status.Replicas > 0 { - nodeVCpus, err := c.retrieveVCpusDetailsPerNode(ctx, nodePool) - if err != nil { - hclusterData.vCpusCount = -1 - hclusterData.vCpusCountErr = err - } else { - hclusterData.vCpusCount += nodeVCpus * nodePool.Status.Replicas - } - } - } +func (c *nodePoolsMetricsCollector) collectMachineDeploymentReplicas(ctx context.Context) map[string]int32 { + result := make(map[string]int32) + machineDeployments := &capiv1.MachineDeploymentList{} + if err := c.List(ctx, machineDeployments); err != nil { + ctrllog.Log.Error(err, "failed to list machine deployments while collecting metrics") + } + for k := range machineDeployments.Items { + md := &machineDeployments.Items[k] + result[md.Namespace+"/"+md.Name] = *md.Spec.Replicas + } + return result +} - // transitionDurationMetric - aggregation - for i := range nodePool.Status.Conditions { - condition := &nodePool.Status.Conditions[i] - if _, isRetained := transitionDurationMetricConditions[condition.Type]; isRetained { - if condition.Status == corev1.ConditionTrue { - t := condition.LastTransitionTime.Time - - if c.lastCollectTime.Before(t) && (t.Before(currentCollectTime) || t.Equal(currentCollectTime)) { - c.transitionDurationMetric.With(map[string]string{"condition": condition.Type}).Observe(t.Sub(nodePool.CreationTimestamp.Time).Seconds()) - } - } - } +func (c *nodePoolsMetricsCollector) aggregateFailureConditions(nodePool *hyperv1.NodePool, platform hyperv1.PlatformType, platformMap map[hyperv1.PlatformType]*map[string]int) { + knownConditionToExpectedStatus := conditions.ExpectedNodePoolConditions(nodePool) + if _, isKnownPlatform := platformMap[platform]; !isKnownPlatform { + platformMap[platform] = createFailureConditionToNodePoolsCountMap(knownConditionToExpectedStatus) + } + failureConditionToNodePoolsCount := platformMap[platform] + for _, condition := range nodePool.Status.Conditions { + expectedStatus, isKnownCondition := knownConditionToExpectedStatus[condition.Type] + if isKnownCondition && condition.Status != expectedStatus { + failureCondPrefix := "" + if expectedStatus == corev1.ConditionTrue { + failureCondPrefix = "not_" } + (*failureConditionToNodePoolsCount)[failureCondPrefix+condition.Type] += 1 + } + } +} - nodePoolLabelValues := []string{nodePool.Namespace, nodePool.Name, hclusterId, nodePool.Spec.ClusterName, string(nodePool.Spec.Platform.Type)} - - // initialRollingOutDurationMetric - if nodePool.Status.Version == "" { - initializingDuration := c.clock.Since(nodePool.CreationTimestamp.Time).Seconds() - - ch <- prometheus.MustNewConstMetric( - initialRollingOutDurationMetricDesc, - prometheus.GaugeValue, - initializingDuration, - nodePoolLabelValues..., - ) - } +func (c *nodePoolsMetricsCollector) aggregateVCpus(ctx context.Context, nodePool *hyperv1.NodePool, hcData *hclusterData) { + if hcData.vCpusCount >= 0 && nodePool.Status.Replicas > 0 { + nodeVCpus, err := c.retrieveVCpusDetailsPerNode(ctx, nodePool) + if err != nil { + hcData.vCpusCount = -1 + hcData.vCpusCountErr = err + } else { + hcData.vCpusCount += nodeVCpus * nodePool.Status.Replicas + } + } +} - // sizeMetric - { - var pathToReplicasCount *map[string]int32 - - switch nodePool.Spec.Management.UpgradeType { - case hyperv1.UpgradeTypeInPlace: - // we use machineSet.Spec.Replicas because .Spec.Replicas will not be set if autoscaling is enabled - pathToReplicasCount = &machineSetPathToReplicasCount - case hyperv1.UpgradeTypeReplace: - // we use machineDeployment.Spec.Replicas because .Spec.Replicas will not be set if autoscaling is enabled - pathToReplicasCount = &machineDeploymentPathToReplicasCount - } - - if pathToReplicasCount != nil { - hcpNs := manifests.HostedControlPlaneNamespace(nodePool.Namespace, nodePool.Spec.ClusterName) - wishedReplicas := float64((*pathToReplicasCount)[hcpNs+"/"+nodePool.Name]) - - ch <- prometheus.MustNewConstMetric( - sizeMetricDesc, - prometheus.GaugeValue, - wishedReplicas, - nodePoolLabelValues..., - ) - } - } +func (c *nodePoolsMetricsCollector) observeTransitionDurations(nodePool *hyperv1.NodePool, currentCollectTime time.Time) { + for i := range nodePool.Status.Conditions { + condition := &nodePool.Status.Conditions[i] + if _, isRetained := transitionDurationMetricConditions[condition.Type]; !isRetained { + continue + } + if condition.Status != corev1.ConditionTrue { + continue + } + t := condition.LastTransitionTime.Time + if c.lastCollectTime.Before(t) && (t.Before(currentCollectTime) || t.Equal(currentCollectTime)) { + c.transitionDurationMetric.With(map[string]string{"condition": condition.Type}).Observe(t.Sub(nodePool.CreationTimestamp.Time).Seconds()) + } + } +} - // availableReplicasMetric - { - availableReplicas := float64(nodePool.Status.Replicas) +func (c *nodePoolsMetricsCollector) collectPerNodePoolMetrics(ch chan<- prometheus.Metric, nodePool *hyperv1.NodePool, labelValues []string, msReplicas, mdReplicas map[string]int32) { + if nodePool.Status.Version == "" { + ch <- prometheus.MustNewConstMetric( + initialRollingOutDurationMetricDesc, + prometheus.GaugeValue, + c.clock.Since(nodePool.CreationTimestamp.Time).Seconds(), + labelValues..., + ) + } - ch <- prometheus.MustNewConstMetric( - availableReplicasMetricDesc, - prometheus.GaugeValue, - availableReplicas, - nodePoolLabelValues..., - ) - } + var pathToReplicasCount *map[string]int32 + switch nodePool.Spec.Management.UpgradeType { + case hyperv1.UpgradeTypeInPlace: + pathToReplicasCount = &msReplicas + case hyperv1.UpgradeTypeReplace: + pathToReplicasCount = &mdReplicas + } + if pathToReplicasCount != nil { + hcpNs := manifests.HostedControlPlaneNamespace(nodePool.Namespace, nodePool.Spec.ClusterName) + ch <- prometheus.MustNewConstMetric( + sizeMetricDesc, + prometheus.GaugeValue, + float64((*pathToReplicasCount)[hcpNs+"/"+nodePool.Name]), + labelValues..., + ) + } - // deletingDurationMetric - if !nodePool.DeletionTimestamp.IsZero() { - deletingDuration := c.clock.Since(nodePool.DeletionTimestamp.Time).Seconds() + ch <- prometheus.MustNewConstMetric( + availableReplicasMetricDesc, + prometheus.GaugeValue, + float64(nodePool.Status.Replicas), + labelValues..., + ) - ch <- prometheus.MustNewConstMetric( - deletingDurationMetricDesc, - prometheus.GaugeValue, - deletingDuration, - nodePoolLabelValues..., - ) - } - } + if !nodePool.DeletionTimestamp.IsZero() { + ch <- prometheus.MustNewConstMetric( + deletingDurationMetricDesc, + prometheus.GaugeValue, + c.clock.Since(nodePool.DeletionTimestamp.Time).Seconds(), + labelValues..., + ) } +} - // AGGREGATED METRICS - - // countByPlatformMetric +func (c *nodePoolsMetricsCollector) emitAggregatedMetrics(ch chan<- prometheus.Metric, platformToNodePoolsCount map[hyperv1.PlatformType]int, platformToFailureConditionToNodePoolsCount map[hyperv1.PlatformType]*map[string]int, hclusterPathToData map[string]*hclusterData) { for platform, nodePoolsCount := range platformToNodePoolsCount { ch <- prometheus.MustNewConstMetric( countByPlatformMetricDesc, @@ -576,7 +540,6 @@ func (c *nodePoolsMetricsCollector) Collect(ch chan<- prometheus.Metric) { ) } - // countByPlatformAndFailureConditionMetric for platform, failureConditionToNodePoolsCount := range platformToFailureConditionToNodePoolsCount { for failureCondition, nodePoolsCount := range *failureConditionToNodePoolsCount { ch <- prometheus.MustNewConstMetric( @@ -589,38 +552,27 @@ func (c *nodePoolsMetricsCollector) Collect(ch chan<- prometheus.Metric) { } } - for _, hclusterData := range hclusterPathToData { - hclusterLabelValues := []string{hclusterData.namespace, hclusterData.name, hclusterData.id, string(hclusterData.platform)} - - // countByHClusterMetric + for _, hcData := range hclusterPathToData { + hclusterLabelValues := []string{hcData.namespace, hcData.name, hcData.id, string(hcData.platform)} ch <- prometheus.MustNewConstMetric( countByHClusterMetricDesc, prometheus.GaugeValue, - float64(hclusterData.nodePoolsCount), + float64(hcData.nodePoolsCount), hclusterLabelValues..., ) - - // vCpusCountByHClusterMetric ch <- prometheus.MustNewConstMetric( vCpusCountByHClusterMetricDesc, prometheus.GaugeValue, - float64(hclusterData.vCpusCount), + float64(hcData.vCpusCount), hclusterLabelValues..., ) - - // vCpusCountByHClusterMetric - if hclusterData.vCpusCountErr != nil { + if hcData.vCpusCountErr != nil { ch <- prometheus.MustNewConstMetric( vCpusComputationErrorByHClusterMetricDesc, prometheus.GaugeValue, 1.0, - append(hclusterLabelValues, hclusterData.vCpusCountErr.Error())..., + append(hclusterLabelValues, hcData.vCpusCountErr.Error())..., ) } } - - // transitionDurationMetric - c.transitionDurationMetric.Collect(ch) - - c.lastCollectTime = currentCollectTime } diff --git a/hypershift-operator/controllers/nodepool/nodepool_controller.go b/hypershift-operator/controllers/nodepool/nodepool_controller.go index 18840aa312c..b6d86acf2cf 100644 --- a/hypershift-operator/controllers/nodepool/nodepool_controller.go +++ b/hypershift-operator/controllers/nodepool/nodepool_controller.go @@ -254,6 +254,7 @@ func (r *NodePoolReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c return result, nil } +//nolint:gocyclo func (r *NodePoolReconciler) reconcile(ctx context.Context, hcluster *hyperv1.HostedCluster, nodePool *hyperv1.NodePool) (ctrl.Result, error) { log := ctrl.LoggerFrom(ctx) diff --git a/hypershift-operator/controllers/platform/aws/controller.go b/hypershift-operator/controllers/platform/aws/controller.go index 1c37a3def62..20a2d428fb6 100644 --- a/hypershift-operator/controllers/platform/aws/controller.go +++ b/hypershift-operator/controllers/platform/aws/controller.go @@ -463,13 +463,10 @@ func listKarpenterSubnetIDs(ctx context.Context, c client.Client, namespace stri return subnetIDs, nil } -func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointServiceStatus(ctx context.Context, awsEndpointService *hyperv1.AWSEndpointService, hostedCluster *hyperv1.HostedCluster, ec2Client awsapi.EC2API, elbv2Client awsapi.ELBV2API) error { - log := ctrl.LoggerFrom(ctx) - - // If a previous awsendpointservice that points to an ingress controller exists, remove it +func (r *AWSEndpointServiceReconciler) deleteObsoleteEndpointService(ctx context.Context, awsEndpointService *hyperv1.AWSEndpointService) (done bool, err error) { endpointServices := &hyperv1.AWSEndpointServiceList{} if err := r.List(ctx, endpointServices, client.InNamespace(awsEndpointService.Namespace)); err != nil { - return fmt.Errorf("failed to list aws endpoint services in namespace: %s: %w", awsEndpointService.Namespace, err) + return false, fmt.Errorf("failed to list aws endpoint services in namespace: %s: %w", awsEndpointService.Namespace, err) } privateRouterEPServiceName := fmt.Sprintf("router-%s", awsEndpointService.Namespace) hasPrivateRouterEPService := false @@ -482,27 +479,29 @@ func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointServiceStatus(ctx con hasPrivateIngressControllerEPService = true } } - // Only if both router and private ingress controller AWSEndpointServices exist, delete the obsolete one - if hasPrivateRouterEPService && hasPrivateIngressControllerEPService { - privateIngressControllerEPService := &hyperv1.AWSEndpointService{ - ObjectMeta: metav1.ObjectMeta{ - Name: privateRouterEPServiceName, - Namespace: awsEndpointService.Namespace, - }, - } - if err := r.Delete(ctx, privateIngressControllerEPService); err != nil { - return fmt.Errorf("failed to delete awsendpointservice %s: %w", client.ObjectKeyFromObject(privateIngressControllerEPService).String(), err) - } - // No need to further reconcile if the endpointservice is the one we just deleted. - if awsEndpointService.Name == privateRouterEPServiceName { - return nil - } + if !hasPrivateRouterEPService || !hasPrivateIngressControllerEPService { + return false, nil + } + privateIngressControllerEPService := &hyperv1.AWSEndpointService{ + ObjectMeta: metav1.ObjectMeta{ + Name: privateRouterEPServiceName, + Namespace: awsEndpointService.Namespace, + }, } + if err := r.Delete(ctx, privateIngressControllerEPService); err != nil { + return false, fmt.Errorf("failed to delete awsendpointservice %s: %w", client.ObjectKeyFromObject(privateIngressControllerEPService).String(), err) + } + if awsEndpointService.Name == privateRouterEPServiceName { + return true, nil + } + return false, nil +} - serviceName := awsEndpointService.Status.EndpointServiceName - var serviceID string +func (r *AWSEndpointServiceReconciler) ensureVpcEndpointService(ctx context.Context, awsEndpointService *hyperv1.AWSEndpointService, ec2Client awsapi.EC2API, elbv2Client awsapi.ELBV2API) (serviceName string, serviceID string, err error) { + log := ctrl.LoggerFrom(ctx) + + serviceName = awsEndpointService.Status.EndpointServiceName if len(serviceName) != 0 { - // check if Endpoint Service exists in AWS output, err := ec2Client.DescribeVpcEndpointServiceConfigurations(ctx, &ec2.DescribeVpcEndpointServiceConfigurationsInput{ Filters: []ec2types.Filter{ { @@ -514,93 +513,93 @@ func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointServiceStatus(ctx con if err != nil { var apiErr smithy.APIError if errors.As(err, &apiErr) { - return errors.New(apiErr.ErrorCode()) + return "", "", errors.New(apiErr.ErrorCode()) } - return err + return "", "", err } if len(output.ServiceConfigurations) == 0 { - // clear the EndpointServiceName so a new Endpoint Service is created on the requeue awsEndpointService.Status.EndpointServiceName = "" - return fmt.Errorf("endpoint service %s not found, resetting status", serviceName) + return "", "", fmt.Errorf("endpoint service %s not found, resetting status", serviceName) } serviceID = aws.ToString(output.ServiceConfigurations[0].ServiceId) log.Info("endpoint service exists", "serviceName", serviceName) - } else { - // determine the LB ARN - lbName := awsEndpointService.Spec.NetworkLoadBalancerName - output, err := elbv2Client.DescribeLoadBalancers(ctx, &elbv2.DescribeLoadBalancersInput{ - Names: []string{lbName}, - }) - if err != nil { - var smithyErr smithy.APIError - if errors.As(err, &smithyErr) { - return errors.New(smithyErr.ErrorCode()) - } - return err - } - if len(output.LoadBalancers) == 0 { - return fmt.Errorf("load balancer %s not found", lbName) - } - lb := output.LoadBalancers[0] - lbARN := lb.LoadBalancerArn - if lbARN == nil { - return fmt.Errorf("load balancer ARN is nil") - } - if lb.State == nil || lb.State.Code != elbv2types.LoadBalancerStateEnumActive { - return fmt.Errorf("load balancer %s is not yet active", *lbARN) + return serviceName, serviceID, nil + } + + lbName := awsEndpointService.Spec.NetworkLoadBalancerName + output, err := elbv2Client.DescribeLoadBalancers(ctx, &elbv2.DescribeLoadBalancersInput{ + Names: []string{lbName}, + }) + if err != nil { + var smithyErr smithy.APIError + if errors.As(err, &smithyErr) { + return "", "", errors.New(smithyErr.ErrorCode()) } + return "", "", err + } + if len(output.LoadBalancers) == 0 { + return "", "", fmt.Errorf("load balancer %s not found", lbName) + } + lb := output.LoadBalancers[0] + lbARN := lb.LoadBalancerArn + if lbARN == nil { + return "", "", fmt.Errorf("load balancer ARN is nil") + } + if lb.State == nil || lb.State.Code != elbv2types.LoadBalancerStateEnumActive { + return "", "", fmt.Errorf("load balancer %s is not yet active", *lbARN) + } - // create the Endpoint Service - tags := apiTagToEC2Tag(awsEndpointService.Spec.ResourceTags) - if r.ManagementClusterCapabilities.Has(capabilities.CapabilityInfrastructure) { - managementClusterInfrastructure := globalconfig.InfrastructureConfig() - if err := r.Get(ctx, client.ObjectKeyFromObject(managementClusterInfrastructure), managementClusterInfrastructure); err != nil { - return fmt.Errorf("failed to get management cluster infrastructure: %w", err) - } - tags = append(tags, ec2types.Tag{ - Key: aws.String("kubernetes.io/cluster/" + managementClusterInfrastructure.Status.InfrastructureName), - Value: aws.String("owned"), - }) - } - createEndpointServiceOutput, err := ec2Client.CreateVpcEndpointServiceConfiguration(ctx, &ec2.CreateVpcEndpointServiceConfigurationInput{ - // TODO: we should probably do some sort of automated acceptance check against the VPC ID in the HostedCluster - AcceptanceRequired: aws.Bool(false), - NetworkLoadBalancerArns: []string{aws.ToString(lbARN)}, - TagSpecifications: []ec2types.TagSpecification{{ - ResourceType: ec2types.ResourceTypeVpcEndpointService, - Tags: tags, - }}, + tags := apiTagToEC2Tag(awsEndpointService.Spec.ResourceTags) + if r.ManagementClusterCapabilities.Has(capabilities.CapabilityInfrastructure) { + managementClusterInfrastructure := globalconfig.InfrastructureConfig() + if err := r.Get(ctx, client.ObjectKeyFromObject(managementClusterInfrastructure), managementClusterInfrastructure); err != nil { + return "", "", fmt.Errorf("failed to get management cluster infrastructure: %w", err) + } + tags = append(tags, ec2types.Tag{ + Key: aws.String("kubernetes.io/cluster/" + managementClusterInfrastructure.Status.InfrastructureName), + Value: aws.String("owned"), }) - if err != nil { - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - if apiErr.ErrorCode() == "InvalidParameter" { - // TODO: optional filter by regex on error msg (could be fragile) - // e.g. "LBs are already associated with another VPC Endpoint Service Configuration" - log.Info("service endpoint might already exist, attempting adoption") - var err error - serviceName, serviceID, err = findExistingVpcEndpointService(ctx, ec2Client, aws.ToString(lbARN)) - if err != nil { - log.Info("existing endpoint service not found, adoption failed", "err", err) - return errors.New(apiErr.ErrorCode()) - } - } else { - return errors.New(apiErr.ErrorCode()) + } + + createEndpointServiceOutput, err := ec2Client.CreateVpcEndpointServiceConfiguration(ctx, &ec2.CreateVpcEndpointServiceConfigurationInput{ + // TODO: we should probably do some sort of automated acceptance check against the VPC ID in the HostedCluster + AcceptanceRequired: aws.Bool(false), + NetworkLoadBalancerArns: []string{aws.ToString(lbARN)}, + TagSpecifications: []ec2types.TagSpecification{{ + ResourceType: ec2types.ResourceTypeVpcEndpointService, + Tags: tags, + }}, + }) + if err != nil { + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + if apiErr.ErrorCode() == "InvalidParameter" { + // TODO: optional filter by regex on error msg (could be fragile) + // e.g. "LBs are already associated with another VPC Endpoint Service Configuration" + log.Info("service endpoint might already exist, attempting adoption") + var adoptErr error + serviceName, serviceID, adoptErr = findExistingVpcEndpointService(ctx, ec2Client, aws.ToString(lbARN)) + if adoptErr != nil { + log.Info("existing endpoint service not found, adoption failed", "err", adoptErr) + return "", "", errors.New(apiErr.ErrorCode()) } + } else { + return "", "", errors.New(apiErr.ErrorCode()) } - if len(serviceName) == 0 { - return err - } - log.Info("endpoint service adopted", "serviceName", serviceName) - } else { - serviceName = aws.ToString(createEndpointServiceOutput.ServiceConfiguration.ServiceName) - serviceID = aws.ToString(createEndpointServiceOutput.ServiceConfiguration.ServiceId) - log.Info("endpoint service created", "serviceName", serviceName) } + if len(serviceName) == 0 { + return "", "", err + } + log.Info("endpoint service adopted", "serviceName", serviceName) + } else { + serviceName = aws.ToString(createEndpointServiceOutput.ServiceConfiguration.ServiceName) + serviceID = aws.ToString(createEndpointServiceOutput.ServiceConfiguration.ServiceId) + log.Info("endpoint service created", "serviceName", serviceName) } - awsEndpointService.Status.EndpointServiceName = serviceName + return serviceName, serviceID, nil +} - // reconcile permissions for aws endpoint service +func (r *AWSEndpointServiceReconciler) reconcileEndpointServicePermissions(ctx context.Context, serviceID string, hostedCluster *hyperv1.HostedCluster, ec2Client awsapi.EC2API) error { permResp, err := ec2Client.DescribeVpcEndpointServicePermissions(ctx, &ec2.DescribeVpcEndpointServicePermissionsInput{ ServiceId: aws.String(serviceID), }) @@ -620,23 +619,45 @@ func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointServiceStatus(ctx con desiredPerms := sets.NewString(controlPlaneOperatorRoleARN) desiredPerms = desiredPerms.Insert(hostedCluster.Spec.Platform.AWS.AdditionalAllowedPrincipals...) - if !desiredPerms.Equal(oldPerms) { - input := &ec2.ModifyVpcEndpointServicePermissionsInput{ - ServiceId: aws.String(serviceID), - } - if added := desiredPerms.Difference(oldPerms).List(); len(added) > 0 { - input.AddAllowedPrincipals = added - } - if removed := oldPerms.Difference(desiredPerms).List(); len(removed) > 0 { - input.RemoveAllowedPrincipals = removed - } - _, err := ec2Client.ModifyVpcEndpointServicePermissions(ctx, input) - if err != nil { - return fmt.Errorf("failed to update vpc endpoint permissions: %w", err) - } + if desiredPerms.Equal(oldPerms) { + return nil + } + + input := &ec2.ModifyVpcEndpointServicePermissionsInput{ + ServiceId: aws.String(serviceID), + } + if added := desiredPerms.Difference(oldPerms).List(); len(added) > 0 { + input.AddAllowedPrincipals = added + } + if removed := oldPerms.Difference(desiredPerms).List(); len(removed) > 0 { + input.RemoveAllowedPrincipals = removed + } + _, err = ec2Client.ModifyVpcEndpointServicePermissions(ctx, input) + if err != nil { + return fmt.Errorf("failed to update vpc endpoint permissions: %w", err) + } + return nil +} + +func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointServiceStatus(ctx context.Context, awsEndpointService *hyperv1.AWSEndpointService, hostedCluster *hyperv1.HostedCluster, ec2Client awsapi.EC2API, elbv2Client awsapi.ELBV2API) error { + done, err := r.deleteObsoleteEndpointService(ctx, awsEndpointService) + if err != nil { + return err + } + if done { + return nil + } + + serviceName, serviceID, err := r.ensureVpcEndpointService(ctx, awsEndpointService, ec2Client, elbv2Client) + if err != nil { + return err } awsEndpointService.Status.EndpointServiceName = serviceName + if err := r.reconcileEndpointServicePermissions(ctx, serviceID, hostedCluster, ec2Client); err != nil { + return err + } + return nil } diff --git a/hypershift-operator/controllers/platform/aws/controller_test.go b/hypershift-operator/controllers/platform/aws/controller_test.go index a3e6e12ad79..d895059bd1c 100644 --- a/hypershift-operator/controllers/platform/aws/controller_test.go +++ b/hypershift-operator/controllers/platform/aws/controller_test.go @@ -1000,6 +1000,257 @@ func TestEnqueueOnKarpenterConfigMapChange(t *testing.T) { } } +func TestListNodePools(t *testing.T) { + testCases := []struct { + name string + clusterName string + namespace string + objects []client.Object + expectedNames []string + expectError bool + }{ + { + name: "When NodePools exist for the given cluster, it should return only those matching the cluster name", + clusterName: "my-cluster", + namespace: "clusters", + objects: []client.Object{ + &hyperv1.NodePool{ + ObjectMeta: metav1.ObjectMeta{Name: "np-1", Namespace: "clusters"}, + Spec: hyperv1.NodePoolSpec{ClusterName: "my-cluster"}, + }, + &hyperv1.NodePool{ + ObjectMeta: metav1.ObjectMeta{Name: "np-2", Namespace: "clusters"}, + Spec: hyperv1.NodePoolSpec{ClusterName: "other-cluster"}, + }, + &hyperv1.NodePool{ + ObjectMeta: metav1.ObjectMeta{Name: "np-3", Namespace: "clusters"}, + Spec: hyperv1.NodePoolSpec{ClusterName: "my-cluster"}, + }, + }, + expectedNames: []string{"np-1", "np-3"}, + }, + { + name: "When no NodePools match the cluster name, it should return an empty list", + clusterName: "my-cluster", + namespace: "clusters", + objects: []client.Object{ + &hyperv1.NodePool{ + ObjectMeta: metav1.ObjectMeta{Name: "np-1", Namespace: "clusters"}, + Spec: hyperv1.NodePoolSpec{ClusterName: "other-cluster"}, + }, + }, + expectedNames: []string{}, + }, + { + name: "When no NodePools exist in the namespace, it should return an empty list", + clusterName: "my-cluster", + namespace: "clusters", + objects: []client.Object{}, + expectedNames: []string{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + fakeClient := fake.NewClientBuilder(). + WithScheme(hyperapi.Scheme). + WithObjects(tc.objects...). + Build() + + result, err := listNodePools(t.Context(), fakeClient, tc.namespace, tc.clusterName) + + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).NotTo(HaveOccurred()) + names := make([]string, len(result)) + for i, np := range result { + names[i] = np.Name + } + g.Expect(names).To(ConsistOf(tc.expectedNames)) + } + }) + } +} + +func TestApiTagToEC2Tag(t *testing.T) { + testCases := []struct { + name string + input []hyperv1.AWSResourceTag + expected []ec2types.Tag + }{ + { + name: "When input is empty, it should return an empty slice", + input: []hyperv1.AWSResourceTag{}, + expected: []ec2types.Tag{}, + }, + { + name: "When input has multiple tags, it should convert all of them", + input: []hyperv1.AWSResourceTag{ + {Key: "env", Value: "prod"}, + {Key: "team", Value: "platform"}, + }, + expected: []ec2types.Tag{ + {Key: aws.String("env"), Value: aws.String("prod")}, + {Key: aws.String("team"), Value: aws.String("platform")}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + result := apiTagToEC2Tag(tc.input) + g.Expect(result).To(HaveLen(len(tc.expected))) + for i := range result { + g.Expect(aws.ToString(result[i].Key)).To(Equal(aws.ToString(tc.expected[i].Key))) + g.Expect(aws.ToString(result[i].Value)).To(Equal(aws.ToString(tc.expected[i].Value))) + } + }) + } +} + +func TestHostedClusterNamespaceAndName(t *testing.T) { + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + expectedNamespace string + expectedName string + }{ + { + name: "When annotation exists with namespace/name, it should return both parts", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "hypershift.openshift.io/cluster": "my-ns/my-cluster", + }, + }, + }, + expectedNamespace: "my-ns", + expectedName: "my-cluster", + }, + { + name: "When annotation is missing, it should return empty strings", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{}, + }, + }, + expectedNamespace: "", + expectedName: "", + }, + { + name: "When annotations map is nil, it should return empty strings", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{}, + }, + expectedNamespace: "", + expectedName: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + ns, name := hostedClusterNamespaceAndName(tc.hcp) + g.Expect(ns).To(Equal(tc.expectedNamespace)) + g.Expect(name).To(Equal(tc.expectedName)) + }) + } +} + +func TestToLowerState(t *testing.T) { + testCases := []struct { + name string + input ec2types.State + expected ec2types.State + }{ + { + name: "When input is PascalCase, it should return lowercase", + input: ec2types.StateAvailable, + expected: "available", + }, + { + name: "When input is already lowercase, it should return unchanged", + input: "pending", + expected: "pending", + }, + { + name: "When input is PendingAcceptance, it should return lowercase", + input: ec2types.StatePendingAcceptance, + expected: "pendingacceptance", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + result := toLowerState(tc.input) + g.Expect(result).To(Equal(tc.expected)) + }) + } +} + +func TestEnqueueOnKarpenterConfigMapCreate(t *testing.T) { + testCases := []struct { + name string + cm *corev1.ConfigMap + expectedQueued int + }{ + { + name: "When a karpenter-managed ConfigMap is created, it should enqueue AWSEndpointServices", + cm: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: karpenterutil.KarpenterSubnetsConfigMapName, + Namespace: "clusters-my-cluster", + Labels: map[string]string{ + "hypershift.openshift.io/managed-by": "karpenter", + }, + }, + }, + // awsEndpointServicesByName returns 3 entries + expectedQueued: 3, + }, + { + name: "When a non-karpenter ConfigMap is created, it should not enqueue", + cm: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "some-other-configmap", + Namespace: "clusters-my-cluster", + }, + }, + expectedQueued: 0, + }, + { + name: "When ConfigMap has karpenter name but lacks managed-by label, it should not enqueue", + cm: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: karpenterutil.KarpenterSubnetsConfigMapName, + Namespace: "clusters-my-cluster", + }, + }, + expectedQueued: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + mgr := &fakeManager{} + r := &AWSEndpointServiceReconciler{} + handler := r.enqueueOnKarpenterConfigMapCreate(mgr) + + q := &captureQueue{} + handler(t.Context(), event.CreateEvent{Object: tc.cm}, q) + + g.Expect(q.added).To(HaveLen(tc.expectedQueued)) + }) + } +} + // fakeManager implements just enough of ctrl.Manager for tests that need mgr.GetLogger(). // All unimplemented methods are delegated to the embedded nil Manager, which will // panic if called — intentionally, as tests should never trigger those paths. diff --git a/hypershift-operator/controllers/scheduler/aws/autoscaler.go b/hypershift-operator/controllers/scheduler/aws/autoscaler.go index 0e095405626..e22fa2b4ade 100644 --- a/hypershift-operator/controllers/scheduler/aws/autoscaler.go +++ b/hypershift-operator/controllers/scheduler/aws/autoscaler.go @@ -390,10 +390,6 @@ func machineSetsToScaleUp(pods []corev1.Pod, machineSets []machinev1beta1.Machin requiredNodeCounts := determineRequiredNodes(pendingPods, pods, nodes) var placeHoldersNeeded []nodeRequirement - - // First, the easy ones. If a specific pair label is required, - // find the corresponding machinesets that are not already - // scaled up. for _, r := range requiredNodeCounts { if r.pairLabel != "" { machineSetsToScale := filterMachineSets(machineSets, func(ms *machinev1beta1.MachineSet) bool { @@ -404,8 +400,6 @@ func machineSetsToScaleUp(pods []corev1.Pod, machineSets []machinev1beta1.Machin result = append(result, machineSetsToScale...) continue } - - // Otherwise, we need to find placeholders without a specific pair label placeHoldersNeeded = append(placeHoldersNeeded, r) } @@ -413,10 +407,17 @@ func machineSetsToScaleUp(pods []corev1.Pod, machineSets []machinev1beta1.Machin return result, pendingPods, requiredNodeCounts } - // Determine which pair labels we cannot - // use to schedule additional placeholders - // 1 - pair labels used by a cluster - // 2 - pair labels where a placeholder is already scheduled + takenPairLabels := collectTakenPairLabels(pods, nodes) + + for _, r := range placeHoldersNeeded { + scaled := scaleMachineSetsForRequirement(r, machineSets, machines, nodes, takenPairLabels) + result = append(result, scaled...) + } + + return result, pendingPods, requiredNodeCounts +} + +func collectTakenPairLabels(pods []corev1.Pod, nodes []corev1.Node) sets.Set[string] { takenPairLabels := sets.New[string]() for _, n := range nodes { if n.Labels[hyperv1.HostedClusterLabel] != "" { @@ -428,101 +429,83 @@ func machineSetsToScaleUp(pods []corev1.Pod, machineSets []machinev1beta1.Machin takenPairLabels.Insert(pairLabel) } } + return takenPairLabels +} - for _, r := range placeHoldersNeeded { - needCount := r.count - // First, find any available nodes of the specified size - // These are nodes that are created but may not be ready - // but will allow scheduling of the pods soon. - // Available nodes must: - // 1 - have the request serving label - // 2 - have matching size label - // 3 - have a pair label that is not already taken - availableNodes := filterNodes(nodes, func(n *corev1.Node) bool { - return n.Labels[hyperv1.RequestServingComponentLabel] != "" && - n.Labels[hyperv1.NodeSizeLabel] == r.sizeLabel && - !takenPairLabels.Has(n.Labels[OSDFleetManagerPairedNodesLabel]) - }) - needCount -= len(availableNodes) +func scaleMachineSetsForRequirement(r nodeRequirement, machineSets []machinev1beta1.MachineSet, machines []machinev1beta1.Machine, nodes []corev1.Node, takenPairLabels sets.Set[string]) []machinev1beta1.MachineSet { + var result []machinev1beta1.MachineSet + needCount := r.count - availableNodeMachineSets := sets.New[string]() - for i := range availableNodes { - msName := nodeMachineSet(&availableNodes[i], machines) - if msName == "" { - continue - } + availableNodes := filterNodes(nodes, func(n *corev1.Node) bool { + return n.Labels[hyperv1.RequestServingComponentLabel] != "" && + n.Labels[hyperv1.NodeSizeLabel] == r.sizeLabel && + !takenPairLabels.Has(n.Labels[OSDFleetManagerPairedNodesLabel]) + }) + needCount -= len(availableNodes) + + availableNodeMachineSets := sets.New[string]() + for i := range availableNodes { + msName := nodeMachineSet(&availableNodes[i], machines) + if msName != "" { availableNodeMachineSets.Insert(msName) } + } - // Second, find any machinesets that have already been scaled up - // but do not have any nodes yet. - // Pending machinesets must: - // 1 - have the request serving label - // 2 - have matching size label - // 3 - be scaled up without available replicas - // 4 - not correspond to any available nodes - // 5 - not have a pair label that is assigned to a cluster - pendingMachineSets := filterMachineSets(machineSets, func(ms *machinev1beta1.MachineSet) bool { - return isRequestServingMachineSet(ms) && - machineSetSize(ms) == r.sizeLabel && - ptr.Deref(ms.Spec.Replicas, 0) > 0 && - ms.Status.AvailableReplicas == 0 && - !availableNodeMachineSets.Has(ms.Name) && - !takenPairLabels.Has(machineSetPairLabel(ms)) - }) - needCount -= len(pendingMachineSets) + pendingMachineSets := filterMachineSets(machineSets, func(ms *machinev1beta1.MachineSet) bool { + return isRequestServingMachineSet(ms) && + machineSetSize(ms) == r.sizeLabel && + ptr.Deref(ms.Spec.Replicas, 0) > 0 && + ms.Status.AvailableReplicas == 0 && + !availableNodeMachineSets.Has(ms.Name) && + !takenPairLabels.Has(machineSetPairLabel(ms)) + }) + needCount -= len(pendingMachineSets) - if needCount < 1 { - continue - } + if needCount < 1 { + return nil + } - // Determine if there are pending machinesets that need the machineSet pair to also be scaled up - // and scale those up first - for _, ms := range pendingMachineSets { - if pairedMachineSet := matchingMachineSet(&ms, machineSets); pairedMachineSet != nil { - if ptr.Deref(pairedMachineSet.Spec.Replicas, 0) == 0 { - result = append(result, *pairedMachineSet) - needCount-- - } + for _, ms := range pendingMachineSets { + if pairedMachineSet := matchingMachineSet(&ms, machineSets); pairedMachineSet != nil { + if ptr.Deref(pairedMachineSet.Spec.Replicas, 0) == 0 { + result = append(result, *pairedMachineSet) + needCount-- } } + } + + if needCount < 1 { + return result + } + + result = append(result, pickAvailableMachineSetPairs(r.sizeLabel, needCount, machineSets, takenPairLabels)...) + return result +} - if needCount < 1 { +func pickAvailableMachineSetPairs(sizeLabel string, needCount int, machineSets []machinev1beta1.MachineSet, takenPairLabels sets.Set[string]) []machinev1beta1.MachineSet { + availableMachineSets := filterMachineSets(machineSets, func(ms *machinev1beta1.MachineSet) bool { + return isRequestServingMachineSet(ms) && + machineSetSize(ms) == sizeLabel && + ptr.Deref(ms.Spec.Replicas, 0) == 0 && + !takenPairLabels.Has(machineSetPairLabel(ms)) + }) + var result []machinev1beta1.MachineSet + toSkip := sets.New[string]() + for _, ms := range availableMachineSets { + if toSkip.Has(ms.Name) { continue } - - // Finally, pick random pairs from available machinesets - // Available machinesets must: - // 1 - have the request serving label - // 2 - have the corresponding size label - // 3 - not be scaled up - // 4 - have a pair label that is not already taken - availableMachineSets := filterMachineSets(machineSets, func(ms *machinev1beta1.MachineSet) bool { - return isRequestServingMachineSet(ms) && - machineSetSize(ms) == r.sizeLabel && - ptr.Deref(ms.Spec.Replicas, 0) == 0 && - !takenPairLabels.Has(machineSetPairLabel(ms)) - }) - var machineSetsToScaleUp []machinev1beta1.MachineSet - toSkip := sets.New[string]() - for _, ms := range availableMachineSets { - if toSkip.Has(ms.Name) { - continue - } - pairMachineSet := matchingMachineSet(&ms, availableMachineSets) - if pairMachineSet == nil { - continue - } - toSkip.Insert(pairMachineSet.Name) - machineSetsToScaleUp = append(machineSetsToScaleUp, ms, *pairMachineSet) - if len(machineSetsToScaleUp) >= needCount { - break - } + pairMachineSet := matchingMachineSet(&ms, availableMachineSets) + if pairMachineSet == nil { + continue + } + toSkip.Insert(pairMachineSet.Name) + result = append(result, ms, *pairMachineSet) + if len(result) >= needCount { + break } - result = append(result, machineSetsToScaleUp...) } - - return result, pendingPods, requiredNodeCounts + return result } type nodeRequirement struct { diff --git a/hypershift-operator/controllers/scheduler/aws/autoscaler_test.go b/hypershift-operator/controllers/scheduler/aws/autoscaler_test.go index de5d634ce2c..a56331d7285 100644 --- a/hypershift-operator/controllers/scheduler/aws/autoscaler_test.go +++ b/hypershift-operator/controllers/scheduler/aws/autoscaler_test.go @@ -862,3 +862,366 @@ func TestNonRequestServingMachineSetsToScale(t *testing.T) { }) } } + +func TestCollectTakenPairLabels(t *testing.T) { + tests := []struct { + name string + pods []corev1.Pod + nodes []corev1.Node + expected sets.Set[string] + }{ + { + name: "When there are no nodes and no pods, it should return empty set", + pods: nil, + nodes: nil, + expected: sets.New[string](), + }, + { + name: "When nodes have cluster labels, their pair labels should be collected", + pods: nil, + nodes: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + hyperv1.HostedClusterLabel: "ns-hc1", + OSDFleetManagerPairedNodesLabel: "pair-a", + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n2", + Labels: map[string]string{ + hyperv1.HostedClusterLabel: "ns-hc2", + OSDFleetManagerPairedNodesLabel: "pair-b", + }, + }, + }, + }, + expected: sets.New[string]("pair-a", "pair-b"), + }, + { + name: "When nodes have no cluster label, their pair labels should not be collected", + pods: nil, + nodes: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-a", + }, + }, + }, + }, + expected: sets.New[string](), + }, + { + name: "When pods are scheduled on nodes with pair labels, those labels should be collected", + pods: []corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{Name: "p1"}, + Spec: corev1.PodSpec{NodeName: "n1"}, + }, + }, + nodes: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-c", + }, + }, + }, + }, + expected: sets.New[string]("pair-c"), + }, + { + name: "When pods have pair label in node selector, those labels should be collected", + pods: []corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{Name: "p1"}, + Spec: corev1.PodSpec{ + NodeSelector: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-d", + }, + }, + }, + }, + nodes: nil, + expected: sets.New[string]("pair-d"), + }, + { + name: "When both nodes and pods contribute pair labels, it should return the union", + pods: []corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{Name: "p1"}, + Spec: corev1.PodSpec{ + NodeSelector: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-from-pod", + }, + }, + }, + }, + nodes: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + hyperv1.HostedClusterLabel: "ns-hc1", + OSDFleetManagerPairedNodesLabel: "pair-from-node", + }, + }, + }, + }, + expected: sets.New[string]("pair-from-pod", "pair-from-node"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + actual := collectTakenPairLabels(test.pods, test.nodes) + g.Expect(actual).To(Equal(test.expected)) + }) + } +} + +func TestScaleMachineSetsForRequirement(t *testing.T) { + reqServingLabel := map[string]string{ + hyperv1.RequestServingComponentLabel: "true", + } + + mkMachineSet := func(name, sizeLabel, pairLabel string, replicas int32, availableReplicas int32) machinev1beta1.MachineSet { + return machinev1beta1.MachineSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: "openshift-machine-api", + }, + Spec: machinev1beta1.MachineSetSpec{ + Replicas: &replicas, + Template: machinev1beta1.MachineTemplateSpec{ + Spec: machinev1beta1.MachineSpec{ + ObjectMeta: machinev1beta1.ObjectMeta{ + Labels: map[string]string{ + hyperv1.NodeSizeLabel: sizeLabel, + OSDFleetManagerPairedNodesLabel: pairLabel, + hyperv1.RequestServingComponentLabel: reqServingLabel[hyperv1.RequestServingComponentLabel], + }, + }, + }, + }, + }, + Status: machinev1beta1.MachineSetStatus{ + AvailableReplicas: availableReplicas, + }, + } + } + + mkNode := func(name, sizeLabel, pairLabel string) corev1.Node { + return corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Annotations: map[string]string{ + machineNameNodeAnnotation: "openshift-machine-api/" + name + "-machine", + }, + Labels: map[string]string{ + hyperv1.RequestServingComponentLabel: "true", + hyperv1.NodeSizeLabel: sizeLabel, + OSDFleetManagerPairedNodesLabel: pairLabel, + }, + }, + } + } + + tests := []struct { + name string + requirement nodeRequirement + machineSets []machinev1beta1.MachineSet + machines []machinev1beta1.Machine + nodes []corev1.Node + takenPairLabels sets.Set[string] + expectedNames []string + }{ + { + name: "When available nodes satisfy the requirement, it should return nil", + requirement: nodeRequirement{sizeLabel: "small", count: 2}, + machineSets: []machinev1beta1.MachineSet{mkMachineSet("ms-1", "small", "pair-a", 0, 0)}, + machines: nil, + nodes: []corev1.Node{mkNode("n1", "small", "pair-free"), mkNode("n2", "small", "pair-free")}, + takenPairLabels: sets.New[string](), + expectedNames: nil, + }, + { + name: "When no nodes or pending machinesets exist, it should pick available machineset pairs", + requirement: nodeRequirement{sizeLabel: "small", count: 2}, + machineSets: []machinev1beta1.MachineSet{ + mkMachineSet("ms-1a", "small", "pair-a", 0, 0), + mkMachineSet("ms-1b", "small", "pair-a", 0, 0), + }, + machines: nil, + nodes: nil, + takenPairLabels: sets.New[string](), + expectedNames: []string{"ms-1a", "ms-1b"}, + }, + { + name: "When a pending machineset exists without its pair scaled up, it should scale the pair", + requirement: nodeRequirement{sizeLabel: "small", count: 2}, + machineSets: []machinev1beta1.MachineSet{ + mkMachineSet("ms-1a", "small", "pair-a", 1, 0), + mkMachineSet("ms-1b", "small", "pair-a", 0, 0), + }, + machines: nil, + nodes: nil, + takenPairLabels: sets.New[string](), + expectedNames: []string{"ms-1b"}, + }, + { + name: "When taken pair labels exclude some machinesets, those should be skipped", + requirement: nodeRequirement{sizeLabel: "small", count: 2}, + machineSets: []machinev1beta1.MachineSet{ + mkMachineSet("ms-1a", "small", "pair-taken", 0, 0), + mkMachineSet("ms-1b", "small", "pair-taken", 0, 0), + mkMachineSet("ms-2a", "small", "pair-free", 0, 0), + mkMachineSet("ms-2b", "small", "pair-free", 0, 0), + }, + machines: nil, + nodes: nil, + takenPairLabels: sets.New[string]("pair-taken"), + expectedNames: []string{"ms-2a", "ms-2b"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + actual := scaleMachineSetsForRequirement(test.requirement, test.machineSets, test.machines, test.nodes, test.takenPairLabels) + actualNames := make([]string, 0, len(actual)) + for _, ms := range actual { + actualNames = append(actualNames, ms.Name) + } + if len(test.expectedNames) == 0 { + g.Expect(actual).To(BeEmpty()) + } else { + g.Expect(sets.New(actualNames...)).To(Equal(sets.New(test.expectedNames...))) + } + }) + } +} + +func TestPickAvailableMachineSetPairs(t *testing.T) { + mkMachineSet := func(name, sizeLabel, pairLabel string) machinev1beta1.MachineSet { + return machinev1beta1.MachineSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: "openshift-machine-api", + }, + Spec: machinev1beta1.MachineSetSpec{ + Replicas: func() *int32 { r := int32(0); return &r }(), + Template: machinev1beta1.MachineTemplateSpec{ + Spec: machinev1beta1.MachineSpec{ + ObjectMeta: machinev1beta1.ObjectMeta{ + Labels: map[string]string{ + hyperv1.NodeSizeLabel: sizeLabel, + OSDFleetManagerPairedNodesLabel: pairLabel, + hyperv1.RequestServingComponentLabel: "true", + }, + }, + }, + }, + }, + } + } + + tests := []struct { + name string + sizeLabel string + needCount int + machineSets []machinev1beta1.MachineSet + takenPairLabels sets.Set[string] + expectedNames []string + }{ + { + name: "When there are no available machinesets, it should return empty", + sizeLabel: "small", + needCount: 2, + machineSets: nil, + takenPairLabels: sets.New[string](), + expectedNames: nil, + }, + { + name: "When a complete pair is available, it should return both machinesets", + sizeLabel: "small", + needCount: 2, + machineSets: []machinev1beta1.MachineSet{ + mkMachineSet("ms-1a", "small", "pair-1"), + mkMachineSet("ms-1b", "small", "pair-1"), + }, + takenPairLabels: sets.New[string](), + expectedNames: []string{"ms-1a", "ms-1b"}, + }, + { + name: "When pair label is taken, it should skip that pair", + sizeLabel: "small", + needCount: 2, + machineSets: []machinev1beta1.MachineSet{ + mkMachineSet("ms-1a", "small", "pair-taken"), + mkMachineSet("ms-1b", "small", "pair-taken"), + mkMachineSet("ms-2a", "small", "pair-free"), + mkMachineSet("ms-2b", "small", "pair-free"), + }, + takenPairLabels: sets.New[string]("pair-taken"), + expectedNames: []string{"ms-2a", "ms-2b"}, + }, + { + name: "When a machineset has no matching pair, it should be skipped", + sizeLabel: "small", + needCount: 2, + machineSets: []machinev1beta1.MachineSet{ + mkMachineSet("ms-1a", "small", "pair-orphan"), + }, + takenPairLabels: sets.New[string](), + expectedNames: nil, + }, + { + name: "When size does not match, it should be skipped", + sizeLabel: "small", + needCount: 2, + machineSets: []machinev1beta1.MachineSet{ + mkMachineSet("ms-1a", "medium", "pair-1"), + mkMachineSet("ms-1b", "medium", "pair-1"), + }, + takenPairLabels: sets.New[string](), + expectedNames: nil, + }, + { + name: "When needCount is satisfied, it should stop picking pairs", + sizeLabel: "small", + needCount: 2, + machineSets: []machinev1beta1.MachineSet{ + mkMachineSet("ms-1a", "small", "pair-1"), + mkMachineSet("ms-1b", "small", "pair-1"), + mkMachineSet("ms-2a", "small", "pair-2"), + mkMachineSet("ms-2b", "small", "pair-2"), + }, + takenPairLabels: sets.New[string](), + expectedNames: []string{"ms-1a", "ms-1b"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + actual := pickAvailableMachineSetPairs(test.sizeLabel, test.needCount, test.machineSets, test.takenPairLabels) + actualNames := make([]string, 0, len(actual)) + for _, ms := range actual { + actualNames = append(actualNames, ms.Name) + } + if len(test.expectedNames) == 0 { + g.Expect(actual).To(BeEmpty()) + } else { + g.Expect(sets.New(actualNames...)).To(Equal(sets.New(test.expectedNames...))) + } + }) + } +} diff --git a/hypershift-operator/controllers/scheduler/aws/dedicated_request_serving_nodes.go b/hypershift-operator/controllers/scheduler/aws/dedicated_request_serving_nodes.go index 40bf8b4c583..69dbb720cae 100644 --- a/hypershift-operator/controllers/scheduler/aws/dedicated_request_serving_nodes.go +++ b/hypershift-operator/controllers/scheduler/aws/dedicated_request_serving_nodes.go @@ -162,7 +162,6 @@ func (r *DedicatedServingComponentScheduler) Reconcile(ctx context.Context, req return ctrl.Result{}, nil } - // Find existing dedicated serving content Nodes for this HC. dedicatedNodesForHC := &corev1.NodeList{} if err := r.List(ctx, dedicatedNodesForHC, client.HasLabels{hyperv1.RequestServingComponentLabel}, @@ -175,7 +174,6 @@ func (r *DedicatedServingComponentScheduler) Reconcile(ctx context.Context, req return ctrl.Result{}, fmt.Errorf("found too many dedicated nodes for HC: %v", len(dedicatedNodesForHC.Items)) } - // We check existing dedicated Nodes are 2. If not e.g. some was deleted, continue. if scheduled := hcluster.Annotations[hyperv1.HostedClusterScheduledAnnotation]; scheduled == "true" && len(dedicatedNodesForHC.Items) == 2 { log.Info("hosted cluster is already scheduled, nothing to do") return ctrl.Result{}, nil @@ -186,12 +184,28 @@ func (r *DedicatedServingComponentScheduler) Reconcile(ctx context.Context, req return ctrl.Result{}, fmt.Errorf("failed to list nodes: %w", err) } + nodesToUse := r.findExistingNodesForCluster(ctx, nodeList, hcluster) + if len(nodesToUse) < 2 { + r.findAvailableNodes(ctx, nodeList, nodesToUse) + } + if len(nodesToUse) < 2 { + return ctrl.Result{}, fmt.Errorf("failed to find enough available nodes for cluster, found %d", len(nodesToUse)) + } + + if err := r.labelAndTaintNodes(ctx, hcluster, nodesToUse); err != nil { + return ctrl.Result{}, err + } + + return ctrl.Result{}, r.updateHostedClusterAnnotations(ctx, hcluster, nodesToUse) +} + +func (r *DedicatedServingComponentScheduler) findExistingNodesForCluster(ctx context.Context, nodeList *corev1.NodeList, hcluster *hyperv1.HostedCluster) map[string]*corev1.Node { + log := ctrl.LoggerFrom(ctx) nodesToUse := map[string]*corev1.Node{} - // first, find any existing nodes already labeled for this hostedcluster + hcValue := fmt.Sprintf("%s-%s", hcluster.Namespace, hcluster.Name) for i := range nodeList.Items { node := &nodeList.Items[i] if !node.DeletionTimestamp.IsZero() { - // Skip nodes that are being deleted continue } zone, hasZoneLabel := node.Labels["topology.kubernetes.io/zone"] @@ -202,76 +216,59 @@ func (r *DedicatedServingComponentScheduler) Reconcile(ctx context.Context, req if !hasHCLabel { continue } - if hcLabel == fmt.Sprintf("%s-%s", hcluster.Namespace, hcluster.Name) { + if hcLabel == hcValue { nodesToUse[zone] = node log.Info("Found existing node for hosted cluster", "node", node.Name, "zone", zone) } } + return nodesToUse +} - if len(nodesToUse) < 2 { - for i := range nodeList.Items { - node := &nodeList.Items[i] - zone, hasZoneLabel := node.Labels["topology.kubernetes.io/zone"] - if !hasZoneLabel { - // No zone has been set on the node, we cannot use it - continue - } - - _, hasHCLabel := node.Labels[hyperv1.HostedClusterLabel] - if hasHCLabel { - // The node has been allocated to a different hosted cluster, skip it - continue - } - - if nodesToUse[zone] == nil { - - // if the candidate Node is not paired with the existing node to use then skip. - paired := false - if len(nodesToUse) > 0 { - for _, n := range nodesToUse { - if n.Labels[OSDFleetManagerPairedNodesLabel] == node.Labels[OSDFleetManagerPairedNodesLabel] { - paired = true - } - } - if !paired { - continue - } - } - - log.Info("Found node to allocate for hosted cluster", "node", node.Name, "zone", zone) - nodesToUse[zone] = node - } - - if len(nodesToUse) == 2 { - break - } +func (r *DedicatedServingComponentScheduler) findAvailableNodes(ctx context.Context, nodeList *corev1.NodeList, nodesToUse map[string]*corev1.Node) { + log := ctrl.LoggerFrom(ctx) + for i := range nodeList.Items { + node := &nodeList.Items[i] + zone, hasZoneLabel := node.Labels["topology.kubernetes.io/zone"] + if !hasZoneLabel { + continue + } + _, hasHCLabel := node.Labels[hyperv1.HostedClusterLabel] + if hasHCLabel { + continue + } + if nodesToUse[zone] != nil { + continue + } + if !isNodePairedWith(node, nodesToUse) { + continue + } + log.Info("Found node to allocate for hosted cluster", "node", node.Name, "zone", zone) + nodesToUse[zone] = node + if len(nodesToUse) == 2 { + break } } - if len(nodesToUse) < 2 { - return ctrl.Result{}, fmt.Errorf("failed to find enough available nodes for cluster, found %d", len(nodesToUse)) +} + +func isNodePairedWith(candidate *corev1.Node, existing map[string]*corev1.Node) bool { + if len(existing) == 0 { + return true } + for _, n := range existing { + if n.Labels[OSDFleetManagerPairedNodesLabel] == candidate.Labels[OSDFleetManagerPairedNodesLabel] { + return true + } + } + return false +} - nodeGoMemLimit := "" - lbSubnets := "" - pairLabel := "" +func (r *DedicatedServingComponentScheduler) labelAndTaintNodes(ctx context.Context, hcluster *hyperv1.HostedCluster, nodesToUse map[string]*corev1.Node) error { + log := ctrl.LoggerFrom(ctx) + hcNameValue := fmt.Sprintf("%s-%s", hcluster.Namespace, hcluster.Name) for _, node := range nodesToUse { originalNode := node.DeepCopy() - if node.Labels[schedulerutil.GoMemLimitLabel] != "" && nodeGoMemLimit == "" { - nodeGoMemLimit = node.Labels[schedulerutil.GoMemLimitLabel] - } - if node.Labels[schedulerutil.LBSubnetsLabel] != "" && lbSubnets == "" { - lbSubnets = node.Labels[schedulerutil.LBSubnetsLabel] - // If subnets are separated by periods, replace them with commas - lbSubnets = strings.ReplaceAll(lbSubnets, ".", ",") - } - if node.Labels[OSDFleetManagerPairedNodesLabel] != "" && pairLabel == "" { - pairLabel = node.Labels[OSDFleetManagerPairedNodesLabel] - } - - // Add taint and labels for specific hosted cluster hasTaint := false - hcNameValue := fmt.Sprintf("%s-%s", hcluster.Namespace, hcluster.Name) for i := range node.Spec.Taints { if node.Spec.Taints[i].Key == HostedClusterTaint { node.Spec.Taints[i].Value = hcNameValue @@ -292,12 +289,31 @@ func (r *DedicatedServingComponentScheduler) Reconcile(ctx context.Context, req node.Labels[HostedClusterNamespaceLabel] = hcluster.Namespace if err := r.Patch(ctx, node, client.MergeFromWithOptions(originalNode, client.MergeFromWithOptimisticLock{})); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update labels and taints on node %s: %w", node.Name, err) + return fmt.Errorf("failed to update labels and taints on node %s: %w", node.Name, err) } log.Info("Node tainted and labeled for hosted cluster", "node", node.Name) } + return nil +} + +func (r *DedicatedServingComponentScheduler) updateHostedClusterAnnotations(ctx context.Context, hcluster *hyperv1.HostedCluster, nodesToUse map[string]*corev1.Node) error { + log := ctrl.LoggerFrom(ctx) + nodeGoMemLimit := "" + lbSubnets := "" + pairLabel := "" + for _, node := range nodesToUse { + if node.Labels[schedulerutil.GoMemLimitLabel] != "" && nodeGoMemLimit == "" { + nodeGoMemLimit = node.Labels[schedulerutil.GoMemLimitLabel] + } + if node.Labels[schedulerutil.LBSubnetsLabel] != "" && lbSubnets == "" { + lbSubnets = node.Labels[schedulerutil.LBSubnetsLabel] + lbSubnets = strings.ReplaceAll(lbSubnets, ".", ",") + } + if node.Labels[OSDFleetManagerPairedNodesLabel] != "" && pairLabel == "" { + pairLabel = node.Labels[OSDFleetManagerPairedNodesLabel] + } + } - // finally update HostedCluster with new annotation log.Info("Setting scheduled annotation on hosted cluster") originalHcluster := hcluster.DeepCopy() hcluster.Annotations[hyperv1.HostedClusterScheduledAnnotation] = "true" @@ -312,10 +328,9 @@ func (r *DedicatedServingComponentScheduler) Reconcile(ctx context.Context, req fmt.Sprintf("%s=%s", OSDFleetManagerPairedNodesLabel, pairLabel) } if err := r.Patch(ctx, hcluster, client.MergeFrom(originalHcluster)); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update hostedcluster annotation: %w", err) + return fmt.Errorf("failed to update hostedcluster annotation: %w", err) } - - return ctrl.Result{}, nil + return nil } const requestServingSchedulerAndSizerName = "DedicatedServingComponentSchedulerAndSizer" @@ -431,26 +446,7 @@ func (r *DedicatedServingComponentSchedulerAndSizer) Reconcile(ctx context.Conte return ctrl.Result{}, fmt.Errorf("failed to get cluster %q: %w", req.NamespacedName, err) } if !hc.DeletionTimestamp.IsZero() { - log.Info("hostedcluster is deleted, cleaning up") - if controllerutil.ContainsFinalizer(hc, schedulerFinalizer) { - if controllerutil.ContainsFinalizer(hc, hostedcluster.HostedClusterFinalizer) { - // Wait until the hosted cluster finalizer is removed - return ctrl.Result{}, nil - } - // Ensure that any placeholder deployment is deleted - if err := r.deletePlaceholderDeployment(ctx, hc); err != nil { - return ctrl.Result{}, err - } - if err := r.deletePairConfigMaps(ctx, hc); err != nil { - return ctrl.Result{}, err - } - controllerutil.RemoveFinalizer(hc, schedulerFinalizer) - if err := r.Update(ctx, hc); err != nil { - return ctrl.Result{}, err - } - } - - return ctrl.Result{}, nil + return r.handleDeletion(ctx, hc) } if hcTopology := hc.Annotations[hyperv1.TopologyAnnotation]; hcTopology != hyperv1.DedicatedRequestServingComponentsTopology { log.Info("hostedcluster does not use isolated request serving components, nothing to do") @@ -467,7 +463,7 @@ func (r *DedicatedServingComponentSchedulerAndSizer) Reconcile(ctx context.Conte isPaused, duration, err := util.ProcessPausedUntilField(hc.Spec.PausedUntil, time.Now()) if err != nil { log.Error(err, "error processing hosted cluster paused field") - return ctrl.Result{}, nil // user needs to reformat the field, returning error is useless + return ctrl.Result{}, nil } if isPaused { log.Info("Reconciliation paused", "pausedUntil", *hc.Spec.PausedUntil) @@ -491,16 +487,67 @@ func (r *DedicatedServingComponentSchedulerAndSizer) Reconcile(ctx context.Conte return ctrl.Result{}, nil } - // Find existing dedicated serving content Nodes for this HC. - dedicatedNodes := &corev1.NodeList{} - if err := r.List(ctx, dedicatedNodes, - client.HasLabels{hyperv1.RequestServingComponentLabel}, - ); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to list nodes: %w", err) + goalNodes, availableNodes, pairLabel, err := r.classifyDedicatedNodes(ctx, hc, desiredSize) + if err != nil { + return ctrl.Result{}, err + } + if pairLabel == "" { + pairLabel, err = r.pairLabelFromConfigMaps(ctx, hc.Namespace, hc.Name) + if err != nil { + return ctrl.Result{}, fmt.Errorf("failed to get pair label from configmaps: %w", err) + } } - var goalNodes, availableNodes []corev1.Node - var pairLabel string + log = log.WithValues("pairLabel", pairLabel) + + if result, done, err := r.backfillOrClaimNodes(ctx, hc, desiredSize, pairLabel, availableNodes, &config); done { + return result, err + } + + nodesByZone := r.goalNodesByZone(goalNodes) + log = log.WithValues("nodes", nodeNamesByZoneMap(nodesByZone)) + + if len(nodesByZone) > 1 { + log.Info("sufficient nodes exist for placement") + if err := schedulerutil.UpdateHostedCluster(ctx, r.Client, hc, desiredSize, &config, goalNodes); err != nil { + return ctrl.Result{}, err + } + log.Info("removing placeholder") + if err := r.deletePlaceholderDeployment(ctx, hc); err != nil { + return ctrl.Result{}, err + } + return ctrl.Result{}, nil + } + + return r.deployAndLabelPlaceholderNodes(ctx, hc, desiredSize, pairLabel, nodesByZone) +} + +func (r *DedicatedServingComponentSchedulerAndSizer) handleDeletion(ctx context.Context, hc *hyperv1.HostedCluster) (ctrl.Result, error) { + log := ctrl.LoggerFrom(ctx) + log.Info("hostedcluster is deleted, cleaning up") + if controllerutil.ContainsFinalizer(hc, schedulerFinalizer) { + if controllerutil.ContainsFinalizer(hc, hostedcluster.HostedClusterFinalizer) { + return ctrl.Result{}, nil + } + if err := r.deletePlaceholderDeployment(ctx, hc); err != nil { + return ctrl.Result{}, err + } + if err := r.deletePairConfigMaps(ctx, hc); err != nil { + return ctrl.Result{}, err + } + controllerutil.RemoveFinalizer(hc, schedulerFinalizer) + if err := r.Update(ctx, hc); err != nil { + return ctrl.Result{}, err + } + } + return ctrl.Result{}, nil +} + +func (r *DedicatedServingComponentSchedulerAndSizer) classifyDedicatedNodes(ctx context.Context, hc *hyperv1.HostedCluster, desiredSize string) (goalNodes, availableNodes []corev1.Node, pairLabel string, err error) { + dedicatedNodes := &corev1.NodeList{} + if err := r.List(ctx, dedicatedNodes, client.HasLabels{hyperv1.RequestServingComponentLabel}); err != nil { + return nil, nil, "", fmt.Errorf("failed to list nodes: %w", err) + } for _, node := range dedicatedNodes.Items { if !node.DeletionTimestamp.IsZero() { continue @@ -516,24 +563,14 @@ func (r *DedicatedServingComponentSchedulerAndSizer) Reconcile(ctx context.Conte availableNodes = append(availableNodes, node) } } - if pairLabel == "" { - // If no nodes were labeled, but only a configmap was created, find the pair label - // to use from the configmaps - pairLabel, err = r.pairLabelFromConfigMaps(ctx, hc.Namespace, hc.Name) - if err != nil { - return ctrl.Result{}, fmt.Errorf("failed to get pair label from configmaps: %w", err) - } - } - - log = log.WithValues("pairLabel", pairLabel) + return goalNodes, availableNodes, pairLabel, nil +} - // Find any nodes that are in the same fleet manager group and have the right size - // but are not labeled with the hosted cluster label. Ensure that these nodes are labeled - // and tainted with the hosted cluster label. This can happen if not all nodes were labeled/tainted - // when they were initially selected. +func (r *DedicatedServingComponentSchedulerAndSizer) backfillOrClaimNodes(ctx context.Context, hc *hyperv1.HostedCluster, desiredSize, pairLabel string, availableNodes []corev1.Node, config *schedulingv1alpha1.ClusterSizingConfiguration) (ctrl.Result, bool, error) { + log := ctrl.LoggerFrom(ctx) if pairLabel != "" { if err := r.ensurePairConfigMap(ctx, pairLabel, hc.Namespace, hc.Name); err != nil { - return ctrl.Result{}, fmt.Errorf("cannot ensure pair label %s config map: %w", pairLabel, err) + return ctrl.Result{}, true, fmt.Errorf("cannot ensure pair label %s config map: %w", pairLabel, err) } var needClusterLabel []corev1.Node for _, node := range availableNodes { @@ -545,70 +582,75 @@ func (r *DedicatedServingComponentSchedulerAndSizer) Reconcile(ctx context.Conte log.Info("backfilling node labels") for _, node := range needClusterLabel { if err := r.ensureHostedClusterLabelAndTaint(ctx, hc, &node); err != nil { - return ctrl.Result{}, err + return ctrl.Result{}, true, err } } - return ctrl.Result{Requeue: true}, nil - } - } else { - // If there isn't a current pair label, then we can select from available nodes selected by placeholders. - sizeConfig := schedulerutil.SizeConfiguration(&config, desiredSize) - if sizeConfig == nil { - return ctrl.Result{}, fmt.Errorf("could not find size configuration for size %s", desiredSize) + return ctrl.Result{Requeue: true}, true, nil } + return ctrl.Result{}, false, nil + } - // If placeholders are present, use those - if sizeConfig.Management != nil && sizeConfig.Management.Placeholders > 0 { - candidateNodes, err := r.nodesFromPlaceholders(ctx, desiredSize) - if err != nil { - return ctrl.Result{}, fmt.Errorf("failed to get nodes from placeholders: %w", err) - } - if len(candidateNodes) > 0 { - pairLabel = candidateNodes[0].Labels[OSDFleetManagerPairedNodesLabel] - if pairLabel == "" { - return ctrl.Result{}, fmt.Errorf("node %s has no pair label", candidateNodes[0].Name) - } - log.WithValues("pairLabel", candidateNodes[0].Labels[OSDFleetManagerPairedNodesLabel]).Info("claiming candidate nodes") - if err := r.ensurePairConfigMap(ctx, pairLabel, hc.Namespace, hc.Name); err != nil { - return ctrl.Result{}, fmt.Errorf("cannot ensure pair label %s config map: %w", pairLabel, err) - } - for _, node := range candidateNodes { - if err := r.ensureHostedClusterLabelAndTaint(ctx, hc, &node); err != nil { - return ctrl.Result{}, err - } - } - return ctrl.Result{Requeue: true}, nil - } + result, claimed, err := r.claimPlaceholderNodes(ctx, hc, desiredSize, config) + if claimed { + return result, true, err + } + return ctrl.Result{}, false, nil +} + +func (r *DedicatedServingComponentSchedulerAndSizer) claimPlaceholderNodes(ctx context.Context, hc *hyperv1.HostedCluster, desiredSize string, config *schedulingv1alpha1.ClusterSizingConfiguration) (ctrl.Result, bool, error) { + log := ctrl.LoggerFrom(ctx) + sizeConfig := schedulerutil.SizeConfiguration(config, desiredSize) + if sizeConfig == nil { + return ctrl.Result{}, true, fmt.Errorf("could not find size configuration for size %s", desiredSize) + } + if sizeConfig.Management == nil || sizeConfig.Management.Placeholders <= 0 { + return ctrl.Result{}, false, nil + } + candidateNodes, err := r.nodesFromPlaceholders(ctx, desiredSize) + if err != nil { + return ctrl.Result{}, true, fmt.Errorf("failed to get nodes from placeholders: %w", err) + } + if len(candidateNodes) == 0 { + return ctrl.Result{}, false, nil + } + pairLabel := candidateNodes[0].Labels[OSDFleetManagerPairedNodesLabel] + if pairLabel == "" { + return ctrl.Result{}, true, fmt.Errorf("node %s has no pair label", candidateNodes[0].Name) + } + log.WithValues("pairLabel", pairLabel).Info("claiming candidate nodes") + if err := r.ensurePairConfigMap(ctx, pairLabel, hc.Namespace, hc.Name); err != nil { + return ctrl.Result{}, true, fmt.Errorf("cannot ensure pair label %s config map: %w", pairLabel, err) + } + for _, node := range candidateNodes { + if err := r.ensureHostedClusterLabelAndTaint(ctx, hc, &node); err != nil { + return ctrl.Result{}, true, err } } + return ctrl.Result{Requeue: true}, true, nil +} - nodeNamesByZone := map[string]string{} +func (r *DedicatedServingComponentSchedulerAndSizer) goalNodesByZone(goalNodes []corev1.Node) map[string]corev1.Node { nodesByZone := map[string]corev1.Node{} for _, node := range goalNodes { if zone := node.Labels[corev1.LabelTopologyZone]; zone != "" { if _, hasNode := nodesByZone[zone]; !hasNode { nodesByZone[zone] = node - nodeNamesByZone[zone] = node.Name } } } - log = log.WithValues("nodes", nodeNamesByZone) + return nodesByZone +} - if len(nodesByZone) > 1 { - log.Info("sufficient nodes exist for placement") - // If we have enough nodes, update the hosted cluster. - if err := schedulerutil.UpdateHostedCluster(ctx, r.Client, hc, desiredSize, &config, goalNodes); err != nil { - return ctrl.Result{}, err - } - // Ensure we don't have a placeholder deployment, since we have nodes - log.Info("removing placeholder") - if err := r.deletePlaceholderDeployment(ctx, hc); err != nil { - return ctrl.Result{}, err - } - return ctrl.Result{}, nil +func nodeNamesByZoneMap(nodesByZone map[string]corev1.Node) map[string]string { + result := map[string]string{} + for zone, node := range nodesByZone { + result[zone] = node.Name } + return result +} - // Create a deployment to ensure nodes of the right size are created +func (r *DedicatedServingComponentSchedulerAndSizer) deployAndLabelPlaceholderNodes(ctx context.Context, hc *hyperv1.HostedCluster, desiredSize, pairLabel string, nodesByZone map[string]corev1.Node) (ctrl.Result, error) { + log := ctrl.LoggerFrom(ctx) nodesNeeded := 2 - len(nodesByZone) if nodesNeeded < 0 { nodesNeeded = 0 @@ -618,36 +660,42 @@ func (r *DedicatedServingComponentSchedulerAndSizer) Reconcile(ctx context.Conte if err != nil { return ctrl.Result{}, err } - if deployment != nil && podspec.IsDeploymentReady(ctx, deployment) { - log.Info("placeholder ready, adding node labels") - nodes, err := r.deploymentNodes(ctx, deployment) - if err != nil { + if deployment == nil || !podspec.IsDeploymentReady(ctx, deployment) { + return ctrl.Result{}, nil + } + log.Info("placeholder ready, adding node labels") + nodes, err := r.deploymentNodes(ctx, deployment) + if err != nil { + return ctrl.Result{}, err + } + pairLabel, err = r.resolvePairLabelFromNodes(nodes) + if err != nil { + return ctrl.Result{}, err + } + if err = r.ensurePairConfigMap(ctx, pairLabel, hc.Namespace, hc.Name); err != nil { + return ctrl.Result{}, fmt.Errorf("cannot ensure pair label %s config map: %w", pairLabel, err) + } + for _, node := range nodes { + if err := r.ensureHostedClusterLabelAndTaint(ctx, hc, &node); err != nil { return ctrl.Result{}, err } - pairLabel = "" - if len(nodes) > 0 { - pairLabel = nodes[0].Labels[OSDFleetManagerPairedNodesLabel] - if pairLabel == "" { - return ctrl.Result{}, fmt.Errorf("node %s has no fleetmanager pair label", nodes[0].Name) - } - } + } + log.Info("removing placeholder") + if err := r.deletePlaceholderDeployment(ctx, hc); err != nil { + return ctrl.Result{}, err + } + return ctrl.Result{}, nil +} + +func (r *DedicatedServingComponentSchedulerAndSizer) resolvePairLabelFromNodes(nodes []corev1.Node) (string, error) { + if len(nodes) > 0 { + pairLabel := nodes[0].Labels[OSDFleetManagerPairedNodesLabel] if pairLabel == "" { - return ctrl.Result{}, fmt.Errorf("cannot determine pair label") - } - if err = r.ensurePairConfigMap(ctx, pairLabel, hc.Namespace, hc.Name); err != nil { - return ctrl.Result{}, fmt.Errorf("cannot ensure pair label %s config map: %w", pairLabel, err) - } - for _, node := range nodes { - if err := r.ensureHostedClusterLabelAndTaint(ctx, hc, &node); err != nil { - return ctrl.Result{}, err - } - } - log.Info("removing placeholder") - if err := r.deletePlaceholderDeployment(ctx, hc); err != nil { - return ctrl.Result{}, err + return "", fmt.Errorf("node %s has no fleetmanager pair label", nodes[0].Name) } + return pairLabel, nil } - return ctrl.Result{}, nil + return "", fmt.Errorf("cannot determine pair label") } func (r *DedicatedServingComponentSchedulerAndSizer) ensurePairConfigMap(ctx context.Context, pairLabel, hcNamespace, hcName string) error { diff --git a/hypershift-operator/controllers/scheduler/aws/dedicated_request_serving_nodes_test.go b/hypershift-operator/controllers/scheduler/aws/dedicated_request_serving_nodes_test.go index 6fd19f7bf2e..c074bc2991a 100644 --- a/hypershift-operator/controllers/scheduler/aws/dedicated_request_serving_nodes_test.go +++ b/hypershift-operator/controllers/scheduler/aws/dedicated_request_serving_nodes_test.go @@ -9,6 +9,8 @@ import ( hyperv1 "github.com/openshift/hypershift/api/hypershift/v1beta1" schedulingv1alpha1 "github.com/openshift/hypershift/api/scheduling/v1alpha1" + "github.com/openshift/hypershift/hypershift-operator/controllers/hostedcluster" + schedulerutil "github.com/openshift/hypershift/hypershift-operator/controllers/scheduler/util" hyperapi "github.com/openshift/hypershift/support/api" appsv1 "k8s.io/api/apps/v1" @@ -867,3 +869,988 @@ func TestFilterNodeEvents(t *testing.T) { }) } } + +func TestIsNodePairedWith(t *testing.T) { + tests := []struct { + name string + candidate *corev1.Node + existing map[string]*corev1.Node + expected bool + }{ + { + name: "When there are no existing nodes, it should return true", + candidate: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-a", + }, + }, + }, + existing: map[string]*corev1.Node{}, + expected: true, + }, + { + name: "When the candidate has the same pair label as an existing node, it should return true", + candidate: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-a", + }, + }, + }, + existing: map[string]*corev1.Node{ + "zone-x": { + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-a", + }, + }, + }, + }, + expected: true, + }, + { + name: "When the candidate has a different pair label from all existing nodes, it should return false", + candidate: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-b", + }, + }, + }, + existing: map[string]*corev1.Node{ + "zone-x": { + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-a", + }, + }, + }, + }, + expected: false, + }, + { + name: "When multiple existing nodes have mixed pair labels, it should return true if any matches", + candidate: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-b", + }, + }, + }, + existing: map[string]*corev1.Node{ + "zone-x": { + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-a", + }, + }, + }, + "zone-y": { + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-b", + }, + }, + }, + }, + expected: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + actual := isNodePairedWith(test.candidate, test.existing) + g.Expect(actual).To(Equal(test.expected)) + }) + } +} + +func TestGoalNodesByZone(t *testing.T) { + tests := []struct { + name string + goalNodes []corev1.Node + expectedKeys []string + }{ + { + name: "When there are no goal nodes, it should return empty map", + goalNodes: nil, + expectedKeys: nil, + }, + { + name: "When nodes are in different zones, it should return one node per zone", + goalNodes: []corev1.Node{ + {ObjectMeta: metav1.ObjectMeta{Name: "n1", Labels: map[string]string{corev1.LabelTopologyZone: "us-east-1a"}}}, + {ObjectMeta: metav1.ObjectMeta{Name: "n2", Labels: map[string]string{corev1.LabelTopologyZone: "us-east-1b"}}}, + }, + expectedKeys: []string{"us-east-1a", "us-east-1b"}, + }, + { + name: "When multiple nodes are in the same zone, it should keep only the first", + goalNodes: []corev1.Node{ + {ObjectMeta: metav1.ObjectMeta{Name: "n1", Labels: map[string]string{corev1.LabelTopologyZone: "us-east-1a"}}}, + {ObjectMeta: metav1.ObjectMeta{Name: "n2", Labels: map[string]string{corev1.LabelTopologyZone: "us-east-1a"}}}, + {ObjectMeta: metav1.ObjectMeta{Name: "n3", Labels: map[string]string{corev1.LabelTopologyZone: "us-east-1b"}}}, + }, + expectedKeys: []string{"us-east-1a", "us-east-1b"}, + }, + { + name: "When a node has no zone label, it should be skipped", + goalNodes: []corev1.Node{ + {ObjectMeta: metav1.ObjectMeta{Name: "n1", Labels: map[string]string{}}}, + {ObjectMeta: metav1.ObjectMeta{Name: "n2", Labels: map[string]string{corev1.LabelTopologyZone: "us-east-1b"}}}, + }, + expectedKeys: []string{"us-east-1b"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + r := &DedicatedServingComponentSchedulerAndSizer{} + result := r.goalNodesByZone(test.goalNodes) + g.Expect(result).To(HaveLen(len(test.expectedKeys))) + for _, key := range test.expectedKeys { + g.Expect(result).To(HaveKey(key)) + } + }) + } +} + +func TestNodeNamesByZoneMap(t *testing.T) { + tests := []struct { + name string + input map[string]corev1.Node + expected map[string]string + }{ + { + name: "When there are no nodes, it should return empty map", + input: map[string]corev1.Node{}, + expected: map[string]string{}, + }, + { + name: "When there are nodes, it should map zone to node name", + input: map[string]corev1.Node{ + "zone-a": {ObjectMeta: metav1.ObjectMeta{Name: "node-1"}}, + "zone-b": {ObjectMeta: metav1.ObjectMeta{Name: "node-2"}}, + }, + expected: map[string]string{ + "zone-a": "node-1", + "zone-b": "node-2", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + result := nodeNamesByZoneMap(test.input) + g.Expect(result).To(Equal(test.expected)) + }) + } +} + +func TestResolvePairLabelFromNodes(t *testing.T) { + tests := []struct { + name string + nodes []corev1.Node + expected string + expectError bool + }{ + { + name: "When there are no nodes, it should return an error", + nodes: nil, + expectError: true, + }, + { + name: "When the first node has a pair label, it should return that label", + nodes: []corev1.Node{ + {ObjectMeta: metav1.ObjectMeta{Name: "n1", Labels: map[string]string{OSDFleetManagerPairedNodesLabel: "pair-x"}}}, + {ObjectMeta: metav1.ObjectMeta{Name: "n2", Labels: map[string]string{OSDFleetManagerPairedNodesLabel: "pair-x"}}}, + }, + expected: "pair-x", + }, + { + name: "When the first node has no pair label, it should return an error", + nodes: []corev1.Node{ + {ObjectMeta: metav1.ObjectMeta{Name: "n1", Labels: map[string]string{}}}, + }, + expectError: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + r := &DedicatedServingComponentSchedulerAndSizer{} + result, err := r.resolvePairLabelFromNodes(test.nodes) + if test.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(result).To(Equal(test.expected)) + } + }) + } +} + +func TestFindExistingNodesForCluster(t *testing.T) { + hc := &hyperv1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "hc1", + }, + } + hcValue := "ns-hc1" + + tests := []struct { + name string + nodeList *corev1.NodeList + expectedNames []string + }{ + { + name: "When there are no nodes, it should return empty map", + nodeList: &corev1.NodeList{}, + expectedNames: nil, + }, + { + name: "When a node matches the cluster, it should be included", + nodeList: &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + "topology.kubernetes.io/zone": "zone-a", + hyperv1.HostedClusterLabel: hcValue, + }, + }, + }, + }, + }, + expectedNames: []string{"n1"}, + }, + { + name: "When a node is being deleted, it should be skipped", + nodeList: &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + DeletionTimestamp: &metav1.Time{Time: metav1.Now().Time}, + Finalizers: []string{"test"}, + Labels: map[string]string{ + "topology.kubernetes.io/zone": "zone-a", + hyperv1.HostedClusterLabel: hcValue, + }, + }, + }, + }, + }, + expectedNames: nil, + }, + { + name: "When a node has no zone label, it should be skipped", + nodeList: &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + hyperv1.HostedClusterLabel: hcValue, + }, + }, + }, + }, + }, + expectedNames: nil, + }, + { + name: "When a node belongs to a different cluster, it should be skipped", + nodeList: &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + "topology.kubernetes.io/zone": "zone-a", + hyperv1.HostedClusterLabel: "other-ns-other-hc", + }, + }, + }, + }, + }, + expectedNames: nil, + }, + { + name: "When nodes are in different zones for the same cluster, it should return both", + nodeList: &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + "topology.kubernetes.io/zone": "zone-a", + hyperv1.HostedClusterLabel: hcValue, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n2", + Labels: map[string]string{ + "topology.kubernetes.io/zone": "zone-b", + hyperv1.HostedClusterLabel: hcValue, + }, + }, + }, + }, + }, + expectedNames: []string{"n1", "n2"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + r := &DedicatedServingComponentScheduler{} + result := r.findExistingNodesForCluster(t.Context(), test.nodeList, hc) + g.Expect(result).To(HaveLen(len(test.expectedNames))) + for _, name := range test.expectedNames { + found := false + for _, node := range result { + if node.Name == name { + found = true + break + } + } + g.Expect(found).To(BeTrue(), "expected node %s not found in result", name) + } + }) + } +} + +func TestFindAvailableNodes(t *testing.T) { + tests := []struct { + name string + nodeList *corev1.NodeList + existing map[string]*corev1.Node + expectedLen int + expectedNames []string + }{ + { + name: "When an unassigned node is in a new zone with matching pair label, it should be selected", + nodeList: &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n2", + Labels: map[string]string{ + "topology.kubernetes.io/zone": "zone-b", + OSDFleetManagerPairedNodesLabel: "pair-1", + }, + }, + }, + }, + }, + existing: map[string]*corev1.Node{ + "zone-a": { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-1", + }, + }, + }, + }, + expectedLen: 2, + expectedNames: []string{"n2"}, + }, + { + name: "When a node has no zone label, it should be skipped", + nodeList: &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n2", + Labels: map[string]string{}, + }, + }, + }, + }, + existing: map[string]*corev1.Node{}, + expectedLen: 0, + }, + { + name: "When a node is already assigned to a cluster, it should be skipped", + nodeList: &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n2", + Labels: map[string]string{ + "topology.kubernetes.io/zone": "zone-b", + hyperv1.HostedClusterLabel: "other-cluster", + OSDFleetManagerPairedNodesLabel: "pair-1", + }, + }, + }, + }, + }, + existing: map[string]*corev1.Node{}, + expectedLen: 0, + }, + { + name: "When a node has a non-matching pair label, it should be skipped", + nodeList: &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n2", + Labels: map[string]string{ + "topology.kubernetes.io/zone": "zone-b", + OSDFleetManagerPairedNodesLabel: "pair-2", + }, + }, + }, + }, + }, + existing: map[string]*corev1.Node{ + "zone-a": { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-1", + }, + }, + }, + }, + expectedLen: 1, + }, + { + name: "When a zone already has a node in use, it should be skipped", + nodeList: &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "n2", + Labels: map[string]string{ + "topology.kubernetes.io/zone": "zone-a", + OSDFleetManagerPairedNodesLabel: "pair-1", + }, + }, + }, + }, + }, + existing: map[string]*corev1.Node{ + "zone-a": { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + OSDFleetManagerPairedNodesLabel: "pair-1", + }, + }, + }, + }, + expectedLen: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + r := &DedicatedServingComponentScheduler{} + r.findAvailableNodes(t.Context(), test.nodeList, test.existing) + g.Expect(test.existing).To(HaveLen(test.expectedLen)) + for _, name := range test.expectedNames { + found := false + for _, node := range test.existing { + if node.Name == name { + found = true + break + } + } + g.Expect(found).To(BeTrue(), "expected node %s not found in result", name) + } + }) + } +} + +func TestUpdateHostedClusterAnnotations(t *testing.T) { + tests := []struct { + name string + nodesToUse map[string]*corev1.Node + expectedAnnotations map[string]string + }{ + { + name: "When nodes have GoMemLimit, LBSubnets, and pair label, it should set all annotations", + nodesToUse: map[string]*corev1.Node{ + "zone-a": { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + schedulerutil.GoMemLimitLabel: "4096", + schedulerutil.LBSubnetsLabel: "subnet-1.subnet-2", + OSDFleetManagerPairedNodesLabel: "pair-1", + }, + }, + }, + }, + expectedAnnotations: map[string]string{ + hyperv1.HostedClusterScheduledAnnotation: "true", + hyperv1.KubeAPIServerGOMemoryLimitAnnotation: "4096", + hyperv1.AWSLoadBalancerSubnetsAnnotation: "subnet-1,subnet-2", + hyperv1.AWSLoadBalancerTargetNodesAnnotation: OSDFleetManagerPairedNodesLabel + "=pair-1", + }, + }, + { + name: "When nodes have no optional labels, it should only set the scheduled annotation", + nodesToUse: map[string]*corev1.Node{ + "zone-a": { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{}, + }, + }, + }, + expectedAnnotations: map[string]string{ + hyperv1.HostedClusterScheduledAnnotation: "true", + }, + }, + { + name: "When LBSubnets use periods as separators, it should replace with commas", + nodesToUse: map[string]*corev1.Node{ + "zone-a": { + ObjectMeta: metav1.ObjectMeta{ + Name: "n1", + Labels: map[string]string{ + schedulerutil.LBSubnetsLabel: "subnet-a.subnet-b.subnet-c", + }, + }, + }, + }, + expectedAnnotations: map[string]string{ + hyperv1.HostedClusterScheduledAnnotation: "true", + hyperv1.AWSLoadBalancerSubnetsAnnotation: "subnet-a,subnet-b,subnet-c", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + hc := &hyperv1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-ns", + Name: "test-hc", + Annotations: map[string]string{}, + }, + } + c := fake.NewClientBuilder().WithScheme(hyperapi.Scheme).WithObjects(hc).Build() + r := &DedicatedServingComponentScheduler{Client: c} + err := r.updateHostedClusterAnnotations(t.Context(), hc, test.nodesToUse) + g.Expect(err).ToNot(HaveOccurred()) + + updated := &hyperv1.HostedCluster{} + err = c.Get(t.Context(), client.ObjectKeyFromObject(hc), updated) + g.Expect(err).ToNot(HaveOccurred()) + for key, value := range test.expectedAnnotations { + g.Expect(updated.Annotations).To(HaveKeyWithValue(key, value)) + } + }) + } +} + +func TestHandleDeletion(t *testing.T) { + now := metav1.Now() + + tests := []struct { + name string + hc *hyperv1.HostedCluster + additionalObjects []client.Object + expectFinalizer bool + }{ + { + name: "When HC has no scheduler finalizer, it should return without changes", + hc: &hyperv1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-ns", + Name: "test-hc", + DeletionTimestamp: &now, + Finalizers: []string{"other-finalizer"}, + Annotations: map[string]string{}, + }, + }, + expectFinalizer: false, + }, + { + name: "When HC still has the hostedcluster finalizer, it should wait and keep the scheduler finalizer", + hc: &hyperv1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-ns", + Name: "test-hc", + DeletionTimestamp: &now, + Finalizers: []string{schedulerFinalizer, hostedcluster.HostedClusterFinalizer}, + Annotations: map[string]string{}, + }, + }, + expectFinalizer: true, + }, + { + name: "When HC has scheduler finalizer and no hostedcluster finalizer, it should remove the finalizer", + hc: &hyperv1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-ns", + Name: "test-hc", + DeletionTimestamp: &now, + Finalizers: []string{schedulerFinalizer}, + Annotations: map[string]string{}, + }, + }, + expectFinalizer: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + objs := []client.Object{test.hc} + objs = append(objs, test.additionalObjects...) + c := fake.NewClientBuilder().WithScheme(hyperapi.Scheme).WithObjects(objs...).Build() + r := &DedicatedServingComponentSchedulerAndSizer{ + Client: c, + createOrUpdate: controllerutil.CreateOrUpdate, + } + _, err := r.handleDeletion(t.Context(), test.hc) + g.Expect(err).ToNot(HaveOccurred()) + + updated := &hyperv1.HostedCluster{} + err = c.Get(t.Context(), client.ObjectKeyFromObject(test.hc), updated) + if test.expectFinalizer { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(controllerutil.ContainsFinalizer(updated, schedulerFinalizer)).To(BeTrue()) + } else { + if err == nil { + g.Expect(controllerutil.ContainsFinalizer(updated, schedulerFinalizer)).To(BeFalse()) + } + } + }) + } +} + +func TestClassifyDedicatedNodes(t *testing.T) { + hc := &hyperv1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "hc1", + }, + } + hcKey := "ns-hc1" + + mkNode := func(name, hcLabel, pairLabel, sizeLabel string, deleting bool) corev1.Node { + n := corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Labels: map[string]string{ + hyperv1.RequestServingComponentLabel: "true", + hyperv1.NodeSizeLabel: sizeLabel, + OSDFleetManagerPairedNodesLabel: pairLabel, + }, + }, + } + if hcLabel != "" { + n.Labels[hyperv1.HostedClusterLabel] = hcLabel + } + if deleting { + now := metav1.Now() + n.DeletionTimestamp = &now + n.Finalizers = []string{"test"} + } + return n + } + + tests := []struct { + name string + nodes []client.Object + desiredSize string + expectedGoalLen int + expectedAvailLen int + expectedPairLabel string + }{ + { + name: "When there are no nodes, it should return empty slices", + nodes: nil, + desiredSize: "small", + expectedGoalLen: 0, + expectedAvailLen: 0, + }, + { + name: "When nodes are labeled for the cluster with matching size and pair, they should be goal nodes", + nodes: []client.Object{ + func() client.Object { n := mkNode("n1", hcKey, "pair-1", "small", false); return &n }(), + func() client.Object { n := mkNode("n2", hcKey, "pair-1", "small", false); return &n }(), + }, + desiredSize: "small", + expectedGoalLen: 2, + expectedAvailLen: 0, + expectedPairLabel: "pair-1", + }, + { + name: "When nodes have no cluster label, they should be available nodes", + nodes: []client.Object{ + func() client.Object { n := mkNode("n1", "", "pair-1", "small", false); return &n }(), + }, + desiredSize: "small", + expectedGoalLen: 0, + expectedAvailLen: 1, + }, + { + name: "When a node is being deleted, it should be skipped entirely", + nodes: []client.Object{ + func() client.Object { n := mkNode("n1", hcKey, "pair-1", "small", true); return &n }(), + }, + desiredSize: "small", + expectedGoalLen: 0, + expectedAvailLen: 0, + }, + { + name: "When a cluster node has wrong size, it should not be a goal node", + nodes: []client.Object{ + func() client.Object { n := mkNode("n1", hcKey, "pair-1", "medium", false); return &n }(), + }, + desiredSize: "small", + expectedGoalLen: 0, + expectedAvailLen: 0, + expectedPairLabel: "pair-1", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + c := fake.NewClientBuilder().WithScheme(hyperapi.Scheme).WithObjects(test.nodes...).Build() + r := &DedicatedServingComponentSchedulerAndSizer{Client: c} + goalNodes, availableNodes, pairLabel, err := r.classifyDedicatedNodes(t.Context(), hc, test.desiredSize) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(goalNodes).To(HaveLen(test.expectedGoalLen)) + g.Expect(availableNodes).To(HaveLen(test.expectedAvailLen)) + g.Expect(pairLabel).To(Equal(test.expectedPairLabel)) + }) + } +} + +func TestEnsurePairConfigMap(t *testing.T) { + tests := []struct { + name string + existing []client.Object + pairLabel string + hcNamespace string + hcName string + expectError bool + }{ + { + name: "When no configmap exists, it should create one", + existing: nil, + pairLabel: "pair-1", + hcNamespace: "ns", + hcName: "hc1", + }, + { + name: "When configmap exists for the same cluster, it should succeed", + existing: []client.Object{ + &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: placeholderNamespace, + Name: "pair-1", + Labels: map[string]string{pairLabelKey: "pair-1"}, + }, + Data: map[string]string{ + clusterNamespaceKey: "ns", + clusterNameKey: "hc1", + }, + }, + }, + pairLabel: "pair-1", + hcNamespace: "ns", + hcName: "hc1", + }, + { + name: "When configmap exists for a different cluster, it should return conflict error", + existing: []client.Object{ + &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: placeholderNamespace, + Name: "pair-1", + Labels: map[string]string{pairLabelKey: "pair-1"}, + }, + Data: map[string]string{ + clusterNamespaceKey: "other-ns", + clusterNameKey: "other-hc", + }, + }, + }, + pairLabel: "pair-1", + hcNamespace: "ns", + hcName: "hc1", + expectError: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + c := fake.NewClientBuilder().WithScheme(hyperapi.Scheme).WithObjects(test.existing...).Build() + r := &DedicatedServingComponentSchedulerAndSizer{Client: c} + err := r.ensurePairConfigMap(t.Context(), test.pairLabel, test.hcNamespace, test.hcName) + if test.expectError { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("conflict")) + } else { + g.Expect(err).ToNot(HaveOccurred()) + cm := &corev1.ConfigMap{} + err = c.Get(t.Context(), types.NamespacedName{Namespace: placeholderNamespace, Name: test.pairLabel}, cm) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(cm.Data[clusterNamespaceKey]).To(Equal(test.hcNamespace)) + g.Expect(cm.Data[clusterNameKey]).To(Equal(test.hcName)) + } + }) + } +} + +func TestPairLabelFromConfigMaps(t *testing.T) { + tests := []struct { + name string + existing []client.Object + namespace string + hcName string + expected string + }{ + { + name: "When no configmaps exist, it should return empty string", + existing: nil, + namespace: "ns", + hcName: "hc1", + expected: "", + }, + { + name: "When a matching configmap exists, it should return the pair label", + existing: []client.Object{ + &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: placeholderNamespace, + Name: "pair-1", + Labels: map[string]string{pairLabelKey: "pair-1"}, + }, + Data: map[string]string{ + clusterNamespaceKey: "ns", + clusterNameKey: "hc1", + }, + }, + }, + namespace: "ns", + hcName: "hc1", + expected: "pair-1", + }, + { + name: "When configmaps exist for other clusters only, it should return empty string", + existing: []client.Object{ + &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: placeholderNamespace, + Name: "pair-1", + Labels: map[string]string{pairLabelKey: "pair-1"}, + }, + Data: map[string]string{ + clusterNamespaceKey: "other-ns", + clusterNameKey: "other-hc", + }, + }, + }, + namespace: "ns", + hcName: "hc1", + expected: "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + c := fake.NewClientBuilder().WithScheme(hyperapi.Scheme).WithObjects(test.existing...).Build() + r := &DedicatedServingComponentSchedulerAndSizer{Client: c} + result, err := r.pairLabelFromConfigMaps(t.Context(), test.namespace, test.hcName) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(result).To(Equal(test.expected)) + }) + } +} + +func TestDeletePairConfigMaps(t *testing.T) { + hc := &hyperv1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "hc1", + }, + } + + tests := []struct { + name string + existing []client.Object + expectedRemaining int + }{ + { + name: "When there are no configmaps, it should succeed", + existing: nil, + expectedRemaining: 0, + }, + { + name: "When configmaps match the cluster, they should be deleted", + existing: []client.Object{ + &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: placeholderNamespace, + Name: "pair-1", + Labels: map[string]string{pairLabelKey: "pair-1"}, + }, + Data: map[string]string{ + clusterNamespaceKey: "ns", + clusterNameKey: "hc1", + }, + }, + }, + expectedRemaining: 0, + }, + { + name: "When configmaps belong to a different cluster, they should not be deleted", + existing: []client.Object{ + &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: placeholderNamespace, + Name: "pair-1", + Labels: map[string]string{pairLabelKey: "pair-1"}, + }, + Data: map[string]string{ + clusterNamespaceKey: "other-ns", + clusterNameKey: "other-hc", + }, + }, + }, + expectedRemaining: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := NewWithT(t) + c := fake.NewClientBuilder().WithScheme(hyperapi.Scheme).WithObjects(test.existing...).Build() + r := &DedicatedServingComponentSchedulerAndSizer{Client: c} + err := r.deletePairConfigMaps(t.Context(), hc) + g.Expect(err).ToNot(HaveOccurred()) + + cmList := &corev1.ConfigMapList{} + err = c.List(t.Context(), cmList, client.InNamespace(placeholderNamespace)) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(cmList.Items).To(HaveLen(test.expectedRemaining)) + }) + } +} diff --git a/hypershift-operator/main.go b/hypershift-operator/main.go index 2e3a13d8817..7f54779aeba 100644 --- a/hypershift-operator/main.go +++ b/hypershift-operator/main.go @@ -67,17 +67,20 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v5" admissionregistrationv1 "k8s.io/api/admissionregistration/v1" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/discovery" _ "k8s.io/client-go/plugin/pkg/client/auth/gcp" + "k8s.io/client-go/rest" "k8s.io/utils/ptr" "k8s.io/utils/set" @@ -225,11 +228,119 @@ func NewStartCommand() *cobra.Command { func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { log.Info("Starting hypershift-operator-manager", "version", supportedversion.String()) + if err := validateStartOptions(opts, log); err != nil { + return err + } + + restConfig := ctrl.GetConfigOrDie() + restConfig.UserAgent = "hypershift-operator-manager" + kubeDiscoveryClient, err := discovery.NewDiscoveryClientForConfig(restConfig) + if err != nil { + return fmt.Errorf("unable to create discovery client: %w", err) + } + + mgmtClusterCaps, err := capabilities.DetectManagementClusterCapabilities(kubeDiscoveryClient) + if err != nil { + return fmt.Errorf("unable to detect cluster capabilities: %w", err) + } + + webhookOptions, err := configureWebhookOptions(ctx, restConfig, mgmtClusterCaps, opts) + if err != nil { + return err + } + + mgr, err := createManager(restConfig, webhookOptions, opts) + if err != nil { + return err + } + + operatorImage, err := resolveOperatorImage(ctx, mgr, opts, log) + if err != nil { + return err + } + + createOrUpdate := upsert.New(opts.EnableCIDebugOutput) + + metricsSet, sreConfigHash, err := setupMetricsSet(mgr, opts, log) + if err != nil { + return err + } + + apiReadingClient, err := crclient.New(mgr.GetConfig(), crclient.Options{Scheme: hyperapi.Scheme}) + if err != nil { + return fmt.Errorf("failed to construct api reading client: %w", err) + } + + if err := reconcileDeprecationValidatingAdmissionPolicy(ctx, apiReadingClient, mgmtClusterCaps, log); err != nil { + return fmt.Errorf("failed to reconcile deprecation ValidatingAdmissionPolicy: %w", err) + } + + registryProvider, err := globalconfig.NewCommonRegistryProvider(ctx, mgmtClusterCaps, apiReadingClient, opts.RegistryOverrides) + if err != nil { + return fmt.Errorf("failed to create registry provider: %w", err) + } + + if err := setupHostedClusterController(ctx, mgr, opts, mgmtClusterCaps, operatorImage, createOrUpdate, metricsSet, sreConfigHash, registryProvider, log); err != nil { + return err + } + + if err := cleanupLegacyWebhook(ctx, mgr, opts); err != nil { + return err + } + + ec2Client := setupEC2Client(ctx, opts) + npmetrics.CreateAndRegisterNodePoolsMetricsCollector(mgr.GetClient(), ec2Client) + + if err := setupNodePoolController(ctx, mgr, opts, operatorImage, createOrUpdate, registryProvider, ec2Client, log); err != nil { + return err + } + + if mgmtClusterCaps.Has(capabilities.CapabilityProxy) { + if err := proxy.Setup(mgr, opts.Namespace, opts.DeploymentName); err != nil { + return fmt.Errorf("failed to set up the proxy controller: %w", err) + } + } + + enableSizeTagging := os.Getenv("ENABLE_SIZE_TAGGING") == "1" + if enableSizeTagging { + if err := hostedclustersizing.SetupWithManager(ctx, mgr, operatorImage, registryProvider.ReleaseProvider, registryProvider.MetadataProvider); err != nil { + return fmt.Errorf("failed to set up hosted cluster sizing operator: %w", err) + } + } + + if err := setupPlatformControllers(mgr, opts, mgmtClusterCaps, createOrUpdate, log); err != nil { + return err + } + + if err := setupSupportControllers(mgr, opts, mgmtClusterCaps, operatorImage, createOrUpdate, registryProvider, log); err != nil { + return err + } + + if err := setupSchedulerControllers(ctx, mgr, opts, createOrUpdate, enableSizeTagging, log); err != nil { + return err + } + + if err := reconcileDefaultIngressController(ctx, apiReadingClient, log); err != nil { + return err + } + + if err := setupOperatorInfoMetric(mgr); err != nil { + return fmt.Errorf("failed to setup metrics: %w", err) + } + + if err := setupAuditLogPersistence(mgr, opts, log); err != nil { + return err + } + + log.Info("starting manager") + return mgr.Start(ctx) +} + +func validateStartOptions(opts *StartOptions, log logr.Logger) error { if opts.EtcdBackupMaxCount < 1 { return fmt.Errorf("--etcd-backup-max-count must be at least 1, got %d", opts.EtcdBackupMaxCount) } - // Validate scale-from-zero configuration early supportedProviders := set.New("aws") if opts.ScaleFromZeroCreds != "" { if opts.ScaleFromZeroProvider == "" { @@ -248,44 +359,40 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { } else if opts.ScaleFromZeroProvider != "" { log.Info("WARNING: --scale-from-zero-provider is set but --scale-from-zero-creds is empty; scale-from-zero will be disabled", "provider", opts.ScaleFromZeroProvider) } + return nil +} - restConfig := ctrl.GetConfigOrDie() - restConfig.UserAgent = "hypershift-operator-manager" - kubeDiscoveryClient, err := discovery.NewDiscoveryClientForConfig(restConfig) - if err != nil { - return fmt.Errorf("unable to create discovery client: %w", err) +func configureWebhookOptions(ctx context.Context, restConfig *rest.Config, mgmtClusterCaps *capabilities.ManagementClusterCapabilities, opts *StartOptions) (webhook.Options, error) { + webhookOptions := webhook.Options{Port: 9443, CertDir: opts.CertDir} + if !mgmtClusterCaps.Has(capabilities.CapabilityAPIServer) { + return webhookOptions, nil } - mgmtClusterCaps, err := capabilities.DetectManagementClusterCapabilities(kubeDiscoveryClient) + configClient, err := configv1.NewForConfig(restConfig) if err != nil { - return fmt.Errorf("unable to detect cluster capabilities: %w", err) + return webhookOptions, fmt.Errorf("unable to create config client: %w", err) } - webhookOptions := webhook.Options{Port: 9443, CertDir: opts.CertDir} - if mgmtClusterCaps.Has(capabilities.CapabilityAPIServer) { - configClient, err := configv1.NewForConfig(restConfig) - if err != nil { - return fmt.Errorf("unable to create config client: %w", err) - } - - apiServerConfig, err := configClient.APIServers().Get(ctx, "cluster", metav1.GetOptions{}) - if err != nil { - return fmt.Errorf("unable to get the api server config: %w", err) - } - - minTLSVersionSetter, err := config.SetMinTLSVersionUsingAPIServer(apiServerConfig) - if err != nil { - return fmt.Errorf("unable to configure webhook server tls version: %w", err) - } + apiServerConfig, err := configClient.APIServers().Get(ctx, "cluster", metav1.GetOptions{}) + if err != nil { + return webhookOptions, fmt.Errorf("unable to get the api server config: %w", err) + } - cipherSuitesSetter, err := config.SetCipherSuitesUsingAPIServer(apiServerConfig) - if err != nil { - return fmt.Errorf("unable to configure webhook server cipher suites: %w", err) - } + minTLSVersionSetter, err := config.SetMinTLSVersionUsingAPIServer(apiServerConfig) + if err != nil { + return webhookOptions, fmt.Errorf("unable to configure webhook server tls version: %w", err) + } - webhookOptions.TLSOpts = []func(*tls.Config){minTLSVersionSetter, cipherSuitesSetter} + cipherSuitesSetter, err := config.SetCipherSuitesUsingAPIServer(apiServerConfig) + if err != nil { + return webhookOptions, fmt.Errorf("unable to configure webhook server cipher suites: %w", err) } + webhookOptions.TLSOpts = []func(*tls.Config){minTLSVersionSetter, cipherSuitesSetter} + return webhookOptions, nil +} + +func createManager(restConfig *rest.Config, webhookOptions webhook.Options, opts *StartOptions) (ctrl.Manager, error) { leaseDuration := time.Second * 60 renewDeadline := time.Second * 40 retryPeriod := time.Second * 15 @@ -310,9 +417,12 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { RetryPeriod: &retryPeriod, }) if err != nil { - return fmt.Errorf("unable to start manager: %w", err) + return nil, fmt.Errorf("unable to start manager: %w", err) } + return mgr, nil +} +func resolveOperatorImage(ctx context.Context, mgr ctrl.Manager, opts *StartOptions, log logr.Logger) (string, error) { lookupOperatorImage := func(userSpecifiedImage string) (string, error) { if len(userSpecifiedImage) > 0 { log.Info("using image from arguments", "image", userSpecifiedImage) @@ -322,10 +432,7 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { if err := mgr.GetAPIReader().Get(ctx, crclient.ObjectKeyFromObject(me), me); err != nil { return "", fmt.Errorf("failed to get operator pod %s: %w", crclient.ObjectKeyFromObject(me), err) } - // Use the container status to make sure we get the sha256 reference rather than a potentially - // floating tag. for _, container := range me.Status.ContainerStatuses { - // TODO: could use downward API for this too, overkill? if container.Name == "operator" { return strings.TrimPrefix(container.ImageID, "docker-pullable://"), nil } @@ -333,28 +440,29 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { return "", fmt.Errorf("couldn't locate operator container on deployment") } var operatorImage string + var err error if err := wait.PollUntilContextTimeout(ctx, 5*time.Second, 30*time.Second, true, func(ctx context.Context) (bool, error) { operatorImage, err = lookupOperatorImage(opts.ControlPlaneOperatorImage) if err != nil { return false, err } - // Apparently this is occasionally set to an empty string if operatorImage == "" { log.Info("operator image is empty, retrying") return false, nil } return true, nil }); err != nil { - return fmt.Errorf("failed to find operator image: %w", err) + return "", fmt.Errorf("failed to find operator image: %w", err) } log.Info("using hosted control plane operator image", "operator-image", operatorImage) + return operatorImage, nil +} - createOrUpdate := upsert.New(opts.EnableCIDebugOutput) - +func setupMetricsSet(mgr ctrl.Manager, opts *StartOptions, log logr.Logger) (metrics.MetricsSet, string, error) { metricsSet, err := metrics.MetricsSetFromEnv() if err != nil { - return err + return "", "", err } log.Info("Using metrics set", "set", metricsSet.String()) @@ -367,34 +475,21 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { if apierrors.IsNotFound(err) { log.Info("WARNING: no configuration found for the SRE metrics set") } else { - return fmt.Errorf("unable to read SRE metrics set configmap: %w", err) + return "", "", fmt.Errorf("unable to read SRE metrics set configmap: %w", err) } } else { if err := metrics.LoadSREMetricsSetConfigurationFromConfigMap(cm); err != nil { - return fmt.Errorf("unable to load SRE metrics configuration: %w", err) + return "", "", fmt.Errorf("unable to load SRE metrics configuration: %w", err) } sreConfigHash = metrics.SREMetricsSetConfigHash(cm) } } + return metricsSet, sreConfigHash, nil +} - // The mgr and therefore the cache is not started yet, thus we have to construct a client that - // directly reads from the api. - apiReadingClient, err := crclient.New(mgr.GetConfig(), crclient.Options{Scheme: hyperapi.Scheme}) - if err != nil { - return fmt.Errorf("failed to construct api reading client: %w", err) - } - - // Reconcile deprecation ValidatingAdmissionPolicy if supported - if err := reconcileDeprecationValidatingAdmissionPolicy(ctx, apiReadingClient, mgmtClusterCaps, log); err != nil { - return fmt.Errorf("failed to reconcile deprecation ValidatingAdmissionPolicy: %w", err) - } - - // Create the registry provider for the release and image metadata providers - registryProvider, err := globalconfig.NewCommonRegistryProvider(ctx, mgmtClusterCaps, apiReadingClient, opts.RegistryOverrides) - +func setupHostedClusterController(ctx context.Context, mgr ctrl.Manager, opts *StartOptions, mgmtClusterCaps *capabilities.ManagementClusterCapabilities, operatorImage string, createOrUpdate upsert.CreateOrUpdateProvider, metricsSet metrics.MetricsSet, sreConfigHash string, registryProvider globalconfig.CommonRegistryProvider, log logr.Logger) error { monitoringDashboards := (os.Getenv("MONITORING_DASHBOARDS") == "1") enableCVOManagementClusterMetricsAccess := (os.Getenv(config.EnableCVOManagementClusterMetricsAccessEnvVar) == "1") - enableEtcdRecovery := os.Getenv(config.EnableEtcdRecoveryEnvVar) == "1" certRotationScale, err := pkiconfig.GetCertRotationScale() @@ -439,42 +534,44 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { } } hcmetrics.CreateAndRegisterHostedClustersMetricsCollector(mgr.GetClient()) + return nil +} - // Since we dropped the validation webhook server we need to ensure this resource doesn't exist - // otherwise it will intercept kas requests and fail. - // TODO (alberto): dropped in 4.14. - if !opts.EnableValidatingWebhook { - validatingWebhookConfiguration := &admissionregistrationv1.ValidatingWebhookConfiguration{ - TypeMeta: metav1.TypeMeta{ - Kind: "ValidatingWebhookConfiguration", - APIVersion: admissionregistrationv1.SchemeGroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Namespace: opts.Namespace, - Name: hyperv1.GroupVersion.Group, - }, - } - if err := mgr.GetClient().Delete(ctx, validatingWebhookConfiguration); err != nil { - if !apierrors.IsNotFound(err) { - return err - } +func cleanupLegacyWebhook(ctx context.Context, mgr ctrl.Manager, opts *StartOptions) error { + if opts.EnableValidatingWebhook { + return nil + } + validatingWebhookConfiguration := &admissionregistrationv1.ValidatingWebhookConfiguration{ + TypeMeta: metav1.TypeMeta{ + Kind: "ValidatingWebhookConfiguration", + APIVersion: admissionregistrationv1.SchemeGroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Namespace: opts.Namespace, + Name: hyperv1.GroupVersion.Group, + }, + } + if err := mgr.GetClient().Delete(ctx, validatingWebhookConfiguration); err != nil { + if !apierrors.IsNotFound(err) { + return err } } + return nil +} - var ec2Client awsapi.EC2API - - if hyperv1.PlatformType(opts.PrivatePlatform) == hyperv1.AWSPlatform { - awsSession := awsutil.NewSession(ctx, "hypershift-operator", "", "", "", "") - awsConfig := awsutil.NewConfig() - ec2Client = ec2.NewFromConfig(*awsSession, func(o *ec2.Options) { - o.Retryer = awsConfig() - }) +func setupEC2Client(ctx context.Context, opts *StartOptions) awsapi.EC2API { + if hyperv1.PlatformType(opts.PrivatePlatform) != hyperv1.AWSPlatform { + return nil } + awsSession := awsutil.NewSession(ctx, "hypershift-operator", "", "", "", "") + awsConfig := awsutil.NewConfig() + return ec2.NewFromConfig(*awsSession, func(o *ec2.Options) { + o.Retryer = awsConfig() + }) +} - npmetrics.CreateAndRegisterNodePoolsMetricsCollector(mgr.GetClient(), ec2Client) - +func setupNodePoolController(ctx context.Context, mgr ctrl.Manager, opts *StartOptions, operatorImage string, createOrUpdate upsert.CreateOrUpdateProvider, registryProvider globalconfig.CommonRegistryProvider, ec2Client awsapi.EC2API, log logr.Logger) error { var instanceTypeProvider instancetype.Provider - if opts.ScaleFromZeroCreds != "" && opts.ScaleFromZeroProvider != "" { switch strings.ToLower(opts.ScaleFromZeroProvider) { case "aws": @@ -486,7 +583,6 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { instanceTypeProvider = awsinstancetype.NewProvider(scaleFromZeroEC2Client) log.Info("Instance type provider initialized", "provider", opts.ScaleFromZeroProvider) default: - // Should not happen due to validation, but handle gracefully log.Info("WARNING: Unsupported scale-from-zero provider", "provider", opts.ScaleFromZeroProvider) } } @@ -503,21 +599,10 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { }).SetupWithManager(mgr); err != nil { return fmt.Errorf("unable to create controller: %w", err) } + return nil +} - if mgmtClusterCaps.Has(capabilities.CapabilityProxy) { - if err := proxy.Setup(mgr, opts.Namespace, opts.DeploymentName); err != nil { - return fmt.Errorf("failed to set up the proxy controller: %w", err) - } - } - - enableSizeTagging := os.Getenv("ENABLE_SIZE_TAGGING") == "1" - if enableSizeTagging { - if err := hostedclustersizing.SetupWithManager(ctx, mgr, operatorImage, registryProvider.ReleaseProvider, registryProvider.MetadataProvider); err != nil { - return fmt.Errorf("failed to set up hosted cluster sizing operator: %w", err) - } - } - - // Start platform-specific controllers +func setupPlatformControllers(mgr ctrl.Manager, opts *StartOptions, mgmtClusterCaps *capabilities.ManagementClusterCapabilities, createOrUpdate upsert.CreateOrUpdateProvider, log logr.Logger) error { switch hyperv1.PlatformType(opts.PrivatePlatform) { case hyperv1.AWSPlatform: if err := (&aws.AWSEndpointServiceReconciler{ @@ -536,107 +621,117 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { return fmt.Errorf("unable to create GCPPrivateServiceConnect controller: %w", err) } case hyperv1.AzurePlatform: - // ARO HCP uses Swift networking, not Private Link Services - if !azureutil.IsAroHCP() { - azureCloudName := os.Getenv("AZURE_CLOUD_NAME") - if azureCloudName == "" { - azureCloudName = config.DefaultAzureCloud - } - cloudConfig, err := azureutil.GetAzureCloudConfiguration(azureCloudName) - if err != nil { - return fmt.Errorf("failed to get Azure cloud configuration: %w", err) - } + if err := setupAzurePlatformController(mgr, log); err != nil { + return err + } + } + return nil +} - // Environment variables are used because the Azure SDK's DefaultAzureCredential - // and WorkloadIdentityCredential read credentials from env vars by design - // (see azure-sdk-for-go/sdk/azidentity). Two credential modes are supported: - // 1. Workload Identity: The Azure AD Workload Identity webhook injects - // AZURE_CLIENT_ID, AZURE_TENANT_ID, and AZURE_FEDERATED_TOKEN_FILE - // into the pod. DefaultAzureCredential picks these up automatically. - // 2. Credential file: A JSON file with clientId, clientSecret, tenantId, - // subscriptionId is parsed and used with NewClientSecretCredential directly. - var azureCreds azcore.TokenCredential - if plsClientID := os.Getenv("AZURE_PLS_CLIENT_ID"); plsClientID != "" { - log.Info("Using Azure Workload Identity for PLS operations", "clientID", plsClientID) - azureCreds, err = azidentity.NewDefaultAzureCredential( - &azidentity.DefaultAzureCredentialOptions{ - ClientOptions: azcore.ClientOptions{Cloud: cloudConfig}, - }, - ) - if err != nil { - return fmt.Errorf("failed to create Azure workload identity credentials: %w", err) - } - } else if credFile := os.Getenv("AZURE_CREDENTIALS_FILE"); credFile != "" { - raw, err := os.ReadFile(credFile) - if err != nil { - return fmt.Errorf("failed to read Azure credentials file %q: %w", credFile, err) - } - var creds struct { - SubscriptionID string `json:"subscriptionId"` - ClientID string `json:"clientId"` - ClientSecret string `json:"clientSecret"` - TenantID string `json:"tenantId"` - } - if err := yaml.Unmarshal(raw, &creds); err != nil { - return fmt.Errorf("failed to parse Azure credentials file %q: %w", credFile, err) - } - azureCreds, err = azidentity.NewClientSecretCredential( - creds.TenantID, creds.ClientID, creds.ClientSecret, - &azidentity.ClientSecretCredentialOptions{ - ClientOptions: azcore.ClientOptions{Cloud: cloudConfig}, - }, - ) - if err != nil { - return fmt.Errorf("failed to create Azure client secret credentials: %w", err) - } - if os.Getenv("AZURE_SUBSCRIPTION_ID") == "" { - _ = os.Setenv("AZURE_SUBSCRIPTION_ID", creds.SubscriptionID) - } - } else { - return fmt.Errorf("either AZURE_PLS_CLIENT_ID or AZURE_CREDENTIALS_FILE must be set for Azure Private Link Service operations") - } +func setupAzurePlatformController(mgr ctrl.Manager, log logr.Logger) error { + if azureutil.IsAroHCP() { + return nil + } + azureCloudName := os.Getenv("AZURE_CLOUD_NAME") + if azureCloudName == "" { + azureCloudName = config.DefaultAzureCloud + } + cloudConfig, err := azureutil.GetAzureCloudConfiguration(azureCloudName) + if err != nil { + return fmt.Errorf("failed to get Azure cloud configuration: %w", err) + } - azureSubscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") - if azureSubscriptionID == "" { - return fmt.Errorf("AZURE_SUBSCRIPTION_ID environment variable is required for Azure platform") - } - armClientOpts := azureutil.NewARMClientOptions(cloudConfig) - plsClient, err := armnetwork.NewPrivateLinkServicesClient(azureSubscriptionID, azureCreds, armClientOpts) - if err != nil { - return fmt.Errorf("failed to create Azure Private Link Services client: %w", err) - } - lbClient, err := armnetwork.NewLoadBalancersClient(azureSubscriptionID, azureCreds, armClientOpts) - if err != nil { - return fmt.Errorf("failed to create Azure Load Balancers client: %w", err) - } - subnetsClient, err := armnetwork.NewSubnetsClient(azureSubscriptionID, azureCreds, armClientOpts) - if err != nil { - return fmt.Errorf("failed to create Azure Subnets client: %w", err) - } - azureResourceGroup := os.Getenv("AZURE_RESOURCE_GROUP") - if azureResourceGroup == "" { - return fmt.Errorf("AZURE_RESOURCE_GROUP environment variable is required for Azure platform") - } - if err := (&azureplatform.AzurePrivateLinkServiceController{ - Client: mgr.GetClient(), - PrivateLinkServices: plsClient, - LoadBalancers: lbClient, - Subnets: subnetsClient, - ManagementResourceGroup: azureResourceGroup, - }).SetupWithManager(mgr); err != nil { - return fmt.Errorf("unable to create AzurePrivateLinkService controller: %w", err) - } + azureCreds, err := resolveAzureCredentials(cloudConfig, log) + if err != nil { + return err + } + + azureSubscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if azureSubscriptionID == "" { + return fmt.Errorf("AZURE_SUBSCRIPTION_ID environment variable is required for Azure platform") + } + armClientOpts := azureutil.NewARMClientOptions(cloudConfig) + plsClient, err := armnetwork.NewPrivateLinkServicesClient(azureSubscriptionID, azureCreds, armClientOpts) + if err != nil { + return fmt.Errorf("failed to create Azure Private Link Services client: %w", err) + } + lbClient, err := armnetwork.NewLoadBalancersClient(azureSubscriptionID, azureCreds, armClientOpts) + if err != nil { + return fmt.Errorf("failed to create Azure Load Balancers client: %w", err) + } + subnetsClient, err := armnetwork.NewSubnetsClient(azureSubscriptionID, azureCreds, armClientOpts) + if err != nil { + return fmt.Errorf("failed to create Azure Subnets client: %w", err) + } + azureResourceGroup := os.Getenv("AZURE_RESOURCE_GROUP") + if azureResourceGroup == "" { + return fmt.Errorf("AZURE_RESOURCE_GROUP environment variable is required for Azure platform") + } + if err := (&azureplatform.AzurePrivateLinkServiceController{ + Client: mgr.GetClient(), + PrivateLinkServices: plsClient, + LoadBalancers: lbClient, + Subnets: subnetsClient, + ManagementResourceGroup: azureResourceGroup, + }).SetupWithManager(mgr); err != nil { + return fmt.Errorf("unable to create AzurePrivateLinkService controller: %w", err) + } + return nil +} + +func resolveAzureCredentials(cloudConfig cloud.Configuration, log logr.Logger) (azcore.TokenCredential, error) { + if plsClientID := os.Getenv("AZURE_PLS_CLIENT_ID"); plsClientID != "" { + log.Info("Using Azure Workload Identity for PLS operations", "clientID", plsClientID) + creds, err := azidentity.NewDefaultAzureCredential( + &azidentity.DefaultAzureCredentialOptions{ + ClientOptions: azcore.ClientOptions{Cloud: cloudConfig}, + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to create Azure workload identity credentials: %w", err) } + return creds, nil + } + + credFile := os.Getenv("AZURE_CREDENTIALS_FILE") + if credFile == "" { + return nil, fmt.Errorf("either AZURE_PLS_CLIENT_ID or AZURE_CREDENTIALS_FILE must be set for Azure Private Link Service operations") } - // Start controller to manage supported versions configmap + raw, err := os.ReadFile(credFile) + if err != nil { + return nil, fmt.Errorf("failed to read Azure credentials file %q: %w", credFile, err) + } + var parsedCreds struct { + SubscriptionID string `json:"subscriptionId"` + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + TenantID string `json:"tenantId"` + } + if err := yaml.Unmarshal(raw, &parsedCreds); err != nil { + return nil, fmt.Errorf("failed to parse Azure credentials file %q: %w", credFile, err) + } + creds, err := azidentity.NewClientSecretCredential( + parsedCreds.TenantID, parsedCreds.ClientID, parsedCreds.ClientSecret, + &azidentity.ClientSecretCredentialOptions{ + ClientOptions: azcore.ClientOptions{Cloud: cloudConfig}, + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to create Azure client secret credentials: %w", err) + } + if os.Getenv("AZURE_SUBSCRIPTION_ID") == "" { + _ = os.Setenv("AZURE_SUBSCRIPTION_ID", parsedCreds.SubscriptionID) + } + return creds, nil +} + +func setupSupportControllers(mgr ctrl.Manager, opts *StartOptions, mgmtClusterCaps *capabilities.ManagementClusterCapabilities, operatorImage string, createOrUpdate upsert.CreateOrUpdateProvider, registryProvider globalconfig.CommonRegistryProvider, log logr.Logger) error { if err := hosupportedversion.New(mgr.GetClient(), createOrUpdate, opts.Namespace). SetupWithManager(mgr); err != nil { return fmt.Errorf("unable to create supported version controller: %w", err) } - // If enabled, start controller to ensure UWM stack is enabled and configured - // to remotely write telemetry metrics. if opts.EnableUWMTelemetryRemoteWrite { if err := (&uwmtelemetry.Reconciler{ Namespace: opts.Namespace, @@ -683,46 +778,18 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { return fmt.Errorf("unable to create webhook cert controller: %w", err) } } + return nil +} - // Start controllers to manage dedicated request serving isolation +func setupSchedulerControllers(ctx context.Context, mgr ctrl.Manager, opts *StartOptions, createOrUpdate upsert.CreateOrUpdateProvider, enableSizeTagging bool, log logr.Logger) error { if opts.EnableDedicatedRequestServingIsolation && !azureutil.IsAroHCP() { - // Use the new scheduler if we support size tagging on hosted clusters if enableSizeTagging { - hcScheduler := awsscheduler.DedicatedServingComponentSchedulerAndSizer{} - if err := hcScheduler.SetupWithManager(ctx, mgr, createOrUpdate); err != nil { - return fmt.Errorf("unable to create dedicated serving component scheduler/resizer controller: %w", err) - } - placeholderScheduler := awsscheduler.PlaceholderScheduler{} - if err := placeholderScheduler.SetupWithManager(ctx, mgr); err != nil { - return fmt.Errorf("unable to create placeholder scheduler controller: %w", err) - } - autoScaler := awsscheduler.RequestServingNodeAutoscaler{} - if err := autoScaler.SetupWithManager(mgr); err != nil { - return fmt.Errorf("unable to create autoscaler controller: %w", err) - } - deScaler := awsscheduler.MachineSetDescaler{} - if err := deScaler.SetupWithManager(mgr); err != nil { - return fmt.Errorf("unable to create machine set descaler controller: %w", err) - } - nonRequestServingNodeAutoscaler := awsscheduler.NonRequestServingNodeAutoscaler{} - if err := nonRequestServingNodeAutoscaler.SetupWithManager(mgr); err != nil { - return fmt.Errorf("unable to create non request serving node autoscaler controller: %w", err) - } - if err := resourcebasedcpautoscaler.SetupWithManager(mgr); err != nil { - return fmt.Errorf("unable to setup control plane autoscaler controller: %w", err) + if err := setupSizeTaggingSchedulers(ctx, mgr, createOrUpdate); err != nil { + return err } } else { - nodeReaper := awsscheduler.DedicatedServingComponentNodeReaper{ - Client: mgr.GetClient(), - } - if err := nodeReaper.SetupWithManager(mgr); err != nil { - return fmt.Errorf("unable to create dedicated serving component node reaper controller: %w", err) - } - hcScheduler := awsscheduler.DedicatedServingComponentScheduler{ - Client: mgr.GetClient(), - } - if err := hcScheduler.SetupWithManager(mgr, createOrUpdate); err != nil { - return fmt.Errorf("unable to create dedicated serving component scheduler controller: %w", err) + if err := setupLegacySchedulers(mgr, createOrUpdate); err != nil { + return err } } } else { @@ -735,8 +802,53 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { return fmt.Errorf("unable to create aro scheduler controller: %w", err) } } + return nil +} + +func setupSizeTaggingSchedulers(ctx context.Context, mgr ctrl.Manager, createOrUpdate upsert.CreateOrUpdateProvider) error { + hcScheduler := awsscheduler.DedicatedServingComponentSchedulerAndSizer{} + if err := hcScheduler.SetupWithManager(ctx, mgr, createOrUpdate); err != nil { + return fmt.Errorf("unable to create dedicated serving component scheduler/resizer controller: %w", err) + } + placeholderScheduler := awsscheduler.PlaceholderScheduler{} + if err := placeholderScheduler.SetupWithManager(ctx, mgr); err != nil { + return fmt.Errorf("unable to create placeholder scheduler controller: %w", err) + } + autoScaler := awsscheduler.RequestServingNodeAutoscaler{} + if err := autoScaler.SetupWithManager(mgr); err != nil { + return fmt.Errorf("unable to create autoscaler controller: %w", err) + } + deScaler := awsscheduler.MachineSetDescaler{} + if err := deScaler.SetupWithManager(mgr); err != nil { + return fmt.Errorf("unable to create machine set descaler controller: %w", err) + } + nonRequestServingNodeAutoscaler := awsscheduler.NonRequestServingNodeAutoscaler{} + if err := nonRequestServingNodeAutoscaler.SetupWithManager(mgr); err != nil { + return fmt.Errorf("unable to create non request serving node autoscaler controller: %w", err) + } + if err := resourcebasedcpautoscaler.SetupWithManager(mgr); err != nil { + return fmt.Errorf("unable to setup control plane autoscaler controller: %w", err) + } + return nil +} + +func setupLegacySchedulers(mgr ctrl.Manager, createOrUpdate upsert.CreateOrUpdateProvider) error { + nodeReaper := awsscheduler.DedicatedServingComponentNodeReaper{ + Client: mgr.GetClient(), + } + if err := nodeReaper.SetupWithManager(mgr); err != nil { + return fmt.Errorf("unable to create dedicated serving component node reaper controller: %w", err) + } + hcScheduler := awsscheduler.DedicatedServingComponentScheduler{ + Client: mgr.GetClient(), + } + if err := hcScheduler.SetupWithManager(mgr, createOrUpdate); err != nil { + return fmt.Errorf("unable to create dedicated serving component scheduler controller: %w", err) + } + return nil +} - // If it exists, block default ingress controller from admitting HCP private routes +func reconcileDefaultIngressController(ctx context.Context, apiReadingClient crclient.Client, log logr.Logger) error { ic := &operatorv1.IngressController{ ObjectMeta: metav1.ObjectMeta{ Name: "default", @@ -767,19 +879,15 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { return fmt.Errorf("failed to reconcile default ingress controller: %w", err) } log.Info("reconciled default ingress controller") - } - if err != nil && apierrors.IsNotFound(err) { + } else if !apierrors.IsNotFound(err) && !meta.IsNoMatchError(err) { return fmt.Errorf("failed to get ingress controller: %w", err) } + return nil +} - if err := setupOperatorInfoMetric(mgr); err != nil { - return fmt.Errorf("failed to setup metrics: %w", err) - } - - // Setup audit log persistence webhooks and controller if enabled +func setupAuditLogPersistence(mgr ctrl.Manager, opts *StartOptions, log logr.Logger) error { enableAuditLogPersistence := os.Getenv("ENABLE_AUDIT_LOG_PERSISTENCE") == "true" if enableAuditLogPersistence && opts.CertDir != "" { - // Register pod mutating webhook hookServer := mgr.GetWebhookServer() hookServer.Register("/mutate-kas-audit-logs", &webhook.Admission{ Handler: auditlogpersistence.NewPodWebhookHandler( @@ -789,7 +897,6 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { ), }) - // Register ConfigMap mutating webhook hookServer.Register("/mutate-kas-audit-log-config", &webhook.Admission{ Handler: auditlogpersistence.NewConfigMapWebhookHandler( mgr.GetLogger(), @@ -802,7 +909,6 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { } if enableAuditLogPersistence { - // Setup snapshot controller if err := auditlogpersistence.SetupSnapshotController(mgr); err != nil { return fmt.Errorf("failed to set up snapshot controller: %w", err) } @@ -810,10 +916,7 @@ func run(ctx context.Context, opts *StartOptions, log logr.Logger) error { } else { log.Info("Audit log persistence feature disabled") } - - // Start the controllers - log.Info("starting manager") - return mgr.Start(ctx) + return nil } // reconcileDeprecationValidatingAdmissionPolicy reconciles the deprecation ValidatingAdmissionPolicy From 0a7a1b9cc64c4f3a58bf93aeb848ed544e6742c3 Mon Sep 17 00:00:00 2001 From: Bryan Cox Date: Thu, 23 Apr 2026 10:21:37 -0400 Subject: [PATCH 3/7] refactor(control-plane-operator): reduce cyclomatic complexity in CPO controllers - Extract helper functions from AWS private link controller to reduce reconcileAWSEndpointService complexity; fix nil dereference on empty VpcEndpoints and add namespace scoping for service list - Fix BasicAuth TLS cert/key mount collision in OAuth IDP conversion (both v1 and v2); split TLS assertions for independent validation - Refactor HCP controller: extract deleteAWSDefaultSecurityGroup, reconcileKubeadminPassword, and other helpers; fix error code handling with switch statement - Fix error masking in HCCO resources controller: don't bypass aggregate error on recovery requeue; continue reconciling when registry config GET fails - Fix error text in OAuth API server helper from "openshift apiserver" to "openshift oauth apiserver" - Simplify PATCH assertion condition in resources test - Add cloud cleanup positive assertions - Extract helpers from ignition server local provider - Extract helpers from metrics proxy scrape config and Azure private link controller - Add behavior-driven unit tests for all extracted functions Signed-off-by: Bryan Cox Commit-Message-Assisted-by: Claude (via Claude Code) --- .../awsprivatelink_controller.go | 361 +++-- .../awsprivatelink_controller_test.go | 626 ++++++++ .../azureprivatelinkservice/controller.go | 295 ++-- .../controller_test.go | 329 +++- .../hostedcontrolplane_controller.go | 1191 +++++++------- .../hostedcontrolplane_controller_test.go | 1412 +++++++++++++++++ .../hostedcontrolplane/oauth/idp_convert.go | 498 +++--- .../oauth/idp_convert_test.go | 715 +++++++++ .../v2/metrics_proxy/deployment_test.go | 211 +-- .../v2/metrics_proxy/scrape_config.go | 276 ++-- .../v2/metrics_proxy/scrape_config_test.go | 552 +++++++ .../v2/oauth/idp_convert.go | 549 ++++--- .../v2/oauth/idp_convert_test.go | 715 +++++++++ .../controllers/resources/resources.go | 409 ++--- .../controllers/resources/resources_test.go | 854 ++++++++-- control-plane-operator/main.go | 554 ++++--- .../controllers/local_ignitionprovider.go | 862 +++++----- .../local_ignitionprovider_test.go | 993 ++++++++++++ 18 files changed, 8591 insertions(+), 2811 deletions(-) diff --git a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go index 1dba7b85b32..c93e7adc6f3 100644 --- a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go +++ b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go @@ -603,173 +603,201 @@ func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointService(ctx context.C return err } + endpointID, endpointDNSEntries, err := r.ensureVPCEndpoint(ctx, ec2Client, awsEndpointService, hcp, log) + if err != nil { + return err + } + + if len(endpointDNSEntries) == 0 { + log.Info("endpoint has no DNS entries, skipping DNS record creation", "endpointID", endpointID) + return nil + } + + fqdns, zoneID, err := r.reconcileEndpointDNSRecords(ctx, route53Client, awsEndpointService, hcp, endpointDNSEntries, log) + if err != nil { + return err + } + + awsEndpointService.Status.DNSNames = fqdns + awsEndpointService.Status.DNSZoneID = zoneID + + return r.reconcileExternalNameServices(ctx, hcp, endpointDNSEntries, log) +} + +func (r *AWSEndpointServiceReconciler) ensureVPCEndpoint(ctx context.Context, ec2Client awsapi.EC2API, awsEndpointService *hyperv1.AWSEndpointService, hcp *hyperv1.HostedControlPlane, log logr.Logger) (string, []ec2types.DnsEntry, error) { endpointID := awsEndpointService.Status.EndpointID - var endpointDNSEntries []ec2types.DnsEntry if endpointID != "" { - // check if Endpoint exists in AWS - output, err := ec2Client.DescribeVpcEndpoints(ctx, &ec2.DescribeVpcEndpointsInput{ - VpcEndpointIds: []string{endpointID}, - }) - if err != nil { - log.Error(err, "failed to describe vpc endpoint", "endpointID", endpointID) - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - if apiErr.ErrorCode() == "InvalidVpcEndpointId.NotFound" { - // clear the EndpointID so a new Endpoint is created on the requeue - awsEndpointService.Status.EndpointID = "" - return fmt.Errorf("endpoint with id %s not found, resetting status", endpointID) - } else { - return errors.New(apiErr.ErrorCode()) - } + return r.reconcileExistingEndpoint(ctx, ec2Client, awsEndpointService, endpointID, log) + } + return r.reconcileNewEndpoint(ctx, ec2Client, awsEndpointService, hcp, log) +} + +func (r *AWSEndpointServiceReconciler) reconcileExistingEndpoint(ctx context.Context, ec2Client awsapi.EC2API, awsEndpointService *hyperv1.AWSEndpointService, endpointID string, log logr.Logger) (string, []ec2types.DnsEntry, error) { + output, err := ec2Client.DescribeVpcEndpoints(ctx, &ec2.DescribeVpcEndpointsInput{ + VpcEndpointIds: []string{endpointID}, + }) + if err != nil { + log.Error(err, "failed to describe vpc endpoint", "endpointID", endpointID) + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + if apiErr.ErrorCode() == "InvalidVpcEndpointId.NotFound" { + awsEndpointService.Status.EndpointID = "" + return "", nil, fmt.Errorf("endpoint with id %s not found, resetting status", endpointID) + } else { + return "", nil, errors.New(apiErr.ErrorCode()) } - return err } + return "", nil, err + } - if aws.ToString(output.VpcEndpoints[0].ServiceName) != awsEndpointService.Status.EndpointServiceName { - log.Info("endpoint links to wrong endpointservice, deleting...", "LinkedVPCEndpointServiceName", aws.ToString(output.VpcEndpoints[0].ServiceName), "WantedVPCEndpointService", awsEndpointService.Status.EndpointServiceName) - if _, err := ec2Client.DeleteVpcEndpoints(ctx, &ec2.DeleteVpcEndpointsInput{ - VpcEndpointIds: []string{aws.ToString(output.VpcEndpoints[0].VpcEndpointId)}, - }); err != nil { - log.Error(err, "failed to delete vpc endpoint", "id", aws.ToString(output.VpcEndpoints[0].VpcEndpointId)) - return fmt.Errorf("error deleting AWSEndpoint: %w", err) - } + if len(output.VpcEndpoints) == 0 { + awsEndpointService.Status.EndpointID = "" + return "", nil, fmt.Errorf("endpoint with id %s not found, resetting status", endpointID) + } - // Once the VPC Endpoint is deleted, we need to send an error in order to reexecute the reconcilliation - return fmt.Errorf("current endpoint %s is not pointing to the existing .Status.EndpointServiceName, reconciling by deleting endpoint", aws.ToString(output.VpcEndpoints[0].ServiceName)) - } + if err := deleteEndpointIfWrongService(ctx, ec2Client, output.VpcEndpoints[0], awsEndpointService.Status.EndpointServiceName, log); err != nil { + return "", nil, err + } - if len(output.VpcEndpoints) == 0 { - // This should not happen but just in case - // clear the EndpointID so a new Endpoint is created on the requeue - awsEndpointService.Status.EndpointID = "" - return fmt.Errorf("endpoint with id %s not found, resetting status", endpointID) - } - log.Info("endpoint exists", "endpointID", endpointID) - endpointDNSEntries = output.VpcEndpoints[0].DnsEntries + log.Info("endpoint exists", "endpointID", endpointID) - // Ensure endpoint has the right subnets. - addedSubnet, removedSubnet := diffIDs(awsEndpointService.Spec.SubnetIDs, output.VpcEndpoints[0].SubnetIds) + if err := modifyEndpointIfNeeded(ctx, ec2Client, awsEndpointService, output.VpcEndpoints[0], endpointID, log); err != nil { + return "", nil, err + } - // Ensure endpoint has the right SG. - existingSG := make([]string, 0) - for _, group := range output.VpcEndpoints[0].Groups { - existingSG = append(existingSG, aws.ToString(group.GroupId)) - } - addedSG, _ := diffIDs([]string{awsEndpointService.Status.SecurityGroupID}, existingSG) - - if addedSubnet != nil || removedSubnet != nil || addedSG != nil { - log.Info("endpoint subnets or security groups have changed") - _, err := ec2Client.ModifyVpcEndpoint(ctx, &ec2.ModifyVpcEndpointInput{ - VpcEndpointId: aws.String(endpointID), - AddSubnetIds: addedSubnet, - RemoveSubnetIds: removedSubnet, - AddSecurityGroupIds: addedSG, - }) - if err != nil { - log.Error(err, "failed to modify vpc endpoint", "id", endpointID, "addSubnets", addedSubnet, "removeSubnets", removedSubnet, "addSG", addedSG) - msg := err.Error() - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - msg = apiErr.ErrorCode() - } - log.Error(err, "failed to modify vpc endpoint") - return fmt.Errorf("failed to modify vpc endpoint: %s", msg) - } - log.Info("endpoint subnets updated") - } else { - log.Info("endpoint subnets are unchanged") + return endpointID, output.VpcEndpoints[0].DnsEntries, nil +} + +func deleteEndpointIfWrongService(ctx context.Context, ec2Client awsapi.EC2API, endpoint ec2types.VpcEndpoint, expectedServiceName string, log logr.Logger) error { + if aws.ToString(endpoint.ServiceName) == expectedServiceName { + return nil + } + log.Info("endpoint links to wrong endpointservice, deleting...", "LinkedVPCEndpointServiceName", aws.ToString(endpoint.ServiceName), "WantedVPCEndpointService", expectedServiceName) + if _, err := ec2Client.DeleteVpcEndpoints(ctx, &ec2.DeleteVpcEndpointsInput{ + VpcEndpointIds: []string{aws.ToString(endpoint.VpcEndpointId)}, + }); err != nil { + log.Error(err, "failed to delete vpc endpoint", "id", aws.ToString(endpoint.VpcEndpointId)) + return fmt.Errorf("error deleting AWSEndpoint: %w", err) + } + return fmt.Errorf("current endpoint %s is not pointing to the existing .Status.EndpointServiceName, reconciling by deleting endpoint", aws.ToString(endpoint.ServiceName)) +} + +func modifyEndpointIfNeeded(ctx context.Context, ec2Client awsapi.EC2API, awsEndpointService *hyperv1.AWSEndpointService, endpoint ec2types.VpcEndpoint, endpointID string, log logr.Logger) error { + addedSubnet, removedSubnet := diffIDs(awsEndpointService.Spec.SubnetIDs, endpoint.SubnetIds) + + existingSG := make([]string, 0) + for _, group := range endpoint.Groups { + existingSG = append(existingSG, aws.ToString(group.GroupId)) + } + addedSG, _ := diffIDs([]string{awsEndpointService.Status.SecurityGroupID}, existingSG) + + if addedSubnet == nil && removedSubnet == nil && addedSG == nil { + log.Info("endpoint subnets are unchanged") + return nil + } + + log.Info("endpoint subnets or security groups have changed") + _, err := ec2Client.ModifyVpcEndpoint(ctx, &ec2.ModifyVpcEndpointInput{ + VpcEndpointId: aws.String(endpointID), + AddSubnetIds: addedSubnet, + RemoveSubnetIds: removedSubnet, + AddSecurityGroupIds: addedSG, + }) + if err != nil { + log.Error(err, "failed to modify vpc endpoint", "id", endpointID, "addSubnets", addedSubnet, "removeSubnets", removedSubnet, "addSG", addedSG) + msg := err.Error() + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + msg = apiErr.ErrorCode() } - } else { - if !hasAWSConfig(&hcp.Spec.Platform) { - return fmt.Errorf("AWS platform information not provided in HostedControlPlane") + log.Error(err, "failed to modify vpc endpoint") + return fmt.Errorf("failed to modify vpc endpoint: %s", msg) + } + log.Info("endpoint subnets updated") + return nil +} + +func (r *AWSEndpointServiceReconciler) reconcileNewEndpoint(ctx context.Context, ec2Client awsapi.EC2API, awsEndpointService *hyperv1.AWSEndpointService, hcp *hyperv1.HostedControlPlane, log logr.Logger) (string, []ec2types.DnsEntry, error) { + if !hasAWSConfig(&hcp.Spec.Platform) { + return "", nil, fmt.Errorf("AWS platform information not provided in HostedControlPlane") + } + + output, err := ec2Client.DescribeVpcEndpoints(ctx, &ec2.DescribeVpcEndpointsInput{ + Filters: apiTagToEC2Filter(awsEndpointService.Name, hcp.Spec.Platform.AWS.ResourceTags), + }) + if err != nil { + msg := err.Error() + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + msg = apiErr.ErrorCode() } + log.Error(err, "failed to describe vpc endpoints") + return "", nil, fmt.Errorf("failed to describe vpc endpoints: %s", msg) + } - // Verify there is not already an Endpoint that we can adopt - // This can happen if we have a stale status on AWSEndpointService or encountered - // an error updating the AWSEndpointService on the previous reconcile - output, err := ec2Client.DescribeVpcEndpoints(ctx, &ec2.DescribeVpcEndpointsInput{ - Filters: apiTagToEC2Filter(awsEndpointService.Name, hcp.Spec.Platform.AWS.ResourceTags), - }) - if err != nil { - msg := err.Error() - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - msg = apiErr.ErrorCode() - } - log.Error(err, "failed to describe vpc endpoints") - return fmt.Errorf("failed to describe vpc endpoints: %s", msg) + if len(output.VpcEndpoints) != 0 { + if err := deleteEndpointIfWrongService(ctx, ec2Client, output.VpcEndpoints[0], awsEndpointService.Status.EndpointServiceName, log); err != nil { + return "", nil, err } - if len(output.VpcEndpoints) != 0 { - if aws.ToString(output.VpcEndpoints[0].ServiceName) != awsEndpointService.Status.EndpointServiceName { - log.Info("endpoint links to wrong endpointservice, deleting...", "LinkedVPCEndpointServiceName", aws.ToString(output.VpcEndpoints[0].ServiceName), "WantedVPCEndpointService", awsEndpointService.Status.EndpointServiceName) - if _, err := ec2Client.DeleteVpcEndpoints(ctx, &ec2.DeleteVpcEndpointsInput{ - VpcEndpointIds: []string{aws.ToString(output.VpcEndpoints[0].VpcEndpointId)}, - }); err != nil { - log.Error(err, "failed to delete vpc endpoint", "id", aws.ToString(output.VpcEndpoints[0].VpcEndpointId)) - return fmt.Errorf("error deleting AWSEndpoint: %w", err) - } + endpointID := aws.ToString(output.VpcEndpoints[0].VpcEndpointId) + log.Info("endpoint already exists, adopting", "endpointID", endpointID) + awsEndpointService.Status.EndpointID = endpointID + return endpointID, output.VpcEndpoints[0].DnsEntries, nil + } - // Once the VPC Endpoint is deleted, we need to send an error in order to reexecute the reconcilliation - return fmt.Errorf("current endpoint %s is not pointing to the existing .Status.EndpointServiceName, reconciling by deleting endpoint", aws.ToString(output.VpcEndpoints[0].ServiceName)) - } - endpointID = aws.ToString(output.VpcEndpoints[0].VpcEndpointId) - log.Info("endpoint already exists, adopting", "endpointID", endpointID) - awsEndpointService.Status.EndpointID = endpointID - endpointDNSEntries = output.VpcEndpoints[0].DnsEntries - } else { - log.Info("endpoint does not already exist") - - if awsEndpointService.Status.SecurityGroupID == "" { - return fmt.Errorf("security group ID doesn't exist yet for the endpoint to use") - } - output, err := ec2Client.CreateVpcEndpoint(ctx, &ec2.CreateVpcEndpointInput{ - SecurityGroupIds: []string{awsEndpointService.Status.SecurityGroupID}, - ServiceName: aws.String(awsEndpointService.Status.EndpointServiceName), - VpcId: aws.String(hcp.Spec.Platform.AWS.CloudProviderConfig.VPC), - VpcEndpointType: ec2types.VpcEndpointTypeInterface, - SubnetIds: awsEndpointService.Spec.SubnetIDs, - TagSpecifications: []ec2types.TagSpecification{{ - ResourceType: ec2types.ResourceTypeVpcEndpoint, - Tags: apiTagToEC2Tag(awsEndpointService.Name, hcp.Spec.Platform.AWS.ResourceTags), - }}, - }) - if err != nil { - msg := err.Error() - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - msg = apiErr.ErrorCode() - } - log.Error(err, "failed to create vpc endpoint") - return fmt.Errorf("failed to create vpc endpoint: %s", msg) - } - if output == nil || output.VpcEndpoint == nil { - return fmt.Errorf("CreateVpcEndpoint output is nil") - } + return r.createVPCEndpoint(ctx, ec2Client, awsEndpointService, hcp, log) +} - endpointID = aws.ToString(output.VpcEndpoint.VpcEndpointId) - log.Info("endpoint created", "endpointID", endpointID) - awsEndpointService.Status.EndpointID = endpointID - endpointDNSEntries = output.VpcEndpoint.DnsEntries +func (r *AWSEndpointServiceReconciler) createVPCEndpoint(ctx context.Context, ec2Client awsapi.EC2API, awsEndpointService *hyperv1.AWSEndpointService, hcp *hyperv1.HostedControlPlane, log logr.Logger) (string, []ec2types.DnsEntry, error) { + log.Info("endpoint does not already exist") + + if awsEndpointService.Status.SecurityGroupID == "" { + return "", nil, fmt.Errorf("security group ID doesn't exist yet for the endpoint to use") + } + output, err := ec2Client.CreateVpcEndpoint(ctx, &ec2.CreateVpcEndpointInput{ + SecurityGroupIds: []string{awsEndpointService.Status.SecurityGroupID}, + ServiceName: aws.String(awsEndpointService.Status.EndpointServiceName), + VpcId: aws.String(hcp.Spec.Platform.AWS.CloudProviderConfig.VPC), + VpcEndpointType: ec2types.VpcEndpointTypeInterface, + SubnetIds: awsEndpointService.Spec.SubnetIDs, + TagSpecifications: []ec2types.TagSpecification{{ + ResourceType: ec2types.ResourceTypeVpcEndpoint, + Tags: apiTagToEC2Tag(awsEndpointService.Name, hcp.Spec.Platform.AWS.ResourceTags), + }}, + }) + if err != nil { + msg := err.Error() + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + msg = apiErr.ErrorCode() } + log.Error(err, "failed to create vpc endpoint") + return "", nil, fmt.Errorf("failed to create vpc endpoint: %s", msg) } - - if len(endpointDNSEntries) == 0 { - log.Info("endpoint has no DNS entries, skipping DNS record creation", "endpointID", endpointID) - return nil + if output == nil || output.VpcEndpoint == nil { + return "", nil, fmt.Errorf("CreateVpcEndpoint output is nil") } + endpointID := aws.ToString(output.VpcEndpoint.VpcEndpointId) + log.Info("endpoint created", "endpointID", endpointID) + awsEndpointService.Status.EndpointID = endpointID + return endpointID, output.VpcEndpoint.DnsEntries, nil +} + +func (r *AWSEndpointServiceReconciler) reconcileEndpointDNSRecords(ctx context.Context, route53Client awsapi.ROUTE53API, awsEndpointService *hyperv1.AWSEndpointService, hcp *hyperv1.HostedControlPlane, endpointDNSEntries []ec2types.DnsEntry, log logr.Logger) ([]string, string, error) { recordNames := recordsForService(awsEndpointService, hcp) if len(recordNames) == 0 { log.Info("WARNING: no mapping from AWSEndpointService to DNS") - return nil + return nil, "", nil } - zoneName := zoneName(hcp.Name) + zn := zoneName(hcp.Name) var zoneID string if r.awsClientBuilder.getLocalHostedZoneID() == "" { - zoneID, err = lookupZoneID(ctx, route53Client, zoneName) + var err error + zoneID, err = lookupZoneID(ctx, route53Client, zn) if err != nil { - return err + return nil, "", err } r.awsClientBuilder.setLocalHostedZoneID(zoneID) } else { @@ -778,21 +806,23 @@ func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointService(ctx context.C var fqdns []string for _, recordName := range recordNames { - fqdn := fmt.Sprintf("%s.%s", recordName, zoneName) + fqdn := fmt.Sprintf("%s.%s", recordName, zn) fqdns = append(fqdns, fqdn) - err = CreateRecord(ctx, route53Client, zoneID, fqdn, aws.ToString(endpointDNSEntries[0].DnsName), route53types.RRTypeCname) + err := CreateRecord(ctx, route53Client, zoneID, fqdn, aws.ToString(endpointDNSEntries[0].DnsName), route53types.RRTypeCname) if err != nil { - return err + return nil, "", err } log.Info("DNS record created", "fqdn", fqdn) } - awsEndpointService.Status.DNSNames = fqdns - awsEndpointService.Status.DNSZoneID = zoneID + return fqdns, zoneID, nil +} + +func (r *AWSEndpointServiceReconciler) reconcileExternalNameServices(ctx context.Context, hcp *hyperv1.HostedControlPlane, endpointDNSEntries []ec2types.DnsEntry, log logr.Logger) error { + isPublic := netutil.IsPublicHCP(hcp) + externalNames := hcpExternalNames(hcp) - if isPublic, externalNames := netutil.IsPublicHCP(hcp), hcpExternalNames(hcp); !isPublic && len(externalNames) > 0 { - // only if not public and external names are configured, create services of type ExternalName so external-dns - // can create records for them + if !isPublic && len(externalNames) > 0 { var errs []error for svcType, externalName := range externalNames { var svc *corev1.Service @@ -812,27 +842,26 @@ func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointService(ctx context.C if len(errs) > 0 { return fmt.Errorf("failed to create external services for private endpoints: %w", utilerrors.NewAggregate(errs)) } - } else { - // if the cluster is public, ensure that any ExternalName services are removed - privateExternalServices := &corev1.ServiceList{} - if err := r.List(ctx, privateExternalServices, client.HasLabels{externalPrivateServiceLabel}); err != nil { - return fmt.Errorf("cannot list private external services: %w", err) - } - if len(privateExternalServices.Items) > 0 { - log.Info("Removing private external services", "count", len(privateExternalServices.Items)) - var errs []error - for i := range privateExternalServices.Items { - svc := &privateExternalServices.Items[i] - if err := r.Delete(ctx, svc); err != nil { - errs = append(errs, fmt.Errorf("failed to delete private external service %s: %w", svc.Name, err)) - } - } - if len(errs) > 0 { - return utilerrors.NewAggregate(errs) + return nil + } + + privateExternalServices := &corev1.ServiceList{} + if err := r.List(ctx, privateExternalServices, client.InNamespace(hcp.Namespace), client.HasLabels{externalPrivateServiceLabel}); err != nil { + return fmt.Errorf("cannot list private external services: %w", err) + } + if len(privateExternalServices.Items) > 0 { + log.Info("Removing private external services", "count", len(privateExternalServices.Items)) + var errs []error + for i := range privateExternalServices.Items { + svc := &privateExternalServices.Items[i] + if err := r.Delete(ctx, svc); err != nil { + errs = append(errs, fmt.Errorf("failed to delete private external service %s: %w", svc.Name, err)) } } + if len(errs) > 0 { + return utilerrors.NewAggregate(errs) + } } - return nil } diff --git a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go index 539a96f3248..efddc4d4286 100644 --- a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go +++ b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go @@ -20,6 +20,7 @@ import ( route53types "github.com/aws/aws-sdk-go-v2/service/route53/types" "github.com/aws/smithy-go" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" @@ -769,6 +770,631 @@ func TestDeleteSecurityGroup(t *testing.T) { // // A proper fix requires persisting the SharedVPC role ARNs in the AWSEndpointService // status so the deletion path can authenticate independently of the HCP. +func TestHasAWSConfig(t *testing.T) { + tests := []struct { + name string + platform hyperv1.PlatformSpec + expected bool + }{ + { + name: "When all AWS config fields are present, it should return true", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + CloudProviderConfig: &hyperv1.AWSCloudProviderConfig{ + Subnet: &hyperv1.AWSResourceReference{ + ID: aws.String("subnet-123"), + }, + }, + }, + }, + expected: true, + }, + { + name: "When platform type is not AWS, it should return false", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AzurePlatform, + }, + expected: false, + }, + { + name: "When AWS spec is nil, it should return false", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: nil, + }, + expected: false, + }, + { + name: "When CloudProviderConfig is nil, it should return false", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + CloudProviderConfig: nil, + }, + }, + expected: false, + }, + { + name: "When Subnet is nil, it should return false", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + CloudProviderConfig: &hyperv1.AWSCloudProviderConfig{ + Subnet: nil, + }, + }, + }, + expected: false, + }, + { + name: "When Subnet.ID is nil, it should return false", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + CloudProviderConfig: &hyperv1.AWSCloudProviderConfig{ + Subnet: &hyperv1.AWSResourceReference{ + ID: nil, + }, + }, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(hasAWSConfig(&tt.platform)).To(Equal(tt.expected)) + }) + } +} + +func TestVPCEndpointPort(t *testing.T) { + tests := []struct { + name string + svcName string + expected int32 + }{ + { + name: "When service is kube-apiserver-private, it should return 6443", + svcName: "kube-apiserver-private", + expected: 6443, + }, + { + name: "When service is private-router, it should return 443", + svcName: "private-router", + expected: 443, + }, + { + name: "When service is an unknown name, it should return 443", + svcName: "some-other-service", + expected: 443, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + aes := &hyperv1.AWSEndpointService{ + ObjectMeta: metav1.ObjectMeta{Name: tt.svcName}, + } + g.Expect(vpcEndpointPort(aes)).To(Equal(tt.expected)) + }) + } +} + +func TestVPCEndpointSecurityGroupName(t *testing.T) { + tests := []struct { + name string + infraID string + endpointName string + expected string + }{ + { + name: "When given infraID and endpoint name, it should format correctly", + infraID: "my-infra", + endpointName: "kube-apiserver-private", + expected: "my-infra-vpce-kube-apiserver-private", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(vpcEndpointSecurityGroupName(tt.infraID, tt.endpointName)).To(Equal(tt.expected)) + }) + } +} + +func TestVPCEndpointSecurityGroupFilter(t *testing.T) { + tests := []struct { + name string + infraID string + endpointName string + }{ + { + name: "When given infraID and endpoint name, it should return cluster tag and name tag filters", + infraID: "test-infra", + endpointName: "kube-apiserver-private", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + filters := vpcEndpointSecurityGroupFilter(tt.infraID, tt.endpointName) + g.Expect(filters).To(HaveLen(2)) + g.Expect(aws.ToString(filters[0].Name)).To(Equal("tag:kubernetes.io/cluster/test-infra")) + g.Expect(filters[0].Values).To(Equal([]string{"owned"})) + g.Expect(aws.ToString(filters[1].Name)).To(Equal("tag:Name")) + g.Expect(filters[1].Values).To(Equal([]string{"test-infra-vpce-kube-apiserver-private"})) + }) + } +} + +func TestApiTagToEC2Tag(t *testing.T) { + tests := []struct { + name string + svcName string + tags []hyperv1.AWSResourceTag + expected []ec2types.Tag + }{ + { + name: "When no resource tags are provided, it should return only the AWSEndpointService tag", + svcName: "my-svc", + tags: nil, + expected: []ec2types.Tag{ + {Key: aws.String("AWSEndpointService"), Value: aws.String("my-svc")}, + }, + }, + { + name: "When resource tags are provided, it should include them plus the AWSEndpointService tag", + svcName: "my-svc", + tags: []hyperv1.AWSResourceTag{ + {Key: "env", Value: "prod"}, + {Key: "team", Value: "platform"}, + }, + expected: []ec2types.Tag{ + {Key: aws.String("env"), Value: aws.String("prod")}, + {Key: aws.String("team"), Value: aws.String("platform")}, + {Key: aws.String("AWSEndpointService"), Value: aws.String("my-svc")}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + result := apiTagToEC2Tag(tt.svcName, tt.tags) + g.Expect(result).To(Equal(tt.expected)) + }) + } +} + +func TestApiTagToEC2Filter(t *testing.T) { + tests := []struct { + name string + svcName string + tags []hyperv1.AWSResourceTag + expected []ec2types.Filter + }{ + { + name: "When no resource tags are provided, it should return only the AWSEndpointService filter", + svcName: "my-svc", + tags: nil, + expected: []ec2types.Filter{ + {Name: aws.String("tag:AWSEndpointService"), Values: []string{"my-svc"}}, + }, + }, + { + name: "When resource tags are provided, it should include them as tag filters plus the AWSEndpointService filter", + svcName: "my-svc", + tags: []hyperv1.AWSResourceTag{ + {Key: "env", Value: "prod"}, + }, + expected: []ec2types.Filter{ + {Name: aws.String("tag:env"), Values: []string{"prod"}}, + {Name: aws.String("tag:AWSEndpointService"), Values: []string{"my-svc"}}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + result := apiTagToEC2Filter(tt.svcName, tt.tags) + g.Expect(result).To(Equal(tt.expected)) + }) + } +} + +func TestZoneName(t *testing.T) { + tests := []struct { + name string + hcpName string + expected string + }{ + { + name: "When given an HCP name, it should append the hypershift.local suffix", + hcpName: "my-cluster", + expected: "my-cluster.hypershift.local", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(zoneName(tt.hcpName)).To(Equal(tt.expected)) + }) + } +} + +func TestRouterZoneName(t *testing.T) { + tests := []struct { + name string + hcpName string + expected string + }{ + { + name: "When given an HCP name, it should prepend apps and append the hypershift.local suffix", + hcpName: "my-cluster", + expected: "apps.my-cluster.hypershift.local", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(RouterZoneName(tt.hcpName)).To(Equal(tt.expected)) + }) + } +} + +func TestControllerName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "When given a name, it should append the observer suffix", + input: "kube-apiserver", + expected: "kube-apiserver-observer", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(ControllerName(tt.input)).To(Equal(tt.expected)) + }) + } +} + +func TestNameMapper(t *testing.T) { + tests := []struct { + name string + watchedNames []string + incomingName string + incomingNS string + expectRequests int + }{ + { + name: "When incoming object name matches a watched name, it should return a reconcile request", + watchedNames: []string{"kube-apiserver-private", "private-router"}, + incomingName: "kube-apiserver-private", + incomingNS: "test-ns", + expectRequests: 1, + }, + { + name: "When incoming object name does not match any watched name, it should return nil", + watchedNames: []string{"kube-apiserver-private"}, + incomingName: "other-service", + incomingNS: "test-ns", + expectRequests: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + mapFn := nameMapper(tt.watchedNames) + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: tt.incomingName, + Namespace: tt.incomingNS, + }, + } + requests := mapFn(t.Context(), svc) + g.Expect(requests).To(HaveLen(tt.expectRequests)) + if tt.expectRequests > 0 { + g.Expect(requests[0].Name).To(Equal(tt.incomingName)) + g.Expect(requests[0].Namespace).To(Equal(tt.incomingNS)) + } + }) + } +} + +func TestHCPExternalNames(t *testing.T) { + tests := []struct { + name string + hcp *hyperv1.HostedControlPlane + expected map[string]string + }{ + { + name: "When API and OAuth both have Route strategies with hostnames, it should return both", + hcp: &hyperv1.HostedControlPlane{ + Spec: hyperv1.HostedControlPlaneSpec{ + Services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.APIServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{ + Type: hyperv1.Route, + Route: &hyperv1.RoutePublishingStrategy{Hostname: "api.example.com"}, + }, + }, + { + Service: hyperv1.OAuthServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{ + Type: hyperv1.Route, + Route: &hyperv1.RoutePublishingStrategy{Hostname: "oauth.example.com"}, + }, + }, + }, + }, + }, + expected: map[string]string{ + "api": "api.example.com", + "oauth": "oauth.example.com", + }, + }, + { + name: "When no Route strategies are configured, it should return empty map", + hcp: &hyperv1.HostedControlPlane{ + Spec: hyperv1.HostedControlPlaneSpec{ + Services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.APIServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{ + Type: hyperv1.LoadBalancer, + }, + }, + }, + }, + }, + expected: map[string]string{}, + }, + { + name: "When Route strategy has no hostname, it should not include that entry", + hcp: &hyperv1.HostedControlPlane{ + Spec: hyperv1.HostedControlPlaneSpec{ + Services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.APIServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{ + Type: hyperv1.Route, + Route: &hyperv1.RoutePublishingStrategy{Hostname: ""}, + }, + }, + }, + }, + }, + expected: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + result := hcpExternalNames(tt.hcp) + g.Expect(result).To(Equal(tt.expected)) + }) + } +} + +func TestReconcileExternalService(t *testing.T) { + tests := []struct { + name string + hostName string + targetCName string + expectedLabels map[string]string + }{ + { + name: "When reconciling an external service, it should set type, labels, annotations, and external name", + hostName: "api.example.com", + targetCName: "vpce-abc.vpce-svc.us-east-1.vpce.amazonaws.com", + expectedLabels: map[string]string{ + externalPrivateServiceLabel: "true", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + hcp := &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: "test-ns", + UID: "test-uid", + }, + } + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver-private-external", + Namespace: "test-ns", + }, + } + + err := reconcileExternalService(svc, hcp, tt.hostName, tt.targetCName) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(svc.Spec.Type).To(Equal(corev1.ServiceTypeExternalName)) + g.Expect(svc.Spec.ExternalName).To(Equal(tt.targetCName)) + g.Expect(svc.Labels[externalPrivateServiceLabel]).To(Equal("true")) + g.Expect(svc.Annotations[hyperv1.ExternalDNSHostnameAnnotation]).To(Equal(tt.hostName)) + g.Expect(svc.OwnerReferences).To(HaveLen(1)) + g.Expect(svc.OwnerReferences[0].Name).To(Equal("test-hcp")) + }) + } +} + +func TestDeleteEndpointIfWrongService(t *testing.T) { + tests := []struct { + name string + endpointServiceName string + expectedServiceName string + setupEC2Mock func(*gomock.Controller) *awsapi.MockEC2API + expectError bool + }{ + { + name: "When endpoint points to the correct service, it should return nil", + endpointServiceName: "com.amazonaws.vpce-svc-abc", + expectedServiceName: "com.amazonaws.vpce-svc-abc", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + return awsapi.NewMockEC2API(mockCtrl) + }, + expectError: false, + }, + { + name: "When endpoint points to wrong service, it should delete and return error", + endpointServiceName: "com.amazonaws.vpce-svc-wrong", + expectedServiceName: "com.amazonaws.vpce-svc-correct", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().DeleteVpcEndpoints(gomock.Any(), gomock.Any()).Return(&ec2v2.DeleteVpcEndpointsOutput{}, nil) + return m + }, + expectError: true, + }, + { + name: "When endpoint points to wrong service and delete fails, it should return the delete error", + endpointServiceName: "com.amazonaws.vpce-svc-wrong", + expectedServiceName: "com.amazonaws.vpce-svc-correct", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().DeleteVpcEndpoints(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("access denied")) + return m + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + mockCtrl := gomock.NewController(t) + mockEC2 := tt.setupEC2Mock(mockCtrl) + + endpoint := ec2types.VpcEndpoint{ + VpcEndpointId: aws.String("vpce-123"), + ServiceName: aws.String(tt.endpointServiceName), + } + ctx := ctrl.LoggerInto(t.Context(), ctrl.Log.WithName("test")) + err := deleteEndpointIfWrongService(ctx, mockEC2, endpoint, tt.expectedServiceName, ctrl.Log.WithName("test")) + if tt.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + } + }) + } +} + +func TestModifyEndpointIfNeeded(t *testing.T) { + tests := []struct { + name string + specSubnetIDs []string + endpointSubnetIDs []string + specSecurityGroupID string + endpointGroups []ec2types.SecurityGroupIdentifier + setupEC2Mock func(*gomock.Controller) *awsapi.MockEC2API + expectError bool + }{ + { + name: "When subnets and security groups are unchanged, it should not modify", + specSubnetIDs: []string{"subnet-1", "subnet-2"}, + endpointSubnetIDs: []string{"subnet-1", "subnet-2"}, + specSecurityGroupID: "sg-123", + endpointGroups: []ec2types.SecurityGroupIdentifier{{GroupId: aws.String("sg-123")}}, + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + return awsapi.NewMockEC2API(mockCtrl) + }, + expectError: false, + }, + { + name: "When subnets have changed, it should call ModifyVpcEndpoint", + specSubnetIDs: []string{"subnet-1", "subnet-3"}, + endpointSubnetIDs: []string{"subnet-1", "subnet-2"}, + specSecurityGroupID: "sg-123", + endpointGroups: []ec2types.SecurityGroupIdentifier{{GroupId: aws.String("sg-123")}}, + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().ModifyVpcEndpoint(gomock.Any(), gomock.Any()).Return(&ec2v2.ModifyVpcEndpointOutput{}, nil) + return m + }, + expectError: false, + }, + { + name: "When security group needs adding, it should call ModifyVpcEndpoint", + specSubnetIDs: []string{"subnet-1"}, + endpointSubnetIDs: []string{"subnet-1"}, + specSecurityGroupID: "sg-new", + endpointGroups: []ec2types.SecurityGroupIdentifier{{GroupId: aws.String("sg-old")}}, + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().ModifyVpcEndpoint(gomock.Any(), gomock.Any()).Return(&ec2v2.ModifyVpcEndpointOutput{}, nil) + return m + }, + expectError: false, + }, + { + name: "When ModifyVpcEndpoint fails, it should return error", + specSubnetIDs: []string{"subnet-1", "subnet-3"}, + endpointSubnetIDs: []string{"subnet-1", "subnet-2"}, + specSecurityGroupID: "sg-123", + endpointGroups: []ec2types.SecurityGroupIdentifier{{GroupId: aws.String("sg-123")}}, + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().ModifyVpcEndpoint(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("throttling")) + return m + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + mockCtrl := gomock.NewController(t) + mockEC2 := tt.setupEC2Mock(mockCtrl) + + aes := &hyperv1.AWSEndpointService{ + Spec: hyperv1.AWSEndpointServiceSpec{ + SubnetIDs: tt.specSubnetIDs, + }, + Status: hyperv1.AWSEndpointServiceStatus{ + SecurityGroupID: tt.specSecurityGroupID, + }, + } + + endpoint := ec2types.VpcEndpoint{ + SubnetIds: tt.endpointSubnetIDs, + Groups: tt.endpointGroups, + } + + ctx := ctrl.LoggerInto(t.Context(), ctrl.Log.WithName("test")) + err := modifyEndpointIfNeeded(ctx, mockEC2, aes, endpoint, "vpce-123", ctrl.Log.WithName("test")) + if tt.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + } + }) + } +} + func TestReconcileDeletionSharedVPC(t *testing.T) { now := metav1.NewTime(time.Now()) diff --git a/control-plane-operator/controllers/azureprivatelinkservice/controller.go b/control-plane-operator/controllers/azureprivatelinkservice/controller.go index a43b09b57be..13cd2255fdc 100644 --- a/control-plane-operator/controllers/azureprivatelinkservice/controller.go +++ b/control-plane-operator/controllers/azureprivatelinkservice/controller.go @@ -860,191 +860,182 @@ func (r *AzurePrivateLinkServiceReconciler) handleAzureError(ctx context.Context // CPO finalizer has completed PE cleanup and removed its finalizer from the CR. func (r *AzurePrivateLinkServiceReconciler) reconcileDelete(ctx context.Context, azPLS *hyperv1.AzurePrivateLinkService, log logr.Logger) error { resourceGroup := azPLS.Spec.ResourceGroupName - - // 1. Delete DNS resources using the zone name persisted in status. - // This avoids a dependency on the HostedControlPlane during deletion, which may - // already be torn down or unavailable when the finalizer runs. dnsZoneName := azPLS.Status.DNSZoneName + if dnsZoneName != "" { - // Delete both A records (KAS apex and wildcard apps) - for _, recordName := range []string{kasARecordName, appsARecordName} { - log.Info("Deleting A record", "record", recordName, "zone", dnsZoneName) - deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) - if _, err := r.RecordSets.Delete(deleteCtx, resourceGroup, dnsZoneName, armprivatedns.RecordTypeA, recordName, nil); err != nil { - cancel() - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete A record %q: %w", recordName, err) - } - } - cancel() + if err := r.deleteDNSResources(ctx, resourceGroup, dnsZoneName, azPLS.Name, log); err != nil { + return err } + } else { + log.V(1).Info("DNSZoneName not set in status, skipping DNS cleanup") + } - // 2. Delete VNet link (must be deleted before zone) - linkName := vnetLinkName(azPLS.Name) - log.Info("Deleting VNet Link", "name", linkName) - deleteCtx2, cancel2 := context.WithTimeout(ctx, azureAPITimeout) - defer cancel2() - linkPoller, err := r.VirtualNetworkLinks.BeginDelete(deleteCtx2, resourceGroup, dnsZoneName, linkName, nil) - if err != nil { + if err := r.deleteBaseDomainResources(ctx, azPLS, resourceGroup, dnsZoneName, log); err != nil { + return err + } + + return r.deletePrivateEndpoint(ctx, resourceGroup, azPLS.Name, azPLS.Status.PrivateEndpointID, log) +} + +func (r *AzurePrivateLinkServiceReconciler) deleteDNSResources(ctx context.Context, resourceGroup, dnsZoneName, crName string, log logr.Logger) error { + for _, recordName := range []string{kasARecordName, appsARecordName} { + log.Info("Deleting A record", "record", recordName, "zone", dnsZoneName) + deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) + if _, err := r.RecordSets.Delete(deleteCtx, resourceGroup, dnsZoneName, armprivatedns.RecordTypeA, recordName, nil); err != nil { + cancel() if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to begin deleting VNet Link: %w", err) - } - } else if linkPoller != nil { - linkPollCtx, linkPollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) - defer linkPollCancel() - - if _, err := linkPoller.PollUntilDone(linkPollCtx, nil); err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete VNet Link: %w", err) - } + return fmt.Errorf("failed to delete A record %q: %w", recordName, err) } } + cancel() + } - // 3. Delete Private DNS Zone - log.Info("Deleting Private DNS Zone", "zone", dnsZoneName) - deleteCtx3, cancel3 := context.WithTimeout(ctx, azureAPITimeout) - defer cancel3() - zonePoller, err := r.PrivateDNSZones.BeginDelete(deleteCtx3, resourceGroup, dnsZoneName, nil) - if err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to begin deleting Private DNS Zone: %w", err) - } - } else if zonePoller != nil { - zonePollCtx, zonePollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) - defer zonePollCancel() - - if _, err := zonePoller.PollUntilDone(zonePollCtx, nil); err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete Private DNS Zone: %w", err) - } - } + if err := r.deleteVNetLink(ctx, resourceGroup, dnsZoneName, vnetLinkName(crName), log); err != nil { + return err + } + + return r.deleteDNSZone(ctx, resourceGroup, dnsZoneName, log) +} + +func (r *AzurePrivateLinkServiceReconciler) deleteVNetLink(ctx context.Context, resourceGroup, zoneName, linkName string, log logr.Logger) error { + log.Info("Deleting VNet Link", "name", linkName) + deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) + defer cancel() + linkPoller, err := r.VirtualNetworkLinks.BeginDelete(deleteCtx, resourceGroup, zoneName, linkName, nil) + if err != nil { + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to begin deleting VNet Link: %w", err) } - } else { - log.V(1).Info("DNSZoneName not set in status, skipping DNS cleanup") + return nil } + if linkPoller == nil { + return nil + } + linkPollCtx, linkPollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) + defer linkPollCancel() - // 4. Delete base domain DNS resources (A records, VNet link, zone) - baseDomain := azPLS.Spec.BaseDomain - if baseDomain != "" { - // Extract cluster name from the hypershift.local zone name (format: ".hypershift.local") - clusterName := "" - if dnsZoneName != "" { - // Strip the ".hypershift.local" suffix to get the cluster name - if name, ok := strings.CutSuffix(dnsZoneName, "."+privateDNSZoneSuffix); ok { - clusterName = name - } + if _, err := linkPoller.PollUntilDone(linkPollCtx, nil); err != nil { + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to delete VNet Link: %w", err) } + } + return nil +} - if clusterName != "" { - // Determine which base domain A records this CR owns based on its name. - // This mirrors the creation logic in reconcileBaseDomainDNS. - var baseDomainRecords []string - if azPLS.Name == privateRouterCRName { - baseDomainRecords = append(baseDomainRecords, kasBaseDomainRecordPrefix+clusterName) - // Only delete the oauth record if there is no sibling OAuth CR that - // owns it. This prevents the private-router deletion from removing - // an oauth record that now belongs to the dedicated OAuth CR. - hasSiblings, err := r.hasSiblingCR(ctx, azPLS) - if err != nil { - return fmt.Errorf("failed to check for sibling CRs during base domain cleanup: %w", err) - } - if !hasSiblings { - baseDomainRecords = append(baseDomainRecords, oauthBaseDomainRecordPrefix+clusterName) - } - } else { - baseDomainRecords = append(baseDomainRecords, oauthBaseDomainRecordPrefix+clusterName) - } +func (r *AzurePrivateLinkServiceReconciler) deleteDNSZone(ctx context.Context, resourceGroup, zoneName string, log logr.Logger) error { + log.Info("Deleting Private DNS Zone", "zone", zoneName) + deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) + defer cancel() + zonePoller, err := r.PrivateDNSZones.BeginDelete(deleteCtx, resourceGroup, zoneName, nil) + if err != nil { + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to begin deleting Private DNS Zone: %w", err) + } + return nil + } + if zonePoller == nil { + return nil + } + zonePollCtx, zonePollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) + defer zonePollCancel() - for _, recordName := range baseDomainRecords { - log.Info("Deleting base domain A record", "record", recordName, "zone", baseDomain) - deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) - if _, err := r.RecordSets.Delete(deleteCtx, resourceGroup, baseDomain, armprivatedns.RecordTypeA, recordName, nil); err != nil { - cancel() - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete base domain A record %q: %w", recordName, err) - } - } - cancel() - } + if _, err := zonePoller.PollUntilDone(zonePollCtx, nil); err != nil { + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to delete Private DNS Zone: %w", err) } + } + return nil +} - // Delete base domain VNet link - bdLinkName := baseDomainVNetLinkName(azPLS.Name) - log.Info("Deleting base domain VNet Link", "name", bdLinkName) - bdLinkCtx, bdLinkCancel := context.WithTimeout(ctx, azureAPITimeout) - defer bdLinkCancel() - bdLinkPoller, err := r.VirtualNetworkLinks.BeginDelete(bdLinkCtx, resourceGroup, baseDomain, bdLinkName, nil) - if err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to begin deleting base domain VNet Link: %w", err) - } - } else if bdLinkPoller != nil { - bdLinkPollCtx, bdLinkPollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) - defer bdLinkPollCancel() - - if _, err := bdLinkPoller.PollUntilDone(bdLinkPollCtx, nil); err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete base domain VNet Link: %w", err) - } - } +func (r *AzurePrivateLinkServiceReconciler) deleteBaseDomainResources(ctx context.Context, azPLS *hyperv1.AzurePrivateLinkService, resourceGroup, dnsZoneName string, log logr.Logger) error { + baseDomain := azPLS.Spec.BaseDomain + if baseDomain == "" { + return nil + } + + if err := r.deleteBaseDomainARecords(ctx, azPLS, resourceGroup, baseDomain, dnsZoneName, log); err != nil { + return err + } + + if err := r.deleteVNetLink(ctx, resourceGroup, baseDomain, baseDomainVNetLinkName(azPLS.Name), log); err != nil { + return err + } + + hasSiblings, err := r.hasSiblingCR(ctx, azPLS) + if err != nil { + return fmt.Errorf("failed to check for sibling CRs during base domain zone cleanup: %w", err) + } + + if !hasSiblings { + log.Info("Deleting base domain Private DNS Zone (last CR using this zone)", "zone", baseDomain) + return r.deleteDNSZone(ctx, resourceGroup, baseDomain, log) + } + log.Info("Skipping base domain zone deletion, other CRs still use it", "zone", baseDomain) + return nil +} + +func (r *AzurePrivateLinkServiceReconciler) deleteBaseDomainARecords(ctx context.Context, azPLS *hyperv1.AzurePrivateLinkService, resourceGroup, baseDomain, dnsZoneName string, log logr.Logger) error { + clusterName := "" + if dnsZoneName != "" { + if name, ok := strings.CutSuffix(dnsZoneName, "."+privateDNSZoneSuffix); ok { + clusterName = name } + } + if clusterName == "" { + return nil + } - // Only delete the base domain DNS zone if no other CRs share it. - // When multiple CRs (e.g., private-router and oauth-openshift) use the same - // base domain zone, the zone must not be deleted until the last CR is removed. + var baseDomainRecords []string + if azPLS.Name == privateRouterCRName { + baseDomainRecords = append(baseDomainRecords, kasBaseDomainRecordPrefix+clusterName) hasSiblings, err := r.hasSiblingCR(ctx, azPLS) if err != nil { - return fmt.Errorf("failed to check for sibling CRs during base domain zone cleanup: %w", err) + return fmt.Errorf("failed to check for sibling CRs during base domain cleanup: %w", err) } - if !hasSiblings { - log.Info("Deleting base domain Private DNS Zone (last CR using this zone)", "zone", baseDomain) - bdZoneCtx, bdZoneCancel := context.WithTimeout(ctx, azureAPITimeout) - defer bdZoneCancel() - bdZonePoller, err := r.PrivateDNSZones.BeginDelete(bdZoneCtx, resourceGroup, baseDomain, nil) - if err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to begin deleting base domain Private DNS Zone: %w", err) - } - } else if bdZonePoller != nil { - bdZonePollCtx, bdZonePollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) - defer bdZonePollCancel() - - if _, err := bdZonePoller.PollUntilDone(bdZonePollCtx, nil); err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete base domain Private DNS Zone: %w", err) - } - } + baseDomainRecords = append(baseDomainRecords, oauthBaseDomainRecordPrefix+clusterName) + } + } else { + baseDomainRecords = append(baseDomainRecords, oauthBaseDomainRecordPrefix+clusterName) + } + + for _, recordName := range baseDomainRecords { + log.Info("Deleting base domain A record", "record", recordName, "zone", baseDomain) + deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) + if _, err := r.RecordSets.Delete(deleteCtx, resourceGroup, baseDomain, armprivatedns.RecordTypeA, recordName, nil); err != nil { + cancel() + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to delete base domain A record %q: %w", recordName, err) } - } else { - log.Info("Skipping base domain zone deletion, other CRs still use it", "zone", baseDomain) } + cancel() } + return nil +} - // 5. Delete Private Endpoint - // Always attempt deletion by deterministic name, even when PrivateEndpointID is empty. - // If the status was never populated (e.g., status update failed after PE creation), - // relying solely on PrivateEndpointID would orphan the PE in the customer's subscription. - peName := privateEndpointName(azPLS.Name) - log.Info("Deleting Private Endpoint", "name", peName, "hasStatusID", azPLS.Status.PrivateEndpointID != "") - deleteCtx4, cancel4 := context.WithTimeout(ctx, azureAPITimeout) - defer cancel4() - pePoller, err := r.PrivateEndpoints.BeginDelete(deleteCtx4, resourceGroup, peName, nil) +func (r *AzurePrivateLinkServiceReconciler) deletePrivateEndpoint(ctx context.Context, resourceGroup, crName, statusPEID string, log logr.Logger) error { + peName := privateEndpointName(crName) + log.Info("Deleting Private Endpoint", "name", peName, "hasStatusID", statusPEID != "") + deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) + defer cancel() + pePoller, err := r.PrivateEndpoints.BeginDelete(deleteCtx, resourceGroup, peName, nil) if err != nil { if !azureutil.IsAzureNotFoundError(err) { return fmt.Errorf("failed to begin deleting Private Endpoint: %w", err) } - } else if pePoller != nil { - pePollCtx, pePollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) - defer pePollCancel() + return nil + } + if pePoller == nil { + return nil + } + pePollCtx, pePollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) + defer pePollCancel() - if _, err := pePoller.PollUntilDone(pePollCtx, nil); err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete Private Endpoint: %w", err) - } + if _, err := pePoller.PollUntilDone(pePollCtx, nil); err != nil { + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to delete Private Endpoint: %w", err) } } - return nil } diff --git a/control-plane-operator/controllers/azureprivatelinkservice/controller_test.go b/control-plane-operator/controllers/azureprivatelinkservice/controller_test.go index bf377fa6c9d..3e63b043a3e 100644 --- a/control-plane-operator/controllers/azureprivatelinkservice/controller_test.go +++ b/control-plane-operator/controllers/azureprivatelinkservice/controller_test.go @@ -1715,22 +1715,88 @@ func TestReconcileDelete_WhenNoSiblingCRs_ItShouldDeleteBaseDomainZone(t *testin "should delete both VNet links when the last CR goes away") } -func TestHasSiblingCR(t *testing.T) { +func TestBaseDomainVNetLinkName(t *testing.T) { + t.Parallel() + tests := []struct { + name string + crName string + expected string + }{ + { + name: "When CR name is private-router, it should append basedomain VNet link suffix", + crName: "private-router", + expected: "private-router-basedomain-vnet-link", + }, + { + name: "When CR name is oauth-openshift, it should append basedomain VNet link suffix", + crName: "oauth-openshift", + expected: "oauth-openshift-basedomain-vnet-link", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + g := NewGomegaWithT(t) + result := baseDomainVNetLinkName(tt.crName) + g.Expect(result).To(Equal(tt.expected)) + }) + } +} + +func TestDNSZoneConfigErrMsgQualifier(t *testing.T) { t.Parallel() tests := []struct { name string - azPLS *hyperv1.AzurePrivateLinkService - siblings []client.Object - expectHas bool - expectErr bool + logPrefix string + expected string }{ { - name: "When sibling with same base domain exists it should return true", - azPLS: func() *hyperv1.AzurePrivateLinkService { - cr := newTestAzurePLS(t, "private-router", "test-ns") - cr.Spec.BaseDomain = "example.com" - return cr - }(), + name: "When logPrefix is empty, it should return empty string", + logPrefix: "", + expected: "", + }, + { + name: "When logPrefix is set, it should return prefix followed by a space", + logPrefix: "base domain", + expected: "base domain ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + g := NewGomegaWithT(t) + cfg := dnsZoneConfig{logPrefix: tt.logPrefix} + g.Expect(cfg.errMsgQualifier()).To(Equal(tt.expected)) + }) + } +} + +func TestRecordNamesForCR(t *testing.T) { + tests := []struct { + name string + crName string + clusterName string + siblings []client.Object + expected []string + }{ + { + name: "When CR is not private-router, it should return only the oauth record", + crName: "oauth-openshift", + clusterName: "my-cluster", + expected: []string{"oauth-my-cluster"}, + }, + { + name: "When CR is private-router with no sibling, it should return api and oauth records", + crName: "private-router", + clusterName: "my-cluster", + expected: []string{"api-my-cluster", "oauth-my-cluster"}, + }, + { + name: "When CR is private-router with sibling OAuth CR, it should return only api record", + crName: "private-router", + clusterName: "my-cluster", siblings: []client.Object{ func() *hyperv1.AzurePrivateLinkService { cr := newTestAzurePLS(t, "oauth-openshift", "test-ns") @@ -1738,122 +1804,253 @@ func TestHasSiblingCR(t *testing.T) { return cr }(), }, - expectHas: true, + expected: []string{"api-my-cluster"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + scheme := newTestScheme(t, g) + + azPLS := newTestAzurePLS(t, tt.crName, "test-ns") + azPLS.Spec.BaseDomain = "example.com" + + objs := append([]client.Object{azPLS}, tt.siblings...) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(objs...). + Build() + + r := &AzurePrivateLinkServiceReconciler{Client: fakeClient} + result, err := r.recordNamesForCR(t.Context(), azPLS, tt.clusterName, testr.New(t)) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(result).To(Equal(tt.expected)) + }) + } +} + +func TestDeleteBaseDomainARecords(t *testing.T) { + tests := []struct { + name string + crName string + dnsZoneName string + baseDomain string + siblings []client.Object + expectedDeletedRecords []string + expectDeleteCalled bool + }{ + { + name: "When dnsZoneName is empty, it should skip deletion", + crName: "private-router", + dnsZoneName: "", + baseDomain: "example.com", + expectDeleteCalled: false, }, { - name: "When no siblings exist it should return false", - azPLS: func() *hyperv1.AzurePrivateLinkService { - cr := newTestAzurePLS(t, "private-router", "test-ns") - cr.Spec.BaseDomain = "example.com" - return cr - }(), - siblings: []client.Object{}, - expectHas: false, + name: "When dnsZoneName has wrong suffix, it should skip deletion", + crName: "private-router", + dnsZoneName: "cluster.wrong.suffix", + baseDomain: "example.com", + expectDeleteCalled: false, }, { - name: "When sibling has different base domain it should return false", - azPLS: func() *hyperv1.AzurePrivateLinkService { - cr := newTestAzurePLS(t, "private-router", "test-ns") - cr.Spec.BaseDomain = "example.com" - return cr - }(), + name: "When private-router with no siblings, it should delete api and oauth records", + crName: "private-router", + dnsZoneName: "my-cluster.hypershift.local", + baseDomain: "example.com", + expectedDeletedRecords: []string{"api-my-cluster", "oauth-my-cluster"}, + expectDeleteCalled: true, + }, + { + name: "When private-router with sibling OAuth CR, it should delete only api record", + crName: "private-router", + dnsZoneName: "my-cluster.hypershift.local", + baseDomain: "example.com", siblings: []client.Object{ func() *hyperv1.AzurePrivateLinkService { cr := newTestAzurePLS(t, "oauth-openshift", "test-ns") - cr.Spec.BaseDomain = "other.com" + cr.Spec.BaseDomain = "example.com" return cr }(), }, - expectHas: false, + expectedDeletedRecords: []string{"api-my-cluster"}, + expectDeleteCalled: true, + }, + { + name: "When non-private-router CR, it should delete only oauth record", + crName: "oauth-openshift", + dnsZoneName: "my-cluster.hypershift.local", + baseDomain: "example.com", + expectedDeletedRecords: []string{"oauth-my-cluster"}, + expectDeleteCalled: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() g := NewGomegaWithT(t) scheme := newTestScheme(t, g) - objs := append([]client.Object{tt.azPLS}, tt.siblings...) + azPLS := newTestAzurePLS(t, tt.crName, "test-ns") + azPLS.Spec.BaseDomain = tt.baseDomain + azPLS.Status.DNSZoneName = tt.dnsZoneName + + objs := append([]client.Object{azPLS}, tt.siblings...) fakeClient := fake.NewClientBuilder(). WithScheme(scheme). WithObjects(objs...). Build() - r := &AzurePrivateLinkServiceReconciler{Client: fakeClient} - has, err := r.hasSiblingCR(t.Context(), tt.azPLS) + mockRecords := &mockRecordSets{} + r := &AzurePrivateLinkServiceReconciler{ + Client: fakeClient, + RecordSets: mockRecords, + } - if tt.expectErr { - g.Expect(err).To(HaveOccurred()) - } else { - g.Expect(err).ToNot(HaveOccurred()) - g.Expect(has).To(Equal(tt.expectHas)) + err := r.deleteBaseDomainARecords(t.Context(), azPLS, "test-rg", tt.baseDomain, tt.dnsZoneName, testr.New(t)) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(mockRecords.deleteCalled).To(Equal(tt.expectDeleteCalled)) + if tt.expectDeleteCalled { + g.Expect(mockRecords.deletedRecordNames).To(Equal(tt.expectedDeletedRecords)) } }) } } func TestMapHCPToAzurePLS(t *testing.T) { - t.Parallel() tests := []struct { - name string - hcp *hyperv1.HostedControlPlane - existingPLS []client.Object - expectedLen int + name string + hcp *hyperv1.HostedControlPlane + plsCRs []client.Object + expectRequests int }{ { - name: "When HCP has PLS finalizer and AzurePLS CRs exist it should return reconcile requests", + name: "When HCP has the Azure PLS finalizer and PLS CRs exist, it should return requests for all PLS CRs", hcp: func() *hyperv1.HostedControlPlane { hcp := newTestHCP(t, "test-hcp", "test-ns", "api.test.example.com") hcp.Finalizers = []string{hcpAzurePLSFinalizerName} return hcp }(), - existingPLS: []client.Object{ - newTestAzurePLS(t, "pls-1", "test-ns"), - newTestAzurePLS(t, "pls-2", "test-ns"), + plsCRs: []client.Object{ + newTestAzurePLS(t, "private-router", "test-ns"), + newTestAzurePLS(t, "oauth-openshift", "test-ns"), }, - expectedLen: 2, + expectRequests: 2, }, { - name: "When HCP does not have PLS finalizer it should return nil", - hcp: newTestHCP(t, "test-hcp", "test-ns", "api.test.example.com"), - expectedLen: 0, + name: "When HCP does not have the Azure PLS finalizer, it should return no requests", + hcp: newTestHCP(t, "test-hcp", "test-ns", "api.test.example.com"), + plsCRs: []client.Object{ + newTestAzurePLS(t, "private-router", "test-ns"), + }, + expectRequests: 0, }, { - name: "When HCP has PLS finalizer but no AzurePLS CRs exist it should return empty list", + name: "When HCP has the finalizer but no PLS CRs exist, it should return no requests", hcp: func() *hyperv1.HostedControlPlane { hcp := newTestHCP(t, "test-hcp", "test-ns", "api.test.example.com") hcp.Finalizers = []string{hcpAzurePLSFinalizerName} return hcp }(), - expectedLen: 0, + plsCRs: []client.Object{}, + expectRequests: 0, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() g := NewGomegaWithT(t) scheme := newTestScheme(t, g) - objs := []client.Object{tt.hcp} - objs = append(objs, tt.existingPLS...) - + objs := append([]client.Object{tt.hcp}, tt.plsCRs...) fakeClient := fake.NewClientBuilder(). WithScheme(scheme). WithObjects(objs...). Build() r := &AzurePrivateLinkServiceReconciler{Client: fakeClient} - mapFunc := r.mapHCPToAzurePLS() - + mapFn := r.mapHCPToAzurePLS() ctx := log.IntoContext(t.Context(), testr.New(t)) - requests := mapFunc(ctx, tt.hcp) + requests := mapFn(ctx, tt.hcp) + g.Expect(requests).To(HaveLen(tt.expectRequests)) + }) + } +} - if tt.expectedLen == 0 { - g.Expect(requests).To(BeEmpty()) +func TestHasSiblingCR(t *testing.T) { + t.Parallel() + tests := []struct { + name string + azPLS *hyperv1.AzurePrivateLinkService + siblings []client.Object + expectHas bool + expectErr bool + }{ + { + name: "When sibling with same base domain exists it should return true", + azPLS: func() *hyperv1.AzurePrivateLinkService { + cr := newTestAzurePLS(t, "private-router", "test-ns") + cr.Spec.BaseDomain = "example.com" + return cr + }(), + siblings: []client.Object{ + func() *hyperv1.AzurePrivateLinkService { + cr := newTestAzurePLS(t, "oauth-openshift", "test-ns") + cr.Spec.BaseDomain = "example.com" + return cr + }(), + }, + expectHas: true, + }, + { + name: "When no siblings exist it should return false", + azPLS: func() *hyperv1.AzurePrivateLinkService { + cr := newTestAzurePLS(t, "private-router", "test-ns") + cr.Spec.BaseDomain = "example.com" + return cr + }(), + siblings: []client.Object{}, + expectHas: false, + }, + { + name: "When sibling has different base domain it should return false", + azPLS: func() *hyperv1.AzurePrivateLinkService { + cr := newTestAzurePLS(t, "private-router", "test-ns") + cr.Spec.BaseDomain = "example.com" + return cr + }(), + siblings: []client.Object{ + func() *hyperv1.AzurePrivateLinkService { + cr := newTestAzurePLS(t, "oauth-openshift", "test-ns") + cr.Spec.BaseDomain = "other.com" + return cr + }(), + }, + expectHas: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + g := NewGomegaWithT(t) + scheme := newTestScheme(t, g) + + objs := append([]client.Object{tt.azPLS}, tt.siblings...) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(objs...). + Build() + + r := &AzurePrivateLinkServiceReconciler{Client: fakeClient} + has, err := r.hasSiblingCR(t.Context(), tt.azPLS) + + if tt.expectErr { + g.Expect(err).To(HaveOccurred()) } else { - g.Expect(requests).To(HaveLen(tt.expectedLen)) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(has).To(Equal(tt.expectHas)) } }) } @@ -2331,7 +2528,7 @@ func TestReconcileDelete_WhenBaseDomainVNetLinkDeleteFails_ItShouldReturnError(t err := r.reconcileDelete(t.Context(), azPLS, testr.New(t)) g.Expect(err).To(HaveOccurred()) - g.Expect(err).To(MatchError(ContainSubstring("failed to begin deleting base domain VNet Link"))) + g.Expect(err).To(MatchError(ContainSubstring("failed to begin deleting VNet Link"))) } func TestReconcileDelete_WhenBaseDomainZoneDeleteFails_ItShouldReturnError(t *testing.T) { @@ -2358,7 +2555,7 @@ func TestReconcileDelete_WhenBaseDomainZoneDeleteFails_ItShouldReturnError(t *te err := r.reconcileDelete(t.Context(), azPLS, testr.New(t)) g.Expect(err).To(HaveOccurred()) - g.Expect(err).To(MatchError(ContainSubstring("failed to begin deleting base domain Private DNS Zone"))) + g.Expect(err).To(MatchError(ContainSubstring("failed to begin deleting Private DNS Zone"))) } // Tests CR deletion with DNSZoneName set in status, which causes reconcileDelete @@ -2921,7 +3118,7 @@ func TestReconcileDelete_WhenBaseDomainVNetLinkPollerFails_ItShouldReturnError(t err := r.reconcileDelete(t.Context(), azPLS, testr.New(t)) g.Expect(err).To(HaveOccurred()) - g.Expect(err).To(MatchError(ContainSubstring("failed to delete base domain VNet Link"))) + g.Expect(err).To(MatchError(ContainSubstring("failed to delete VNet Link"))) } func TestReconcileDelete_WhenBaseDomainZonePollerFails_ItShouldReturnError(t *testing.T) { @@ -2948,7 +3145,7 @@ func TestReconcileDelete_WhenBaseDomainZonePollerFails_ItShouldReturnError(t *te err := r.reconcileDelete(t.Context(), azPLS, testr.New(t)) g.Expect(err).To(HaveOccurred()) - g.Expect(err).To(MatchError(ContainSubstring("failed to delete base domain Private DNS Zone"))) + g.Expect(err).To(MatchError(ContainSubstring("failed to delete Private DNS Zone"))) } // Tests CR deletion without DNSZoneName, which skips DNS cleanup and only verifies diff --git a/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go b/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go index 5f923e8d5d5..017ac10a8d3 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go +++ b/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go @@ -367,153 +367,101 @@ func (r *HostedControlPlaneReconciler) eventHandlers(scheme *runtime.Scheme, res return handlers } -func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { - r.Log = ctrl.LoggerFrom(ctx) - r.Log.Info("Reconciling") - - // Fetch the hostedControlPlane instance - hostedControlPlane := &hyperv1.HostedControlPlane{} - err := r.Client.Get(ctx, req.NamespacedName, hostedControlPlane) - if err != nil { - if apierrors.IsNotFound(err) { - return ctrl.Result{}, nil - } - return ctrl.Result{}, err - } - - originalHostedControlPlane := hostedControlPlane.DeepCopy() - - // Return early if deleted - if !hostedControlPlane.DeletionTimestamp.IsZero() { - condition := &metav1.Condition{ - Type: string(hyperv1.AWSDefaultSecurityGroupDeleted), - } - if shouldCleanupCloudResources(r.Log, hostedControlPlane) { - if code, destroyErr := r.destroyAWSDefaultSecurityGroup(ctx, hostedControlPlane); destroyErr != nil { - condition.Message = "failed to delete AWS default security group" - if code == "DependencyViolation" { - condition.Message = destroyErr.Error() - } - condition.Reason = hyperv1.AWSErrorReason - condition.Status = metav1.ConditionFalse - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, *condition) - - if err := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for security group deletion: %w. Condition error message: %v", err, condition.Message) - } - - if code == "UnauthorizedOperation" { - r.Log.Error(destroyErr, "Skipping AWS default security group deletion because of unauthorized operation.") - } - if code == "DependencyViolation" { - r.Log.Error(destroyErr, "Skipping AWS default security group deletion because of dependency violation.") - } else { - return ctrl.Result{}, fmt.Errorf("failed to delete AWS default security group: %w", destroyErr) - } - } else { - condition.Message = hyperv1.AllIsWellMessage - condition.Reason = hyperv1.AsExpectedReason - condition.Status = metav1.ConditionTrue - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, *condition) +func (r *HostedControlPlaneReconciler) reconcileDeletion(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane, originalHostedControlPlane *hyperv1.HostedControlPlane) (ctrl.Result, error) { + condition := &metav1.Condition{ + Type: string(hyperv1.AWSDefaultSecurityGroupDeleted), + } + if shouldCleanupCloudResources(r.Log, hostedControlPlane) { + if code, destroyErr := r.destroyAWSDefaultSecurityGroup(ctx, hostedControlPlane); destroyErr != nil { + condition.Message = "failed to delete AWS default security group" + if code == "DependencyViolation" { + condition.Message = destroyErr.Error() + } + condition.Reason = hyperv1.AWSErrorReason + condition.Status = metav1.ConditionFalse + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, *condition) - if err := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for security group deletion: %w. Condition message: %v", err, condition.Message) - } + if err := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for security group deletion: %w. Condition error message: %v", err, condition.Message) } - done, err := r.removeCloudResources(ctx, hostedControlPlane) - if err != nil { - return ctrl.Result{}, fmt.Errorf("failed to ensure cloud resources are removed: %w", err) + switch code { + case "UnauthorizedOperation": + r.Log.Error(destroyErr, "Skipping AWS default security group deletion because of unauthorized operation.") + case "DependencyViolation": + r.Log.Error(destroyErr, "Skipping AWS default security group deletion because of dependency violation.") + default: + return ctrl.Result{}, fmt.Errorf("failed to delete AWS default security group: %w", destroyErr) } - if !done { - return ctrl.Result{RequeueAfter: time.Minute}, nil + } else { + condition.Message = hyperv1.AllIsWellMessage + condition.Reason = hyperv1.AsExpectedReason + condition.Status = metav1.ConditionTrue + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, *condition) + + if err := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for security group deletion: %w. Condition message: %v", err, condition.Message) } } - if controllerutil.ContainsFinalizer(hostedControlPlane, finalizer) { - originalHCP := hostedControlPlane.DeepCopy() - controllerutil.RemoveFinalizer(hostedControlPlane, finalizer) - if err := r.Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHCP, client.MergeFromWithOptimisticLock{})); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to remove finalizer from cluster: %w", err) - } + done, err := r.removeCloudResources(ctx, hostedControlPlane) + if err != nil { + return ctrl.Result{}, fmt.Errorf("failed to ensure cloud resources are removed: %w", err) + } + if !done { + return ctrl.Result{RequeueAfter: time.Minute}, nil } - return ctrl.Result{}, nil } - // Ensure the hostedControlPlane has a finalizer for cleanup - if !controllerutil.ContainsFinalizer(hostedControlPlane, finalizer) { + if controllerutil.ContainsFinalizer(hostedControlPlane, finalizer) { originalHCP := hostedControlPlane.DeepCopy() - controllerutil.AddFinalizer(hostedControlPlane, finalizer) + controllerutil.RemoveFinalizer(hostedControlPlane, finalizer) if err := r.Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHCP, client.MergeFromWithOptimisticLock{})); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to add finalizer to hostedControlPlane: %w", err) + return ctrl.Result{}, fmt.Errorf("failed to remove finalizer from cluster: %w", err) } } + return ctrl.Result{}, nil +} - if r.OperateOnReleaseImage != "" && r.OperateOnReleaseImage != util.HCPControlPlaneReleaseImage(hostedControlPlane) { - r.Log.Info("releaseImage is " + util.HCPControlPlaneReleaseImage(hostedControlPlane) + ", but this operator is configured for " + r.OperateOnReleaseImage + ", skipping reconciliation") - return ctrl.Result{}, nil - } - - // Reconcile global configuration validation status - { - condition := metav1.Condition{ - Type: string(hyperv1.ValidHostedControlPlaneConfiguration), - ObservedGeneration: hostedControlPlane.Generation, - } - if err := r.validateConfigAndClusterCapabilities(ctx, hostedControlPlane); err != nil { - condition.Status = metav1.ConditionFalse - condition.Message = err.Error() - condition.Reason = hyperv1.InsufficientClusterCapabilitiesReason - } else { - condition.Status = metav1.ConditionTrue - condition.Message = "Configuration passes validation" - condition.Reason = hyperv1.AsExpectedReason - } - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) - } - - // Reconcile etcd cluster status - { - newCondition := metav1.Condition{ - Type: string(hyperv1.EtcdAvailable), - Status: metav1.ConditionUnknown, - Reason: hyperv1.StatusUnknownReason, - } - switch hostedControlPlane.Spec.Etcd.ManagementType { - case hyperv1.Managed: - r.Log.Info("Reconciling etcd cluster status for managed strategy") - sts := manifests.EtcdStatefulSet(hostedControlPlane.Namespace) - if err := r.Get(ctx, client.ObjectKeyFromObject(sts), sts); err != nil { - if apierrors.IsNotFound(err) { - newCondition = metav1.Condition{ - Type: string(hyperv1.EtcdAvailable), - Status: metav1.ConditionFalse, - Reason: hyperv1.EtcdStatefulSetNotFoundReason, - } - } else { - return ctrl.Result{}, fmt.Errorf("failed to fetch etcd statefulset %s/%s: %w", sts.Namespace, sts.Name, err) +func (r *HostedControlPlaneReconciler) reconcileEtcdStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) error { + newCondition := metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionUnknown, + Reason: hyperv1.StatusUnknownReason, + } + switch hostedControlPlane.Spec.Etcd.ManagementType { + case hyperv1.Managed: + r.Log.Info("Reconciling etcd cluster status for managed strategy") + sts := manifests.EtcdStatefulSet(hostedControlPlane.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(sts), sts); err != nil { + if apierrors.IsNotFound(err) { + newCondition = metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.EtcdStatefulSetNotFoundReason, } } else { - conditionPtr, err := r.etcdStatefulSetCondition(ctx, sts) - if err != nil { - return ctrl.Result{}, fmt.Errorf("failed to get etcd statefulset status: %w", err) - } - newCondition = *conditionPtr + return fmt.Errorf("failed to fetch etcd statefulset %s/%s: %w", sts.Namespace, sts.Name, err) } - case hyperv1.Unmanaged: - r.Log.Info("Assuming Etcd cluster is running in unmanaged etcd strategy") - newCondition = metav1.Condition{ - Type: string(hyperv1.EtcdAvailable), - Status: metav1.ConditionTrue, - Reason: "EtcdRunning", - Message: "Etcd cluster is assumed to be running in unmanaged state", + } else { + conditionPtr, err := r.etcdStatefulSetCondition(ctx, sts) + if err != nil { + return fmt.Errorf("failed to get etcd statefulset status: %w", err) } + newCondition = *conditionPtr + } + case hyperv1.Unmanaged: + r.Log.Info("Assuming Etcd cluster is running in unmanaged etcd strategy") + newCondition = metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: "EtcdRunning", + Message: "Etcd cluster is assumed to be running in unmanaged state", } - newCondition.ObservedGeneration = hostedControlPlane.Generation - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) } + newCondition.ObservedGeneration = hostedControlPlane.Generation + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) - // Reconcile etcd restore status if hostedControlPlane.Spec.Etcd.ManagementType == hyperv1.Managed && hostedControlPlane.Spec.Etcd.Managed != nil && len(hostedControlPlane.Spec.Etcd.Managed.Storage.RestoreSnapshotURL) > 0 { restoreCondition := meta.FindStatusCondition(hostedControlPlane.Status.Conditions, string(hyperv1.EtcdSnapshotRestored)) @@ -521,218 +469,176 @@ func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.R r.Log.Info("Reconciling etcd cluster restore status") sts := manifests.EtcdStatefulSet(hostedControlPlane.Namespace) if err := r.Get(ctx, client.ObjectKeyFromObject(sts), sts); err == nil { - newCondition := metav1.Condition{} + rc := metav1.Condition{} conditionPtr := r.etcdRestoredCondition(ctx, sts) if conditionPtr != nil { - newCondition = *conditionPtr - newCondition.ObservedGeneration = hostedControlPlane.Generation - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) + rc = *conditionPtr + rc.ObservedGeneration = hostedControlPlane.Generation + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, rc) } } } } - // Validate KMS config - switch hostedControlPlane.Spec.Platform.Type { - case hyperv1.AWSPlatform: - r.validateAWSKMSConfig(ctx, hostedControlPlane) - case hyperv1.AzurePlatform: - r.validateAzureKMSConfig(ctx, hostedControlPlane) - } + return nil +} - // Reconcile Kube APIServer status - { - newCondition := metav1.Condition{ +func (r *HostedControlPlaneReconciler) reconcileKASStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) error { + newCondition := metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionUnknown, + Reason: hyperv1.StatusUnknownReason, + } + deployment := manifests.KASDeployment(hostedControlPlane.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(deployment), deployment); err != nil { + if apierrors.IsNotFound(err) { + newCondition = metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.NotFoundReason, + Message: "Kube APIServer deployment not found", + } + } else { + return fmt.Errorf("failed to fetch Kube APIServer deployment %s/%s: %w", deployment.Namespace, deployment.Name, err) + } + } else { + newCondition = metav1.Condition{ Type: string(hyperv1.KubeAPIServerAvailable), - Status: metav1.ConditionUnknown, + Status: metav1.ConditionFalse, Reason: hyperv1.StatusUnknownReason, } - deployment := manifests.KASDeployment(hostedControlPlane.Namespace) - if err := r.Get(ctx, client.ObjectKeyFromObject(deployment), deployment); err != nil { - if apierrors.IsNotFound(err) { - newCondition = metav1.Condition{ - Type: string(hyperv1.KubeAPIServerAvailable), - Status: metav1.ConditionFalse, - Reason: hyperv1.NotFoundReason, - Message: "Kube APIServer deployment not found", - } - } else { - return ctrl.Result{}, fmt.Errorf("failed to fetch Kube APIServer deployment %s/%s: %w", deployment.Namespace, deployment.Name, err) - } - } else { - // Assume the deployment is unavailable until proven otherwise. - newCondition = metav1.Condition{ - Type: string(hyperv1.KubeAPIServerAvailable), - Status: metav1.ConditionFalse, - Reason: hyperv1.StatusUnknownReason, - } - for _, cond := range deployment.Status.Conditions { - if cond.Type == appsv1.DeploymentAvailable { - if cond.Status == corev1.ConditionTrue { - newCondition = metav1.Condition{ - Type: string(hyperv1.KubeAPIServerAvailable), - Status: metav1.ConditionTrue, - Reason: hyperv1.AsExpectedReason, - Message: "Kube APIServer deployment is available", - } - } else { - newCondition = metav1.Condition{ - Type: string(hyperv1.KubeAPIServerAvailable), - Status: metav1.ConditionFalse, - Reason: hyperv1.WaitingForAvailableReason, - Message: "Waiting for Kube APIServer deployment to become available", - } + for _, cond := range deployment.Status.Conditions { + if cond.Type == appsv1.DeploymentAvailable { + if cond.Status == corev1.ConditionTrue { + newCondition = metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + Message: "Kube APIServer deployment is available", + } + } else { + newCondition = metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.WaitingForAvailableReason, + Message: "Waiting for Kube APIServer deployment to become available", } - break } + break } } - newCondition.ObservedGeneration = hostedControlPlane.Generation - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) } + newCondition.ObservedGeneration = hostedControlPlane.Generation + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) + return nil +} - // Reconcile Degraded status - { - condition := metav1.Condition{ - Type: string(hyperv1.HostedControlPlaneDegraded), - Status: metav1.ConditionFalse, - Reason: hyperv1.AsExpectedReason, - ObservedGeneration: hostedControlPlane.Generation, +func (r *HostedControlPlaneReconciler) reconcileDegradedStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) error { + condition := metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionFalse, + Reason: hyperv1.AsExpectedReason, + ObservedGeneration: hostedControlPlane.Generation, + } + cpoManagedDeploymentList := &appsv1.DeploymentList{} + if err := r.List(ctx, cpoManagedDeploymentList, client.MatchingLabels{ + component.ManagedByLabel: "control-plane-operator", + }, client.InNamespace(hostedControlPlane.Namespace)); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to list managed deployments in namespace %s: %w", hostedControlPlane.Namespace, err) } - cpoManagedDeploymentList := &appsv1.DeploymentList{} - if err := r.List(ctx, cpoManagedDeploymentList, client.MatchingLabels{ - component.ManagedByLabel: "control-plane-operator", - }, client.InNamespace(hostedControlPlane.Namespace)); err != nil { - if !apierrors.IsNotFound(err) { - return ctrl.Result{}, fmt.Errorf("failed to list managed deployments in namespace %s: %w", hostedControlPlane.Namespace, err) - } + } + var errs []error + sort.SliceStable(cpoManagedDeploymentList.Items, func(i, j int) bool { + return cpoManagedDeploymentList.Items[i].Name < cpoManagedDeploymentList.Items[j].Name + }) + for _, deployment := range cpoManagedDeploymentList.Items { + if deployment.Status.UnavailableReplicas > 0 { + errs = append(errs, fmt.Errorf("%s deployment has %d unavailable replicas", deployment.Name, deployment.Status.UnavailableReplicas)) } - var errs []error - sort.SliceStable(cpoManagedDeploymentList.Items, func(i, j int) bool { - return cpoManagedDeploymentList.Items[i].Name < cpoManagedDeploymentList.Items[j].Name - }) - for _, deployment := range cpoManagedDeploymentList.Items { - if deployment.Status.UnavailableReplicas > 0 { - errs = append(errs, fmt.Errorf("%s deployment has %d unavailable replicas", deployment.Name, deployment.Status.UnavailableReplicas)) - } + } + err := utilerrors.NewAggregate(errs) + if err != nil { + condition.Status = metav1.ConditionTrue + condition.Reason = "UnavailableReplicas" + condition.Message = err.Error() + } + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) + return nil +} + +func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + r.Log = ctrl.LoggerFrom(ctx) + r.Log.Info("Reconciling") + + hostedControlPlane := &hyperv1.HostedControlPlane{} + err := r.Client.Get(ctx, req.NamespacedName, hostedControlPlane) + if err != nil { + if apierrors.IsNotFound(err) { + return ctrl.Result{}, nil } - err := utilerrors.NewAggregate(errs) - if err != nil { - condition.Status = metav1.ConditionTrue - condition.Reason = "UnavailableReplicas" - condition.Message = err.Error() + return ctrl.Result{}, err + } + + originalHostedControlPlane := hostedControlPlane.DeepCopy() + + if !hostedControlPlane.DeletionTimestamp.IsZero() { + return r.reconcileDeletion(ctx, hostedControlPlane, originalHostedControlPlane) + } + + if !controllerutil.ContainsFinalizer(hostedControlPlane, finalizer) { + originalHCP := hostedControlPlane.DeepCopy() + controllerutil.AddFinalizer(hostedControlPlane, finalizer) + if err := r.Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHCP, client.MergeFromWithOptimisticLock{})); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to add finalizer to hostedControlPlane: %w", err) } - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) } - // Reconcile infrastructure status + if r.OperateOnReleaseImage != "" && r.OperateOnReleaseImage != util.HCPControlPlaneReleaseImage(hostedControlPlane) { + r.Log.Info("releaseImage is " + util.HCPControlPlaneReleaseImage(hostedControlPlane) + ", but this operator is configured for " + r.OperateOnReleaseImage + ", skipping reconciliation") + return ctrl.Result{}, nil + } + { - r.Log.Info("Reconciling infrastructure status") - newCondition := metav1.Condition{ - Type: string(hyperv1.InfrastructureReady), - Status: metav1.ConditionUnknown, - Reason: hyperv1.StatusUnknownReason, + condition := metav1.Condition{ + Type: string(hyperv1.ValidHostedControlPlaneConfiguration), + ObservedGeneration: hostedControlPlane.Generation, } - infraStatus, err := r.reconcileInfrastructureStatus(ctx, hostedControlPlane) - if err != nil { - newCondition = metav1.Condition{ - Type: string(hyperv1.InfrastructureReady), - Status: metav1.ConditionUnknown, - Reason: hyperv1.InfraStatusFailureReason, - Message: err.Error(), - } - r.Log.Error(err, "failed to determine infrastructure status") + if err := r.validateConfigAndClusterCapabilities(ctx, hostedControlPlane); err != nil { + condition.Status = metav1.ConditionFalse + condition.Message = err.Error() + condition.Reason = hyperv1.InsufficientClusterCapabilitiesReason } else { - if infraStatus.IsReady() { - hostedControlPlane.Status.ControlPlaneEndpoint = hyperv1.APIEndpoint{ - Host: infraStatus.APIHost, - Port: infraStatus.APIPort, - } - newCondition = metav1.Condition{ - Type: string(hyperv1.InfrastructureReady), - Status: metav1.ConditionTrue, - Message: hyperv1.AllIsWellMessage, - Reason: hyperv1.AsExpectedReason, - } - if util.HCPOAuthEnabled(hostedControlPlane) { - hostedControlPlane.Status.OAuthCallbackURLTemplate = fmt.Sprintf("https://%s:%d/oauth2callback/[identity-provider-name]", infraStatus.OAuthHost, infraStatus.OAuthPort) - } - } else { - message := "Cluster infrastructure is still provisioning" - if len(infraStatus.Message) > 0 { - message = infraStatus.Message - } - newCondition = metav1.Condition{ - Type: string(hyperv1.InfrastructureReady), - Status: metav1.ConditionFalse, - Reason: hyperv1.WaitingOnInfrastructureReadyReason, - Message: message, - } - r.Log.Info("Infrastructure is not yet ready") - } + condition.Status = metav1.ConditionTrue + condition.Message = "Configuration passes validation" + condition.Reason = hyperv1.AsExpectedReason } - newCondition.ObservedGeneration = hostedControlPlane.Generation - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) } - // Reconcile external DNS status - { - newCondition := metav1.Condition{ - Type: string(hyperv1.ExternalDNSReachable), - Status: metav1.ConditionUnknown, - Reason: hyperv1.StatusUnknownReason, - } + if err := r.reconcileEtcdStatus(ctx, hostedControlPlane); err != nil { + return ctrl.Result{}, err + } - kasExternalHostname := netutil.ServiceExternalDNSHostname(hostedControlPlane, hyperv1.APIServer) - if kasExternalHostname != "" { - if err := netutil.ResolveDNSHostname(ctx, kasExternalHostname); err != nil { - newCondition = metav1.Condition{ - Type: string(hyperv1.ExternalDNSReachable), - Status: metav1.ConditionFalse, - Reason: hyperv1.ExternalDNSHostNotReachableReason, - Message: err.Error(), - } - } else { - newCondition = metav1.Condition{ - Type: string(hyperv1.ExternalDNSReachable), - Status: metav1.ConditionTrue, - Message: hyperv1.AllIsWellMessage, - Reason: hyperv1.AsExpectedReason, - } - } - } else { - newCondition.Message = "External DNS is not configured" - } + switch hostedControlPlane.Spec.Platform.Type { + case hyperv1.AWSPlatform: + r.validateAWSKMSConfig(ctx, hostedControlPlane) + case hyperv1.AzurePlatform: + r.validateAzureKMSConfig(ctx, hostedControlPlane) + } - newCondition.ObservedGeneration = hostedControlPlane.Generation - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) + if err := r.reconcileKASStatus(ctx, hostedControlPlane); err != nil { + return ctrl.Result{}, err } - // Reconcile hostedcontrolplane availability and Ready flag - { - healthCheckErr := r.healthCheckKASLoadBalancers(ctx, hostedControlPlane) - - // Check control plane components availability only if the HCP is not already marked as available - var componentsNotAvailableMsg string - var componentsErr error - availableCondition := meta.FindStatusCondition(hostedControlPlane.Status.Conditions, string(hyperv1.HostedControlPlaneAvailable)) - alreadyAvailable := availableCondition != nil && availableCondition.Status == metav1.ConditionTrue - if !alreadyAvailable { - componentsNotAvailableMsg, componentsErr = r.controlPlaneComponentsAvailable(ctx, hostedControlPlane) - } - - ready, condition := reconcileAvailabilityStatus( - hostedControlPlane.Status.Conditions, - hostedControlPlane.Status.KubeConfig != nil, - healthCheckErr, - componentsNotAvailableMsg, - componentsErr, - hostedControlPlane.Generation, - ) - hostedControlPlane.Status.Ready = ready - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) + if err := r.reconcileDegradedStatus(ctx, hostedControlPlane); err != nil { + return ctrl.Result{}, err } + r.reconcileInfrastructureStatusCondition(ctx, hostedControlPlane) + r.reconcileExternalDNSStatusCondition(ctx, hostedControlPlane) + r.reconcileAvailabilityAndReadyStatus(ctx, hostedControlPlane) + // Admin Kubeconfig kubeconfig := manifests.KASAdminKubeconfigSecret(hostedControlPlane.Namespace, hostedControlPlane.Spec.KubeConfig) if err := r.Get(ctx, client.ObjectKeyFromObject(kubeconfig), kubeconfig); err != nil { @@ -754,20 +660,8 @@ func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.R return reconcile.Result{}, err } - explicitOauthConfig := hostedControlPlane.Spec.Configuration != nil && hostedControlPlane.Spec.Configuration.OAuth != nil - if explicitOauthConfig { - hostedControlPlane.Status.KubeadminPassword = nil - } else { - kubeadminPassword := common.KubeadminPasswordSecret(hostedControlPlane.Namespace) - if err := r.Get(ctx, client.ObjectKeyFromObject(kubeadminPassword), kubeadminPassword); err != nil { - if !apierrors.IsNotFound(err) { - return reconcile.Result{}, fmt.Errorf("failed to get kubeadmin password: %w", err) - } - } else { - hostedControlPlane.Status.KubeadminPassword = &corev1.LocalObjectReference{ - Name: kubeadminPassword.Name, - } - } + if err := r.reconcileKubeadminPasswordStatus(ctx, hostedControlPlane); err != nil { + return reconcile.Result{}, err } // Reconcile valid release info status @@ -776,38 +670,8 @@ func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.R return reconcile.Result{}, fmt.Errorf("failed to look up release image metadata: %w", err) } - // Reconcile controlPlaneVersion status. - // This runs after LookupReleaseImage so we can use the version and resolved - // digest from the release image metadata. - { - clk := r.clock - if clk == nil { - clk = clock.RealClock{} - } - // Resolve the release image to its digest so controlPlaneVersion records - // the immutable image reference, consistent with how CVO records images. - pullSecret := common.PullSecret(hostedControlPlane.Namespace) - if err := r.Client.Get(ctx, client.ObjectKeyFromObject(pullSecret), pullSecret); err != nil { - return reconcile.Result{}, fmt.Errorf("failed to get pull secret for version reconciliation: %w", err) - } - _, resolvedRef, err := r.ImageMetadataProvider.GetDigest(ctx, util.HCPControlPlaneReleaseImage(hostedControlPlane), pullSecret.Data[corev1.DockerConfigJsonKey]) - if err != nil { - return reconcile.Result{}, fmt.Errorf("failed to resolve control plane release image digest: %w", err) - } - resolvedImage := resolvedRef.String() - - componentsList := &hyperv1.ControlPlaneComponentList{} - if listErr := r.Client.List(ctx, componentsList, client.InNamespace(hostedControlPlane.Namespace)); listErr != nil { - // On list failure, ensure a Partial entry exists so consumers - // know an upgrade was attempted. Preserve observedGeneration. - hostedControlPlane.Status.ControlPlaneVersion = ensureControlPlaneVersionPartial(hostedControlPlane, clk, releaseImage.Version(), resolvedImage) - // Persist the Partial entry before returning the error. - if patchErr := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); patchErr != nil { - return reconcile.Result{}, fmt.Errorf("failed to patch status after component list failure: %w (list error: %v)", patchErr, listErr) - } - return reconcile.Result{}, fmt.Errorf("failed to list control plane components for version reconciliation: %w", listErr) - } - hostedControlPlane.Status.ControlPlaneVersion = reconcileControlPlaneVersion(hostedControlPlane, componentsList.Items, clk, releaseImage.Version(), resolvedImage) + if err := r.reconcileControlPlaneVersionStatus(ctx, hostedControlPlane, originalHostedControlPlane, releaseImage); err != nil { + return reconcile.Result{}, err } hostedControlPlane.Status.Initialized = true @@ -848,8 +712,157 @@ func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.R if !hostedControlPlane.Status.Ready { return ctrl.Result{RequeueAfter: hcpNotReadyRequeueInterval}, nil } - - return ctrl.Result{RequeueAfter: hcpReadyRequeueInterval}, nil + + return ctrl.Result{RequeueAfter: hcpReadyRequeueInterval}, nil +} + +func (r *HostedControlPlaneReconciler) reconcileInfrastructureStatusCondition(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) { + r.Log.Info("Reconciling infrastructure status") + newCondition := metav1.Condition{ + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionUnknown, + Reason: hyperv1.StatusUnknownReason, + } + infraStatus, err := r.reconcileInfrastructureStatus(ctx, hostedControlPlane) + if err != nil { + newCondition = metav1.Condition{ + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionUnknown, + Reason: hyperv1.InfraStatusFailureReason, + Message: err.Error(), + } + r.Log.Error(err, "failed to determine infrastructure status") + } else if infraStatus.IsReady() { + hostedControlPlane.Status.ControlPlaneEndpoint = hyperv1.APIEndpoint{ + Host: infraStatus.APIHost, + Port: infraStatus.APIPort, + } + newCondition = metav1.Condition{ + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionTrue, + Message: hyperv1.AllIsWellMessage, + Reason: hyperv1.AsExpectedReason, + } + if util.HCPOAuthEnabled(hostedControlPlane) { + hostedControlPlane.Status.OAuthCallbackURLTemplate = fmt.Sprintf("https://%s:%d/oauth2callback/[identity-provider-name]", infraStatus.OAuthHost, infraStatus.OAuthPort) + } + } else { + message := "Cluster infrastructure is still provisioning" + if len(infraStatus.Message) > 0 { + message = infraStatus.Message + } + newCondition = metav1.Condition{ + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionFalse, + Reason: hyperv1.WaitingOnInfrastructureReadyReason, + Message: message, + } + r.Log.Info("Infrastructure is not yet ready") + } + newCondition.ObservedGeneration = hostedControlPlane.Generation + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) +} + +func (r *HostedControlPlaneReconciler) reconcileExternalDNSStatusCondition(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) { + newCondition := metav1.Condition{ + Type: string(hyperv1.ExternalDNSReachable), + Status: metav1.ConditionUnknown, + Reason: hyperv1.StatusUnknownReason, + } + + kasExternalHostname := netutil.ServiceExternalDNSHostname(hostedControlPlane, hyperv1.APIServer) + if kasExternalHostname != "" { + if err := netutil.ResolveDNSHostname(ctx, kasExternalHostname); err != nil { + newCondition = metav1.Condition{ + Type: string(hyperv1.ExternalDNSReachable), + Status: metav1.ConditionFalse, + Reason: hyperv1.ExternalDNSHostNotReachableReason, + Message: err.Error(), + } + } else { + newCondition = metav1.Condition{ + Type: string(hyperv1.ExternalDNSReachable), + Status: metav1.ConditionTrue, + Message: hyperv1.AllIsWellMessage, + Reason: hyperv1.AsExpectedReason, + } + } + } else { + newCondition.Message = "External DNS is not configured" + } + + newCondition.ObservedGeneration = hostedControlPlane.Generation + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) +} + +func (r *HostedControlPlaneReconciler) reconcileAvailabilityAndReadyStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) { + healthCheckErr := r.healthCheckKASLoadBalancers(ctx, hostedControlPlane) + + var componentsNotAvailableMsg string + var componentsErr error + availableCondition := meta.FindStatusCondition(hostedControlPlane.Status.Conditions, string(hyperv1.HostedControlPlaneAvailable)) + alreadyAvailable := availableCondition != nil && availableCondition.Status == metav1.ConditionTrue + if !alreadyAvailable { + componentsNotAvailableMsg, componentsErr = r.controlPlaneComponentsAvailable(ctx, hostedControlPlane) + } + + ready, condition := reconcileAvailabilityStatus( + hostedControlPlane.Status.Conditions, + hostedControlPlane.Status.KubeConfig != nil, + healthCheckErr, + componentsNotAvailableMsg, + componentsErr, + hostedControlPlane.Generation, + ) + hostedControlPlane.Status.Ready = ready + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) +} + +func (r *HostedControlPlaneReconciler) reconcileKubeadminPasswordStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) error { + explicitOauthConfig := hostedControlPlane.Spec.Configuration != nil && hostedControlPlane.Spec.Configuration.OAuth != nil + if explicitOauthConfig { + hostedControlPlane.Status.KubeadminPassword = nil + return nil + } + kubeadminPassword := common.KubeadminPasswordSecret(hostedControlPlane.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(kubeadminPassword), kubeadminPassword); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to get kubeadmin password: %w", err) + } + hostedControlPlane.Status.KubeadminPassword = nil + } else { + hostedControlPlane.Status.KubeadminPassword = &corev1.LocalObjectReference{ + Name: kubeadminPassword.Name, + } + } + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileControlPlaneVersionStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane, originalHostedControlPlane *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) error { + clk := r.clock + if clk == nil { + clk = clock.RealClock{} + } + pullSecret := common.PullSecret(hostedControlPlane.Namespace) + if err := r.Client.Get(ctx, client.ObjectKeyFromObject(pullSecret), pullSecret); err != nil { + return fmt.Errorf("failed to get pull secret for version reconciliation: %w", err) + } + _, resolvedRef, err := r.ImageMetadataProvider.GetDigest(ctx, util.HCPControlPlaneReleaseImage(hostedControlPlane), pullSecret.Data[corev1.DockerConfigJsonKey]) + if err != nil { + return fmt.Errorf("failed to resolve control plane release image digest: %w", err) + } + resolvedImage := resolvedRef.String() + + componentsList := &hyperv1.ControlPlaneComponentList{} + if listErr := r.Client.List(ctx, componentsList, client.InNamespace(hostedControlPlane.Namespace)); listErr != nil { + hostedControlPlane.Status.ControlPlaneVersion = ensureControlPlaneVersionPartial(hostedControlPlane, clk, releaseImage.Version(), resolvedImage) + if patchErr := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); patchErr != nil { + return fmt.Errorf("failed to patch status after component list failure: %w (list error: %v)", patchErr, listErr) + } + return fmt.Errorf("failed to list control plane components for version reconciliation: %w", listErr) + } + hostedControlPlane.Status.ControlPlaneVersion = reconcileControlPlaneVersion(hostedControlPlane, componentsList.Items, clk, releaseImage.Version(), resolvedImage) + return nil } // reconcileAvailabilityStatus determines the HostedControlPlane availability condition @@ -1293,35 +1306,7 @@ func (r *HostedControlPlaneReconciler) reconcileKubeadminPassword(ctx context.Co return nil } -func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hyperv1.HostedControlPlane, infraStatus infra.InfrastructureStatus, createOrUpdate upsert.CreateOrUpdateFN) error { - p := pki.NewPKIParams(hcp, infraStatus.APIHost, infraStatus.OAuthHost, infraStatus.KonnectivityHost) - - // Root CA - rootCASecret := manifests.RootCASecret(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, rootCASecret, func() error { - return pki.ReconcileRootCA(rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile root CA: %w", err) - } - - var observedDefaultIngressCert *corev1.ConfigMap - if capabilities.IsIngressCapabilityEnabled(hcp.Spec.Capabilities) { - observedDefaultIngressCert = manifests.IngressObservedDefaultIngressCertCA(hcp.Namespace) - if err := r.Get(ctx, client.ObjectKeyFromObject(observedDefaultIngressCert), observedDefaultIngressCert); err != nil { - if !apierrors.IsNotFound(err) { - return fmt.Errorf("failed to get observed default ingress cert: %w", err) - } - observedDefaultIngressCert = nil - } - } - rootCAConfigMap := manifests.RootCAConfigMap(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, rootCAConfigMap, func() error { - return pki.ReconcileRootCAConfigMap(rootCAConfigMap, p.OwnerRef, rootCASecret, observedDefaultIngressCert) - }); err != nil { - return fmt.Errorf("failed to reconcile root CA configmap: %w", err) - } - - // Etcd signer for all the etcd-related certs +func (r *HostedControlPlaneReconciler) reconcileEtcdCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN) error { etcdSignerSecret := manifests.EtcdSignerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdSignerSecret, func() error { return pki.ReconcileEtcdSignerSecret(etcdSignerSecret, p.OwnerRef) @@ -1336,7 +1321,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd signer CA configmap: %w", err) } - // Etcd client secret etcdClientSecret := manifests.EtcdClientSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdClientSecret, func() error { return pki.ReconcileEtcdClientSecret(etcdClientSecret, etcdSignerSecret, p.OwnerRef) @@ -1344,7 +1328,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd client secret: %w", err) } - // Etcd server secret etcdServerSecret := manifests.EtcdServerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdServerSecret, func() error { return pki.ReconcileEtcdServerSecret(etcdServerSecret, etcdSignerSecret, p.OwnerRef) @@ -1352,7 +1335,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd server secret: %w", err) } - // Etcd peer secret etcdPeerSecret := manifests.EtcdPeerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdPeerSecret, func() error { return pki.ReconcileEtcdPeerSecret(etcdPeerSecret, etcdSignerSecret, p.OwnerRef) @@ -1360,8 +1342,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd peer secret: %w", err) } - // Etcd metrics signer - // Etcd signer for all the etcd-related certs etcdMetricsSignerSecret := manifests.EtcdMetricsSignerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdMetricsSignerSecret, func() error { return pki.ReconcileEtcdMetricsSignerSecret(etcdMetricsSignerSecret, p.OwnerRef) @@ -1376,7 +1356,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd signer CA configmap: %w", err) } - // Etcd client secret etcdMetricsClientSecret := manifests.EtcdMetricsClientSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdMetricsClientSecret, func() error { return pki.ReconcileEtcdMetricsClientSecret(etcdMetricsClientSecret, etcdMetricsSignerSecret, p.OwnerRef) @@ -1384,7 +1363,10 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd client secret: %w", err) } - // KAS server secret + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileKASCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { kasServerSecret := manifests.KASServerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, kasServerSecret, func() error { return pki.ReconcileKASServerCertSecret(kasServerSecret, rootCASecret, p.OwnerRef, p.ExternalAPIAddress, p.InternalAPIAddress, p.ServiceCIDR, p.NodeInternalAPIServerIP) @@ -1392,7 +1374,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile kas server secret: %w", err) } - // KAS server private secret kasServerPrivateSecret := manifests.KASServerPrivateCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, kasServerPrivateSecret, func() error { return pki.ReconcileKASServerPrivateCertSecret(kasServerPrivateSecret, rootCASecret, p.OwnerRef) @@ -1409,7 +1390,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return err } - // Service account signing key secret serviceAccountSigningKeySecret := manifests.ServiceAccountSigningKeySecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, serviceAccountSigningKeySecret, func() error { return pki.ReconcileServiceAccountSigningKeySecret(serviceAccountSigningKeySecret, p.OwnerRef) @@ -1417,7 +1397,10 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile api server service account key secret: %w", err) } - // OpenShift APIServer + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileOpenshiftCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { openshiftAPIServerCertSecret := manifests.OpenShiftAPIServerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, openshiftAPIServerCertSecret, func() error { return pki.ReconcileOpenShiftAPIServerCertSecret(openshiftAPIServerCertSecret, rootCASecret, p.OwnerRef) @@ -1426,7 +1409,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy } if util.HCPOAuthEnabled(hcp) { - // OpenShift OAuth APIServer openshiftOAuthAPIServerCertSecret := manifests.OpenShiftOAuthAPIServerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, openshiftOAuthAPIServerCertSecret, func() error { return pki.ReconcileOpenShiftOAuthAPIServerCertSecret(openshiftOAuthAPIServerCertSecret, rootCASecret, p.OwnerRef) @@ -1435,7 +1417,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy } } - // OpenShift ControllerManager Cert openshiftControllerManagerCertSecret := manifests.OpenShiftControllerManagerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, openshiftControllerManagerCertSecret, func() error { return pki.ReconcileOpenShiftControllerManagerCertSecret(openshiftControllerManagerCertSecret, rootCASecret, p.OwnerRef) @@ -1443,7 +1424,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile openshift controller manager cert: %w", err) } - // OpenShift Route ControllerManager Cert openshiftRouteControllerManagerCertSecret := manifests.OpenShiftRouteControllerManagerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, openshiftRouteControllerManagerCertSecret, func() error { return pki.ReconcileOpenShiftControllerManagerCertSecret(openshiftRouteControllerManagerCertSecret, rootCASecret, p.OwnerRef) @@ -1451,7 +1431,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile openshift route controller manager cert: %w", err) } - // Cluster Policy Controller Cert clusterPolicyControllerCertSecret := manifests.ClusterPolicyControllerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, clusterPolicyControllerCertSecret, func() error { return pki.ReconcileOpenShiftControllerManagerCertSecret(clusterPolicyControllerCertSecret, rootCASecret, p.OwnerRef) @@ -1459,6 +1438,10 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile cluster policy controller cert: %w", err) } + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileKonnectivityCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN) error { konnectivitySigner := manifests.KonnectivitySignerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, konnectivitySigner, func() error { return pki.ReconcileKonnectivitySignerSecret(konnectivitySigner, p.OwnerRef) @@ -1473,7 +1456,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile konnectivity CA config map: %v", err) } - // Konnectivity Server Cert konnectivityServerSecret := manifests.KonnectivityServerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, konnectivityServerSecret, func() error { return pki.ReconcileKonnectivityServerSecret(konnectivityServerSecret, konnectivitySigner, p.OwnerRef) @@ -1481,7 +1463,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile konnectivity server cert: %w", err) } - // Konnectivity Cluster Cert konnectivityClusterSecret := manifests.KonnectivityClusterSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, konnectivityClusterSecret, func() error { return pki.ReconcileKonnectivityClusterSecret(konnectivityClusterSecret, konnectivitySigner, p.OwnerRef, p.ExternalKconnectivityAddress) @@ -1489,7 +1470,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile konnectivity cluster cert: %w", err) } - // Konnectivity Client Cert konnectivityClientSecret := manifests.KonnectivityClientSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, konnectivityClientSecret, func() error { return pki.ReconcileKonnectivityClientSecret(konnectivityClientSecret, konnectivitySigner, p.OwnerRef) @@ -1497,7 +1477,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile konnectivity client cert: %w", err) } - // Konnectivity Agent Cert konnectivityAgentSecret := manifests.KonnectivityAgentSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, konnectivityAgentSecret, func() error { return pki.ReconcileKonnectivityAgentSecret(konnectivityAgentSecret, konnectivitySigner, p.OwnerRef) @@ -1505,84 +1484,51 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile konnectivity agent cert: %w", err) } - // Reconcile ingress serving certificate only if Ingress capability is enabled. - if capabilities.IsIngressCapabilityEnabled(hcp.Spec.Capabilities) { - // Ingress Cert - ingressCert := manifests.IngressCert(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, ingressCert, func() error { - return pki.ReconcileIngressCert(ingressCert, rootCASecret, p.OwnerRef, p.IngressSubdomain) - }); err != nil { - return fmt.Errorf("failed to reconcile ingress cert secret: %w", err) - } - } - - var userCABundles []client.ObjectKey - if hcp.Spec.AdditionalTrustBundle != nil { - userCABundles = append(userCABundles, client.ObjectKey{Namespace: hcp.Namespace, Name: hcp.Spec.AdditionalTrustBundle.Name}) - } - if hcp.Spec.Configuration != nil && hcp.Spec.Configuration.Proxy != nil && hcp.Spec.Configuration.Proxy.TrustedCA.Name != "" { - userCABundles = append(userCABundles, client.ObjectKey{Namespace: hcp.Namespace, Name: hcp.Spec.Configuration.Proxy.TrustedCA.Name}) - } + return nil +} - trustedCABundle := manifests.TrustedCABundleConfigMap(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, trustedCABundle, func() error { - return r.reconcileManagedTrustedCABundle(ctx, trustedCABundle, userCABundles) +func (r *HostedControlPlaneReconciler) reconcileOAuthCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret, trustedCABundle *corev1.ConfigMap) error { + oauthServerCert := manifests.OpenShiftOAuthServerCert(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, oauthServerCert, func() error { + return pki.ReconcileOAuthServerCert(oauthServerCert, rootCASecret, p.OwnerRef, p.ExternalOauthAddress) }); err != nil { - return fmt.Errorf("failed to reconcile managed UserCA configMap: %w", err) + return fmt.Errorf("failed to reconcile oauth cert secret: %w", err) } - if util.HCPOAuthEnabled(hcp) { - // OAuth server Cert - oauthServerCert := manifests.OpenShiftOAuthServerCert(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, oauthServerCert, func() error { - return pki.ReconcileOAuthServerCert(oauthServerCert, rootCASecret, p.OwnerRef, p.ExternalOauthAddress) - }); err != nil { - return fmt.Errorf("failed to reconcile oauth cert secret: %w", err) - } - - // OAuth master CA Bundle - bundleSources := []*corev1.Secret{oauthServerCert} - if hcp.Spec.Configuration != nil && hcp.Spec.Configuration.APIServer != nil { - for _, namedCert := range hcp.Spec.Configuration.APIServer.ServingCerts.NamedCertificates { - if namedCert.ServingCertificate.Name == "" { - continue - } - certSecret := &corev1.Secret{ObjectMeta: metav1.ObjectMeta{Name: namedCert.ServingCertificate.Name, Namespace: hcp.Namespace}} - if err := r.Get(ctx, client.ObjectKeyFromObject(certSecret), certSecret); err != nil { - return fmt.Errorf("failed to get named certificate %s: %w", certSecret.Name, err) - } - bundleSources = append(bundleSources, certSecret) + bundleSources := []*corev1.Secret{oauthServerCert} + if hcp.Spec.Configuration != nil && hcp.Spec.Configuration.APIServer != nil { + for _, namedCert := range hcp.Spec.Configuration.APIServer.ServingCerts.NamedCertificates { + if namedCert.ServingCertificate.Name == "" { + continue } + certSecret := &corev1.Secret{ObjectMeta: metav1.ObjectMeta{Name: namedCert.ServingCertificate.Name, Namespace: hcp.Namespace}} + if err := r.Get(ctx, client.ObjectKeyFromObject(certSecret), certSecret); err != nil { + return fmt.Errorf("failed to get named certificate %s: %w", certSecret.Name, err) + } + bundleSources = append(bundleSources, certSecret) } + } - if trustedCABundle.Data[certs.UserCABundleMapKey] != "" { - bundleSources = append(bundleSources, &corev1.Secret{ - Data: map[string][]byte{ - certs.CASignerCertMapKey: []byte(trustedCABundle.Data[certs.UserCABundleMapKey]), - }, - }) - } - - oauthMasterCABundle := manifests.OpenShiftOAuthMasterCABundle(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, oauthMasterCABundle, func() error { - return pki.ReconcileOAuthMasterCABundle(oauthMasterCABundle, p.OwnerRef, bundleSources) - }); err != nil { - return fmt.Errorf("failed to reconcile oauth cert secret: %w", err) - } + if trustedCABundle.Data[certs.UserCABundleMapKey] != "" { + bundleSources = append(bundleSources, &corev1.Secret{ + Data: map[string][]byte{ + certs.CASignerCertMapKey: []byte(trustedCABundle.Data[certs.UserCABundleMapKey]), + }, + }) } - // MCS Cert - if _, exists := hcp.Annotations[hyperv1.DisableIgnitionServerAnnotation]; !exists { - machineConfigServerCert := manifests.MachineConfigServerCert(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, machineConfigServerCert, func() error { - return pki.ReconcileMachineConfigServerCert(machineConfigServerCert, rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile machine config server cert secret: %w", err) - } + oauthMasterCABundle := manifests.OpenShiftOAuthMasterCABundle(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, oauthMasterCABundle, func() error { + return pki.ReconcileOAuthMasterCABundle(oauthMasterCABundle, p.OwnerRef, bundleSources) + }); err != nil { + return fmt.Errorf("failed to reconcile oauth cert secret: %w", err) } - var err error + + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileOLMAndMiscCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { if capabilities.IsNodeTuningCapabilityEnabled(hcp.Spec.Capabilities) { - // Cluster Node Tuning Operator metrics Serving Cert NodeTuningOperatorServingCert := manifests.ClusterNodeTuningOperatorServingCertSecret(hcp.Namespace) NodeTuningOperatorService := manifests.ClusterNodeTuningOperatorMetricsService(hcp.Namespace) err := removeServiceCAAnnotationAndSecret(ctx, r.Client, NodeTuningOperatorService, NodeTuningOperatorServingCert) @@ -1595,7 +1541,7 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile node tuning operator serving cert: %w", err) } } - // OLM PackageServer Cert + packageServerCertSecret := manifests.OLMPackageServerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, packageServerCertSecret, func() error { return pki.ReconcileOLMPackageServerCertSecret(packageServerCertSecret, rootCASecret, p.OwnerRef) @@ -1603,7 +1549,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile packageserver cert: %w", err) } - // OLM Catalog Operator Serving Cert catalogOperatorServingCert := manifests.OLMCatalogOperatorServingCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, catalogOperatorServingCert, func() error { return pki.ReconcileOLMCatalogOperatorServingCertSecret(catalogOperatorServingCert, rootCASecret, p.OwnerRef) @@ -1611,7 +1556,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile olm catalog operator serving cert: %w", err) } - // OLM Operator Serving Cert olmOperatorServingCert := manifests.OLMOperatorServingCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, olmOperatorServingCert, func() error { return pki.ReconcileOLMOperatorServingCertSecret(olmOperatorServingCert, rootCASecret, p.OwnerRef) @@ -1620,7 +1564,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy } if capabilities.IsImageRegistryCapabilityEnabled(hcp.Spec.Capabilities) { - // Image Registry Operator Serving Cert imageRegistryOperatorServingCert := manifests.ImageRegistryOperatorServingCert(hcp.Namespace) if _, err := createOrUpdate(ctx, r, imageRegistryOperatorServingCert, func() error { return pki.ReconcileRegistryOperatorServingCert(imageRegistryOperatorServingCert, rootCASecret, p.OwnerRef) @@ -1643,31 +1586,26 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile cvo serving cert: %w", err) } - // For the Multus Admission Controller, Network Node Identity, and OVN Control Plane Metrics Serving Certs: - // We want to remove the secret if there was an existing one created by service-ca; otherwise, it will cause - // issues in cases where you are upgrading an older CPO prior to us adding the feature to reconcile the serving - // cert secret ourselves. + return nil +} - // Multus Admission Controller Serving Cert - only if Multus is not disabled +func (r *HostedControlPlaneReconciler) reconcileNetworkServingCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { if !netutil.IsDisableMultiNetwork(hcp) { multusAdmissionControllerService := manifests.MultusAdmissionControllerService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(multusAdmissionControllerService), multusAdmissionControllerService); err != nil { + if err := r.Get(ctx, client.ObjectKeyFromObject(multusAdmissionControllerService), multusAdmissionControllerService); err != nil { if !apierrors.IsNotFound(err) { return fmt.Errorf("failed to retrieve multus-admission-controller service: %w", err) } } - // If the service doesn't have the service ca annotation, delete any previous secret with the annotation and - // reconcile the secret with our own rootCA; otherwise, skip reconciling the secret with our own rootCA. if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(multusAdmissionControllerService); !hasServiceCAAnnotation { multusAdmissionControllerServingCertSecret := manifests.MultusAdmissionControllerServingCert(hcp.Namespace) - err = removeServiceCASecret(ctx, r.Client, multusAdmissionControllerServingCertSecret) - if err != nil { + if err := removeServiceCASecret(ctx, r.Client, multusAdmissionControllerServingCertSecret); err != nil { return err } - if _, err = createOrUpdate(ctx, r, multusAdmissionControllerServingCertSecret, func() error { + if _, err := createOrUpdate(ctx, r, multusAdmissionControllerServingCertSecret, func() error { return pki.ReconcileMultusAdmissionControllerServingCertSecret(multusAdmissionControllerServingCertSecret, rootCASecret, p.OwnerRef) }); err != nil { return fmt.Errorf("failed to reconcile multus admission controller serving cert: %w", err) @@ -1675,179 +1613,272 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy } } - // Network Node Identity Serving Cert networkNodeIdentityService := manifests.NetworkNodeIdentityService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(networkNodeIdentityService), networkNodeIdentityService); err != nil { + if err := r.Get(ctx, client.ObjectKeyFromObject(networkNodeIdentityService), networkNodeIdentityService); err != nil { if !apierrors.IsNotFound(err) { return fmt.Errorf("failed to retrieve network-node-identity service: %w", err) } } - // If the service doesn't have the service ca annotation, delete any previous secret with the annotation and - // reconcile the secret with our own rootCA; otherwise, skip reconciling the secret with our own rootCA. if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(networkNodeIdentityService); !hasServiceCAAnnotation { networkNodeIdentityServingCertSecret := manifests.NetworkNodeIdentityControllerServingCert(hcp.Namespace) - err = removeServiceCASecret(ctx, r.Client, networkNodeIdentityServingCertSecret) - if err != nil { + if err := removeServiceCASecret(ctx, r.Client, networkNodeIdentityServingCertSecret); err != nil { return err } - if _, err = createOrUpdate(ctx, r, networkNodeIdentityServingCertSecret, func() error { + if _, err := createOrUpdate(ctx, r, networkNodeIdentityServingCertSecret, func() error { return pki.ReconcileNetworkNodeIdentityServingCertSecret(networkNodeIdentityServingCertSecret, rootCASecret, p.OwnerRef) }); err != nil { return fmt.Errorf("failed to reconcile network node identity serving cert: %w", err) } } - // OVN Control Plane Metrics Serving Cert ovnControlPlaneService := manifests.OVNKubernetesControlPlaneService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(ovnControlPlaneService), ovnControlPlaneService); err != nil { + if err := r.Get(ctx, client.ObjectKeyFromObject(ovnControlPlaneService), ovnControlPlaneService); err != nil { if !apierrors.IsNotFound(err) { return fmt.Errorf("failed to retrieve ovn-kubernetes-control-plane service: %w", err) } } - // If the service doesn't have the service ca annotation, delete any previous secret with the annotation and - // reconcile the secret with our own rootCA; otherwise, skip reconciling the secret with our own rootCA. if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(ovnControlPlaneService); !hasServiceCAAnnotation { ovnControlPlaneMetricsServingCertSecret := manifests.OVNControlPlaneMetricsServingCert(hcp.Namespace) - err = removeServiceCASecret(ctx, r.Client, ovnControlPlaneMetricsServingCertSecret) - if err != nil { + if err := removeServiceCASecret(ctx, r.Client, ovnControlPlaneMetricsServingCertSecret); err != nil { return err } - if _, err = createOrUpdate(ctx, r, ovnControlPlaneMetricsServingCertSecret, func() error { + if _, err := createOrUpdate(ctx, r, ovnControlPlaneMetricsServingCertSecret, func() error { return pki.ReconcileOVNControlPlaneMetricsServingCertSecret(ovnControlPlaneMetricsServingCertSecret, rootCASecret, p.OwnerRef) }); err != nil { return fmt.Errorf("failed to reconcile OVN control plane serving cert: %w", err) } } - if _, exists := hcp.Annotations[hyperv1.DisableIgnitionServerAnnotation]; !exists { - if hcp.Spec.Platform.Type != hyperv1.IBMCloudPlatform { - ignitionServerCert := manifests.IgnitionServerCertSecret(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, ignitionServerCert, func() error { - return pki.ReconcileIgnitionServerCertSecret(ignitionServerCert, rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile ignition server cert: %w", err) - } + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileAWSPlatformCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { + awsPodIdentityWebhookServingCert := manifests.AWSPodIdentityWebhookServingCert(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, awsPodIdentityWebhookServingCert, func() error { + return pki.ReconcileAWSPodIdentityWebhookServingCert(awsPodIdentityWebhookServingCert, rootCASecret, p.OwnerRef) + }); err != nil { + return fmt.Errorf("failed to reconcile %s secret: %w", awsPodIdentityWebhookServingCert.Name, err) + } + + awsEBSCsiDriverControllerMetricsService := manifests.AWSEBSCsiDriverControllerMetricsService(hcp.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(awsEBSCsiDriverControllerMetricsService), awsEBSCsiDriverControllerMetricsService); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to retrieve aws-ebs-csi-driver-controller-metrics service: %w", err) } } - // Platform specific certs - switch hcp.Spec.Platform.Type { - case hyperv1.AWSPlatform: - awsPodIdentityWebhookServingCert := manifests.AWSPodIdentityWebhookServingCert(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, awsPodIdentityWebhookServingCert, func() error { - return pki.ReconcileAWSPodIdentityWebhookServingCert(awsPodIdentityWebhookServingCert, rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile %s secret: %w", awsPodIdentityWebhookServingCert.Name, err) + if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(awsEBSCsiDriverControllerMetricsService); !hasServiceCAAnnotation { + awsEBSCsiDriverControllerMetricsServingCert := manifests.AWSEBSCsiDriverControllerMetricsServingCert(hcp.Namespace) + + if err := removeServiceCASecret(ctx, r.Client, awsEBSCsiDriverControllerMetricsServingCert); err != nil { + return err } - awsEBSCsiDriverControllerMetricsService := manifests.AWSEBSCsiDriverControllerMetricsService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(awsEBSCsiDriverControllerMetricsService), awsEBSCsiDriverControllerMetricsService); err != nil { - if !apierrors.IsNotFound(err) { - return fmt.Errorf("failed to retrieve aws-ebs-csi-driver-controller-metrics service: %w", err) - } + if _, err := createOrUpdate(ctx, r, awsEBSCsiDriverControllerMetricsServingCert, func() error { + return pki.ReconcileAWSEBSCsiDriverControllerMetricsServingCertSecret(awsEBSCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) + }); err != nil { + return fmt.Errorf("failed to reconcile aws ebs csi driver controller metrics serving cert: %w", err) } + } - if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(awsEBSCsiDriverControllerMetricsService); !hasServiceCAAnnotation { - awsEBSCsiDriverControllerMetricsServingCert := manifests.AWSEBSCsiDriverControllerMetricsServingCert(hcp.Namespace) + return nil +} - err = removeServiceCASecret(ctx, r.Client, awsEBSCsiDriverControllerMetricsServingCert) - if err != nil { - return err - } +func (r *HostedControlPlaneReconciler) reconcileAzurePlatformCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { + azureWorkloadIdentityWebhookServingCert := manifests.AzureWorkloadIdentityWebhookServingCert(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, azureWorkloadIdentityWebhookServingCert, func() error { + return pki.ReconcileAzureWorkloadIdentityWebhookServingCert(azureWorkloadIdentityWebhookServingCert, rootCASecret, p.OwnerRef) + }); err != nil { + return fmt.Errorf("failed to reconcile %s secret: %w", azureWorkloadIdentityWebhookServingCert.Name, err) + } - if _, err = createOrUpdate(ctx, r, awsEBSCsiDriverControllerMetricsServingCert, func() error { - return pki.ReconcileAWSEBSCsiDriverControllerMetricsServingCertSecret(awsEBSCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile aws ebs csi driver controller metrics serving cert: %w", err) - } + AzureDiskCsiDriverOperatorServingCert := manifests.AzureDiskCSIDriverOperatorServingCertSecret(hcp.Namespace) + AzureDiskCsiDriverOperatorService := manifests.AzureDiskCSIDriverOperatorMetricsService(hcp.Namespace) + if err := removeServiceCAAnnotationAndSecret(ctx, r.Client, AzureDiskCsiDriverOperatorService, AzureDiskCsiDriverOperatorServingCert); err != nil { + r.Log.Error(err, "failed to remove service ca annotation and secret: %w") + } + if _, err := createOrUpdate(ctx, r, AzureDiskCsiDriverOperatorServingCert, func() error { + z := pki.ReconcileAzureDiskCsiDriverOperatorMetricsServingCertSecret(AzureDiskCsiDriverOperatorServingCert, rootCASecret, p.OwnerRef) + return z + }); err != nil { + return fmt.Errorf("failed to reconcile azure-disk csi driver operator serving cert: %w", err) + } + + azureDiskCsiDriverControllerMetricsService := manifests.AzureDiskCsiDriverControllerMetricsService(hcp.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(azureDiskCsiDriverControllerMetricsService), azureDiskCsiDriverControllerMetricsService); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to retrieve azure-disk-csi-driver-controller-metrics service: %w", err) } - case hyperv1.AzurePlatform: - azureWorkloadIdentityWebhookServingCert := manifests.AzureWorkloadIdentityWebhookServingCert(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, azureWorkloadIdentityWebhookServingCert, func() error { - return pki.ReconcileAzureWorkloadIdentityWebhookServingCert(azureWorkloadIdentityWebhookServingCert, rootCASecret, p.OwnerRef) + } + + if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(azureDiskCsiDriverControllerMetricsService); !hasServiceCAAnnotation { + azureDiskCsiDriverControllerMetricsServingCert := manifests.AzureDiskCsiDriverControllerMetricsServingCert(hcp.Namespace) + + if err := removeServiceCASecret(ctx, r.Client, azureDiskCsiDriverControllerMetricsServingCert); err != nil { + return err + } + + if _, err := createOrUpdate(ctx, r, azureDiskCsiDriverControllerMetricsServingCert, func() error { + return pki.ReconcileAzureDiskCsiDriverControllerMetricsServingCertSecret(azureDiskCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) }); err != nil { - return fmt.Errorf("failed to reconcile %s secret: %w", azureWorkloadIdentityWebhookServingCert.Name, err) + return fmt.Errorf("failed to reconcile azure disk csi driver controller metrics serving cert: %w", err) } + } - // Azure-disk CSI driver Operator metrics Serving Cert - AzureDiskCsiDriverOperatorServingCert := manifests.AzureDiskCSIDriverOperatorServingCertSecret(hcp.Namespace) - AzureDiskCsiDriverOperatorService := manifests.AzureDiskCSIDriverOperatorMetricsService(hcp.Namespace) - err := removeServiceCAAnnotationAndSecret(ctx, r.Client, AzureDiskCsiDriverOperatorService, AzureDiskCsiDriverOperatorServingCert) - if err != nil { - r.Log.Error(err, "failed to remove service ca annotation and secret: %w") + AzureFileCsiDriverOperatorServingCert := manifests.AzureFileCSIDriverOperatorServingCertSecret(hcp.Namespace) + AzureFileCsiDriverOperatorService := manifests.AzureFileCSIDriverOperatorMetricsService(hcp.Namespace) + if err := removeServiceCAAnnotationAndSecret(ctx, r.Client, AzureFileCsiDriverOperatorService, AzureFileCsiDriverOperatorServingCert); err != nil { + r.Log.Error(err, "failed to remove service ca annotation and secret: %w") + } + if _, err := createOrUpdate(ctx, r, AzureFileCsiDriverOperatorServingCert, func() error { + z := pki.ReconcileAzureFileCsiDriverOperatorMetricsServingCertSecret(AzureFileCsiDriverOperatorServingCert, rootCASecret, p.OwnerRef) + return z + }); err != nil { + return fmt.Errorf("failed to reconcile azure-file csi driver operator serving cert: %w", err) + } + + azureFileCsiDriverControllerMetricsService := manifests.AzureFileCsiDriverControllerMetricsService(hcp.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(azureFileCsiDriverControllerMetricsService), azureFileCsiDriverControllerMetricsService); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to retrieve azure-file-csi-driver-controller-metrics service: %w", err) + } + } + + if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(azureFileCsiDriverControllerMetricsService); !hasServiceCAAnnotation { + azureFileCsiDriverControllerMetricsServingCert := manifests.AzureFileCsiDriverControllerMetricsServingCert(hcp.Namespace) + + if err := removeServiceCASecret(ctx, r.Client, azureFileCsiDriverControllerMetricsServingCert); err != nil { + return err } - if _, err = createOrUpdate(ctx, r, AzureDiskCsiDriverOperatorServingCert, func() error { - z := pki.ReconcileAzureDiskCsiDriverOperatorMetricsServingCertSecret(AzureDiskCsiDriverOperatorServingCert, rootCASecret, p.OwnerRef) - return z + + if _, err := createOrUpdate(ctx, r, azureFileCsiDriverControllerMetricsServingCert, func() error { + return pki.ReconcileAzureFileCsiDriverControllerMetricsServingCertSecret(azureFileCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) }); err != nil { - return fmt.Errorf("failed to reconcile azure-disk csi driver operator serving cert: %w", err) + return fmt.Errorf("failed to reconcile azure file csi driver controller metrics serving cert: %w", err) } + } + + return nil +} + +func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hyperv1.HostedControlPlane, infraStatus infra.InfrastructureStatus, createOrUpdate upsert.CreateOrUpdateFN) error { + p := pki.NewPKIParams(hcp, infraStatus.APIHost, infraStatus.OAuthHost, infraStatus.KonnectivityHost) + + rootCASecret := manifests.RootCASecret(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, rootCASecret, func() error { + return pki.ReconcileRootCA(rootCASecret, p.OwnerRef) + }); err != nil { + return fmt.Errorf("failed to reconcile root CA: %w", err) + } - azureDiskCsiDriverControllerMetricsService := manifests.AzureDiskCsiDriverControllerMetricsService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(azureDiskCsiDriverControllerMetricsService), azureDiskCsiDriverControllerMetricsService); err != nil { + var observedDefaultIngressCert *corev1.ConfigMap + if capabilities.IsIngressCapabilityEnabled(hcp.Spec.Capabilities) { + observedDefaultIngressCert = manifests.IngressObservedDefaultIngressCertCA(hcp.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(observedDefaultIngressCert), observedDefaultIngressCert); err != nil { if !apierrors.IsNotFound(err) { - return fmt.Errorf("failed to retrieve azure-disk-csi-driver-controller-metrics service: %w", err) + return fmt.Errorf("failed to get observed default ingress cert: %w", err) } + observedDefaultIngressCert = nil } + } + rootCAConfigMap := manifests.RootCAConfigMap(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, rootCAConfigMap, func() error { + return pki.ReconcileRootCAConfigMap(rootCAConfigMap, p.OwnerRef, rootCASecret, observedDefaultIngressCert) + }); err != nil { + return fmt.Errorf("failed to reconcile root CA configmap: %w", err) + } - if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(azureDiskCsiDriverControllerMetricsService); !hasServiceCAAnnotation { - azureDiskCsiDriverControllerMetricsServingCert := manifests.AzureDiskCsiDriverControllerMetricsServingCert(hcp.Namespace) + if err := r.reconcileEtcdCerts(ctx, hcp, p, createOrUpdate); err != nil { + return err + } - err = removeServiceCASecret(ctx, r.Client, azureDiskCsiDriverControllerMetricsServingCert) - if err != nil { - return err - } + if err := r.reconcileKASCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } - if _, err = createOrUpdate(ctx, r, azureDiskCsiDriverControllerMetricsServingCert, func() error { - return pki.ReconcileAzureDiskCsiDriverControllerMetricsServingCertSecret(azureDiskCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile azure disk csi driver controller metrics serving cert: %w", err) - } - } + if err := r.reconcileOpenshiftCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } - // Azure-file CSI driver Operator metrics Serving Cert - AzureFileCsiDriverOperatorServingCert := manifests.AzureFileCSIDriverOperatorServingCertSecret(hcp.Namespace) - AzureFileCsiDriverOperatorService := manifests.AzureFileCSIDriverOperatorMetricsService(hcp.Namespace) - err = removeServiceCAAnnotationAndSecret(ctx, r.Client, AzureFileCsiDriverOperatorService, AzureFileCsiDriverOperatorServingCert) - if err != nil { - r.Log.Error(err, "failed to remove service ca annotation and secret: %w") - } - if _, err = createOrUpdate(ctx, r, AzureFileCsiDriverOperatorServingCert, func() error { - z := pki.ReconcileAzureFileCsiDriverOperatorMetricsServingCertSecret(AzureFileCsiDriverOperatorServingCert, rootCASecret, p.OwnerRef) - return z + if err := r.reconcileKonnectivityCerts(ctx, hcp, p, createOrUpdate); err != nil { + return err + } + + if capabilities.IsIngressCapabilityEnabled(hcp.Spec.Capabilities) { + ingressCert := manifests.IngressCert(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, ingressCert, func() error { + return pki.ReconcileIngressCert(ingressCert, rootCASecret, p.OwnerRef, p.IngressSubdomain) }); err != nil { - return fmt.Errorf("failed to reconcile azure-file csi driver operator serving cert: %w", err) + return fmt.Errorf("failed to reconcile ingress cert secret: %w", err) } + } - azureFileCsiDriverControllerMetricsService := manifests.AzureFileCsiDriverControllerMetricsService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(azureFileCsiDriverControllerMetricsService), azureFileCsiDriverControllerMetricsService); err != nil { - if !apierrors.IsNotFound(err) { - return fmt.Errorf("failed to retrieve azure-file-csi-driver-controller-metrics service: %w", err) - } + var userCABundles []client.ObjectKey + if hcp.Spec.AdditionalTrustBundle != nil { + userCABundles = append(userCABundles, client.ObjectKey{Namespace: hcp.Namespace, Name: hcp.Spec.AdditionalTrustBundle.Name}) + } + if hcp.Spec.Configuration != nil && hcp.Spec.Configuration.Proxy != nil && hcp.Spec.Configuration.Proxy.TrustedCA.Name != "" { + userCABundles = append(userCABundles, client.ObjectKey{Namespace: hcp.Namespace, Name: hcp.Spec.Configuration.Proxy.TrustedCA.Name}) + } + + trustedCABundle := manifests.TrustedCABundleConfigMap(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, trustedCABundle, func() error { + return r.reconcileManagedTrustedCABundle(ctx, trustedCABundle, userCABundles) + }); err != nil { + return fmt.Errorf("failed to reconcile managed UserCA configMap: %w", err) + } + + if util.HCPOAuthEnabled(hcp) { + if err := r.reconcileOAuthCerts(ctx, hcp, p, createOrUpdate, rootCASecret, trustedCABundle); err != nil { + return err } + } + + if _, exists := hcp.Annotations[hyperv1.DisableIgnitionServerAnnotation]; !exists { + machineConfigServerCert := manifests.MachineConfigServerCert(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, machineConfigServerCert, func() error { + return pki.ReconcileMachineConfigServerCert(machineConfigServerCert, rootCASecret, p.OwnerRef) + }); err != nil { + return fmt.Errorf("failed to reconcile machine config server cert secret: %w", err) + } + } - if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(azureFileCsiDriverControllerMetricsService); !hasServiceCAAnnotation { - azureFileCsiDriverControllerMetricsServingCert := manifests.AzureFileCsiDriverControllerMetricsServingCert(hcp.Namespace) + if err := r.reconcileOLMAndMiscCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } - err = removeServiceCASecret(ctx, r.Client, azureFileCsiDriverControllerMetricsServingCert) - if err != nil { - return err - } + if err := r.reconcileNetworkServingCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } - if _, err := createOrUpdate(ctx, r, azureFileCsiDriverControllerMetricsServingCert, func() error { - return pki.ReconcileAzureFileCsiDriverControllerMetricsServingCertSecret(azureFileCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) + if _, exists := hcp.Annotations[hyperv1.DisableIgnitionServerAnnotation]; !exists { + if hcp.Spec.Platform.Type != hyperv1.IBMCloudPlatform { + ignitionServerCert := manifests.IgnitionServerCertSecret(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, ignitionServerCert, func() error { + return pki.ReconcileIgnitionServerCertSecret(ignitionServerCert, rootCASecret, p.OwnerRef) }); err != nil { - return fmt.Errorf("failed to reconcile azure file csi driver controller metrics serving cert: %w", err) + return fmt.Errorf("failed to reconcile ignition server cert: %w", err) } } } + switch hcp.Spec.Platform.Type { + case hyperv1.AWSPlatform: + if err := r.reconcileAWSPlatformCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } + case hyperv1.AzurePlatform: + if err := r.reconcileAzurePlatformCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } + } + return nil } diff --git a/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller_test.go b/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller_test.go index f4f43ab5151..08eb8ae1e89 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller_test.go +++ b/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller_test.go @@ -38,7 +38,9 @@ import ( "github.com/openshift/hypershift/support/releaseinfo/testutils" "github.com/openshift/hypershift/support/testutil" "github.com/openshift/hypershift/support/thirdparty/library-go/pkg/image/dockerv1client" + "github.com/openshift/hypershift/support/thirdparty/library-go/pkg/image/reference" "github.com/openshift/hypershift/support/upsert" + "github.com/openshift/hypershift/support/util" "github.com/openshift/hypershift/support/util/fakeimagemetadataprovider" configv1 "github.com/openshift/api/config/v1" @@ -54,6 +56,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/util/workqueue" + "k8s.io/utils/clock" + testingclock "k8s.io/utils/clock/testing" "k8s.io/utils/ptr" ctrl "sigs.k8s.io/controller-runtime" @@ -67,7 +71,9 @@ import ( "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/yaml" + "github.com/docker/distribution" "github.com/go-logr/zapr" + "github.com/opencontainers/go-digest" "go.uber.org/mock/gomock" "go.uber.org/zap/zaptest" ) @@ -2873,3 +2879,1409 @@ func TestRemoveCloudResources(t *testing.T) { }) } } +func TestReconcileEtcdStatus(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + expectedCondType string + expectedCondition metav1.Condition + expectError bool + }{ + { + name: "When etcd management type is Unmanaged, it should set EtcdAvailable to True", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 3, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Unmanaged, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: "EtcdRunning", + Message: "Etcd cluster is assumed to be running in unmanaged state", + ObservedGeneration: 3, + }, + }, + { + name: "When etcd management type is Managed and StatefulSet is not found, it should set EtcdAvailable to False with NotFound reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 5, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Managed, + }, + }, + }, + existingObjects: []client.Object{}, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.EtcdStatefulSetNotFoundReason, + ObservedGeneration: 5, + }, + }, + { + name: "When etcd management type is Managed and StatefulSet exists with quorum, it should set EtcdAvailable to True", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Managed, + }, + }, + }, + existingObjects: []client.Object{ + &appsv1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd", + Namespace: testNamespace, + }, + Spec: appsv1.StatefulSetSpec{ + Replicas: ptr.To[int32](3), + }, + Status: appsv1.StatefulSetStatus{ + ReadyReplicas: 3, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.EtcdQuorumAvailableReason, + ObservedGeneration: 2, + }, + }, + { + name: "When etcd management type is Managed and StatefulSet has no quorum, it should set EtcdAvailable to False", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 4, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Managed, + }, + }, + }, + existingObjects: []client.Object{ + &appsv1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd", + Namespace: testNamespace, + }, + Spec: appsv1.StatefulSetSpec{ + Replicas: ptr.To[int32](3), + }, + Status: appsv1.StatefulSetStatus{ + ReadyReplicas: 0, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.EtcdWaitingForQuorumReason, + ObservedGeneration: 4, + }, + }, + { + name: "When etcd management type is Managed and Get returns unexpected error, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Managed, + }, + }, + }, + expectError: true, + }, + { + name: "When etcd management type is Managed with RestoreSnapshotURL and StatefulSet has ready pods, it should set EtcdSnapshotRestored condition", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 6, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Managed, + Managed: &hyperv1.ManagedEtcdSpec{ + Storage: hyperv1.ManagedEtcdStorageSpec{ + RestoreSnapshotURL: []string{"https://example.com/snapshot"}, + }, + }, + }, + }, + }, + existingObjects: []client.Object{ + &appsv1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd", + Namespace: testNamespace, + }, + Spec: appsv1.StatefulSetSpec{ + Replicas: ptr.To[int32](1), + }, + Status: appsv1.StatefulSetStatus{ + ReadyReplicas: 1, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-0", + Namespace: testNamespace, + Labels: map[string]string{ + "app": "etcd", + }, + }, + Status: corev1.PodStatus{ + InitContainerStatuses: []corev1.ContainerStatus{ + { + Name: "etcd-init", + Ready: true, + }, + }, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.EtcdQuorumAvailableReason, + ObservedGeneration: 6, + }, + }, + { + name: "When etcd management type is empty, it should set condition to Unknown", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: "", + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionUnknown, + Reason: hyperv1.StatusUnknownReason, + ObservedGeneration: 1, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + var c client.Client + if tc.expectError { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithInterceptorFuncs(interceptor.Funcs{ + Get: func(ctx context.Context, client client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + return fmt.Errorf("simulated get error") + }, + }). + Build() + } else { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(tc.existingObjects...). + Build() + } + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + err := r.reconcileEtcdStatus(t.Context(), tc.hcp) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.EtcdAvailable)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Type).To(Equal(tc.expectedCondition.Type)) + g.Expect(cond.Status).To(Equal(tc.expectedCondition.Status)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondition.Reason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.expectedCondition.ObservedGeneration)) + if tc.expectedCondition.Message != "" { + g.Expect(cond.Message).To(Equal(tc.expectedCondition.Message)) + } + + // For the restore snapshot test case, also verify EtcdSnapshotRestored condition + if tc.name == "When etcd management type is Managed with RestoreSnapshotURL and StatefulSet has ready pods, it should set EtcdSnapshotRestored condition" { + restoreCond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.EtcdSnapshotRestored)) + g.Expect(restoreCond).ToNot(BeNil()) + g.Expect(restoreCond.Status).To(Equal(metav1.ConditionTrue)) + g.Expect(restoreCond.Reason).To(Equal(hyperv1.AsExpectedReason)) + } + }) + } +} + +func TestReconcileKASStatus(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + expectedCondition metav1.Condition + expectError bool + }{ + { + name: "When KAS deployment is not found, it should set KubeAPIServerAvailable to False with NotFound reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 3, + }, + }, + existingObjects: []client.Object{}, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.NotFoundReason, + Message: "Kube APIServer deployment not found", + ObservedGeneration: 3, + }, + }, + { + name: "When KAS deployment exists and is Available, it should set KubeAPIServerAvailable to True", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 5, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + }, + Status: appsv1.DeploymentStatus{ + Conditions: []appsv1.DeploymentCondition{ + { + Type: appsv1.DeploymentAvailable, + Status: corev1.ConditionTrue, + }, + }, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + Message: "Kube APIServer deployment is available", + ObservedGeneration: 5, + }, + }, + { + name: "When KAS deployment exists but Available condition is False, it should set KubeAPIServerAvailable to False with WaitingForAvailable reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 7, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + }, + Status: appsv1.DeploymentStatus{ + Conditions: []appsv1.DeploymentCondition{ + { + Type: appsv1.DeploymentAvailable, + Status: corev1.ConditionFalse, + }, + }, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.WaitingForAvailableReason, + Message: "Waiting for Kube APIServer deployment to become available", + ObservedGeneration: 7, + }, + }, + { + name: "When KAS deployment exists but has no Available condition, it should set KubeAPIServerAvailable to False with StatusUnknown reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + }, + Status: appsv1.DeploymentStatus{ + Conditions: []appsv1.DeploymentCondition{ + { + Type: appsv1.DeploymentProgressing, + Status: corev1.ConditionTrue, + }, + }, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.StatusUnknownReason, + ObservedGeneration: 2, + }, + }, + { + name: "When KAS deployment exists with empty conditions, it should set KubeAPIServerAvailable to False with StatusUnknown reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + }, + Status: appsv1.DeploymentStatus{}, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.StatusUnknownReason, + ObservedGeneration: 1, + }, + }, + { + name: "When Get returns unexpected error, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + var c client.Client + if tc.expectError { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithInterceptorFuncs(interceptor.Funcs{ + Get: func(ctx context.Context, client client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + return fmt.Errorf("simulated get error") + }, + }). + Build() + } else { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(tc.existingObjects...). + Build() + } + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + err := r.reconcileKASStatus(t.Context(), tc.hcp) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.KubeAPIServerAvailable)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Type).To(Equal(tc.expectedCondition.Type)) + g.Expect(cond.Status).To(Equal(tc.expectedCondition.Status)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondition.Reason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.expectedCondition.ObservedGeneration)) + if tc.expectedCondition.Message != "" { + g.Expect(cond.Message).To(Equal(tc.expectedCondition.Message)) + } + }) + } +} + +func TestReconcileDegradedStatus(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + expectedCondition metav1.Condition + expectError bool + }{ + { + name: "When no CPO-managed deployments exist, it should set Degraded to False", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + }, + existingObjects: []client.Object{}, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionFalse, + Reason: hyperv1.AsExpectedReason, + ObservedGeneration: 2, + }, + }, + { + name: "When all CPO-managed deployments are fully available, it should set Degraded to False", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 3, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + Labels: map[string]string{ + controlplanecomponent.ManagedByLabel: "control-plane-operator", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 0, + }, + }, + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-controller-manager", + Namespace: testNamespace, + Labels: map[string]string{ + controlplanecomponent.ManagedByLabel: "control-plane-operator", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 0, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionFalse, + Reason: hyperv1.AsExpectedReason, + ObservedGeneration: 3, + }, + }, + { + name: "When a single CPO-managed deployment has unavailable replicas, it should set Degraded to True", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 4, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + Labels: map[string]string{ + controlplanecomponent.ManagedByLabel: "control-plane-operator", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 2, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionTrue, + Reason: "UnavailableReplicas", + Message: "kube-apiserver deployment has 2 unavailable replicas", + ObservedGeneration: 4, + }, + }, + { + name: "When multiple CPO-managed deployments have unavailable replicas, it should aggregate all errors in message", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 5, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + Labels: map[string]string{ + controlplanecomponent.ManagedByLabel: "control-plane-operator", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 1, + }, + }, + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-controller-manager", + Namespace: testNamespace, + Labels: map[string]string{ + controlplanecomponent.ManagedByLabel: "control-plane-operator", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 3, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionTrue, + Reason: "UnavailableReplicas", + ObservedGeneration: 5, + }, + }, + { + name: "When deployments exist without the CPO managed-by label, it should ignore them and set Degraded to False", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 6, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "unmanaged-deployment", + Namespace: testNamespace, + Labels: map[string]string{ + "app": "something-else", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 5, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionFalse, + Reason: hyperv1.AsExpectedReason, + ObservedGeneration: 6, + }, + }, + { + name: "When List returns unexpected error, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + var c client.Client + if tc.expectError { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithInterceptorFuncs(interceptor.Funcs{ + List: func(ctx context.Context, client client.WithWatch, list client.ObjectList, opts ...client.ListOption) error { + return fmt.Errorf("simulated list error") + }, + }). + Build() + } else { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(tc.existingObjects...). + Build() + } + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + err := r.reconcileDegradedStatus(t.Context(), tc.hcp) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.HostedControlPlaneDegraded)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Type).To(Equal(tc.expectedCondition.Type)) + g.Expect(cond.Status).To(Equal(tc.expectedCondition.Status)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondition.Reason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.expectedCondition.ObservedGeneration)) + if tc.expectedCondition.Message != "" { + g.Expect(cond.Message).To(ContainSubstring(tc.expectedCondition.Message)) + } + + // For the multi-deployment case, verify that both deployment names appear in the message + if tc.name == "When multiple CPO-managed deployments have unavailable replicas, it should aggregate all errors in message" { + g.Expect(cond.Message).To(ContainSubstring("kube-apiserver")) + g.Expect(cond.Message).To(ContainSubstring("kube-controller-manager")) + } + }) + } +} + +func TestReconcileInfrastructureStatusCondition(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + infraStatus infra.InfrastructureStatus + infraErr error + expectedCondStatus metav1.ConditionStatus + expectedCondReason string + expectedEndpoint hyperv1.APIEndpoint + expectOAuthCallback bool + }{ + { + name: "When infrastructure is ready, it should set InfrastructureReady to True and populate endpoint", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 3, + }, + }, + infraStatus: infra.InfrastructureStatus{ + APIHost: "api.example.com", + APIPort: 6443, + KonnectivityHost: "konnectivity.example.com", + KonnectivityPort: 8091, + }, + expectedCondStatus: metav1.ConditionTrue, + expectedCondReason: hyperv1.AsExpectedReason, + expectedEndpoint: hyperv1.APIEndpoint{ + Host: "api.example.com", + Port: 6443, + }, + }, + { + name: "When infrastructure is not ready, it should set InfrastructureReady to False", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 4, + }, + }, + infraStatus: infra.InfrastructureStatus{ + APIHost: "", + APIPort: 0, + Message: "Load balancer pending", + }, + expectedCondStatus: metav1.ConditionFalse, + expectedCondReason: hyperv1.WaitingOnInfrastructureReadyReason, + }, + { + name: "When infrastructure status returns error, it should set InfrastructureReady to Unknown", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 5, + }, + }, + infraErr: fmt.Errorf("failed to get infrastructure status"), + expectedCondStatus: metav1.ConditionUnknown, + expectedCondReason: hyperv1.InfraStatusFailureReason, + }, + { + name: "When infrastructure is not ready with empty message, it should use default provisioning message", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 6, + }, + }, + infraStatus: infra.InfrastructureStatus{ + APIHost: "", + APIPort: 0, + }, + expectedCondStatus: metav1.ConditionFalse, + expectedCondReason: hyperv1.WaitingOnInfrastructureReadyReason, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + r := &HostedControlPlaneReconciler{ + Client: fake.NewClientBuilder().WithScheme(api.Scheme).Build(), + Log: zapr.NewLogger(zaptest.NewLogger(t)), + reconcileInfrastructureStatus: func(ctx context.Context, hcp *hyperv1.HostedControlPlane) (infra.InfrastructureStatus, error) { + return tc.infraStatus, tc.infraErr + }, + } + + r.reconcileInfrastructureStatusCondition(t.Context(), tc.hcp) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.InfrastructureReady)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Status).To(Equal(tc.expectedCondStatus)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondReason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.hcp.Generation)) + + if tc.expectedCondStatus == metav1.ConditionTrue { + g.Expect(tc.hcp.Status.ControlPlaneEndpoint).To(Equal(tc.expectedEndpoint)) + } + + if tc.expectedCondStatus == metav1.ConditionFalse && tc.infraStatus.Message == "" { + g.Expect(cond.Message).To(Equal("Cluster infrastructure is still provisioning")) + } + if tc.expectedCondStatus == metav1.ConditionFalse && tc.infraStatus.Message != "" { + g.Expect(cond.Message).To(Equal(tc.infraStatus.Message)) + } + }) + } +} + +func TestReconcileExternalDNSStatusCondition(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + expectedCondStatus metav1.ConditionStatus + expectedCondReason string + expectedMessage string + }{ + { + name: "When no external DNS hostname is configured, it should set ExternalDNSReachable to Unknown", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Networking: hyperv1.ClusterNetworking{ + APIServer: &hyperv1.APIServerNetworking{ + Port: ptr.To[int32](6443), + }, + }, + }, + }, + expectedCondStatus: metav1.ConditionUnknown, + expectedCondReason: hyperv1.StatusUnknownReason, + expectedMessage: "External DNS is not configured", + }, + { + name: "When HCP is private (no PublicZoneID), it should set ExternalDNSReachable to Unknown", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.Private, + }, + }, + Services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.APIServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{ + Type: hyperv1.LoadBalancer, + LoadBalancer: &hyperv1.LoadBalancerPublishingStrategy{ + Hostname: "api.example.com", + }, + }, + }, + }, + }, + }, + expectedCondStatus: metav1.ConditionUnknown, + expectedCondReason: hyperv1.StatusUnknownReason, + expectedMessage: "External DNS is not configured", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + r := &HostedControlPlaneReconciler{ + Client: fake.NewClientBuilder().WithScheme(api.Scheme).Build(), + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + r.reconcileExternalDNSStatusCondition(t.Context(), tc.hcp) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.ExternalDNSReachable)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Status).To(Equal(tc.expectedCondStatus)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondReason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.hcp.Generation)) + if tc.expectedMessage != "" { + g.Expect(cond.Message).To(Equal(tc.expectedMessage)) + } + }) + } +} + +func TestReconcileAvailabilityAndReadyStatus(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + expectedReady bool + expectedCondStatus metav1.ConditionStatus + expectedCondReason string + }{ + { + name: "When no status conditions exist and no kubeconfig, it should set not ready with Unknown reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + }, + expectedReady: false, + expectedCondStatus: metav1.ConditionFalse, + expectedCondReason: hyperv1.StatusUnknownReason, + }, + { + name: "When infrastructure condition is False, it should propagate infrastructure failure", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + Status: hyperv1.HostedControlPlaneStatus{ + KubeConfig: &hyperv1.KubeconfigSecretRef{ + Name: "admin-kubeconfig", + Key: "kubeconfig", + }, + Conditions: []metav1.Condition{ + { + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionFalse, + Reason: hyperv1.WaitingOnInfrastructureReadyReason, + }, + { + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.EtcdQuorumAvailableReason, + }, + { + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + }, + }, + }, + }, + expectedReady: false, + expectedCondStatus: metav1.ConditionFalse, + expectedCondReason: hyperv1.WaitingOnInfrastructureReadyReason, + }, + { + name: "When health check fails but no conditions are set, it should report KAS LB not reachable", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 3, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + // No service strategy for APIServer - health check returns error + Services: []hyperv1.ServicePublishingStrategyMapping{}, + }, + Status: hyperv1.HostedControlPlaneStatus{ + KubeConfig: &hyperv1.KubeconfigSecretRef{ + Name: "admin-kubeconfig", + Key: "kubeconfig", + }, + Conditions: []metav1.Condition{ + { + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + }, + { + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.EtcdQuorumAvailableReason, + }, + { + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + }, + }, + }, + }, + expectedReady: false, + expectedCondStatus: metav1.ConditionFalse, + expectedCondReason: hyperv1.KASLoadBalancerNotReachableReason, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + objs := []client.Object{} + objs = append(objs, tc.existingObjects...) + + c := fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(objs...). + Build() + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + r.reconcileAvailabilityAndReadyStatus(t.Context(), tc.hcp) + + g.Expect(tc.hcp.Status.Ready).To(Equal(tc.expectedReady)) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.HostedControlPlaneAvailable)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Status).To(Equal(tc.expectedCondStatus)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondReason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.hcp.Generation)) + }) + } +} + +func TestReconcileKubeadminPasswordStatus(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + expectedPasswordRef *corev1.LocalObjectReference + expectError bool + }{ + { + name: "When explicit OAuth config is specified, it should clear kubeadmin password status", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Configuration: &hyperv1.ClusterConfiguration{ + OAuth: &configv1.OAuthSpec{ + IdentityProviders: []configv1.IdentityProvider{ + { + Name: "test-idp", + IdentityProviderConfig: configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeOpenID, + }, + }, + }, + }, + }, + }, + Status: hyperv1.HostedControlPlaneStatus{ + KubeadminPassword: &corev1.LocalObjectReference{ + Name: "old-kubeadmin-password", + }, + }, + }, + expectedPasswordRef: nil, + }, + { + name: "When no OAuth config and kubeadmin password secret exists, it should set password status", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + }, + }, + existingObjects: []client.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kubeadmin-password", + Namespace: testNamespace, + }, + Data: map[string][]byte{ + "password": []byte("test-password"), + }, + }, + }, + expectedPasswordRef: &corev1.LocalObjectReference{ + Name: "kubeadmin-password", + }, + }, + { + name: "When no OAuth config and kubeadmin password secret does not exist, it should leave password status nil", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + }, + }, + existingObjects: []client.Object{}, + expectedPasswordRef: nil, + }, + { + name: "When no OAuth config and Get returns unexpected error, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + }, + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + var c client.Client + if tc.expectError { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithInterceptorFuncs(interceptor.Funcs{ + Get: func(ctx context.Context, client client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + return fmt.Errorf("simulated get error") + }, + }). + Build() + } else { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(tc.existingObjects...). + Build() + } + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + err := r.reconcileKubeadminPasswordStatus(t.Context(), tc.hcp) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + if tc.expectedPasswordRef == nil { + g.Expect(tc.hcp.Status.KubeadminPassword).To(BeNil()) + } else { + g.Expect(tc.hcp.Status.KubeadminPassword).ToNot(BeNil()) + g.Expect(tc.hcp.Status.KubeadminPassword.Name).To(Equal(tc.expectedPasswordRef.Name)) + } + }) + } +} + +// fakeVersionImageMetadataProvider is a simple test double for ImageMetadataProvider +// that returns deterministic results without contacting a registry. +type fakeVersionImageMetadataProvider struct { + fakeDigest string + fakeRef *reference.DockerImageReference + digestErr error +} + +func (f *fakeVersionImageMetadataProvider) ImageMetadata(_ context.Context, _ string, _ []byte) (*dockerv1client.DockerImageConfig, error) { + return &dockerv1client.DockerImageConfig{}, nil +} + +func (f *fakeVersionImageMetadataProvider) GetManifest(_ context.Context, _ string, _ []byte) (distribution.Manifest, error) { + return nil, nil +} + +func (f *fakeVersionImageMetadataProvider) GetDigest(_ context.Context, _ string, _ []byte) (digest.Digest, *reference.DockerImageReference, error) { + if f.digestErr != nil { + return "", nil, f.digestErr + } + return digest.Digest(f.fakeDigest), f.fakeRef, nil +} + +func (f *fakeVersionImageMetadataProvider) GetMetadata(_ context.Context, _ string, _ []byte) (*dockerv1client.DockerImageConfig, []distribution.Descriptor, distribution.BlobStore, error) { + return &dockerv1client.DockerImageConfig{}, nil, nil, nil +} + +func (f *fakeVersionImageMetadataProvider) GetOverride(_ context.Context, _ string, _ []byte) (*reference.DockerImageReference, error) { + return f.fakeRef, nil +} + +func TestReconcileControlPlaneVersionStatus(t *testing.T) { + testNamespace := "test-namespace" + fakeClock := testingclock.NewFakeClock(time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)) + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + digestErr error + expectError bool + expectedDesired configv1.Release + expectedState configv1.UpdateState + }{ + { + name: "When pull secret exists and components are listed successfully with first population, it should create Partial history entry", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + ReleaseImage: "quay.io/openshift-release-dev/ocp-release:4.20.0", + }, + }, + existingObjects: []client.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pull-secret", + Namespace: testNamespace, + }, + Data: map[string][]byte{ + corev1.DockerConfigJsonKey: []byte("{}"), + }, + }, + }, + expectedDesired: configv1.Release{ + Version: "4.20.0", + Image: "quay.io/openshift-release-dev/ocp-release@sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + }, + expectedState: configv1.PartialUpdate, + }, + { + name: "When pull secret is missing, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + ReleaseImage: "quay.io/openshift-release-dev/ocp-release:4.20.0", + }, + }, + existingObjects: []client.Object{}, + expectError: true, + }, + { + name: "When GetDigest fails, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + ReleaseImage: "quay.io/openshift-release-dev/ocp-release:4.20.0", + }, + }, + existingObjects: []client.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pull-secret", + Namespace: testNamespace, + }, + Data: map[string][]byte{ + corev1.DockerConfigJsonKey: []byte("{}"), + }, + }, + }, + digestErr: fmt.Errorf("failed to resolve digest"), + expectError: true, + }, + { + name: "When component list fails, it should set partial version and return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + ReleaseImage: "quay.io/openshift-release-dev/ocp-release:4.20.0", + }, + }, + existingObjects: []client.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pull-secret", + Namespace: testNamespace, + }, + Data: map[string][]byte{ + corev1.DockerConfigJsonKey: []byte("{}"), + }, + }, + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + releaseImage := testutils.InitReleaseImageOrDie("4.20.0") + + resolvedRef := &reference.DockerImageReference{ + Registry: "quay.io", + Namespace: "openshift-release-dev", + Name: "ocp-release", + ID: "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + } + + imgProvider := &fakeVersionImageMetadataProvider{ + fakeDigest: "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + fakeRef: resolvedRef, + digestErr: tc.digestErr, + } + + clientBuilder := fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(tc.existingObjects...) + + // For the "component list fails" case, intercept the List call to fail + if tc.name == "When component list fails, it should set partial version and return error" { + clientBuilder = clientBuilder.WithInterceptorFuncs(interceptor.Funcs{ + List: func(ctx context.Context, c client.WithWatch, list client.ObjectList, opts ...client.ListOption) error { + if _, ok := list.(*hyperv1.ControlPlaneComponentList); ok { + return fmt.Errorf("simulated list error") + } + return c.List(ctx, list, opts...) + }, + }) + // Also need to add status subresource for patching + clientBuilder = clientBuilder.WithStatusSubresource(tc.hcp) + } + + c := clientBuilder.Build() + + // For the component list failure case, we need the HCP to exist in the fake client + if tc.name == "When component list fails, it should set partial version and return error" { + g.Expect(c.Create(t.Context(), tc.hcp)).To(Succeed()) + } + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + ImageMetadataProvider: imgProvider, + clock: fakeClock, + } + + originalHCP := tc.hcp.DeepCopy() + err := r.reconcileControlPlaneVersionStatus(t.Context(), tc.hcp, originalHCP, releaseImage) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + + // For the component list failure case, verify partial status was still populated + if tc.name == "When component list fails, it should set partial version and return error" { + g.Expect(tc.hcp.Status.ControlPlaneVersion.Desired.Version).To(Equal("4.20.0")) + g.Expect(tc.hcp.Status.ControlPlaneVersion.Desired.Image).ToNot(BeEmpty()) + g.Expect(tc.hcp.Status.ControlPlaneVersion.History).ToNot(BeEmpty()) + } + return + } + g.Expect(err).ToNot(HaveOccurred()) + + g.Expect(tc.hcp.Status.ControlPlaneVersion.Desired.Version).To(Equal(tc.expectedDesired.Version)) + g.Expect(tc.hcp.Status.ControlPlaneVersion.Desired.Image).To(Equal(tc.expectedDesired.Image)) + g.Expect(tc.hcp.Status.ControlPlaneVersion.History).ToNot(BeEmpty()) + g.Expect(tc.hcp.Status.ControlPlaneVersion.History[0].State).To(Equal(tc.expectedState)) + }) + } +} + +// Compile-time assertion that fakeVersionImageMetadataProvider satisfies the interface. +var _ util.ImageMetadataProvider = &fakeVersionImageMetadataProvider{} + +// Compile-time assertion for clock interface used by tests. +var _ clock.Clock = &testingclock.FakeClock{} diff --git a/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert.go b/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert.go index 60dd8db418a..4aeb258be96 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert.go +++ b/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert.go @@ -137,6 +137,223 @@ func defaultIDPMappingMethods(identityProviders []configv1.IdentityProvider) []c return out } +func convertBasicAuthIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + basicAuthConfig := providerConfig.BasicAuth + if basicAuthConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.BasicAuthPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "BasicAuthPasswordIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + RemoteConnectionInfo: configv1.RemoteConnectionInfo{URL: basicAuthConfig.URL}, + } + if basicAuthConfig.CA.Name != "" { + provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, basicAuthConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + if basicAuthConfig.TLSClientCert.Name != "" { + provider.RemoteConnectionInfo.CertFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) + } + if basicAuthConfig.TLSClientKey.Name != "" { + provider.RemoteConnectionInfo.KeyFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} + +func convertGitHubIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + githubConfig := providerConfig.GitHub + if githubConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.GitHubIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "GitHubIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + ClientID: githubConfig.ClientID, + ClientSecret: configv1.StringSource{StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, githubConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), + }}, + Organizations: githubConfig.Organizations, + Teams: githubConfig.Teams, + Hostname: githubConfig.Hostname, + } + if githubConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, githubConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: false, login: true}, nil +} + +func convertGitLabIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + gitlabConfig := providerConfig.GitLab + if gitlabConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.GitLabIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "GitLabIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + URL: gitlabConfig.URL, + ClientID: gitlabConfig.ClientID, + ClientSecret: configv1.StringSource{StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, gitlabConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), + }}, + Legacy: new(bool), + } + if gitlabConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, gitlabConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} + +func convertGoogleIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + googleConfig := providerConfig.Google + if googleConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.GoogleIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "GoogleIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + ClientID: googleConfig.ClientID, + ClientSecret: configv1.StringSource{StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, googleConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), + }}, + HostedDomain: googleConfig.HostedDomain, + } + return &idpData{provider: provider, challenge: false, login: true}, nil +} + +func convertHTPasswdIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + if providerConfig.HTPasswd == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.HTPasswdPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "HTPasswdPasswordIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + File: idpVolumeMounts.SecretPath(i, providerConfig.HTPasswd.FileData.Name, "file-data", configv1.HTPasswdDataKey), + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} + +func convertKeystoneIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + keystoneConfig := providerConfig.Keystone + if keystoneConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.KeystonePasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "KeystonePasswordIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + RemoteConnectionInfo: configv1.RemoteConnectionInfo{URL: keystoneConfig.URL}, + DomainName: keystoneConfig.DomainName, + UseKeystoneIdentity: true, + } + if keystoneConfig.CA.Name != "" { + provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, keystoneConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + if keystoneConfig.TLSClientCert.Name != "" { + provider.RemoteConnectionInfo.CertInfo.CertFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) + } + if keystoneConfig.TLSClientKey.Name != "" { + provider.RemoteConnectionInfo.CertInfo.KeyFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} + +func convertLDAPIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + ldapConfig := providerConfig.LDAP + if ldapConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.LDAPPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "LDAPPasswordIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + URL: ldapConfig.URL, + BindDN: ldapConfig.BindDN, + Insecure: ldapConfig.Insecure, + Attributes: osinv1.LDAPAttributeMapping{ + ID: ldapConfig.Attributes.ID, + PreferredUsername: ldapConfig.Attributes.PreferredUsername, + Name: ldapConfig.Attributes.Name, + Email: ldapConfig.Attributes.Email, + }, + } + if ldapConfig.BindPassword.Name != "" { + provider.BindPassword = configv1.StringSource{StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, ldapConfig.BindPassword.Name, "bind-password", configv1.BindPasswordKey), + }} + } + if ldapConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, ldapConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} + +func convertOpenIDIDP(ctx context.Context, providerConfig *configv1.IdentityProviderConfig, configOverride *ConfigOverride, i int, idpVolumeMounts *IDPVolumeMountInfo, kclient crclient.Client, namespace string, skipKonnectivityDialer bool) (*idpData, error) { + openIDConfig := providerConfig.OpenID + if openIDConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + openIDProvider := &osinv1.OpenIDIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "OpenIDIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + ClientID: openIDConfig.ClientID, + ClientSecret: configv1.StringSource{StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, openIDConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), + }}, + ExtraScopes: openIDConfig.ExtraScopes, + ExtraAuthorizeParameters: openIDConfig.ExtraAuthorizeParameters, + } + if configOverride != nil { + openIDProvider.URLs = configOverride.URLs + openIDProvider.Claims = configOverride.Claims + } else { + urls, err := discoverOpenIDURLs(ctx, kclient, openIDConfig.Issuer, corev1.ServiceAccountRootCAKey, namespace, openIDConfig.CA, skipKonnectivityDialer) + if err != nil { + return nil, err + } + openIDProvider.URLs = *urls + var groups []string + if len(openIDConfig.Claims.Groups) > 0 { + groups = make([]string, len(openIDConfig.Claims.Groups)) + for i, group := range openIDConfig.Claims.Groups { + groups[i] = string(group) + } + } + openIDProvider.Claims = osinv1.OpenIDClaims{ + ID: []string{configv1.UserIDClaim}, + PreferredUsername: openIDConfig.Claims.PreferredUsername, + Name: openIDConfig.Claims.Name, + Email: openIDConfig.Claims.Email, + Groups: groups, + } + } + if len(openIDConfig.CA.Name) > 0 { + openIDProvider.CA = idpVolumeMounts.ConfigMapPath(i, openIDConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + data := &idpData{provider: openIDProvider, login: true} + if configOverride != nil && configOverride.Challenge != nil { + data.challenge = *configOverride.Challenge + } else { + challengeFlowsAllowed, err := checkOIDCPasswordGrantFlow(ctx, kclient, openIDProvider.URLs.Token, openIDConfig.ClientID, namespace, openIDConfig.CA, openIDConfig.ClientSecret, skipKonnectivityDialer) + if err != nil { + return nil, fmt.Errorf("error attempting password grant flow: %v", err) + } + data.challenge = challengeFlowsAllowed + } + return data, nil +} + +func convertRequestHeaderIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + requestHeaderConfig := providerConfig.RequestHeader + if requestHeaderConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.RequestHeaderIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "RequestHeaderIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + LoginURL: requestHeaderConfig.LoginURL, + ChallengeURL: requestHeaderConfig.ChallengeURL, + ClientCA: idpVolumeMounts.ConfigMapPath(i, requestHeaderConfig.ClientCA.Name, "ca", corev1.ServiceAccountRootCAKey), + ClientCommonNames: requestHeaderConfig.ClientCommonNames, + Headers: requestHeaderConfig.Headers, + PreferredUsernameHeaders: requestHeaderConfig.PreferredUsernameHeaders, + NameHeaders: requestHeaderConfig.NameHeaders, + EmailHeaders: requestHeaderConfig.EmailHeaders, + } + return &idpData{ + provider: provider, + challenge: len(requestHeaderConfig.ChallengeURL) > 0, + login: len(requestHeaderConfig.LoginURL) > 0, + }, nil +} + func convertProviderConfigToIDPData( ctx context.Context, providerConfig *configv1.IdentityProviderConfig, @@ -147,289 +364,28 @@ func convertProviderConfigToIDPData( namespace string, skipKonnectivityDialer bool, ) (*idpData, error) { - const missingProviderFmt string = "type %s was specified, but its configuration is missing" - - data := &idpData{login: true} - switch providerConfig.Type { case configv1.IdentityProviderTypeBasicAuth: - basicAuthConfig := providerConfig.BasicAuth - if basicAuthConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - provider := &osinv1.BasicAuthPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "BasicAuthPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - RemoteConnectionInfo: configv1.RemoteConnectionInfo{ - URL: basicAuthConfig.URL, - }, - } - if basicAuthConfig.CA.Name != "" { - provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, basicAuthConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - if basicAuthConfig.TLSClientCert.Name != "" { - provider.RemoteConnectionInfo.CertFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientCert.Name, "tls-client-key", corev1.TLSCertKey) - } - if basicAuthConfig.TLSClientKey.Name != "" { - provider.RemoteConnectionInfo.KeyFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) - } - - data.provider = provider - data.challenge = true - + return convertBasicAuthIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeGitHub: - githubConfig := providerConfig.GitHub - if githubConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - provider := &osinv1.GitHubIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GitHubIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: githubConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, githubConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, - }, - Organizations: githubConfig.Organizations, - Teams: githubConfig.Teams, - Hostname: githubConfig.Hostname, - } - if githubConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, githubConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = provider - data.challenge = false - + return convertGitHubIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeGitLab: - gitlabConfig := providerConfig.GitLab - if gitlabConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - provider := &osinv1.GitLabIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GitLabIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - - URL: gitlabConfig.URL, - ClientID: gitlabConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, gitlabConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, - }, - Legacy: new(bool), // we require OIDC for GitLab now - } - if gitlabConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, gitlabConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - - data.provider = provider - data.challenge = true - + return convertGitLabIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeGoogle: - googleConfig := providerConfig.Google - if googleConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - data.provider = &osinv1.GoogleIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GoogleIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: googleConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, googleConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, - }, - HostedDomain: googleConfig.HostedDomain, - } - data.challenge = false - + return convertGoogleIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeHTPasswd: - if providerConfig.HTPasswd == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - data.provider = &osinv1.HTPasswdPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "HTPasswdPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - File: idpVolumeMounts.SecretPath(i, providerConfig.HTPasswd.FileData.Name, "file-data", configv1.HTPasswdDataKey), - } - data.challenge = true - + return convertHTPasswdIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeKeystone: - keystoneConfig := providerConfig.Keystone - if keystoneConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - provider := &osinv1.KeystonePasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "KeystonePasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - RemoteConnectionInfo: configv1.RemoteConnectionInfo{ - URL: keystoneConfig.URL, - }, - DomainName: keystoneConfig.DomainName, - UseKeystoneIdentity: true, // force use of keystone ID - } - if keystoneConfig.CA.Name != "" { - provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, keystoneConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - if keystoneConfig.TLSClientCert.Name != "" { - provider.RemoteConnectionInfo.CertInfo.CertFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) - } - if keystoneConfig.TLSClientKey.Name != "" { - provider.RemoteConnectionInfo.CertInfo.KeyFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) - } - data.provider = provider - data.challenge = true - + return convertKeystoneIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeLDAP: - ldapConfig := providerConfig.LDAP - if ldapConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - provider := &osinv1.LDAPPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "LDAPPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - URL: ldapConfig.URL, - BindDN: ldapConfig.BindDN, - Insecure: ldapConfig.Insecure, - Attributes: osinv1.LDAPAttributeMapping{ - ID: ldapConfig.Attributes.ID, - PreferredUsername: ldapConfig.Attributes.PreferredUsername, - Name: ldapConfig.Attributes.Name, - Email: ldapConfig.Attributes.Email, - }, - } - if ldapConfig.BindPassword.Name != "" { - provider.BindPassword = configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, ldapConfig.BindPassword.Name, "bind-password", configv1.BindPasswordKey), - }, - } - } - if ldapConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, ldapConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = provider - data.challenge = true - + return convertLDAPIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeOpenID: - openIDConfig := providerConfig.OpenID - if openIDConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - openIDProvider := &osinv1.OpenIDIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "OpenIDIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: openIDConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, openIDConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, - }, - ExtraScopes: openIDConfig.ExtraScopes, - ExtraAuthorizeParameters: openIDConfig.ExtraAuthorizeParameters, - } - //Handle special case for IBM Cloud's OIDC provider (need to override some fields not available in public api) - if configOverride != nil { - openIDProvider.URLs = configOverride.URLs - openIDProvider.Claims = configOverride.Claims - } else { - urls, err := discoverOpenIDURLs(ctx, kclient, openIDConfig.Issuer, corev1.ServiceAccountRootCAKey, namespace, openIDConfig.CA, skipKonnectivityDialer) - if err != nil { - return nil, err - } - openIDProvider.URLs = *urls - var groups []string - if len(openIDConfig.Claims.Groups) > 0 { - groups = make([]string, len(openIDConfig.Claims.Groups)) - for i, group := range openIDConfig.Claims.Groups { - groups[i] = string(group) - } - } - openIDProvider.Claims = osinv1.OpenIDClaims{ - // There is no longer a user-facing setting for ID as it is considered unsafe - ID: []string{configv1.UserIDClaim}, - PreferredUsername: openIDConfig.Claims.PreferredUsername, - Name: openIDConfig.Claims.Name, - Email: openIDConfig.Claims.Email, - Groups: groups, - } - } - if len(openIDConfig.CA.Name) > 0 { - openIDProvider.CA = idpVolumeMounts.ConfigMapPath(i, openIDConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = openIDProvider - - if configOverride != nil && configOverride.Challenge != nil { - data.challenge = *configOverride.Challenge - } else { - // openshift CR validating in kube-apiserver does not allow - // challenge-redirecting IdPs to be configured with OIDC so it is safe - // to allow challenge-issuing flow if it's available on the OIDC side - challengeFlowsAllowed, err := checkOIDCPasswordGrantFlow( - ctx, - kclient, - openIDProvider.URLs.Token, - openIDConfig.ClientID, - namespace, - openIDConfig.CA, - openIDConfig.ClientSecret, - skipKonnectivityDialer, - ) - if err != nil { - return nil, fmt.Errorf("error attempting password grant flow: %v", err) - } - data.challenge = challengeFlowsAllowed - } + return convertOpenIDIDP(ctx, providerConfig, configOverride, i, idpVolumeMounts, kclient, namespace, skipKonnectivityDialer) case configv1.IdentityProviderTypeRequestHeader: - requestHeaderConfig := providerConfig.RequestHeader - if requestHeaderConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - data.provider = &osinv1.RequestHeaderIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "RequestHeaderIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - LoginURL: requestHeaderConfig.LoginURL, - ChallengeURL: requestHeaderConfig.ChallengeURL, - ClientCA: idpVolumeMounts.ConfigMapPath(i, requestHeaderConfig.ClientCA.Name, "ca", corev1.ServiceAccountRootCAKey), - ClientCommonNames: requestHeaderConfig.ClientCommonNames, - Headers: requestHeaderConfig.Headers, - PreferredUsernameHeaders: requestHeaderConfig.PreferredUsernameHeaders, - NameHeaders: requestHeaderConfig.NameHeaders, - EmailHeaders: requestHeaderConfig.EmailHeaders, - } - data.challenge = len(requestHeaderConfig.ChallengeURL) > 0 - data.login = len(requestHeaderConfig.LoginURL) > 0 - + return convertRequestHeaderIDP(providerConfig, i, idpVolumeMounts) default: return nil, fmt.Errorf("the identity provider type '%s' is not supported", providerConfig.Type) - } // switch - - return data, nil + } } const ( diff --git a/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert_test.go b/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert_test.go index f92a638fb7e..d96a6ce199e 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert_test.go +++ b/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert_test.go @@ -210,6 +210,721 @@ func TestOpenIDProviderConversion(t *testing.T) { } } +func TestDefaultIDPMappingMethods(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input []configv1.IdentityProvider + expected []configv1.MappingMethodType + }{ + { + name: "When no identity providers are given, it should return an empty slice", + input: []configv1.IdentityProvider{}, + expected: []configv1.MappingMethodType{}, + }, + { + name: "When mapping method is empty, it should default to claim", + input: []configv1.IdentityProvider{ + {Name: "test", MappingMethod: ""}, + }, + expected: []configv1.MappingMethodType{configv1.MappingMethodClaim}, + }, + { + name: "When mapping method is already set, it should preserve it", + input: []configv1.IdentityProvider{ + {Name: "test", MappingMethod: configv1.MappingMethodAdd}, + }, + expected: []configv1.MappingMethodType{configv1.MappingMethodAdd}, + }, + { + name: "When multiple providers have mixed mapping methods, it should default only empty ones", + input: []configv1.IdentityProvider{ + {Name: "a", MappingMethod: ""}, + {Name: "b", MappingMethod: configv1.MappingMethodLookup}, + {Name: "c", MappingMethod: ""}, + }, + expected: []configv1.MappingMethodType{ + configv1.MappingMethodClaim, + configv1.MappingMethodLookup, + configv1.MappingMethodClaim, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + result := defaultIDPMappingMethods(tc.input) + g.Expect(result).To(HaveLen(len(tc.expected))) + for i, idp := range result { + g.Expect(idp.MappingMethod).To(Equal(tc.expected[i])) + } + }) + } +} + +func newTestVolumeMountInfo() *IDPVolumeMountInfo { + return &IDPVolumeMountInfo{ + Container: oauthContainerMain().Name, + VolumeMounts: podspec.VolumeMounts{ + oauthContainerMain().Name: podspec.ContainerMounts{}, + }, + } +} + +func TestConvertBasicAuthIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When BasicAuth config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + }, + expectErr: true, + }, + { + name: "When BasicAuth config has only URL, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://auth.example.com", + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When BasicAuth config has CA and TLS certs, it should set volume mount paths", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://auth.example.com", + CA: configv1.ConfigMapNameReference{Name: "my-ca"}, + TLSClientCert: configv1.SecretNameReference{Name: "my-cert"}, + TLSClientKey: configv1.SecretNameReference{Name: "my-key"}, + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertBasicAuthIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("configuration is missing")) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.BasicAuthPasswordIdentityProvider) + g.Expect(provider.RemoteConnectionInfo.URL).To(Equal("https://auth.example.com")) + if tc.config.BasicAuth.CA.Name != "" { + g.Expect(provider.RemoteConnectionInfo.CA).To(ContainSubstring("idp_cm_0_ca")) + } + if tc.config.BasicAuth.TLSClientCert.Name != "" { + g.Expect(provider.RemoteConnectionInfo.CertFile).To(ContainSubstring("idp_secret_0_tls-client-cert")) + } + if tc.config.BasicAuth.TLSClientKey.Name != "" { + g.Expect(provider.RemoteConnectionInfo.KeyFile).To(ContainSubstring("idp_secret_0_tls-client-key")) + } + } + }) + } +} + +func TestConvertGitHubIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When GitHub config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + }, + expectErr: true, + }, + { + name: "When GitHub config is provided with organizations, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "my-client-id", + ClientSecret: configv1.SecretNameReference{Name: "gh-secret"}, + Organizations: []string{"org1", "org2"}, + Teams: []string{"org1/team1"}, + Hostname: "github.example.com", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + { + name: "When GitHub config has a CA, it should set the CA path", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "my-client-id", + ClientSecret: configv1.SecretNameReference{Name: "gh-secret"}, + CA: configv1.ConfigMapNameReference{Name: "gh-ca"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGitHubIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GitHubIdentityProvider) + g.Expect(provider.ClientID).To(Equal(tc.config.GitHub.ClientID)) + g.Expect(provider.Organizations).To(Equal(tc.config.GitHub.Organizations)) + if tc.config.GitHub.CA.Name != "" { + g.Expect(provider.CA).To(ContainSubstring("idp_cm_0_ca")) + } + } + }) + } +} + +func TestConvertGitLabIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When GitLab config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitLab, + }, + expectErr: true, + }, + { + name: "When GitLab config is provided, it should produce a valid provider with legacy set", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitLab, + GitLab: &configv1.GitLabIdentityProvider{ + URL: "https://gitlab.example.com", + ClientID: "gl-client", + ClientSecret: configv1.SecretNameReference{Name: "gl-secret"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGitLabIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GitLabIdentityProvider) + g.Expect(provider.URL).To(Equal("https://gitlab.example.com")) + g.Expect(provider.Legacy).ToNot(BeNil()) + g.Expect(*provider.Legacy).To(BeFalse()) + } + }) + } +} + +func TestConvertGoogleIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When Google config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGoogle, + }, + expectErr: true, + }, + { + name: "When Google config is provided with hosted domain, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGoogle, + Google: &configv1.GoogleIdentityProvider{ + ClientID: "google-client", + ClientSecret: configv1.SecretNameReference{Name: "google-secret"}, + HostedDomain: "example.com", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGoogleIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GoogleIdentityProvider) + g.Expect(provider.ClientID).To(Equal("google-client")) + g.Expect(provider.HostedDomain).To(Equal("example.com")) + } + }) + } +} + +func TestConvertHTPasswdIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When HTPasswd config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + }, + expectErr: true, + }, + { + name: "When HTPasswd config is provided, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + HTPasswd: &configv1.HTPasswdIdentityProvider{ + FileData: configv1.SecretNameReference{Name: "htpasswd-secret"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertHTPasswdIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.HTPasswdPasswordIdentityProvider) + g.Expect(provider.File).To(ContainSubstring("idp_secret_0_file-data")) + } + }) + } +} + +func TestConvertKeystoneIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When Keystone config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeKeystone, + }, + expectErr: true, + }, + { + name: "When Keystone config is provided with TLS certs, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeKeystone, + Keystone: &configv1.KeystoneIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://keystone.example.com", + CA: configv1.ConfigMapNameReference{Name: "ks-ca"}, + TLSClientCert: configv1.SecretNameReference{Name: "ks-cert"}, + TLSClientKey: configv1.SecretNameReference{Name: "ks-key"}, + }, + DomainName: "my-domain", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertKeystoneIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.KeystonePasswordIdentityProvider) + g.Expect(provider.RemoteConnectionInfo.URL).To(Equal("https://keystone.example.com")) + g.Expect(provider.DomainName).To(Equal("my-domain")) + g.Expect(provider.UseKeystoneIdentity).To(BeTrue()) + g.Expect(provider.RemoteConnectionInfo.CA).To(ContainSubstring("idp_cm_0_ca")) + g.Expect(provider.RemoteConnectionInfo.CertInfo.CertFile).To(ContainSubstring("idp_secret_0_tls-client-cert")) + g.Expect(provider.RemoteConnectionInfo.CertInfo.KeyFile).To(ContainSubstring("idp_secret_0_tls-client-key")) + } + }) + } +} + +func TestConvertLDAPIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When LDAP config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + }, + expectErr: true, + }, + { + name: "When LDAP config is provided with bind password, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + LDAP: &configv1.LDAPIdentityProvider{ + URL: "ldap://ldap.example.com", + BindDN: "cn=admin,dc=example,dc=com", + BindPassword: configv1.SecretNameReference{Name: "ldap-bind-pw"}, + Insecure: true, + Attributes: configv1.LDAPAttributeMapping{ + ID: []string{"dn"}, + PreferredUsername: []string{"uid"}, + Name: []string{"cn"}, + Email: []string{"mail"}, + }, + CA: configv1.ConfigMapNameReference{Name: "ldap-ca"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When LDAP config has no bind password, it should not set the bind password field", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + LDAP: &configv1.LDAPIdentityProvider{ + URL: "ldap://ldap.example.com", + Insecure: false, + Attributes: configv1.LDAPAttributeMapping{ + ID: []string{"dn"}, + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertLDAPIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.LDAPPasswordIdentityProvider) + g.Expect(provider.URL).To(Equal(tc.config.LDAP.URL)) + g.Expect(provider.BindDN).To(Equal(tc.config.LDAP.BindDN)) + g.Expect(provider.Insecure).To(Equal(tc.config.LDAP.Insecure)) + g.Expect(provider.Attributes.ID).To(Equal(tc.config.LDAP.Attributes.ID)) + if tc.config.LDAP.BindPassword.Name != "" { + g.Expect(provider.BindPassword.StringSourceSpec.File).To(ContainSubstring("idp_secret_0_bind-password")) + } else { + g.Expect(provider.BindPassword.StringSourceSpec.File).To(BeEmpty()) + } + if tc.config.LDAP.CA.Name != "" { + g.Expect(provider.CA).To(ContainSubstring("idp_cm_0_ca")) + } + } + }) + } +} + +func TestConvertRequestHeaderIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When RequestHeader config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + }, + expectErr: true, + }, + { + name: "When RequestHeader has login and challenge URLs, it should set login and challenge to true", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + RequestHeader: &configv1.RequestHeaderIdentityProvider{ + LoginURL: "https://login.example.com", + ChallengeURL: "https://challenge.example.com", + ClientCA: configv1.ConfigMapNameReference{Name: "rh-ca"}, + ClientCommonNames: []string{"client1"}, + Headers: []string{"X-Remote-User"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When RequestHeader has empty login and challenge URLs, it should set login and challenge to false", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + RequestHeader: &configv1.RequestHeaderIdentityProvider{ + ClientCA: configv1.ConfigMapNameReference{Name: "rh-ca"}, + Headers: []string{"X-Remote-User"}, + }, + }, + expectErr: false, + expectLogin: false, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertRequestHeaderIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.RequestHeaderIdentityProvider) + g.Expect(provider.Headers).To(Equal(tc.config.RequestHeader.Headers)) + g.Expect(provider.ClientCA).To(ContainSubstring("idp_cm_0_ca")) + } + }) + } +} + +func TestConvertProviderConfigToIDPData_UnsupportedType(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + config := &configv1.IdentityProviderConfig{ + Type: "UnsupportedType", + } + _, err := convertProviderConfigToIDPData(t.Context(), config, nil, 0, vmi, nil, "test", true) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("is not supported")) +} + +func TestConvertProviderConfigToIDPData_Routing(t *testing.T) { + t.Parallel() + tests := []struct { + name string + providerType configv1.IdentityProviderType + config *configv1.IdentityProviderConfig + expectedKind string + }{ + { + name: "When type is BasicAuth, it should route to BasicAuth converter", + providerType: configv1.IdentityProviderTypeBasicAuth, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{URL: "https://example.com"}, + }, + }, + expectedKind: "BasicAuthPasswordIdentityProvider", + }, + { + name: "When type is GitHub, it should route to GitHub converter", + providerType: configv1.IdentityProviderTypeGitHub, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "id", + ClientSecret: configv1.SecretNameReference{Name: "s"}, + }, + }, + expectedKind: "GitHubIdentityProvider", + }, + { + name: "When type is HTPasswd, it should route to HTPasswd converter", + providerType: configv1.IdentityProviderTypeHTPasswd, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + HTPasswd: &configv1.HTPasswdIdentityProvider{ + FileData: configv1.SecretNameReference{Name: "htpasswd"}, + }, + }, + expectedKind: "HTPasswdPasswordIdentityProvider", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertProviderConfigToIDPData(t.Context(), tc.config, nil, 0, vmi, nil, "test", true) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.provider.GetObjectKind().GroupVersionKind().Kind).To(Equal(tc.expectedKind)) + }) + } +} + +func TestIDPVolumeMountInfo_ConfigMapPath(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + result := vmi.ConfigMapPath(2, "my-configmap", "ca", "ca.crt") + g.Expect(result).To(Equal("/etc/oauth/idp/idp_cm_2_ca/ca.crt")) + g.Expect(vmi.Volumes).To(HaveLen(1)) + g.Expect(vmi.Volumes[0].Name).To(Equal("idp-cm-2-ca")) + g.Expect(vmi.Volumes[0].ConfigMap.Name).To(Equal("my-configmap")) +} + +func TestIDPVolumeMountInfo_SecretPath(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + result := vmi.SecretPath(3, "my-secret", "client-secret", "clientSecret") + g.Expect(result).To(Equal("/etc/oauth/idp/idp_secret_3_client-secret/clientSecret")) + g.Expect(vmi.Volumes).To(HaveLen(1)) + g.Expect(vmi.Volumes[0].Name).To(Equal("idp-secret-3-client-secret")) + g.Expect(vmi.Volumes[0].Secret.SecretName).To(Equal("my-secret")) + g.Expect(vmi.Volumes[0].Secret.DefaultMode).ToNot(BeNil()) + g.Expect(*vmi.Volumes[0].Secret.DefaultMode).To(Equal(int32(0640))) +} + +func TestIsValidURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + rawurl string + optional bool + expected bool + }{ + { + name: "When URL is empty and optional, it should return true", + rawurl: "", + optional: true, + expected: true, + }, + { + name: "When URL is empty and required, it should return false", + rawurl: "", + optional: false, + expected: false, + }, + { + name: "When URL is a valid https URL, it should return true", + rawurl: "https://example.com/auth", + optional: false, + expected: true, + }, + { + name: "When URL uses http scheme, it should return false", + rawurl: "http://example.com/auth", + optional: false, + expected: false, + }, + { + name: "When URL has a fragment, it should return false", + rawurl: "https://example.com/auth#fragment", + optional: false, + expected: false, + }, + { + name: "When URL has no host, it should return false", + rawurl: "https://", + optional: false, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + g.Expect(isValidURL(tc.rawurl, tc.optional)).To(Equal(tc.expected)) + }) + } +} + func TestTransportForCARef(t *testing.T) { namespace := "test" diff --git a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/deployment_test.go b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/deployment_test.go index a13fc371951..cfb173c7403 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/deployment_test.go +++ b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/deployment_test.go @@ -49,6 +49,30 @@ func newPodMonitorWithTLS(name, namespace string, tlsCfg *prometheusoperatorv1.S } } +func newCertVolumeTestContext(namespace string, scheme *runtime.Scheme, objects ...runtime.Object) component.WorkloadContext { + fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(objects...).Build() + return component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, + } +} + +func assertCertVolumeCount(t *testing.T, cpContext component.WorkloadContext, namespace string, expectedVolumes, expectedMounts int) ([]corev1.Volume, []corev1.VolumeMount) { + t.Helper() + volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(volumes) != expectedVolumes { + t.Errorf("expected %d volumes, got %d", expectedVolumes, len(volumes)) + } + if len(mounts) != expectedMounts { + t.Errorf("expected %d mounts, got %d", expectedMounts, len(mounts)) + } + return volumes, mounts +} + func TestCertVolumesFromMonitors(t *testing.T) { t.Parallel() @@ -81,7 +105,6 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }, }) - // sm2 references the same root-ca and metrics-client — should deduplicate. sm2 := newServiceMonitorWithTLS("kube-controller-manager", namespace, &prometheusoperatorv1.TLSConfig{ SafeTLSConfig: prometheusoperatorv1.SafeTLSConfig{ CA: prometheusoperatorv1.SecretOrConfigMap{ @@ -103,27 +126,9 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm1, sm2).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm1, sm2) + volumes, _ := assertCertVolumeCount(t, cpContext, namespace, 2, 2) - // Expect 2 unique volumes: metrics-client (Secret), root-ca (ConfigMap). - if len(volumes) != 2 { - t.Errorf("expected 2 volumes, got %d", len(volumes)) - } - if len(mounts) != 2 { - t.Errorf("expected 2 mounts, got %d", len(mounts)) - } - - // Verify sorted order: metrics-client before root-ca. if len(volumes) >= 2 { if volumes[0].Name != "metrics-client" { t.Errorf("expected first volume to be metrics-client, got %s", volumes[0].Name) @@ -158,21 +163,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm1, sm2).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, _, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(volumes) != 1 { - t.Fatalf("expected 1 volume, got %d", len(volumes)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm1, sm2) + volumes, _ := assertCertVolumeCount(t, cpContext, namespace, 1, 1) cm := volumes[0].VolumeSource.ConfigMap if cm == nil { @@ -207,21 +199,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm1, sm2).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, _, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(volumes) != 1 { - t.Fatalf("expected 1 volume, got %d", len(volumes)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm1, sm2) + volumes, _ := assertCertVolumeCount(t, cpContext, namespace, 1, 1) cm := volumes[0].VolumeSource.ConfigMap if cm == nil { @@ -242,45 +221,15 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, } - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(volumes) != 0 { - t.Errorf("expected 0 volumes, got %d", len(volumes)) - } - if len(mounts) != 0 { - t.Errorf("expected 0 mounts, got %d", len(mounts)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm) + assertCertVolumeCount(t, cpContext, namespace, 0, 0) }) t.Run("When no monitors exist, it should return empty slices", func(t *testing.T) { t.Parallel() - fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(volumes) != 0 { - t.Errorf("expected 0 volumes, got %d", len(volumes)) - } - if len(mounts) != 0 { - t.Errorf("expected 0 mounts, got %d", len(mounts)) - } + cpContext := newCertVolumeTestContext(namespace, scheme) + assertCertVolumeCount(t, cpContext, namespace, 0, 0) }) t.Run("When Secret volumes are generated, they should have optional=true", func(t *testing.T) { @@ -297,21 +246,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(volumes) != 1 { - t.Fatalf("expected 1 volume, got %d", len(volumes)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm) + volumes, mounts := assertCertVolumeCount(t, cpContext, namespace, 1, 1) secret := volumes[0].VolumeSource.Secret if secret == nil { @@ -321,9 +257,6 @@ func TestCertVolumesFromMonitors(t *testing.T) { t.Error("expected Secret volume to have optional=true") } - if len(mounts) != 1 { - t.Fatalf("expected 1 mount, got %d", len(mounts)) - } expectedPath := certBasePath + "/test-secret" if mounts[0].MountPath != expectedPath { t.Errorf("expected mount path %q, got %q", expectedPath, mounts[0].MountPath) @@ -344,21 +277,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, _, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(volumes) != 1 { - t.Fatalf("expected 1 volume, got %d", len(volumes)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm) + volumes, _ := assertCertVolumeCount(t, cpContext, namespace, 1, 1) cm := volumes[0].VolumeSource.ConfigMap if cm == nil { @@ -381,27 +301,12 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(pm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + cpContext := newCertVolumeTestContext(namespace, scheme, pm) + volumes, _ := assertCertVolumeCount(t, cpContext, namespace, 1, 1) - if len(volumes) != 1 { - t.Fatalf("expected 1 volume, got %d", len(volumes)) - } if volumes[0].Name != "root-ca" { t.Errorf("expected volume name root-ca, got %s", volumes[0].Name) } - if len(mounts) != 1 { - t.Fatalf("expected 1 mount, got %d", len(mounts)) - } }) t.Run("When PodMonitor has no TLS config, it should be skipped", func(t *testing.T) { @@ -409,23 +314,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { pm := newPodMonitorWithTLS("cluster-autoscaler", namespace, nil) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(pm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(volumes) != 0 { - t.Errorf("expected 0 volumes, got %d", len(volumes)) - } - if len(mounts) != 0 { - t.Errorf("expected 0 mounts, got %d", len(mounts)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, pm) + assertCertVolumeCount(t, cpContext, namespace, 0, 0) }) t.Run("When ServiceMonitor and PodMonitor share the same CA ref, it should deduplicate", func(t *testing.T) { @@ -450,21 +340,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm, pm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, _, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(volumes) != 1 { - t.Errorf("expected 1 deduplicated volume, got %d", len(volumes)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm, pm) + assertCertVolumeCount(t, cpContext, namespace, 1, 1) }) } diff --git a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config.go b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config.go index 325a2d37e33..6068e51c5ef 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config.go +++ b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config.go @@ -31,7 +31,6 @@ const ( ) func adaptScrapeConfig(cpContext component.WorkloadContext, cm *corev1.ConfigMap) error { - log := logr.FromContextOrDiscard(cpContext) namespace := cpContext.HCP.Namespace endpointResolverURL := fmt.Sprintf("https://endpoint-resolver.%s.svc", namespace) @@ -43,164 +42,193 @@ func adaptScrapeConfig(cpContext component.WorkloadContext, cm *corev1.ConfigMap }, } - // Process ServiceMonitors. + smComponents, err := componentsFromServiceMonitors(cpContext, namespace) + if err != nil { + return err + } + cfg.Components = append(cfg.Components, smComponents...) + + pmComponents, err := componentsFromPodMonitors(cpContext, namespace) + if err != nil { + return err + } + cfg.Components = append(cfg.Components, pmComponents...) + + sort.Slice(cfg.Components, func(i, j int) bool { + return cfg.Components[i].Name < cfg.Components[j].Name + }) + + data, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("failed to marshal scrape config: %w", err) + } + + if cm.Data == nil { + cm.Data = make(map[string]string) + } + cm.Data["config.yaml"] = string(data) + return nil +} + +func componentsFromServiceMonitors(cpContext component.WorkloadContext, namespace string) ([]metricsproxybin.ComponentFileConfig, error) { + log := logr.FromContextOrDiscard(cpContext) + smList := &prometheusoperatorv1.ServiceMonitorList{} if err := cpContext.Client.List(cpContext, smList, client.InNamespace(namespace)); err != nil { - return fmt.Errorf("failed to list ServiceMonitors: %w", err) + return nil, fmt.Errorf("failed to list ServiceMonitors: %w", err) } + var components []metricsproxybin.ComponentFileConfig for i := range smList.Items { - sm := &smList.Items[i] - if len(sm.Spec.Endpoints) == 0 { - continue - } - ep := sm.Spec.Endpoints[0] - - serviceName, podSelector, err := findServiceForMonitor(cpContext, namespace, sm.Name, sm.Spec.Selector) - if err != nil { - log.V(4).Info("skipping ServiceMonitor: service not found", "serviceMonitor", sm.Name, "error", err) - continue - } - if len(podSelector) == 0 { - log.V(4).Info("skipping ServiceMonitor: service has no pod selector", "serviceMonitor", sm.Name, "service", serviceName) - continue + comp, ok := componentFromServiceMonitor(cpContext, log, namespace, &smList.Items[i]) + if ok { + components = append(components, comp) } + } + return components, nil +} - portRef := ep.Port - if portRef == "" && ep.TargetPort != nil { - portRef = ep.TargetPort.String() - } - if portRef == "" { - log.V(4).Info("skipping ServiceMonitor: no port reference", "serviceMonitor", sm.Name) - continue - } +func componentFromServiceMonitor(cpContext component.WorkloadContext, log logr.Logger, namespace string, sm *prometheusoperatorv1.ServiceMonitor) (metricsproxybin.ComponentFileConfig, bool) { + if len(sm.Spec.Endpoints) == 0 { + return metricsproxybin.ComponentFileConfig{}, false + } + ep := sm.Spec.Endpoints[0] - port, err := resolveServicePort(cpContext, namespace, serviceName, portRef, podSelector) - if err != nil { - log.V(4).Info("skipping ServiceMonitor: port not resolvable", "serviceMonitor", sm.Name, "port", portRef, "error", err) - continue - } + serviceName, podSelector, err := findServiceForMonitor(cpContext, namespace, sm.Name, sm.Spec.Selector) + if err != nil { + log.V(4).Info("skipping ServiceMonitor: service not found", "serviceMonitor", sm.Name, "error", err) + return metricsproxybin.ComponentFileConfig{}, false + } + if len(podSelector) == 0 { + log.V(4).Info("skipping ServiceMonitor: service has no pod selector", "serviceMonitor", sm.Name, "service", serviceName) + return metricsproxybin.ComponentFileConfig{}, false + } - scheme := "http" - if ep.Scheme != nil { - scheme = ep.Scheme.String() - } + portRef := ep.Port + if portRef == "" && ep.TargetPort != nil { + portRef = ep.TargetPort.String() + } + if portRef == "" { + log.V(4).Info("skipping ServiceMonitor: no port reference", "serviceMonitor", sm.Name) + return metricsproxybin.ComponentFileConfig{}, false + } - metricsPath := "/metrics" - if ep.Path != "" { - metricsPath = ep.Path - } + port, err := resolveServicePort(cpContext, namespace, serviceName, portRef, podSelector) + if err != nil { + log.V(4).Info("skipping ServiceMonitor: port not resolvable", "serviceMonitor", sm.Name, "port", portRef, "error", err) + return metricsproxybin.ComponentFileConfig{}, false + } - var serverName string - if ep.TLSConfig != nil && ep.TLSConfig.ServerName != nil { - serverName = *ep.TLSConfig.ServerName - } + comp := metricsproxybin.ComponentFileConfig{ + Selector: podSelector, + MetricsPort: port, + MetricsPath: endpointMetricsPath(ep.Path), + MetricsScheme: endpointScheme(ep.Scheme), + TLSServerName: safeTLSServerName(ep.TLSConfig), + } - comp := metricsproxybin.ComponentFileConfig{ - Selector: podSelector, - MetricsPort: port, - MetricsPath: metricsPath, - MetricsScheme: scheme, - TLSServerName: serverName, - } + if ep.TLSConfig != nil { + populateTLSFilePaths(&comp, ep.TLSConfig.CA, ep.TLSConfig.Cert, ep.TLSConfig.KeySecret) + } - if ep.TLSConfig != nil { - comp.CAFile = certFilePathFromSecretOrConfigMap(ep.TLSConfig.CA) - comp.CertFile = certFilePathFromSecretOrConfigMap(ep.TLSConfig.Cert) - if ep.TLSConfig.KeySecret != nil { - comp.KeyFile = filepath.Join(certBasePath, ep.TLSConfig.KeySecret.Name, ep.TLSConfig.KeySecret.Key) - } - } + comp.Name = sm.Name + populateMetricsLabelsFromAnnotations(&comp, sm.Annotations) + return comp, true +} - comp.Name = sm.Name - populateMetricsLabelsFromAnnotations(&comp, sm.Annotations) - cfg.Components = append(cfg.Components, comp) - } +func componentsFromPodMonitors(cpContext component.WorkloadContext, namespace string) ([]metricsproxybin.ComponentFileConfig, error) { + log := logr.FromContextOrDiscard(cpContext) - // Process PodMonitors. pmList := &prometheusoperatorv1.PodMonitorList{} if err := cpContext.Client.List(cpContext, pmList, client.InNamespace(namespace)); err != nil { - return fmt.Errorf("failed to list PodMonitors: %w", err) + return nil, fmt.Errorf("failed to list PodMonitors: %w", err) } + var components []metricsproxybin.ComponentFileConfig for i := range pmList.Items { - pm := &pmList.Items[i] - if len(pm.Spec.PodMetricsEndpoints) == 0 { - continue + comp, ok := componentFromPodMonitor(cpContext, log, namespace, &pmList.Items[i]) + if ok { + components = append(components, comp) } - ep := pm.Spec.PodMetricsEndpoints[0] + } + return components, nil +} - portName := "" - if ep.Port != nil { - portName = *ep.Port - } - if portName == "" { - log.V(4).Info("skipping PodMonitor: no port name", "podMonitor", pm.Name) - continue - } +func componentFromPodMonitor(cpContext component.WorkloadContext, log logr.Logger, namespace string, pm *prometheusoperatorv1.PodMonitor) (metricsproxybin.ComponentFileConfig, bool) { + if len(pm.Spec.PodMetricsEndpoints) == 0 { + return metricsproxybin.ComponentFileConfig{}, false + } + ep := pm.Spec.PodMetricsEndpoints[0] - // Resolve the port number from a Pod matching the PodMonitor's selector. - podSelector, err := metav1.LabelSelectorAsSelector(&pm.Spec.Selector) - if err != nil { - log.V(4).Info("skipping PodMonitor: invalid selector", "podMonitor", pm.Name, "error", err) - continue - } - port, err := resolvePodPort(cpContext, namespace, podSelector, portName) - if err != nil { - log.V(4).Info("skipping PodMonitor: port not resolvable", "podMonitor", pm.Name, "port", portName, "error", err) - continue - } + portName := "" + if ep.Port != nil { + portName = *ep.Port + } + if portName == "" { + log.V(4).Info("skipping PodMonitor: no port name", "podMonitor", pm.Name) + return metricsproxybin.ComponentFileConfig{}, false + } - scheme := "http" - if ep.Scheme != nil { - scheme = ep.Scheme.String() - } + podSelector, err := metav1.LabelSelectorAsSelector(&pm.Spec.Selector) + if err != nil { + log.V(4).Info("skipping PodMonitor: invalid selector", "podMonitor", pm.Name, "error", err) + return metricsproxybin.ComponentFileConfig{}, false + } + port, err := resolvePodPort(cpContext, namespace, podSelector, portName) + if err != nil { + log.V(4).Info("skipping PodMonitor: port not resolvable", "podMonitor", pm.Name, "port", portName, "error", err) + return metricsproxybin.ComponentFileConfig{}, false + } - metricsPath := "/metrics" - if ep.Path != "" { - metricsPath = ep.Path - } + var serverName string + if ep.TLSConfig != nil && ep.TLSConfig.ServerName != nil { + serverName = *ep.TLSConfig.ServerName + } - var serverName string - if ep.TLSConfig != nil && ep.TLSConfig.ServerName != nil { - serverName = *ep.TLSConfig.ServerName - } + comp := metricsproxybin.ComponentFileConfig{ + Selector: pm.Spec.Selector.MatchLabels, + MetricsPort: port, + MetricsPath: endpointMetricsPath(ep.Path), + MetricsScheme: endpointScheme(ep.Scheme), + TLSServerName: serverName, + } - comp := metricsproxybin.ComponentFileConfig{ - Selector: pm.Spec.Selector.MatchLabels, - MetricsPort: port, - MetricsPath: metricsPath, - MetricsScheme: scheme, - TLSServerName: serverName, - } + if ep.TLSConfig != nil { + populateTLSFilePaths(&comp, ep.TLSConfig.CA, ep.TLSConfig.Cert, ep.TLSConfig.KeySecret) + } - if ep.TLSConfig != nil { - comp.CAFile = certFilePathFromSecretOrConfigMap(ep.TLSConfig.CA) - comp.CertFile = certFilePathFromSecretOrConfigMap(ep.TLSConfig.Cert) - if ep.TLSConfig.KeySecret != nil { - comp.KeyFile = filepath.Join(certBasePath, ep.TLSConfig.KeySecret.Name, ep.TLSConfig.KeySecret.Key) - } - } + comp.Name = pm.Name + populateMetricsLabelsFromAnnotations(&comp, pm.Annotations) + return comp, true +} - comp.Name = pm.Name - populateMetricsLabelsFromAnnotations(&comp, pm.Annotations) - cfg.Components = append(cfg.Components, comp) +func endpointScheme(scheme *prometheusoperatorv1.Scheme) string { + if scheme != nil { + return scheme.String() } + return "http" +} - sort.Slice(cfg.Components, func(i, j int) bool { - return cfg.Components[i].Name < cfg.Components[j].Name - }) +func endpointMetricsPath(path string) string { + if path != "" { + return path + } + return "/metrics" +} - data, err := yaml.Marshal(cfg) - if err != nil { - return fmt.Errorf("failed to marshal scrape config: %w", err) +func safeTLSServerName(tlsCfg *prometheusoperatorv1.TLSConfig) string { + if tlsCfg != nil && tlsCfg.ServerName != nil { + return *tlsCfg.ServerName } + return "" +} - if cm.Data == nil { - cm.Data = make(map[string]string) +func populateTLSFilePaths(comp *metricsproxybin.ComponentFileConfig, ca, cert prometheusoperatorv1.SecretOrConfigMap, keySecret *corev1.SecretKeySelector) { + comp.CAFile = certFilePathFromSecretOrConfigMap(ca) + comp.CertFile = certFilePathFromSecretOrConfigMap(cert) + if keySecret != nil { + comp.KeyFile = filepath.Join(certBasePath, keySecret.Name, keySecret.Key) } - cm.Data["config.yaml"] = string(data) - return nil } // certFilePathFromSecretOrConfigMap returns the file path for a volume-mounted diff --git a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config_test.go b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config_test.go index 7652cbc8208..c613de30fa0 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config_test.go +++ b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config_test.go @@ -20,6 +20,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/yaml" + "github.com/go-logr/logr" prometheusoperatorv1 "github.com/prometheus-operator/prometheus-operator/pkg/apis/monitoring/v1" ) @@ -624,6 +625,557 @@ func TestAdaptScrapeConfigPodMonitorMissingPod(t *testing.T) { }) } +func TestEndpointScheme(t *testing.T) { + t.Parallel() + tests := []struct { + name string + scheme *prometheusoperatorv1.Scheme + expected string + }{ + { + name: "When scheme is nil, it should default to http", + scheme: nil, + expected: "http", + }, + { + name: "When scheme is https, it should return https", + scheme: (*prometheusoperatorv1.Scheme)(ptr.To("https")), + expected: "https", + }, + { + name: "When scheme is http, it should return http", + scheme: (*prometheusoperatorv1.Scheme)(ptr.To("http")), + expected: "http", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(endpointScheme(tc.scheme)).To(Equal(tc.expected)) + }) + } +} + +func TestEndpointMetricsPath(t *testing.T) { + t.Parallel() + tests := []struct { + name string + path string + expected string + }{ + { + name: "When path is empty, it should default to /metrics", + path: "", + expected: "/metrics", + }, + { + name: "When path is provided, it should return that path", + path: "/custom/metrics", + expected: "/custom/metrics", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(endpointMetricsPath(tc.path)).To(Equal(tc.expected)) + }) + } +} + +func TestSafeTLSServerName(t *testing.T) { + t.Parallel() + tests := []struct { + name string + tlsCfg *prometheusoperatorv1.TLSConfig + expected string + }{ + { + name: "When TLS config is nil, it should return empty string", + tlsCfg: nil, + expected: "", + }, + { + name: "When ServerName is nil, it should return empty string", + tlsCfg: &prometheusoperatorv1.TLSConfig{ + SafeTLSConfig: prometheusoperatorv1.SafeTLSConfig{ + ServerName: nil, + }, + }, + expected: "", + }, + { + name: "When ServerName is set, it should return the server name", + tlsCfg: &prometheusoperatorv1.TLSConfig{ + SafeTLSConfig: prometheusoperatorv1.SafeTLSConfig{ + ServerName: ptr.To("kube-apiserver"), + }, + }, + expected: "kube-apiserver", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(safeTLSServerName(tc.tlsCfg)).To(Equal(tc.expected)) + }) + } +} + +func TestCertFilePathFromSecretOrConfigMap(t *testing.T) { + t.Parallel() + tests := []struct { + name string + ref prometheusoperatorv1.SecretOrConfigMap + expected string + }{ + { + name: "When both Secret and ConfigMap are nil, it should return empty string", + ref: prometheusoperatorv1.SecretOrConfigMap{}, + expected: "", + }, + { + name: "When ConfigMap is set, it should return configmap path", + ref: prometheusoperatorv1.SecretOrConfigMap{ + ConfigMap: &corev1.ConfigMapKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "root-ca"}, + Key: "ca.crt", + }, + }, + expected: certBasePath + "/root-ca/ca.crt", + }, + { + name: "When Secret is set, it should return secret path", + ref: prometheusoperatorv1.SecretOrConfigMap{ + Secret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "my-secret"}, + Key: "tls.crt", + }, + }, + expected: certBasePath + "/my-secret/tls.crt", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(certFilePathFromSecretOrConfigMap(tc.ref)).To(Equal(tc.expected)) + }) + } +} + +func TestPopulateTLSFilePaths(t *testing.T) { + t.Parallel() + tests := []struct { + name string + ca prometheusoperatorv1.SecretOrConfigMap + cert prometheusoperatorv1.SecretOrConfigMap + keySecret *corev1.SecretKeySelector + expectedCA string + expectedCert string + expectedKey string + }{ + { + name: "When all TLS refs are empty, it should set empty paths", + ca: prometheusoperatorv1.SecretOrConfigMap{}, + cert: prometheusoperatorv1.SecretOrConfigMap{}, + keySecret: nil, + expectedCA: "", + expectedCert: "", + expectedKey: "", + }, + { + name: "When CA, cert, and key are all provided, it should set all paths", + ca: prometheusoperatorv1.SecretOrConfigMap{ + ConfigMap: &corev1.ConfigMapKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "root-ca"}, + Key: "ca.crt", + }, + }, + cert: prometheusoperatorv1.SecretOrConfigMap{ + Secret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "metrics-client"}, + Key: "tls.crt", + }, + }, + keySecret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "metrics-client"}, + Key: "tls.key", + }, + expectedCA: certBasePath + "/root-ca/ca.crt", + expectedCert: certBasePath + "/metrics-client/tls.crt", + expectedKey: certBasePath + "/metrics-client/tls.key", + }, + { + name: "When only CA is provided, it should set only CA path", + ca: prometheusoperatorv1.SecretOrConfigMap{ + ConfigMap: &corev1.ConfigMapKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "root-ca"}, + Key: "ca.crt", + }, + }, + cert: prometheusoperatorv1.SecretOrConfigMap{}, + keySecret: nil, + expectedCA: certBasePath + "/root-ca/ca.crt", + expectedCert: "", + expectedKey: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + comp := &metricsproxybin.ComponentFileConfig{} + populateTLSFilePaths(comp, tc.ca, tc.cert, tc.keySecret) + g.Expect(comp.CAFile).To(Equal(tc.expectedCA)) + g.Expect(comp.CertFile).To(Equal(tc.expectedCert)) + g.Expect(comp.KeyFile).To(Equal(tc.expectedKey)) + }) + } +} + +func TestPopulateMetricsLabelsFromAnnotations(t *testing.T) { + t.Parallel() + tests := []struct { + name string + annotations map[string]string + expectedJob string + expectedNamespace string + expectedService string + expectedEndpoint string + }{ + { + name: "When no annotations are present, it should leave fields empty", + annotations: map[string]string{}, + expectedJob: "", + expectedNamespace: "", + expectedService: "", + expectedEndpoint: "", + }, + { + name: "When annotations is nil, it should leave fields empty", + annotations: nil, + expectedJob: "", + expectedNamespace: "", + expectedService: "", + expectedEndpoint: "", + }, + { + name: "When all metrics annotations are present, it should populate all fields", + annotations: map[string]string{ + "hypershift.openshift.io/metrics-job": "kube-apiserver", + "hypershift.openshift.io/metrics-namespace": "openshift-kube-apiserver", + "hypershift.openshift.io/metrics-service": "kube-apiserver", + "hypershift.openshift.io/metrics-endpoint": "https", + }, + expectedJob: "kube-apiserver", + expectedNamespace: "openshift-kube-apiserver", + expectedService: "kube-apiserver", + expectedEndpoint: "https", + }, + { + name: "When only some annotations are present, it should populate only matching fields", + annotations: map[string]string{ + "hypershift.openshift.io/metrics-job": "etcd", + "hypershift.openshift.io/metrics-namespace": "openshift-etcd", + }, + expectedJob: "etcd", + expectedNamespace: "openshift-etcd", + expectedService: "", + expectedEndpoint: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + comp := &metricsproxybin.ComponentFileConfig{} + populateMetricsLabelsFromAnnotations(comp, tc.annotations) + g.Expect(comp.MetricsJob).To(Equal(tc.expectedJob)) + g.Expect(comp.MetricsNamespace).To(Equal(tc.expectedNamespace)) + g.Expect(comp.MetricsService).To(Equal(tc.expectedService)) + g.Expect(comp.MetricsEndpoint).To(Equal(tc.expectedEndpoint)) + }) + } +} + +func TestFindServiceForMonitor(t *testing.T) { + t.Parallel() + + namespace := "test-ns" + + tests := []struct { + name string + objects []runtime.Object + smName string + selector metav1.LabelSelector + expectedSvc string + expectedErr bool + expectedSelKeys []string + }{ + { + name: "When service exists with same name as ServiceMonitor, it should find it by direct lookup", + objects: []runtime.Object{ + newService("kube-apiserver", namespace, "kube-apiserver", "client", 6443), + }, + smName: "kube-apiserver", + selector: metav1.LabelSelector{MatchLabels: map[string]string{"app": "kube-apiserver"}}, + expectedSvc: "kube-apiserver", + expectedErr: false, + expectedSelKeys: []string{"app"}, + }, + { + name: "When no service matches name but label selector matches, it should fall back to label selector", + objects: []runtime.Object{ + newService("etcd-client", namespace, "etcd", "metrics", 2381), + }, + smName: "etcd", + selector: metav1.LabelSelector{MatchLabels: map[string]string{"app": "etcd"}}, + expectedSvc: "etcd-client", + expectedErr: false, + expectedSelKeys: []string{"app"}, + }, + { + name: "When no service exists at all, it should return an error", + objects: []runtime.Object{}, + smName: "missing", + selector: metav1.LabelSelector{MatchLabels: map[string]string{"app": "missing"}}, + expectedErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(tc.objects...).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + svcName, podSelector, err := findServiceForMonitor(cpContext, namespace, tc.smName, tc.selector) + if tc.expectedErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(svcName).To(Equal(tc.expectedSvc)) + for _, key := range tc.expectedSelKeys { + g.Expect(podSelector).To(HaveKey(key)) + } + } + }) + } +} + +func TestResolveServicePort(t *testing.T) { + t.Parallel() + + namespace := "test-ns" + + tests := []struct { + name string + objects []runtime.Object + serviceName string + portName string + podSelector map[string]string + expectedPort int32 + expectedErr bool + }{ + { + name: "When service has a numeric targetPort, it should resolve directly", + objects: []runtime.Object{ + newService("my-svc", namespace, "app", "https", 6443), + }, + serviceName: "my-svc", + portName: "https", + podSelector: map[string]string{"app": "app"}, + expectedPort: 6443, + expectedErr: false, + }, + { + name: "When service has a named targetPort, it should resolve from a pod", + objects: []runtime.Object{ + newServiceWithTargetPort("my-svc", namespace, "app", "client", 443, intstr.FromString("https")), + newPodWithLabels("my-pod", namespace, "https", 6443, map[string]string{"app": "app"}), + }, + serviceName: "my-svc", + portName: "client", + podSelector: map[string]string{"app": "app"}, + expectedPort: 6443, + expectedErr: false, + }, + { + name: "When service port name does not match, it should return an error", + objects: []runtime.Object{ + newService("my-svc", namespace, "app", "metrics", 8080), + }, + serviceName: "my-svc", + portName: "nonexistent", + podSelector: map[string]string{"app": "app"}, + expectedErr: true, + }, + { + name: "When service does not exist, it should return an error", + objects: []runtime.Object{}, + serviceName: "missing-svc", + portName: "https", + podSelector: map[string]string{"app": "app"}, + expectedErr: true, + }, + { + name: "When service has no targetPort, it should default to the service port value", + objects: []runtime.Object{ + &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "my-svc", Namespace: namespace}, + Spec: corev1.ServiceSpec{ + Selector: map[string]string{"app": "app"}, + Ports: []corev1.ServicePort{{Name: "web", Port: 8443}}, + }, + }, + }, + serviceName: "my-svc", + portName: "web", + podSelector: map[string]string{"app": "app"}, + expectedPort: 8443, + expectedErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(tc.objects...).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + port, err := resolveServicePort(cpContext, namespace, tc.serviceName, tc.portName, tc.podSelector) + if tc.expectedErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(port).To(Equal(tc.expectedPort)) + } + }) + } +} + +func TestComponentFromServiceMonitor_EdgeCases(t *testing.T) { + t.Parallel() + + namespace := "test-ns" + + t.Run("When ServiceMonitor has no endpoints, it should return false", func(t *testing.T) { + g := NewGomegaWithT(t) + sm := &prometheusoperatorv1.ServiceMonitor{ + ObjectMeta: metav1.ObjectMeta{Name: "empty-sm", Namespace: namespace}, + Spec: prometheusoperatorv1.ServiceMonitorSpec{}, + } + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = prometheusoperatorv1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + _, ok := componentFromServiceMonitor(cpContext, logr.Discard(), namespace, sm) + g.Expect(ok).To(BeFalse()) + }) + + t.Run("When endpoint has no port reference, it should return false", func(t *testing.T) { + g := NewGomegaWithT(t) + sm := &prometheusoperatorv1.ServiceMonitor{ + ObjectMeta: metav1.ObjectMeta{Name: "no-port-sm", Namespace: namespace}, + Spec: prometheusoperatorv1.ServiceMonitorSpec{ + Endpoints: []prometheusoperatorv1.Endpoint{{}}, + Selector: metav1.LabelSelector{MatchLabels: map[string]string{"app": "test"}}, + }, + } + svc := newService("no-port-sm", namespace, "test", "metrics", 8080) + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = prometheusoperatorv1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(svc).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + _, ok := componentFromServiceMonitor(cpContext, logr.Discard(), namespace, sm) + g.Expect(ok).To(BeFalse()) + }) +} + +func TestComponentFromPodMonitor_EdgeCases(t *testing.T) { + t.Parallel() + + namespace := "test-ns" + + t.Run("When PodMonitor has no endpoints, it should return false", func(t *testing.T) { + g := NewGomegaWithT(t) + pm := &prometheusoperatorv1.PodMonitor{ + ObjectMeta: metav1.ObjectMeta{Name: "empty-pm", Namespace: namespace}, + Spec: prometheusoperatorv1.PodMonitorSpec{}, + } + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = prometheusoperatorv1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + _, ok := componentFromPodMonitor(cpContext, logr.Discard(), namespace, pm) + g.Expect(ok).To(BeFalse()) + }) + + t.Run("When PodMonitor endpoint has nil port, it should return false", func(t *testing.T) { + g := NewGomegaWithT(t) + pm := &prometheusoperatorv1.PodMonitor{ + ObjectMeta: metav1.ObjectMeta{Name: "nil-port-pm", Namespace: namespace}, + Spec: prometheusoperatorv1.PodMonitorSpec{ + PodMetricsEndpoints: []prometheusoperatorv1.PodMetricsEndpoint{ + {Port: nil}, + }, + Selector: metav1.LabelSelector{MatchLabels: map[string]string{"app": "test"}}, + }, + } + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = prometheusoperatorv1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + _, ok := componentFromPodMonitor(cpContext, logr.Discard(), namespace, pm) + g.Expect(ok).To(BeFalse()) + }) +} + func TestResolvePodPort(t *testing.T) { t.Parallel() diff --git a/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert.go b/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert.go index f41a705612e..d4e12f63669 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert.go +++ b/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert.go @@ -147,291 +147,322 @@ func convertProviderConfigToIDPData( namespace string, skipKonnectivityDialer bool, ) (*idpData, error) { - const missingProviderFmt string = "type %s was specified, but its configuration is missing" - - data := &idpData{login: true} - switch providerConfig.Type { case configv1.IdentityProviderTypeBasicAuth: - basicAuthConfig := providerConfig.BasicAuth - if basicAuthConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - provider := &osinv1.BasicAuthPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "BasicAuthPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - RemoteConnectionInfo: configv1.RemoteConnectionInfo{ - URL: basicAuthConfig.URL, - }, - } - if basicAuthConfig.CA.Name != "" { - provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, basicAuthConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - if basicAuthConfig.TLSClientCert.Name != "" { - provider.RemoteConnectionInfo.CertFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientCert.Name, "tls-client-key", corev1.TLSCertKey) - } - if basicAuthConfig.TLSClientKey.Name != "" { - provider.RemoteConnectionInfo.KeyFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) - } - - data.provider = provider - data.challenge = true - + return convertBasicAuthIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeGitHub: - githubConfig := providerConfig.GitHub - if githubConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - provider := &osinv1.GitHubIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GitHubIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: githubConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, githubConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, - }, - Organizations: githubConfig.Organizations, - Teams: githubConfig.Teams, - Hostname: githubConfig.Hostname, - } - if githubConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, githubConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = provider - data.challenge = false - + return convertGitHubIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeGitLab: - gitlabConfig := providerConfig.GitLab - if gitlabConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } + return convertGitLabIDP(providerConfig, i, idpVolumeMounts) + case configv1.IdentityProviderTypeGoogle: + return convertGoogleIDP(providerConfig, i, idpVolumeMounts) + case configv1.IdentityProviderTypeHTPasswd: + return convertHTPasswdIDP(providerConfig, i, idpVolumeMounts) + case configv1.IdentityProviderTypeKeystone: + return convertKeystoneIDP(providerConfig, i, idpVolumeMounts) + case configv1.IdentityProviderTypeLDAP: + return convertLDAPIDP(providerConfig, i, idpVolumeMounts) + case configv1.IdentityProviderTypeOpenID: + return convertOpenIDIDP(ctx, providerConfig, configOverride, i, idpVolumeMounts, kclient, namespace, skipKonnectivityDialer) + case configv1.IdentityProviderTypeRequestHeader: + return convertRequestHeaderIDP(providerConfig, i, idpVolumeMounts) + default: + return nil, fmt.Errorf("the identity provider type '%s' is not supported", providerConfig.Type) + } +} - provider := &osinv1.GitLabIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GitLabIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, +func convertBasicAuthIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + basicAuthConfig := providerConfig.BasicAuth + if basicAuthConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.BasicAuthPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "BasicAuthPasswordIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + RemoteConnectionInfo: configv1.RemoteConnectionInfo{ + URL: basicAuthConfig.URL, + }, + } + if basicAuthConfig.CA.Name != "" { + provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, basicAuthConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + if basicAuthConfig.TLSClientCert.Name != "" { + provider.RemoteConnectionInfo.CertFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) + } + if basicAuthConfig.TLSClientKey.Name != "" { + provider.RemoteConnectionInfo.KeyFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} - URL: gitlabConfig.URL, - ClientID: gitlabConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, gitlabConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, +func convertGitHubIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + githubConfig := providerConfig.GitHub + if githubConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.GitHubIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "GitHubIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + ClientID: githubConfig.ClientID, + ClientSecret: configv1.StringSource{ + StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, githubConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), }, - Legacy: new(bool), // we require OIDC for GitLab now - } - if gitlabConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, gitlabConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - - data.provider = provider - data.challenge = true - - case configv1.IdentityProviderTypeGoogle: - googleConfig := providerConfig.Google - if googleConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } + }, + Organizations: githubConfig.Organizations, + Teams: githubConfig.Teams, + Hostname: githubConfig.Hostname, + } + if githubConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, githubConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: false, login: true}, nil +} - data.provider = &osinv1.GoogleIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GoogleIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: googleConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, googleConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, +func convertGitLabIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + gitlabConfig := providerConfig.GitLab + if gitlabConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.GitLabIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "GitLabIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + URL: gitlabConfig.URL, + ClientID: gitlabConfig.ClientID, + ClientSecret: configv1.StringSource{ + StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, gitlabConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), }, - HostedDomain: googleConfig.HostedDomain, - } - data.challenge = false - - case configv1.IdentityProviderTypeHTPasswd: - if providerConfig.HTPasswd == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } + }, + Legacy: new(bool), // we require OIDC for GitLab now + } + if gitlabConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, gitlabConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} - data.provider = &osinv1.HTPasswdPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "HTPasswdPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), +func convertGoogleIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + googleConfig := providerConfig.Google + if googleConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.GoogleIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "GoogleIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + ClientID: googleConfig.ClientID, + ClientSecret: configv1.StringSource{ + StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, googleConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), }, - File: idpVolumeMounts.SecretPath(i, providerConfig.HTPasswd.FileData.Name, "file-data", configv1.HTPasswdDataKey), - } - data.challenge = true - - case configv1.IdentityProviderTypeKeystone: - keystoneConfig := providerConfig.Keystone - if keystoneConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } + }, + HostedDomain: googleConfig.HostedDomain, + } + return &idpData{provider: provider, challenge: false, login: true}, nil +} - provider := &osinv1.KeystonePasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "KeystonePasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - RemoteConnectionInfo: configv1.RemoteConnectionInfo{ - URL: keystoneConfig.URL, - }, - DomainName: keystoneConfig.DomainName, - UseKeystoneIdentity: true, // force use of keystone ID - } - if keystoneConfig.CA.Name != "" { - provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, keystoneConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - if keystoneConfig.TLSClientCert.Name != "" { - provider.RemoteConnectionInfo.CertInfo.CertFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) - } - if keystoneConfig.TLSClientKey.Name != "" { - provider.RemoteConnectionInfo.CertInfo.KeyFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) - } - data.provider = provider - data.challenge = true +func convertHTPasswdIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + if providerConfig.HTPasswd == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.HTPasswdPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "HTPasswdPasswordIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + File: idpVolumeMounts.SecretPath(i, providerConfig.HTPasswd.FileData.Name, "file-data", configv1.HTPasswdDataKey), + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} - case configv1.IdentityProviderTypeLDAP: - ldapConfig := providerConfig.LDAP - if ldapConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } +func convertKeystoneIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + keystoneConfig := providerConfig.Keystone + if keystoneConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.KeystonePasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "KeystonePasswordIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + RemoteConnectionInfo: configv1.RemoteConnectionInfo{ + URL: keystoneConfig.URL, + }, + DomainName: keystoneConfig.DomainName, + UseKeystoneIdentity: true, // force use of keystone ID + } + if keystoneConfig.CA.Name != "" { + provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, keystoneConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + if keystoneConfig.TLSClientCert.Name != "" { + provider.RemoteConnectionInfo.CertInfo.CertFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) + } + if keystoneConfig.TLSClientKey.Name != "" { + provider.RemoteConnectionInfo.CertInfo.KeyFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} - provider := &osinv1.LDAPPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "LDAPPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - URL: ldapConfig.URL, - BindDN: ldapConfig.BindDN, - Insecure: ldapConfig.Insecure, - Attributes: osinv1.LDAPAttributeMapping{ - ID: ldapConfig.Attributes.ID, - PreferredUsername: ldapConfig.Attributes.PreferredUsername, - Name: ldapConfig.Attributes.Name, - Email: ldapConfig.Attributes.Email, +func convertLDAPIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + ldapConfig := providerConfig.LDAP + if ldapConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.LDAPPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "LDAPPasswordIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + URL: ldapConfig.URL, + BindDN: ldapConfig.BindDN, + Insecure: ldapConfig.Insecure, + Attributes: osinv1.LDAPAttributeMapping{ + ID: ldapConfig.Attributes.ID, + PreferredUsername: ldapConfig.Attributes.PreferredUsername, + Name: ldapConfig.Attributes.Name, + Email: ldapConfig.Attributes.Email, + }, + } + if ldapConfig.BindPassword.Name != "" { + provider.BindPassword = configv1.StringSource{ + StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, ldapConfig.BindPassword.Name, "bind-password", configv1.BindPasswordKey), }, } - if ldapConfig.BindPassword.Name != "" { - provider.BindPassword = configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, ldapConfig.BindPassword.Name, "bind-password", configv1.BindPasswordKey), - }, - } - } - if ldapConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, ldapConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = provider - data.challenge = true + } + if ldapConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, ldapConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} - case configv1.IdentityProviderTypeOpenID: - openIDConfig := providerConfig.OpenID - if openIDConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } +func convertOpenIDIDP( + ctx context.Context, + providerConfig *configv1.IdentityProviderConfig, + configOverride *ConfigOverride, + i int, + idpVolumeMounts *IDPVolumeMountInfo, + kclient crclient.Reader, + namespace string, + skipKonnectivityDialer bool, +) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + openIDConfig := providerConfig.OpenID + if openIDConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } - openIDProvider := &osinv1.OpenIDIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "OpenIDIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: openIDConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, openIDConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, + openIDProvider := &osinv1.OpenIDIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "OpenIDIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + ClientID: openIDConfig.ClientID, + ClientSecret: configv1.StringSource{ + StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, openIDConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), }, - ExtraScopes: openIDConfig.ExtraScopes, - ExtraAuthorizeParameters: openIDConfig.ExtraAuthorizeParameters, - } - //Handle special case for IBM Cloud's OIDC provider (need to override some fields not available in public api) - if configOverride != nil { - openIDProvider.URLs = configOverride.URLs - openIDProvider.Claims = configOverride.Claims - } else { - urls, err := discoverOpenIDURLs(ctx, kclient, openIDConfig.Issuer, corev1.ServiceAccountRootCAKey, namespace, openIDConfig.CA, skipKonnectivityDialer) - if err != nil { - return nil, err - } - openIDProvider.URLs = *urls - var groups []string - if len(openIDConfig.Claims.Groups) > 0 { - groups = make([]string, len(openIDConfig.Claims.Groups)) - for i, group := range openIDConfig.Claims.Groups { - groups[i] = string(group) - } - } - openIDProvider.Claims = osinv1.OpenIDClaims{ - // There is no longer a user-facing setting for ID as it is considered unsafe - ID: []string{configv1.UserIDClaim}, - PreferredUsername: openIDConfig.Claims.PreferredUsername, - Name: openIDConfig.Claims.Name, - Email: openIDConfig.Claims.Email, - Groups: groups, - } - } - if len(openIDConfig.CA.Name) > 0 { - openIDProvider.CA = idpVolumeMounts.ConfigMapPath(i, openIDConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = openIDProvider - - if configOverride != nil && configOverride.Challenge != nil { - data.challenge = *configOverride.Challenge - } else { - // openshift CR validating in kube-apiserver does not allow - // challenge-redirecting IdPs to be configured with OIDC so it is safe - // to allow challenge-issuing flow if it's available on the OIDC side - challengeFlowsAllowed, err := checkOIDCPasswordGrantFlow( - ctx, - kclient, - openIDProvider.URLs.Token, - openIDConfig.ClientID, - namespace, - openIDConfig.CA, - openIDConfig.ClientSecret, - skipKonnectivityDialer, - ) - if err != nil { - return nil, fmt.Errorf("error attempting password grant flow: %v", err) + }, + ExtraScopes: openIDConfig.ExtraScopes, + ExtraAuthorizeParameters: openIDConfig.ExtraAuthorizeParameters, + } + //Handle special case for IBM Cloud's OIDC provider (need to override some fields not available in public api) + if configOverride != nil { + openIDProvider.URLs = configOverride.URLs + openIDProvider.Claims = configOverride.Claims + } else { + urls, err := discoverOpenIDURLs(ctx, kclient, openIDConfig.Issuer, corev1.ServiceAccountRootCAKey, namespace, openIDConfig.CA, skipKonnectivityDialer) + if err != nil { + return nil, err + } + openIDProvider.URLs = *urls + var groups []string + if len(openIDConfig.Claims.Groups) > 0 { + groups = make([]string, len(openIDConfig.Claims.Groups)) + for i, group := range openIDConfig.Claims.Groups { + groups[i] = string(group) } - data.challenge = challengeFlowsAllowed } - case configv1.IdentityProviderTypeRequestHeader: - requestHeaderConfig := providerConfig.RequestHeader - if requestHeaderConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + openIDProvider.Claims = osinv1.OpenIDClaims{ + // There is no longer a user-facing setting for ID as it is considered unsafe + ID: []string{configv1.UserIDClaim}, + PreferredUsername: openIDConfig.Claims.PreferredUsername, + Name: openIDConfig.Claims.Name, + Email: openIDConfig.Claims.Email, + Groups: groups, } - data.provider = &osinv1.RequestHeaderIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "RequestHeaderIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - LoginURL: requestHeaderConfig.LoginURL, - ChallengeURL: requestHeaderConfig.ChallengeURL, - ClientCA: idpVolumeMounts.ConfigMapPath(i, requestHeaderConfig.ClientCA.Name, "ca", corev1.ServiceAccountRootCAKey), - ClientCommonNames: requestHeaderConfig.ClientCommonNames, - Headers: requestHeaderConfig.Headers, - PreferredUsernameHeaders: requestHeaderConfig.PreferredUsernameHeaders, - NameHeaders: requestHeaderConfig.NameHeaders, - EmailHeaders: requestHeaderConfig.EmailHeaders, - } - data.challenge = len(requestHeaderConfig.ChallengeURL) > 0 - data.login = len(requestHeaderConfig.LoginURL) > 0 - - default: - return nil, fmt.Errorf("the identity provider type '%s' is not supported", providerConfig.Type) - } // switch + } + if len(openIDConfig.CA.Name) > 0 { + openIDProvider.CA = idpVolumeMounts.ConfigMapPath(i, openIDConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + data := &idpData{provider: openIDProvider, login: true} + if configOverride != nil && configOverride.Challenge != nil { + data.challenge = *configOverride.Challenge + } else { + // openshift CR validating in kube-apiserver does not allow + // challenge-redirecting IdPs to be configured with OIDC so it is safe + // to allow challenge-issuing flow if it's available on the OIDC side + challengeFlowsAllowed, err := checkOIDCPasswordGrantFlow( + ctx, + kclient, + openIDProvider.URLs.Token, + openIDConfig.ClientID, + namespace, + openIDConfig.CA, + openIDConfig.ClientSecret, + skipKonnectivityDialer, + ) + if err != nil { + return nil, fmt.Errorf("error attempting password grant flow: %v", err) + } + data.challenge = challengeFlowsAllowed + } return data, nil } +func convertRequestHeaderIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + requestHeaderConfig := providerConfig.RequestHeader + if requestHeaderConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.RequestHeaderIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "RequestHeaderIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + LoginURL: requestHeaderConfig.LoginURL, + ChallengeURL: requestHeaderConfig.ChallengeURL, + ClientCA: idpVolumeMounts.ConfigMapPath(i, requestHeaderConfig.ClientCA.Name, "ca", corev1.ServiceAccountRootCAKey), + ClientCommonNames: requestHeaderConfig.ClientCommonNames, + Headers: requestHeaderConfig.Headers, + PreferredUsernameHeaders: requestHeaderConfig.PreferredUsernameHeaders, + NameHeaders: requestHeaderConfig.NameHeaders, + EmailHeaders: requestHeaderConfig.EmailHeaders, + } + return &idpData{ + provider: provider, + challenge: len(requestHeaderConfig.ChallengeURL) > 0, + login: len(requestHeaderConfig.LoginURL) > 0, + }, nil +} + const ( konnectivityClientDataCertKey = "tls.crt" konnectivityClientDataKey = "tls.key" diff --git a/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert_test.go b/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert_test.go index 78b458a6c96..24adb544a93 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert_test.go +++ b/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert_test.go @@ -210,6 +210,721 @@ func TestOpenIDProviderConversion(t *testing.T) { } } +func TestDefaultIDPMappingMethods(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input []configv1.IdentityProvider + expected []configv1.MappingMethodType + }{ + { + name: "When no identity providers are given, it should return an empty slice", + input: []configv1.IdentityProvider{}, + expected: []configv1.MappingMethodType{}, + }, + { + name: "When mapping method is empty, it should default to claim", + input: []configv1.IdentityProvider{ + {Name: "test", MappingMethod: ""}, + }, + expected: []configv1.MappingMethodType{configv1.MappingMethodClaim}, + }, + { + name: "When mapping method is already set, it should preserve it", + input: []configv1.IdentityProvider{ + {Name: "test", MappingMethod: configv1.MappingMethodAdd}, + }, + expected: []configv1.MappingMethodType{configv1.MappingMethodAdd}, + }, + { + name: "When multiple providers have mixed mapping methods, it should default only empty ones", + input: []configv1.IdentityProvider{ + {Name: "a", MappingMethod: ""}, + {Name: "b", MappingMethod: configv1.MappingMethodLookup}, + {Name: "c", MappingMethod: ""}, + }, + expected: []configv1.MappingMethodType{ + configv1.MappingMethodClaim, + configv1.MappingMethodLookup, + configv1.MappingMethodClaim, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + result := defaultIDPMappingMethods(tc.input) + g.Expect(result).To(HaveLen(len(tc.expected))) + for i, idp := range result { + g.Expect(idp.MappingMethod).To(Equal(tc.expected[i])) + } + }) + } +} + +func newTestVolumeMountInfo() *IDPVolumeMountInfo { + return &IDPVolumeMountInfo{ + Container: ComponentName, + VolumeMounts: podspec.VolumeMounts{ + ComponentName: podspec.ContainerMounts{}, + }, + } +} + +func TestConvertBasicAuthIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When BasicAuth config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + }, + expectErr: true, + }, + { + name: "When BasicAuth config has only URL, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://auth.example.com", + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When BasicAuth config has CA and TLS certs, it should set volume mount paths", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://auth.example.com", + CA: configv1.ConfigMapNameReference{Name: "my-ca"}, + TLSClientCert: configv1.SecretNameReference{Name: "my-cert"}, + TLSClientKey: configv1.SecretNameReference{Name: "my-key"}, + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertBasicAuthIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("configuration is missing")) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.BasicAuthPasswordIdentityProvider) + g.Expect(provider.RemoteConnectionInfo.URL).To(Equal("https://auth.example.com")) + if tc.config.BasicAuth.CA.Name != "" { + g.Expect(provider.RemoteConnectionInfo.CA).To(ContainSubstring("idp_cm_0_ca")) + } + if tc.config.BasicAuth.TLSClientCert.Name != "" { + g.Expect(provider.RemoteConnectionInfo.CertFile).To(ContainSubstring("idp_secret_0_tls-client-cert")) + } + if tc.config.BasicAuth.TLSClientKey.Name != "" { + g.Expect(provider.RemoteConnectionInfo.KeyFile).To(ContainSubstring("idp_secret_0_tls-client-key")) + } + } + }) + } +} + +func TestConvertGitHubIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When GitHub config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + }, + expectErr: true, + }, + { + name: "When GitHub config is provided with organizations, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "my-client-id", + ClientSecret: configv1.SecretNameReference{Name: "gh-secret"}, + Organizations: []string{"org1", "org2"}, + Teams: []string{"org1/team1"}, + Hostname: "github.example.com", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + { + name: "When GitHub config has a CA, it should set the CA path", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "my-client-id", + ClientSecret: configv1.SecretNameReference{Name: "gh-secret"}, + CA: configv1.ConfigMapNameReference{Name: "gh-ca"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGitHubIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GitHubIdentityProvider) + g.Expect(provider.ClientID).To(Equal(tc.config.GitHub.ClientID)) + g.Expect(provider.Organizations).To(Equal(tc.config.GitHub.Organizations)) + if tc.config.GitHub.CA.Name != "" { + g.Expect(provider.CA).To(ContainSubstring("idp_cm_0_ca")) + } + } + }) + } +} + +func TestConvertGitLabIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When GitLab config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitLab, + }, + expectErr: true, + }, + { + name: "When GitLab config is provided, it should produce a valid provider with legacy set", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitLab, + GitLab: &configv1.GitLabIdentityProvider{ + URL: "https://gitlab.example.com", + ClientID: "gl-client", + ClientSecret: configv1.SecretNameReference{Name: "gl-secret"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGitLabIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GitLabIdentityProvider) + g.Expect(provider.URL).To(Equal("https://gitlab.example.com")) + g.Expect(provider.Legacy).ToNot(BeNil()) + g.Expect(*provider.Legacy).To(BeFalse()) + } + }) + } +} + +func TestConvertGoogleIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When Google config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGoogle, + }, + expectErr: true, + }, + { + name: "When Google config is provided with hosted domain, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGoogle, + Google: &configv1.GoogleIdentityProvider{ + ClientID: "google-client", + ClientSecret: configv1.SecretNameReference{Name: "google-secret"}, + HostedDomain: "example.com", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGoogleIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GoogleIdentityProvider) + g.Expect(provider.ClientID).To(Equal("google-client")) + g.Expect(provider.HostedDomain).To(Equal("example.com")) + } + }) + } +} + +func TestConvertHTPasswdIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When HTPasswd config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + }, + expectErr: true, + }, + { + name: "When HTPasswd config is provided, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + HTPasswd: &configv1.HTPasswdIdentityProvider{ + FileData: configv1.SecretNameReference{Name: "htpasswd-secret"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertHTPasswdIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.HTPasswdPasswordIdentityProvider) + g.Expect(provider.File).To(ContainSubstring("idp_secret_0_file-data")) + } + }) + } +} + +func TestConvertKeystoneIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When Keystone config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeKeystone, + }, + expectErr: true, + }, + { + name: "When Keystone config is provided with TLS certs, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeKeystone, + Keystone: &configv1.KeystoneIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://keystone.example.com", + CA: configv1.ConfigMapNameReference{Name: "ks-ca"}, + TLSClientCert: configv1.SecretNameReference{Name: "ks-cert"}, + TLSClientKey: configv1.SecretNameReference{Name: "ks-key"}, + }, + DomainName: "my-domain", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertKeystoneIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.KeystonePasswordIdentityProvider) + g.Expect(provider.RemoteConnectionInfo.URL).To(Equal("https://keystone.example.com")) + g.Expect(provider.DomainName).To(Equal("my-domain")) + g.Expect(provider.UseKeystoneIdentity).To(BeTrue()) + g.Expect(provider.RemoteConnectionInfo.CA).To(ContainSubstring("idp_cm_0_ca")) + g.Expect(provider.RemoteConnectionInfo.CertInfo.CertFile).To(ContainSubstring("idp_secret_0_tls-client-cert")) + g.Expect(provider.RemoteConnectionInfo.CertInfo.KeyFile).To(ContainSubstring("idp_secret_0_tls-client-key")) + } + }) + } +} + +func TestConvertLDAPIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When LDAP config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + }, + expectErr: true, + }, + { + name: "When LDAP config is provided with bind password, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + LDAP: &configv1.LDAPIdentityProvider{ + URL: "ldap://ldap.example.com", + BindDN: "cn=admin,dc=example,dc=com", + BindPassword: configv1.SecretNameReference{Name: "ldap-bind-pw"}, + Insecure: true, + Attributes: configv1.LDAPAttributeMapping{ + ID: []string{"dn"}, + PreferredUsername: []string{"uid"}, + Name: []string{"cn"}, + Email: []string{"mail"}, + }, + CA: configv1.ConfigMapNameReference{Name: "ldap-ca"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When LDAP config has no bind password, it should not set the bind password field", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + LDAP: &configv1.LDAPIdentityProvider{ + URL: "ldap://ldap.example.com", + Insecure: false, + Attributes: configv1.LDAPAttributeMapping{ + ID: []string{"dn"}, + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertLDAPIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.LDAPPasswordIdentityProvider) + g.Expect(provider.URL).To(Equal(tc.config.LDAP.URL)) + g.Expect(provider.BindDN).To(Equal(tc.config.LDAP.BindDN)) + g.Expect(provider.Insecure).To(Equal(tc.config.LDAP.Insecure)) + g.Expect(provider.Attributes.ID).To(Equal(tc.config.LDAP.Attributes.ID)) + if tc.config.LDAP.BindPassword.Name != "" { + g.Expect(provider.BindPassword.StringSourceSpec.File).To(ContainSubstring("idp_secret_0_bind-password")) + } else { + g.Expect(provider.BindPassword.StringSourceSpec.File).To(BeEmpty()) + } + if tc.config.LDAP.CA.Name != "" { + g.Expect(provider.CA).To(ContainSubstring("idp_cm_0_ca")) + } + } + }) + } +} + +func TestConvertRequestHeaderIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When RequestHeader config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + }, + expectErr: true, + }, + { + name: "When RequestHeader has login and challenge URLs, it should set login and challenge to true", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + RequestHeader: &configv1.RequestHeaderIdentityProvider{ + LoginURL: "https://login.example.com", + ChallengeURL: "https://challenge.example.com", + ClientCA: configv1.ConfigMapNameReference{Name: "rh-ca"}, + ClientCommonNames: []string{"client1"}, + Headers: []string{"X-Remote-User"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When RequestHeader has empty login and challenge URLs, it should set login and challenge to false", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + RequestHeader: &configv1.RequestHeaderIdentityProvider{ + ClientCA: configv1.ConfigMapNameReference{Name: "rh-ca"}, + Headers: []string{"X-Remote-User"}, + }, + }, + expectErr: false, + expectLogin: false, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertRequestHeaderIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.RequestHeaderIdentityProvider) + g.Expect(provider.Headers).To(Equal(tc.config.RequestHeader.Headers)) + g.Expect(provider.ClientCA).To(ContainSubstring("idp_cm_0_ca")) + } + }) + } +} + +func TestConvertProviderConfigToIDPData_UnsupportedType(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + config := &configv1.IdentityProviderConfig{ + Type: "UnsupportedType", + } + _, err := convertProviderConfigToIDPData(t.Context(), config, nil, 0, vmi, nil, "test", true) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("is not supported")) +} + +func TestConvertProviderConfigToIDPData_Routing(t *testing.T) { + t.Parallel() + tests := []struct { + name string + providerType configv1.IdentityProviderType + config *configv1.IdentityProviderConfig + expectedKind string + }{ + { + name: "When type is BasicAuth, it should route to BasicAuth converter", + providerType: configv1.IdentityProviderTypeBasicAuth, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{URL: "https://example.com"}, + }, + }, + expectedKind: "BasicAuthPasswordIdentityProvider", + }, + { + name: "When type is GitHub, it should route to GitHub converter", + providerType: configv1.IdentityProviderTypeGitHub, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "id", + ClientSecret: configv1.SecretNameReference{Name: "s"}, + }, + }, + expectedKind: "GitHubIdentityProvider", + }, + { + name: "When type is HTPasswd, it should route to HTPasswd converter", + providerType: configv1.IdentityProviderTypeHTPasswd, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + HTPasswd: &configv1.HTPasswdIdentityProvider{ + FileData: configv1.SecretNameReference{Name: "htpasswd"}, + }, + }, + expectedKind: "HTPasswdPasswordIdentityProvider", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertProviderConfigToIDPData(t.Context(), tc.config, nil, 0, vmi, nil, "test", true) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.provider.GetObjectKind().GroupVersionKind().Kind).To(Equal(tc.expectedKind)) + }) + } +} + +func TestIDPVolumeMountInfo_ConfigMapPath(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + result := vmi.ConfigMapPath(2, "my-configmap", "ca", "ca.crt") + g.Expect(result).To(Equal("/etc/oauth/idp/idp_cm_2_ca/ca.crt")) + g.Expect(vmi.Volumes).To(HaveLen(1)) + g.Expect(vmi.Volumes[0].Name).To(Equal("idp-cm-2-ca")) + g.Expect(vmi.Volumes[0].ConfigMap.Name).To(Equal("my-configmap")) +} + +func TestIDPVolumeMountInfo_SecretPath(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + result := vmi.SecretPath(3, "my-secret", "client-secret", "clientSecret") + g.Expect(result).To(Equal("/etc/oauth/idp/idp_secret_3_client-secret/clientSecret")) + g.Expect(vmi.Volumes).To(HaveLen(1)) + g.Expect(vmi.Volumes[0].Name).To(Equal("idp-secret-3-client-secret")) + g.Expect(vmi.Volumes[0].Secret.SecretName).To(Equal("my-secret")) + g.Expect(vmi.Volumes[0].Secret.DefaultMode).ToNot(BeNil()) + g.Expect(*vmi.Volumes[0].Secret.DefaultMode).To(Equal(int32(0640))) +} + +func TestIsValidURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + rawurl string + optional bool + expected bool + }{ + { + name: "When URL is empty and optional, it should return true", + rawurl: "", + optional: true, + expected: true, + }, + { + name: "When URL is empty and required, it should return false", + rawurl: "", + optional: false, + expected: false, + }, + { + name: "When URL is a valid https URL, it should return true", + rawurl: "https://example.com/auth", + optional: false, + expected: true, + }, + { + name: "When URL uses http scheme, it should return false", + rawurl: "http://example.com/auth", + optional: false, + expected: false, + }, + { + name: "When URL has a fragment, it should return false", + rawurl: "https://example.com/auth#fragment", + optional: false, + expected: false, + }, + { + name: "When URL has no host, it should return false", + rawurl: "https://", + optional: false, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + g.Expect(isValidURL(tc.rawurl, tc.optional)).To(Equal(tc.expected)) + }) + } +} + func TestTransportForCARef(t *testing.T) { namespace := "test" diff --git a/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources.go b/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources.go index 89077b4daa1..3f25380fba6 100644 --- a/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources.go +++ b/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources.go @@ -343,31 +343,7 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result } if !hcp.DeletionTimestamp.IsZero() { - // Delete admission policies during cluster deletion to allow HCCO cleanup operations for ARO HCP - if hcp.Spec.Platform.Type == hyperv1.AzurePlatform { - registryConfigManagementStateAdmissionPolicy := registry.AdmissionPolicy{Name: registry.AdmissionPolicyNameManagementState} - // During cluster deletion, delete the admission policy and its binding to allow CIRO cleanup - log.Info("Cluster is being deleted, deleting registry management state admission policy and binding to allow cleanup") - - // Delete binding first to avoid dangling reference - binding := manifests.ValidatingAdmissionPolicyBinding(fmt.Sprintf("%s-binding", registryConfigManagementStateAdmissionPolicy.Name)) - _, err := k8sutil.DeleteIfNeeded(ctx, r.client, binding) - if err != nil { - return ctrl.Result{}, fmt.Errorf("failed to delete ValidatingAdmissionPolicyBinding %s: %v", binding.Name, err) - } - - // Delete policy - vap := manifests.ValidatingAdmissionPolicy(registryConfigManagementStateAdmissionPolicy.Name) - if _, err := k8sutil.DeleteIfNeeded(ctx, r.client, vap); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to delete ValidatingAdmissionPolicy %s: %v", vap.Name, err) - } - } - - if shouldCleanupCloudResources(hcp) { - log.Info("Cleaning up hosted cluster cloud resources") - return r.destroyCloudResources(ctx, hcp) - } - return ctrl.Result{}, nil + return r.reconcileDeletion(ctx, log, hcp) } if isPaused, duration := util.IsReconciliationPaused(log, hcp.Spec.PausedUntil); isPaused { @@ -453,21 +429,184 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result errs = append(errs, fmt.Errorf("failed to reconcile rbac: %w", err)) } + errs = append(errs, r.reconcileRegistryAndIngress(ctx, hcp, log)...) + + errs = append(errs, r.reconcileAPIServicesAndOAuth(ctx, hcp, log, releaseImage)...) + + errs = append(errs, r.reconcileNetworkingAndSecrets(ctx, hcp, log, pullSecret)...) + + log.Info("reconciling olm resources") + errs = append(errs, r.reconcileOLM(ctx, hcp, pullSecret)...) + + errs = append(errs, r.reconcileStorageAndMisc(ctx, log, hcp, releaseImage)...) + r.cleanupLegacyResources(ctx, log, hcp, releaseImage, &errs) + + errs = append(errs, r.reconcilePlatformSpecificResources(ctx, log, hcp, releaseImage)...) + + if result, err := r.reconcileClusterRecovery(ctx, log, hcp, errs); err != nil || result.RequeueAfter > 0 { + return result, utilerrors.NewAggregate(append(errs, err)) + } + + return ctrl.Result{}, utilerrors.NewAggregate(errs) +} + +func (r *reconciler) reconcileStorageAndMisc(ctx context.Context, log logr.Logger, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) []error { + var errs []error + + log.Info("reconciling kubelet configs") + if err := r.reconcileKubeletConfig(ctx); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile kubelet config: %w", err)) + } + + if hostedcontrolplane.IsStorageAndCSIManaged(hcp) { + log.Info("reconciling storage resources") + errs = append(errs, r.reconcileStorage(ctx, hcp)...) + + log.Info("reconciling node level csi configuration") + if err := r.reconcileCSIDriver(ctx, hcp, releaseImage); err != nil { + errs = append(errs, err) + } + } + + recyclerServiceAccount := manifests.RecyclerServiceAccount() + if _, err := r.CreateOrUpdate(ctx, r.client, recyclerServiceAccount, func() error { + return nil + }); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile pv recycler service account: %w", err)) + } + + log.Info("reconciling observed configuration") + errs = append(errs, r.reconcileObservedConfiguration(ctx, hcp)...) + + errs = append(errs, r.ensureGuestAdmissionWebhooksAreValid(ctx)) + return errs +} + +func (r *reconciler) cleanupLegacyResources(ctx context.Context, log logr.Logger, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage, errs *[]error) { + if !r.isClusterVersionUpdated(ctx, releaseImage.Version()) { + return + } + deleteDNSOperatorDeploymentOnce.Do(func() { + dnsOperatorDeployment := manifests.DNSOperatorDeployment() + log.Info("removing any existing DNS operator deployment") + if err := r.uncachedClient.Delete(ctx, dnsOperatorDeployment); err != nil && !apierrors.IsNotFound(err) { + *errs = append(*errs, err) + } + }) + deleteCVORemovedResourcesOnce.Do(func() { + resources := cvo.ResourcesToRemove(hcp.Spec.Platform.Type) + for _, resource := range resources { + log.Info("removing existing resources", "resource", resource) + if err := r.uncachedClient.Delete(ctx, resource); err != nil && !apierrors.IsNotFound(err) { + *errs = append(*errs, err) + } + } + }) +} + +func (r *reconciler) reconcileDeletion(ctx context.Context, log logr.Logger, hcp *hyperv1.HostedControlPlane) (ctrl.Result, error) { + if hcp.Spec.Platform.Type == hyperv1.AzurePlatform { + registryConfigManagementStateAdmissionPolicy := registry.AdmissionPolicy{Name: registry.AdmissionPolicyNameManagementState} + log.Info("Cluster is being deleted, deleting registry management state admission policy and binding to allow cleanup") + + binding := manifests.ValidatingAdmissionPolicyBinding(fmt.Sprintf("%s-binding", registryConfigManagementStateAdmissionPolicy.Name)) + if _, err := k8sutil.DeleteIfNeeded(ctx, r.client, binding); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to delete ValidatingAdmissionPolicyBinding %s: %v", binding.Name, err) + } + + vap := manifests.ValidatingAdmissionPolicy(registryConfigManagementStateAdmissionPolicy.Name) + if _, err := k8sutil.DeleteIfNeeded(ctx, r.client, vap); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to delete ValidatingAdmissionPolicy %s: %v", vap.Name, err) + } + } + + if shouldCleanupCloudResources(hcp) { + log.Info("Cleaning up hosted cluster cloud resources") + return r.destroyCloudResources(ctx, hcp) + } + return ctrl.Result{}, nil +} + +func (r *reconciler) reconcilePlatformSpecificResources(ctx context.Context, log logr.Logger, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) []error { + var errs []error + switch hcp.Spec.Platform.Type { + case hyperv1.AWSPlatform: + log.Info("reconciling AWS specific resources") + errs = append(errs, r.reconcileAWSIdentityWebhook(ctx)...) + case hyperv1.AzurePlatform: + log.Info("reconciling Azure specific resources") + errs = append(errs, r.reconcileAzureCloudNodeManager(ctx, releaseImage.ComponentImages()["azure-cloud-node-manager"])...) + errs = append(errs, r.reconcileAzureIdentityWebhook(ctx)...) + } + return errs +} + +func (r *reconciler) reconcileClusterRecovery(ctx context.Context, log logr.Logger, hcp *hyperv1.HostedControlPlane, existingErrs []error) (ctrl.Result, error) { + if _, exists := hcp.Annotations[hyperv1.HostedClusterRestoredFromBackupAnnotation]; !exists { + return ctrl.Result{}, nil + } + + condition := &metav1.Condition{ + Type: string(hyperv1.HostedClusterRestoredFromBackup), + Reason: hyperv1.RecoveryFinishedReason, + } + + finished, err := r.reconcileRestoredCluster(ctx, hcp) + if err != nil { + log.Error(err, "failed to reconcile hosted cluster recovery") + return ctrl.Result{}, utilerrors.NewAggregate(append(existingErrs, err)) + } + + if !finished { + log.Info("hosted cluster recovery not finished yet") + condition.Status = metav1.ConditionFalse + condition.Message = "Hosted cluster recovery not finished yet" + } else { + log.Info("hosted cluster recovery finished") + condition.Status = metav1.ConditionTrue + condition.Message = "Hosted cluster recovery finished" + } + + meta.SetStatusCondition(&hcp.Status.Conditions, *condition) + log.Info("setting condition", "type", condition.Type, "status", condition.Status, "message", condition.Message) + if err := r.cpClient.Status().Update(ctx, hcp); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for hosted cluster recovery: %w. Condition error message: %v", err, condition.Message) + } + log.Info("successfully updated hcp status with recovery condition") + + if !finished { + return ctrl.Result{RequeueAfter: 120 * time.Second}, nil + } + return ctrl.Result{}, nil +} + +func (r *reconciler) reconcileCSIDriver(ctx context.Context, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) error { + switch hcp.Spec.Platform.Type { + case hyperv1.KubevirtPlatform: + // Most csi drivers should be laid down by the Cluster Storage Operator (CSO) instead of + // the hcco operator. Only KubeVirt is unique at the moment. + err := kubevirtcsi.ReconcileTenant(r.client, hcp, ctx, r.CreateOrUpdate, releaseImage.ComponentImages()) + if err != nil { + return err + } + } + + return nil +} + +func (r *reconciler) reconcileRegistryAndIngress(ctx context.Context, hcp *hyperv1.HostedControlPlane, log logr.Logger) []error { + var errs []error + registryConfig := manifests.Registry() var registryConfigExists bool - // Check if the registry config exists if err := r.client.Get(ctx, client.ObjectKeyFromObject(registryConfig), registryConfig); err != nil { if !apierrors.IsNotFound(err) { - return ctrl.Result{}, fmt.Errorf("failed to get registry config: %w", err) + errs = append(errs, fmt.Errorf("failed to get registry config: %w", err)) } } else { registryConfigExists = true } - // For platforms where cluster-image-registry-operator (CIRO) needs a PVC to be created, bootstrap needs to happen - // in CIRO before the registry config is created. For now, this is the case for the OpenStack platform. - // If the object exist, we reconcile the registry config for other fields as it should be fine since the PVC would - // exist at this point. if capabilities.IsImageRegistryCapabilityEnabled(hcp.Spec.Capabilities) { if imageRegistryPlatformWithPVC(hcp.Spec.Platform.Type) && (!registryConfigExists || registryConfig == nil) { log.Info("skipping registry config to let CIRO bootstrap") @@ -480,23 +619,16 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result } log.Info("reconciling registry config") if _, err := r.CreateOrUpdate(ctx, r.client, registryConfig, func() error { - err = registry.ReconcileRegistryConfig(registryConfig, r.platformType, hcp.Spec.InfrastructureAvailabilityPolicy) - if err != nil { - return err - } - return nil + return registry.ReconcileRegistryConfig(registryConfig, r.platformType, hcp.Spec.InfrastructureAvailabilityPolicy) }); err != nil { errs = append(errs, fmt.Errorf("failed to reconcile imageregistry config: %w", err)) } - // TODO: remove this when ROSA HCP stops setting the managementState to Removed to disable the Image Registry if registryConfig.Spec.ManagementState == operatorv1.Removed && r.platformType != hyperv1.IBMCloudPlatform && r.platformType != hyperv1.AzurePlatform { log.Info("imageregistry operator managementstate is removed, disabling openshift-controller-manager controllers and cleaning up resources") ocmConfigMap := cpomanifests.OpenShiftControllerManagerConfig(r.hcpNamespace) if _, err := r.CreateOrUpdate(ctx, r.cpClient, ocmConfigMap, func() error { if ocmConfigMap.Data == nil { - // CPO has not created the configmap yet, wait for create - // This should not happen as we are started by the CPO after the configmap should be created return nil } config := &openshiftcpv1.OpenShiftControllerManagerConfig{} @@ -536,8 +668,7 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result } } } - // Reconcile the IngressController resource only if the ingress capability is enabled. - // Skip this step if the user explicitly disabled ingress. + if capabilities.IsIngressCapabilityEnabled(hcp.Spec.Capabilities) { log.Info("reconciling ingress controller") if err := r.reconcileIngressController(ctx, hcp); err != nil { @@ -550,6 +681,12 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result errs = append(errs, fmt.Errorf("failed to reconcile oauth client secrets: %w", err)) } + return errs +} + +func (r *reconciler) reconcileAPIServicesAndOAuth(ctx context.Context, hcp *hyperv1.HostedControlPlane, log logr.Logger, releaseImage *releaseinfo.ReleaseImage) []error { + var errs []error + log.Info("reconciling kube control plane signer secret") kubeControlPlaneSignerSecret := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ @@ -629,29 +766,7 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result } if util.HCPOAuthEnabled(hcp) { - log.Info("reconciling openshift oauth apiserver apiservices") - if err := r.reconcileOpenshiftOAuthAPIServerAPIServices(ctx, hcp); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile openshift apiserver service: %w", err)) - } - - log.Info("reconciling openshift oauth apiserver service") - openshiftOAuthAPIServerService := manifests.OpenShiftOAuthAPIServerClusterService() - if _, err := r.CreateOrUpdate(ctx, r.client, openshiftOAuthAPIServerService, func() error { - oapi.ReconcileClusterService(openshiftOAuthAPIServerService) - return nil - }); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile openshift oauth apiserver service: %w", err)) - } - - log.Info("reconciling openshift oauth apiserver endpoints") - if err := r.reconcileOpenshiftOAuthAPIServerEndpoints(ctx, hcp); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile openshift apiserver endpoints: %w", err)) - } - - log.Info("reconciling kubeadmin password hash secret") - if err := r.reconcileKubeadminPasswordHashSecret(ctx, hcp); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile kubeadmin password hash secret: %w", err)) - } + errs = append(errs, r.reconcileOAuthAPIServerResources(ctx, hcp, log)...) } log.Info("reconciling kube apiserver service monitor") @@ -667,6 +782,42 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result errs = append(errs, fmt.Errorf("failed to reconcile metrics forwarder: %w", err)) } + return errs +} + +func (r *reconciler) reconcileOAuthAPIServerResources(ctx context.Context, hcp *hyperv1.HostedControlPlane, log logr.Logger) []error { + var errs []error + + log.Info("reconciling openshift oauth apiserver apiservices") + if err := r.reconcileOpenshiftOAuthAPIServerAPIServices(ctx, hcp); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile openshift oauth apiserver apiservices: %w", err)) + } + + log.Info("reconciling openshift oauth apiserver service") + openshiftOAuthAPIServerService := manifests.OpenShiftOAuthAPIServerClusterService() + if _, err := r.CreateOrUpdate(ctx, r.client, openshiftOAuthAPIServerService, func() error { + oapi.ReconcileClusterService(openshiftOAuthAPIServerService) + return nil + }); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile openshift oauth apiserver service: %w", err)) + } + + log.Info("reconciling openshift oauth apiserver endpoints") + if err := r.reconcileOpenshiftOAuthAPIServerEndpoints(ctx, hcp); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile openshift oauth apiserver endpoints: %w", err)) + } + + log.Info("reconciling kubeadmin password hash secret") + if err := r.reconcileKubeadminPasswordHashSecret(ctx, hcp); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile kubeadmin password hash secret: %w", err)) + } + + return errs +} + +func (r *reconciler) reconcileNetworkingAndSecrets(ctx context.Context, hcp *hyperv1.HostedControlPlane, log logr.Logger, pullSecret *corev1.Secret) []error { + var errs []error + log.Info("reconciling network operator") networkOperator := networkoperator.NetworkOperator() var ovnConfig *hyperv1.OVNKubernetesConfig @@ -679,12 +830,10 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result }); err != nil { errs = append(errs, fmt.Errorf("failed to reconcile network operator: %w", err)) } - // Detect suboptimal MTU size on kubevirt hosted cluster with ovn-k and raise a condition in such a case if err := networkoperator.DetectSuboptimalMTU(ctx, r.cpClient, networkOperator, hcp); err != nil { errs = append(errs, err) } - // this allows users to disable data collection in sensitive environments - // solves https://issues.redhat.com/browse/OCPBUGS-12208 + ensureExistsReconciliationStrategy := false if _, exists := hcp.Annotations[hyperv1.EnsureExistsPullSecretReconciliation]; exists { ensureExistsReconciliationStrategy = true @@ -779,131 +928,7 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result errs = append(errs, fmt.Errorf("failed to reconcile openshift controller manager service ca bundle: %w", err)) } - log.Info("reconciling olm resources") - errs = append(errs, r.reconcileOLM(ctx, hcp, pullSecret)...) - - log.Info("reconciling kubelet configs") - if err := r.reconcileKubeletConfig(ctx); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile kubelet config: %w", err)) - } - - if hostedcontrolplane.IsStorageAndCSIManaged(hcp) { - log.Info("reconciling storage resources") - errs = append(errs, r.reconcileStorage(ctx, hcp)...) - - log.Info("reconciling node level csi configuration") - if err := r.reconcileCSIDriver(ctx, hcp, releaseImage); err != nil { - errs = append(errs, err) - } - } - - recyclerServiceAccount := manifests.RecyclerServiceAccount() - if _, err := r.CreateOrUpdate(ctx, r.client, recyclerServiceAccount, func() error { - return nil - }); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile pv recycler service account: %w", err)) - } - - log.Info("reconciling observed configuration") - errs = append(errs, r.reconcileObservedConfiguration(ctx, hcp)...) - - errs = append(errs, r.ensureGuestAdmissionWebhooksAreValid(ctx)) - - // Delete the DNS operator deployment in the hosted cluster, if it is - // present there. A separate DNS operator deployment runs as part of - // the hosted control-plane, but an upgraded cluster might still have - // an old DNS operator deployment in the hosted cluster. The caching - // client has a label selector that doesn't match the deployment, - // so we must use the uncached client for this delete call. To avoid - // excessive API calls using the uncached client, the delete call is - // guarded using a sync.Once. - if r.isClusterVersionUpdated(ctx, releaseImage.Version()) { - deleteDNSOperatorDeploymentOnce.Do(func() { - dnsOperatorDeployment := manifests.DNSOperatorDeployment() - log.Info("removing any existing DNS operator deployment") - if err := r.uncachedClient.Delete(ctx, dnsOperatorDeployment); err != nil && !apierrors.IsNotFound(err) { - errs = append(errs, err) - } - }) - deleteCVORemovedResourcesOnce.Do(func() { - resources := cvo.ResourcesToRemove(hcp.Spec.Platform.Type) - for _, resource := range resources { - log.Info("removing existing resources", "resource", resource) - if err := r.uncachedClient.Delete(ctx, resource); err != nil && !apierrors.IsNotFound(err) { - errs = append(errs, err) - } - } - }) - } - - // Reconcile platform specific resources - switch hcp.Spec.Platform.Type { - case hyperv1.AWSPlatform: - log.Info("reconciling AWS specific resources") - errs = append(errs, r.reconcileAWSIdentityWebhook(ctx)...) - case hyperv1.AzurePlatform: - log.Info("reconciling Azure specific resources") - errs = append(errs, r.reconcileAzureCloudNodeManager(ctx, releaseImage.ComponentImages()["azure-cloud-node-manager"])...) - errs = append(errs, r.reconcileAzureIdentityWebhook(ctx)...) - } - - // Reconcile hostedCluster recovery if the hosted cluster was restored from backup - if _, exists := hcp.Annotations[hyperv1.HostedClusterRestoredFromBackupAnnotation]; exists { - var ( - finished bool - err error - ) - condition := &metav1.Condition{ - Type: string(hyperv1.HostedClusterRestoredFromBackup), - Reason: hyperv1.RecoveryFinishedReason, - } - - finished, err = r.reconcileRestoredCluster(ctx, hcp) - if err != nil { - log.Error(err, "failed to reconcile hosted cluster recovery") - return ctrl.Result{}, utilerrors.NewAggregate(append(errs, err)) - } - - if !finished { - log.Info("hosted cluster recovery not finished yet") - condition.Status = metav1.ConditionFalse - condition.Message = "Hosted cluster recovery not finished yet" - meta.SetStatusCondition(&hcp.Status.Conditions, *condition) - log.Info("setting condition", "type", condition.Type, "status", condition.Status, "message", condition.Message) - if err := r.cpClient.Status().Update(ctx, hcp); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for hosted cluster recovery: %w. Condition error message: %v", err, condition.Message) - } - log.Info("successfully updated hcp status with recovery not finished condition") - - return ctrl.Result{RequeueAfter: 120 * time.Second}, nil - } - - log.Info("hosted cluster recovery finished") - condition.Status = metav1.ConditionTrue - condition.Message = "Hosted cluster recovery finished" - meta.SetStatusCondition(&hcp.Status.Conditions, *condition) - log.Info("setting condition", "type", condition.Type, "status", condition.Status, "message", condition.Message) - if err := r.cpClient.Status().Update(ctx, hcp); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for hosted cluster recovery: %w. Condition error message: %v", err, condition.Message) - } - log.Info("successfully updated hcp status with recovery finished condition") - } - - return ctrl.Result{}, utilerrors.NewAggregate(errs) -} - -func (r *reconciler) reconcileCSIDriver(ctx context.Context, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) error { - switch hcp.Spec.Platform.Type { - case hyperv1.KubevirtPlatform: - // Most csi drivers should be laid down by the Cluster Storage Operator (CSO) instead of - // the hcco operator. Only KubeVirt is unique at the moment. - err := kubevirtcsi.ReconcileTenant(r.client, hcp, ctx, r.CreateOrUpdate, releaseImage.ComponentImages()) - if err != nil { - return err - } - } - - return nil + return errs } func (r *reconciler) reconcileMetricsForwarder(ctx context.Context, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) error { diff --git a/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources_test.go b/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources_test.go index def47e1bd6e..611f39abe08 100644 --- a/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources_test.go +++ b/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources_test.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "math/rand" + "reflect" "strings" + "sync" "testing" "time" @@ -15,15 +17,18 @@ import ( "github.com/openshift/hypershift/control-plane-operator/hostedclusterconfigoperator/api" "github.com/openshift/hypershift/control-plane-operator/hostedclusterconfigoperator/controllers/resources/kas" "github.com/openshift/hypershift/control-plane-operator/hostedclusterconfigoperator/controllers/resources/manifests" + "github.com/openshift/hypershift/control-plane-operator/hostedclusterconfigoperator/controllers/resources/registry" "github.com/openshift/hypershift/hypershift-operator/controllers/nodepool" "github.com/openshift/hypershift/support/azureutil" "github.com/openshift/hypershift/support/globalconfig" "github.com/openshift/hypershift/support/netutil" + "github.com/openshift/hypershift/support/releaseinfo" fakereleaseprovider "github.com/openshift/hypershift/support/releaseinfo/fake" supportutil "github.com/openshift/hypershift/support/util" "github.com/openshift/hypershift/support/util/fakeimagemetadataprovider" configv1 "github.com/openshift/api/config/v1" + imageapi "github.com/openshift/api/image/v1" operatorv1 "github.com/openshift/api/operator/v1" appsv1 "k8s.io/api/apps/v1" @@ -2699,6 +2704,140 @@ func Test_reconciler_reconcileControlPlaneConnectionAvailable(t *testing.T) { } } +func verifyKASCheckerLabelsAndSelectors(t *testing.T, dep *appsv1.Deployment) { + t.Helper() + if dep.Spec.Selector == nil || dep.Spec.Selector.MatchLabels["app"] != manifests.KASConnectionCheckerName { + t.Error("Selector labels not set correctly") + } + if dep.Spec.Template.ObjectMeta.Labels["app"] != manifests.KASConnectionCheckerName { + t.Error("Pod template labels not set correctly") + } +} + +func verifyKASCheckerContainerBasics(t *testing.T, dep *appsv1.Deployment, expectedImage string) corev1.Container { + t.Helper() + if len(dep.Spec.Template.Spec.Containers) != 1 { + t.Fatalf("Expected 1 container, got %d", len(dep.Spec.Template.Spec.Containers)) + } + container := dep.Spec.Template.Spec.Containers[0] + if container.Name != "connection-checker" { + t.Errorf("Expected container name 'connection-checker', got %s", container.Name) + } + if container.Image != expectedImage { + t.Errorf("Expected cli image %s, got %s", expectedImage, container.Image) + } + return container +} + +func verifyKASCheckerScript(t *testing.T, container corev1.Container) { + t.Helper() + if len(container.Command) != 3 || container.Command[0] != "/bin/sh" || container.Command[1] != "-c" { + t.Fatalf("Expected command [/bin/sh -c