diff --git a/.golangci.yml b/.golangci.yml index f3657cf77a0..f7322988a68 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -3,9 +3,12 @@ run: allow-parallel-runners: true linters: enable: + - gocyclo - misspell - unparam settings: + gocyclo: + min-complexity: 30 govet: enable: - nilness diff --git a/cmd/cluster/core/create.go b/cmd/cluster/core/create.go index 9dd96f94c1d..4c3ccc771e0 100644 --- a/cmd/cluster/core/create.go +++ b/cmd/cluster/core/create.go @@ -236,51 +236,25 @@ func (r *resources) asObjects() []crclient.Object { func prototypeResources(ctx context.Context, opts *CreateOptions) (*resources, error) { prototype := &resources{} - // allow client side defaulting when release image is empty but release stream is set. - if len(opts.ReleaseImage) == 0 && len(opts.ReleaseStream) != 0 { - client, err := util.GetClient() - if err != nil { - return nil, fmt.Errorf("failed to get client: %w", err) - } - defaultVersion, err := supportedversion.LookupDefaultOCPVersion(ctx, opts.ReleaseStream, client) - if err != nil { - return nil, fmt.Errorf("release image is required when unable to lookup default OCP version: %w", err) - } - opts.ReleaseImage = defaultVersion.PullSpec + if err := resolveReleaseImage(ctx, opts); err != nil { + return nil, err } - annotations := map[string]string{} - for _, s := range opts.Annotations { - pair := strings.SplitN(s, "=", 2) - if len(pair) != 2 { - return nil, fmt.Errorf("invalid annotation: %s", s) - } - k, v := pair[0], pair[1] - annotations[k] = v + annotations, err := parseKeyValuePairs(opts.Annotations, "annotation") + if err != nil { + return nil, err } - - labels := map[string]string{} - for _, s := range opts.Labels { - pair := strings.SplitN(s, "=", 2) - if len(pair) != 2 { - return nil, fmt.Errorf("invalid label: %s", s) - } - k, v := pair[0], pair[1] - labels[k] = v + labels, err := parseKeyValuePairs(opts.Labels, "label") + if err != nil { + return nil, err } - if len(opts.ControlPlaneOperatorImage) > 0 { annotations[hyperv1.ControlPlaneOperatorImageAnnotation] = opts.ControlPlaneOperatorImage } - pullSecret := opts.PullSecret - var err error - // overrides if pullSecretFile is set - if len(opts.PullSecretFile) > 0 { - pullSecret, err = os.ReadFile(opts.PullSecretFile) - if err != nil { - return nil, fmt.Errorf("failed to read pull secret file: %w", err) - } + pullSecret, err := resolvePullSecret(opts) + if err != nil { + return nil, err } prototype.Namespace = &corev1.Namespace{ @@ -350,49 +324,124 @@ func prototypeResources(ctx context.Context, opts *CreateOptions) (*resources, e }, } + applyClusterCapabilities(prototype.Cluster, opts) + if err := applyEtcdConfig(prototype.Cluster, opts); err != nil { + return nil, err + } + if err := applySSHKey(prototype, opts); err != nil { + return nil, err + } + if err := applyPausedUntil(prototype.Cluster, opts); err != nil { + return nil, err + } + applyOLMConfig(prototype.Cluster, opts) + if err := applyNetworkConfig(prototype.Cluster, opts); err != nil { + return nil, err + } + applySchedulingConfig(prototype.Cluster, opts) + if err := applyTrustBundleAndImageSources(prototype, opts); err != nil { + return nil, err + } + applyFeatureSet(prototype.Cluster, opts) + + if len(opts.KubeAPIServerDNSName) > 0 { + if err := validation.IsDNS1123Subdomain(opts.KubeAPIServerDNSName); len(err) > 0 { + return nil, fmt.Errorf("KubeAPIServerDNSName failed DNS validation: %s", strings.Join(err[:], " ")) + } + prototype.Cluster.Spec.KubeAPIServerDNSName = opts.KubeAPIServerDNSName + } + + return prototype, nil +} + +func resolveReleaseImage(ctx context.Context, opts *CreateOptions) error { + if len(opts.ReleaseImage) != 0 || len(opts.ReleaseStream) == 0 { + return nil + } + client, err := util.GetClient() + if err != nil { + return fmt.Errorf("failed to get client: %w", err) + } + defaultVersion, err := supportedversion.LookupDefaultOCPVersion(ctx, opts.ReleaseStream, client) + if err != nil { + return fmt.Errorf("release image is required when unable to lookup default OCP version: %w", err) + } + opts.ReleaseImage = defaultVersion.PullSpec + return nil +} + +func parseKeyValuePairs(items []string, kind string) (map[string]string, error) { + result := map[string]string{} + for _, s := range items { + pair := strings.SplitN(s, "=", 2) + if len(pair) != 2 { + return nil, fmt.Errorf("invalid %s: %s", kind, s) + } + if pair[0] == "" { + return nil, fmt.Errorf("invalid %s: key must not be empty in %q", kind, s) + } + result[pair[0]] = pair[1] + } + return result, nil +} + +func resolvePullSecret(opts *CreateOptions) ([]byte, error) { + if len(opts.PullSecretFile) > 0 { + data, err := os.ReadFile(opts.PullSecretFile) + if err != nil { + return nil, fmt.Errorf("failed to read pull secret file: %w", err) + } + return data, nil + } + return opts.PullSecret, nil +} + +func applyClusterCapabilities(cluster *hyperv1.HostedCluster, opts *CreateOptions) { if len(opts.EnableClusterCapabilities) > 0 { caps := make([]hyperv1.OptionalCapability, len(opts.EnableClusterCapabilities)) for i, c := range opts.EnableClusterCapabilities { caps[i] = hyperv1.OptionalCapability(c) } - prototype.Cluster.Spec.Capabilities.Enabled = caps + cluster.Spec.Capabilities.Enabled = caps } - if len(opts.DisableClusterCapabilities) > 0 { caps := make([]hyperv1.OptionalCapability, len(opts.DisableClusterCapabilities)) for i, c := range opts.DisableClusterCapabilities { caps[i] = hyperv1.OptionalCapability(c) } - prototype.Cluster.Spec.Capabilities.Disabled = caps + cluster.Spec.Capabilities.Disabled = caps } +} +func applyEtcdConfig(cluster *hyperv1.HostedCluster, opts *CreateOptions) error { if opts.EtcdStorageClass != "" { - prototype.Cluster.Spec.Etcd.Managed.Storage.PersistentVolume.StorageClassName = ptr.To(opts.EtcdStorageClass) + cluster.Spec.Etcd.Managed.Storage.PersistentVolume.StorageClassName = ptr.To(opts.EtcdStorageClass) } - if opts.EtcdStorageSize != "" { etcdStorageSize, err := resource.ParseQuantity(opts.EtcdStorageSize) if err != nil { - return nil, fmt.Errorf("failed parse ectd storage size: %w", err) + return fmt.Errorf("failed parse ectd storage size: %w", err) } - prototype.Cluster.Spec.Etcd.Managed.Storage.PersistentVolume.Size = &etcdStorageSize + cluster.Spec.Etcd.Managed.Storage.PersistentVolume.Size = &etcdStorageSize } + return nil +} +func applySSHKey(prototype *resources, opts *CreateOptions) error { sshKey, sshPrivateKey := opts.PublicKey, opts.PrivateKey - // overrides secret if SSHKeyFile is set + var err error if len(opts.SSHKeyFile) > 0 { if opts.GenerateSSH { - return nil, fmt.Errorf("--generate-ssh and --ssh-key cannot be specified together") + return fmt.Errorf("--generate-ssh and --ssh-key cannot be specified together") } - key, err := os.ReadFile(opts.SSHKeyFile) + sshKey, err = os.ReadFile(opts.SSHKeyFile) if err != nil { - return nil, fmt.Errorf("failed to read ssh key file: %w", err) + return fmt.Errorf("failed to read ssh key file: %w", err) } - sshKey = key } else if opts.GenerateSSH { sshKey, sshPrivateKey, err = util.GenerateSSHKeys() if err != nil { - return nil, fmt.Errorf("failed to generate ssh keys: %w", err) + return fmt.Errorf("failed to generate ssh keys: %w", err) } } if len(sshKey) > 0 { @@ -415,105 +464,96 @@ func prototypeResources(ctx context.Context, opts *CreateOptions) (*resources, e } prototype.Cluster.Spec.SSHKey = corev1.LocalObjectReference{Name: prototype.SSHKey.Name} } + return nil +} - // validate pausedUntil value - // valid values are either "true" or RFC3339 format date - if len(opts.PausedUntil) > 0 && opts.PausedUntil != "true" { - _, err := time.Parse(time.RFC3339, opts.PausedUntil) - if err != nil { - return nil, fmt.Errorf("invalid pausedUntil value, should be \"true\" or a valid RFC3339 date format: %w", err) - } - prototype.Cluster.Spec.PausedUntil = &opts.PausedUntil +func applyPausedUntil(cluster *hyperv1.HostedCluster, opts *CreateOptions) error { + if len(opts.PausedUntil) == 0 || opts.PausedUntil == "true" { + return nil + } + _, err := time.Parse(time.RFC3339, opts.PausedUntil) + if err != nil { + return fmt.Errorf("invalid pausedUntil value, should be \"true\" or a valid RFC3339 date format: %w", err) } + cluster.Spec.PausedUntil = &opts.PausedUntil + return nil +} +func applyOLMConfig(cluster *hyperv1.HostedCluster, opts *CreateOptions) { if opts.OLMDisableDefaultSources { - prototype.Cluster.Spec.Configuration.OperatorHub = &configv1.OperatorHubSpec{ + cluster.Spec.Configuration.OperatorHub = &configv1.OperatorHubSpec{ DisableAllDefaultSources: true, } } - if len(opts.OLMCatalogPlacement) > 0 { - prototype.Cluster.Spec.OLMCatalogPlacement = opts.OLMCatalogPlacement + cluster.Spec.OLMCatalogPlacement = opts.OLMCatalogPlacement } +} - var clusterNetworkEntries []hyperv1.ClusterNetworkEntry +func applyNetworkConfig(cluster *hyperv1.HostedCluster, opts *CreateOptions) error { for _, cidr := range opts.ClusterCIDR { parsedCIDR, err := ipnet.ParseCIDR(cidr) if err != nil { - return nil, fmt.Errorf("parsing ClusterCIDR (%s): %w", cidr, err) + return fmt.Errorf("parsing ClusterCIDR (%s): %w", cidr, err) } - clusterNetworkEntries = append(clusterNetworkEntries, hyperv1.ClusterNetworkEntry{CIDR: *parsedCIDR}) + cluster.Spec.Networking.ClusterNetwork = append(cluster.Spec.Networking.ClusterNetwork, hyperv1.ClusterNetworkEntry{CIDR: *parsedCIDR}) } - prototype.Cluster.Spec.Networking.ClusterNetwork = clusterNetworkEntries - - var serviceNetworkEntries []hyperv1.ServiceNetworkEntry for _, cidr := range opts.ServiceCIDR { parsedCIDR, err := ipnet.ParseCIDR(cidr) if err != nil { - return nil, fmt.Errorf("parsing ServiceCIDR (%s): %w", cidr, err) + return fmt.Errorf("parsing ServiceCIDR (%s): %w", cidr, err) } - serviceNetworkEntries = append(serviceNetworkEntries, hyperv1.ServiceNetworkEntry{CIDR: *parsedCIDR}) + cluster.Spec.Networking.ServiceNetwork = append(cluster.Spec.Networking.ServiceNetwork, hyperv1.ServiceNetworkEntry{CIDR: *parsedCIDR}) } - prototype.Cluster.Spec.Networking.ServiceNetwork = serviceNetworkEntries - - var machineNetworkEntries []hyperv1.MachineNetworkEntry for _, cidr := range opts.MachineCIDR { parsedCIDR, err := ipnet.ParseCIDR(cidr) if err != nil { - return nil, fmt.Errorf("parsing MachineCIDR (%s): %w", cidr, err) + return fmt.Errorf("parsing MachineCIDR (%s): %w", cidr, err) } - machineNetworkEntries = append(machineNetworkEntries, hyperv1.MachineNetworkEntry{CIDR: *parsedCIDR}) + cluster.Spec.Networking.MachineNetwork = append(cluster.Spec.Networking.MachineNetwork, hyperv1.MachineNetworkEntry{CIDR: *parsedCIDR}) } - prototype.Cluster.Spec.Networking.MachineNetwork = machineNetworkEntries if opts.DisableMultiNetwork { - if prototype.Cluster.Spec.OperatorConfiguration == nil { - prototype.Cluster.Spec.OperatorConfiguration = &hyperv1.OperatorConfiguration{} - } - if prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator == nil { - prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator = &hyperv1.ClusterNetworkOperatorSpec{} - } - prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.DisableMultiNetwork = &opts.DisableMultiNetwork + ensureClusterNetworkOperatorSpec(cluster) + cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.DisableMultiNetwork = &opts.DisableMultiNetwork } - if opts.OVNKubernetesMTU > 0 { - if prototype.Cluster.Spec.OperatorConfiguration == nil { - prototype.Cluster.Spec.OperatorConfiguration = &hyperv1.OperatorConfiguration{} + ensureClusterNetworkOperatorSpec(cluster) + if cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig == nil { + cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig = &hyperv1.OVNKubernetesConfig{} } - if prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator == nil { - prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator = &hyperv1.ClusterNetworkOperatorSpec{} - } - if prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig == nil { - prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig = &hyperv1.OVNKubernetesConfig{} - } - prototype.Cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig.MTU = opts.OVNKubernetesMTU + cluster.Spec.OperatorConfiguration.ClusterNetworkOperator.OVNKubernetesConfig.MTU = opts.OVNKubernetesMTU } - if opts.AllocateNodeCIDRs { enabled := hyperv1.AllocateNodeCIDRsEnabled - prototype.Cluster.Spec.Networking.AllocateNodeCIDRs = &enabled + cluster.Spec.Networking.AllocateNodeCIDRs = &enabled } + return nil +} - if opts.NodeSelector != nil { - prototype.Cluster.Spec.NodeSelector = opts.NodeSelector +func ensureClusterNetworkOperatorSpec(cluster *hyperv1.HostedCluster) { + if cluster.Spec.OperatorConfiguration == nil { + cluster.Spec.OperatorConfiguration = &hyperv1.OperatorConfiguration{} } - - if opts.PodsLabels != nil { - prototype.Cluster.Spec.Labels = opts.PodsLabels + if cluster.Spec.OperatorConfiguration.ClusterNetworkOperator == nil { + cluster.Spec.OperatorConfiguration.ClusterNetworkOperator = &hyperv1.ClusterNetworkOperatorSpec{} } +} - for _, tStr := range opts.Tolerations { - toleration, err := parseTolerationString(tStr) - if err != nil { - return nil, err - } - prototype.Cluster.Spec.Tolerations = append(prototype.Cluster.Spec.Tolerations, *toleration) +func applySchedulingConfig(cluster *hyperv1.HostedCluster, opts *CreateOptions) { + if opts.NodeSelector != nil { + cluster.Spec.NodeSelector = opts.NodeSelector } + if opts.PodsLabels != nil { + cluster.Spec.Labels = opts.PodsLabels + } +} +func applyTrustBundleAndImageSources(prototype *resources, opts *CreateOptions) error { if len(opts.AdditionalTrustBundle) > 0 { userCABundle, err := os.ReadFile(opts.AdditionalTrustBundle) if err != nil { - return nil, fmt.Errorf("failed to read additional trust bundle file: %w", err) + return fmt.Errorf("failed to read additional trust bundle file: %w", err) } prototype.AdditionalTrustBundle = &corev1.ConfigMap{ TypeMeta: metav1.TypeMeta{ @@ -534,44 +574,42 @@ func prototypeResources(ctx context.Context, opts *CreateOptions) (*resources, e if len(opts.ImageContentSources) > 0 { icspFileBytes, err := os.ReadFile(opts.ImageContentSources) if err != nil { - return nil, fmt.Errorf("failed to read image content sources file: %w", err) + return fmt.Errorf("failed to read image content sources file: %w", err) } - var imageContentSources []hyperv1.ImageContentSource err = yaml.Unmarshal(icspFileBytes, &imageContentSources) if err != nil { - return nil, fmt.Errorf("unable to deserialize image content sources file: %w", err) + return fmt.Errorf("unable to deserialize image content sources file: %w", err) } prototype.Cluster.Spec.ImageContentSources = imageContentSources } - if opts.FeatureSet != string(configv1.Default) { - switch opts.FeatureSet { - case string(configv1.TechPreviewNoUpgrade): - prototype.Cluster.Spec.Configuration.FeatureGate = &configv1.FeatureGateSpec{ - FeatureGateSelection: configv1.FeatureGateSelection{ - FeatureSet: configv1.TechPreviewNoUpgrade, - }, - } - case string(configv1.DevPreviewNoUpgrade): - prototype.Cluster.Spec.Configuration.FeatureGate = &configv1.FeatureGateSpec{ - FeatureGateSelection: configv1.FeatureGateSelection{ - FeatureSet: configv1.DevPreviewNoUpgrade, - }, - } - default: - return nil, fmt.Errorf("invalid feature set: %s", opts.FeatureSet) + for _, tStr := range opts.Tolerations { + toleration, err := parseTolerationString(tStr) + if err != nil { + return err } + prototype.Cluster.Spec.Tolerations = append(prototype.Cluster.Spec.Tolerations, *toleration) } - if len(opts.KubeAPIServerDNSName) > 0 { - if err := validation.IsDNS1123Subdomain(opts.KubeAPIServerDNSName); len(err) > 0 { - return nil, fmt.Errorf("KubeAPIServerDNSName failed DNS validation: %s", strings.Join(err[:], " ")) + return nil +} + +func applyFeatureSet(cluster *hyperv1.HostedCluster, opts *CreateOptions) { + switch opts.FeatureSet { + case string(configv1.TechPreviewNoUpgrade): + cluster.Spec.Configuration.FeatureGate = &configv1.FeatureGateSpec{ + FeatureGateSelection: configv1.FeatureGateSelection{ + FeatureSet: configv1.TechPreviewNoUpgrade, + }, + } + case string(configv1.DevPreviewNoUpgrade): + cluster.Spec.Configuration.FeatureGate = &configv1.FeatureGateSpec{ + FeatureGateSelection: configv1.FeatureGateSelection{ + FeatureSet: configv1.DevPreviewNoUpgrade, + }, } - prototype.Cluster.Spec.KubeAPIServerDNSName = opts.KubeAPIServerDNSName } - - return prototype, nil } func apply(ctx context.Context, l logr.Logger, infraID string, objects []crclient.Object, waitForRollout bool, mutate func(crclient.Object)) error { @@ -680,81 +718,103 @@ func (opts *RawCreateOptions) Validate(ctx context.Context) (*ValidatedCreateOpt if opts.Name == "" { return nil, errors.New("--name is required") } - if opts.PullSecretFile == "" { return nil, errors.New("--pull-secret is required") } + if err := opts.validateVersionAndWait(ctx); err != nil { + return nil, err + } + + errs := validation.IsDNS1123Label(opts.Name) + if len(errs) > 0 { + return nil, fmt.Errorf("HostedCluster name failed RFC1123 validation: %s", strings.Join(errs[:], " ")) + } + + if err := opts.validateClusterExistence(ctx); err != nil { + return nil, err + } + if err := opts.validateArchAndFeatureSet(); err != nil { + return nil, err + } + if err := opts.validateCapabilities(); err != nil { + return nil, err + } + if err := opts.validateNetworkOptions(); err != nil { + return nil, err + } + + return &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: opts, + }, + }, nil +} +func (opts *RawCreateOptions) validateVersionAndWait(ctx context.Context) error { if opts.VersionCheck { versionCLI := supportedversion.GetRevision() client, err := util.GetClient() if err != nil { - return nil, fmt.Errorf("failed to get client: %w", err) + return fmt.Errorf("failed to get client: %w", err) } if err := validateVersion(ctx, versionCLI, client); err != nil { - return nil, fmt.Errorf("version validation failed: %w", err) + return fmt.Errorf("version validation failed: %w", err) } } - if opts.Wait && opts.NodePoolReplicas < 1 { - return nil, errors.New("--wait requires --node-pool-replicas > 0") + return errors.New("--wait requires --node-pool-replicas > 0") } + return nil +} - // Validate HostedCluster name follows RFC1123 standard - // https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-label-names - errs := validation.IsDNS1123Label(opts.Name) - if len(errs) > 0 { - return nil, fmt.Errorf("HostedCluster name failed RFC1123 validation: %s", strings.Join(errs[:], " ")) +func (opts *RawCreateOptions) validateClusterExistence(ctx context.Context) error { + if opts.Render { + return nil + } + client, err := util.GetClient() + if err != nil { + return err + } + cluster := &hyperv1.HostedCluster{ObjectMeta: metav1.ObjectMeta{Namespace: opts.Namespace, Name: opts.Name}} + if err := client.Get(ctx, crclient.ObjectKeyFromObject(cluster), cluster); err == nil { + return fmt.Errorf("hostedcluster %s already exists", crclient.ObjectKeyFromObject(cluster)) + } else if !apierrors.IsNotFound(err) { + return fmt.Errorf("hostedcluster doesn't exist validation failed with error: %w", err) } - if !opts.Render { - client, err := util.GetClient() - if err != nil { - return nil, err - } - // Validate HostedCluster with this name doesn't exist in the namespace - cluster := &hyperv1.HostedCluster{ObjectMeta: metav1.ObjectMeta{Namespace: opts.Namespace, Name: opts.Name}} - if err := client.Get(ctx, crclient.ObjectKeyFromObject(cluster), cluster); err == nil { - return nil, fmt.Errorf("hostedcluster %s already exists", crclient.ObjectKeyFromObject(cluster)) - } else if !apierrors.IsNotFound(err) { - return nil, fmt.Errorf("hostedcluster doesn't exist validation failed with error: %w", err) - } - - // Validate multi-arch aspects - kc, err := hyperutil.GetKubeClientSet() - if err != nil { - return nil, fmt.Errorf("could not retrieve kube clientset: %w", err) - } - if err := validateMgmtClusterAndNodePoolCPUArchitectures(ctx, opts, kc, &hyperutil.RegistryClientImageMetadataProvider{}); err != nil { - if strings.Contains(err.Error(), "failed to retrieve manifest") { - opts.Log.Info("WARNING: Unable to access the payload, skipping the Architectures check.", "error", err.Error()) - } else { - return nil, err - } + kc, err := hyperutil.GetKubeClientSet() + if err != nil { + return fmt.Errorf("could not retrieve kube clientset: %w", err) + } + if err := validateMgmtClusterAndNodePoolCPUArchitectures(ctx, opts, kc, &hyperutil.RegistryClientImageMetadataProvider{}); err != nil { + if strings.Contains(err.Error(), "failed to retrieve manifest") { + opts.Log.Info("WARNING: Unable to access the payload, skipping the Architectures check.", "error", err.Error()) + } else { + return err } } + return nil +} +func (opts *RawCreateOptions) validateArchAndFeatureSet() error { arch := strings.ToLower(opts.Arch) switch arch { - case hyperv1.ArchitectureAMD64: - case hyperv1.ArchitectureARM64: - case hyperv1.ArchitecturePPC64LE: - case hyperv1.ArchitectureS390X: + case hyperv1.ArchitectureAMD64, hyperv1.ArchitectureARM64, hyperv1.ArchitecturePPC64LE, hyperv1.ArchitectureS390X: default: - return nil, fmt.Errorf("specified arch %q is not supported", opts.Arch) + return fmt.Errorf("specified arch %q is not supported", opts.Arch) } - // Validate feature set is "", TechPreviewNoUpgrade, or DevPreviewNoUpgrade switch opts.FeatureSet { - case string(configv1.Default): - case string(configv1.TechPreviewNoUpgrade): - case string(configv1.DevPreviewNoUpgrade): + case string(configv1.Default), string(configv1.TechPreviewNoUpgrade), string(configv1.DevPreviewNoUpgrade): case string(configv1.CustomNoUpgrade): - return nil, fmt.Errorf("only a predefined feature set is supported by the feature-set flag") + return fmt.Errorf("only a predefined feature set is supported by the feature-set flag") default: - return nil, fmt.Errorf("specified feature set %q is not supported", opts.FeatureSet) + return fmt.Errorf("specified feature set %q is not supported", opts.FeatureSet) } + return nil +} +func (opts *RawCreateOptions) validateCapabilities() error { acceptedValues := sets.NewString( string(hyperv1.ImageRegistryCapability), string(hyperv1.OpenShiftSamplesCapability), @@ -764,54 +824,44 @@ func (opts *RawCreateOptions) Validate(ctx context.Context) (*ValidatedCreateOpt string(hyperv1.NodeTuningCapability), string(hyperv1.IngressCapability), ) - if len(opts.DisableClusterCapabilities) > 0 { - for _, capability := range opts.DisableClusterCapabilities { - if !acceptedValues.Has(capability) { - return nil, fmt.Errorf("unknown disabled capability: %s, accepted values are: %v", capability, acceptedValues.List()) - } + for _, capability := range opts.DisableClusterCapabilities { + if !acceptedValues.Has(capability) { + return fmt.Errorf("unknown disabled capability: %s, accepted values are: %v", capability, acceptedValues.List()) } } - if len(opts.EnableClusterCapabilities) > 0 { - for _, capability := range opts.EnableClusterCapabilities { - if !acceptedValues.Has(capability) { - return nil, fmt.Errorf("unknown enabled capability: %s, accepted values are: %v", capability, acceptedValues.List()) - } + for _, capability := range opts.EnableClusterCapabilities { + if !acceptedValues.Has(capability) { + return fmt.Errorf("unknown enabled capability: %s, accepted values are: %v", capability, acceptedValues.List()) } } - disabledCaps := sets.NewString(opts.DisableClusterCapabilities...) if disabledCaps.Has(string(hyperv1.IngressCapability)) && !disabledCaps.Has(string(hyperv1.ConsoleCapability)) { - return nil, fmt.Errorf("ingress capability can only be disabled if Console capability is also disabled") + return fmt.Errorf("ingress capability can only be disabled if Console capability is also disabled") } - if len(opts.KubeAPIServerDNSName) > 0 { if err := validation.IsDNS1123Subdomain(opts.KubeAPIServerDNSName); len(err) > 0 { - return nil, fmt.Errorf("KubeAPIServerDNSName failed DNS validation: %s", strings.Join(err[:], " ")) + return fmt.Errorf("KubeAPIServerDNSName failed DNS validation: %s", strings.Join(err[:], " ")) } } + return nil +} +func (opts *RawCreateOptions) validateNetworkOptions() error { if opts.DisableMultiNetwork && opts.NetworkType != "Other" { - return nil, fmt.Errorf("disableMultiNetwork is only allowed when networkType is 'Other' (got '%s')", opts.NetworkType) + return fmt.Errorf("disableMultiNetwork is only allowed when networkType is 'Other' (got '%s')", opts.NetworkType) } - if opts.OVNKubernetesMTU != 0 { if opts.NetworkType != string(hyperv1.OVNKubernetes) { - return nil, fmt.Errorf("--ovn-kubernetes-mtu is only valid when --network-type is OVNKubernetes (got '%s')", opts.NetworkType) + return fmt.Errorf("--ovn-kubernetes-mtu is only valid when --network-type is OVNKubernetes (got '%s')", opts.NetworkType) } if opts.OVNKubernetesMTU < 576 || opts.OVNKubernetesMTU > 9216 { - return nil, fmt.Errorf("--ovn-kubernetes-mtu must be between 576 and 9216 (got %d)", opts.OVNKubernetesMTU) + return fmt.Errorf("--ovn-kubernetes-mtu must be between 576 and 9216 (got %d)", opts.OVNKubernetesMTU) } } - if opts.AllocateNodeCIDRs && opts.NetworkType != "Other" { - return nil, fmt.Errorf("allocateNodeCIDRs is only allowed when networkType is 'Other' (got '%s')", opts.NetworkType) + return fmt.Errorf("allocateNodeCIDRs is only allowed when networkType is 'Other' (got '%s')", opts.NetworkType) } - - return &ValidatedCreateOptions{ - validatedCreateOptions: &validatedCreateOptions{ - RawCreateOptions: opts, - }, - }, nil + return nil } // completedCreateOptions is a private wrapper that enforces a call of Complete() before CreateCluster() can be invoked. diff --git a/cmd/cluster/core/create_test.go b/cmd/cluster/core/create_test.go index e2c0e66257c..532f77e8ed6 100644 --- a/cmd/cluster/core/create_test.go +++ b/cmd/cluster/core/create_test.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "testing" + "time" . "github.com/onsi/gomega" @@ -13,6 +14,8 @@ import ( "github.com/openshift/hypershift/support/thirdparty/library-go/pkg/image/dockerv1client" "github.com/openshift/hypershift/support/util/fakeimagemetadataprovider" + configv1 "github.com/openshift/api/config/v1" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -887,3 +890,840 @@ func TestGetServicePublishingStrategyMapping(t *testing.T) { }) } } + +func TestParseKeyValuePairs(t *testing.T) { + tests := []struct { + name string + items []string + kind string + expected map[string]string + expectError string + }{ + { + name: "When valid key=value pairs are provided, it should return a map", + items: []string{"key1=value1", "key2=value2"}, + kind: "annotation", + expected: map[string]string{"key1": "value1", "key2": "value2"}, + }, + { + name: "When an empty list is provided, it should return an empty map", + items: []string{}, + kind: "label", + expected: map[string]string{}, + }, + { + name: "When a malformed pair without equals sign is provided, it should return an error", + items: []string{"badentry"}, + kind: "annotation", + expectError: "invalid annotation: badentry", + }, + { + name: "When a value contains an equals sign, it should split only on the first equals", + items: []string{"key=val=extra"}, + kind: "label", + expected: map[string]string{"key": "val=extra"}, + }, + { + name: "When a value is empty after equals sign, it should accept it", + items: []string{"key="}, + kind: "annotation", + expected: map[string]string{"key": ""}, + }, + { + name: "When the key is empty, it should return an error", + items: []string{"=value"}, + kind: "label", + expectError: "key must not be empty", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + result, err := parseKeyValuePairs(tc.items, tc.kind) + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(result).To(Equal(tc.expected)) + } + }) + } +} + +func TestApplyEtcdConfig(t *testing.T) { + tests := []struct { + name string + etcdStorageClass string + etcdStorageSize string + expectError string + expectStorageClass *string + expectSizeString string + }{ + { + name: "When no etcd storage options are provided, it should leave defaults unchanged", + etcdStorageClass: "", + etcdStorageSize: "", + }, + { + name: "When etcd storage class is provided, it should set the storage class", + etcdStorageClass: "gp3-csi", + expectStorageClass: ptr.To("gp3-csi"), + }, + { + name: "When a valid etcd storage size is provided, it should set the storage size", + etcdStorageSize: "8Gi", + expectSizeString: "8Gi", + }, + { + name: "When an invalid etcd storage size is provided, it should return an error", + etcdStorageSize: "notasize", + expectError: "failed parse ectd storage size", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Etcd: hyperv1.EtcdSpec{ + Managed: &hyperv1.ManagedEtcdSpec{ + Storage: hyperv1.ManagedEtcdStorageSpec{ + PersistentVolume: &hyperv1.PersistentVolumeEtcdStorageSpec{ + Size: &hyperv1.DefaultPersistentVolumeEtcdStorageSize, + }, + }, + }, + }, + }, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + EtcdStorageClass: tc.etcdStorageClass, + EtcdStorageSize: tc.etcdStorageSize, + }, + }, + }, + }, + } + + err := applyEtcdConfig(cluster, opts) + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + if tc.expectStorageClass != nil { + g.Expect(cluster.Spec.Etcd.Managed.Storage.PersistentVolume.StorageClassName).To(Equal(tc.expectStorageClass)) + } + if tc.expectSizeString != "" { + g.Expect(cluster.Spec.Etcd.Managed.Storage.PersistentVolume.Size.String()).To(Equal(tc.expectSizeString)) + } + // When no storage options are provided, verify defaults remain unchanged + if tc.etcdStorageClass == "" && tc.etcdStorageSize == "" { + // StorageClassName should remain nil (default) + g.Expect(cluster.Spec.Etcd.Managed.Storage.PersistentVolume.StorageClassName).To(BeNil(), "StorageClassName should remain nil when not specified") + // Size should remain at the default value + g.Expect(cluster.Spec.Etcd.Managed.Storage.PersistentVolume.Size).To(Equal(&hyperv1.DefaultPersistentVolumeEtcdStorageSize), "Size should remain at default when not specified") + } + } + }) + } +} + +func TestApplyPausedUntil(t *testing.T) { + tests := []struct { + name string + pausedUntil string + expectError string + expectSet bool + }{ + { + name: "When pausedUntil is empty, it should not set PausedUntil on the cluster", + pausedUntil: "", + expectSet: false, + }, + { + name: "When pausedUntil is 'true', it should not set PausedUntil on the cluster", + pausedUntil: "true", + expectSet: false, + }, + { + name: "When pausedUntil is a valid RFC3339 date, it should set PausedUntil on the cluster", + pausedUntil: "2026-12-31T23:59:59Z", + expectSet: true, + }, + { + name: "When pausedUntil is an invalid date format, it should return an error", + pausedUntil: "not-a-date", + expectError: "invalid pausedUntil value", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{} + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + PausedUntil: tc.pausedUntil, + }, + }, + }, + }, + } + + err := applyPausedUntil(cluster, opts) + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + if tc.expectSet { + g.Expect(cluster.Spec.PausedUntil).NotTo(BeNil()) + g.Expect(*cluster.Spec.PausedUntil).To(Equal(tc.pausedUntil)) + } else { + g.Expect(cluster.Spec.PausedUntil).To(BeNil()) + } + } + }) + } +} + +func TestApplyOLMConfig(t *testing.T) { + tests := []struct { + name string + olmDisableDefaultSources bool + olmCatalogPlacement hyperv1.OLMCatalogPlacement + expectOperatorHub bool + expectCatalogPlacement hyperv1.OLMCatalogPlacement + }{ + { + name: "When OLM default sources are disabled, it should set DisableAllDefaultSources", + olmDisableDefaultSources: true, + olmCatalogPlacement: "", + expectOperatorHub: true, + }, + { + name: "When OLM catalog placement is set to Guest, it should set OLMCatalogPlacement", + olmCatalogPlacement: hyperv1.GuestOLMCatalogPlacement, + expectCatalogPlacement: hyperv1.GuestOLMCatalogPlacement, + }, + { + name: "When no OLM options are set, it should not modify the cluster", + expectOperatorHub: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Configuration: &hyperv1.ClusterConfiguration{}, + }, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + OLMDisableDefaultSources: tc.olmDisableDefaultSources, + OLMCatalogPlacement: tc.olmCatalogPlacement, + }, + }, + }, + }, + } + + applyOLMConfig(cluster, opts) + if tc.expectOperatorHub { + g.Expect(cluster.Spec.Configuration.OperatorHub).NotTo(BeNil()) + g.Expect(cluster.Spec.Configuration.OperatorHub.DisableAllDefaultSources).To(BeTrue()) + } + if tc.expectCatalogPlacement != "" { + g.Expect(cluster.Spec.OLMCatalogPlacement).To(Equal(tc.expectCatalogPlacement)) + } + // When no OLM options are set, verify cluster remains unmodified + if !tc.olmDisableDefaultSources && tc.olmCatalogPlacement == "" { + g.Expect(cluster.Spec.Configuration.OperatorHub).To(BeNil(), "OperatorHub should remain nil when OLM options are not set") + g.Expect(cluster.Spec.OLMCatalogPlacement).To(BeEmpty(), "OLMCatalogPlacement should remain empty when not specified") + } + }) + } +} + +func TestApplyFeatureSet(t *testing.T) { + tests := []struct { + name string + featureSet string + expectFeatureGate bool + expectedFeatureSet configv1.FeatureSet + }{ + { + name: "When feature set is Default, it should not set FeatureGate", + featureSet: string(configv1.Default), + expectFeatureGate: false, + }, + { + name: "When feature set is TechPreviewNoUpgrade, it should set the TechPreviewNoUpgrade feature gate", + featureSet: string(configv1.TechPreviewNoUpgrade), + expectFeatureGate: true, + expectedFeatureSet: configv1.TechPreviewNoUpgrade, + }, + { + name: "When feature set is DevPreviewNoUpgrade, it should set the DevPreviewNoUpgrade feature gate", + featureSet: string(configv1.DevPreviewNoUpgrade), + expectFeatureGate: true, + expectedFeatureSet: configv1.DevPreviewNoUpgrade, + }, + { + name: "When feature set is an unrecognized value, it should not set FeatureGate", + featureSet: "SomethingElse", + expectFeatureGate: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Configuration: &hyperv1.ClusterConfiguration{}, + }, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + FeatureSet: tc.featureSet, + }, + }, + }, + }, + } + + applyFeatureSet(cluster, opts) + if tc.expectFeatureGate { + g.Expect(cluster.Spec.Configuration.FeatureGate).NotTo(BeNil()) + g.Expect(cluster.Spec.Configuration.FeatureGate.FeatureSet).To(Equal(tc.expectedFeatureSet)) + } else { + g.Expect(cluster.Spec.Configuration.FeatureGate).To(BeNil()) + } + }) + } +} + +func TestApplyNetworkConfig(t *testing.T) { + tests := []struct { + name string + clusterCIDR []string + serviceCIDR []string + machineCIDR []string + expectError string + expectClusters int + expectServices int + expectMachines int + }{ + { + name: "When valid CIDRs are provided, it should parse and append them", + clusterCIDR: []string{"10.128.0.0/14"}, + serviceCIDR: []string{"172.30.0.0/16"}, + machineCIDR: []string{"10.0.0.0/16"}, + expectClusters: 1, + expectServices: 1, + expectMachines: 1, + }, + { + name: "When dual-stack CIDRs are provided, it should parse both", + clusterCIDR: []string{"10.128.0.0/14", "fd01::/48"}, + serviceCIDR: []string{"172.30.0.0/16", "fd02::/112"}, + expectClusters: 2, + expectServices: 2, + expectMachines: 0, + }, + { + name: "When an invalid cluster CIDR is provided, it should return an error", + clusterCIDR: []string{"not-a-cidr"}, + expectError: "parsing ClusterCIDR", + }, + { + name: "When an invalid service CIDR is provided, it should return an error", + serviceCIDR: []string{"not-a-cidr"}, + expectError: "parsing ServiceCIDR", + }, + { + name: "When an invalid machine CIDR is provided, it should return an error", + machineCIDR: []string{"not-a-cidr"}, + expectError: "parsing MachineCIDR", + }, + { + name: "When no CIDRs are provided, it should produce empty network lists", + expectClusters: 0, + expectServices: 0, + expectMachines: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{} + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + ClusterCIDR: tc.clusterCIDR, + ServiceCIDR: tc.serviceCIDR, + MachineCIDR: tc.machineCIDR, + }, + }, + }, + }, + } + + err := applyNetworkConfig(cluster, opts) + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(cluster.Spec.Networking.ClusterNetwork).To(HaveLen(tc.expectClusters)) + g.Expect(cluster.Spec.Networking.ServiceNetwork).To(HaveLen(tc.expectServices)) + g.Expect(cluster.Spec.Networking.MachineNetwork).To(HaveLen(tc.expectMachines)) + } + }) + } +} + +func TestApplySchedulingConfig(t *testing.T) { + tests := []struct { + name string + nodeSelector map[string]string + podsLabels map[string]string + expectNodeSelector map[string]string + expectLabels map[string]string + }{ + { + name: "When node selector is provided, it should set NodeSelector on the cluster", + nodeSelector: map[string]string{"role": "cp", "disk": "fast"}, + expectNodeSelector: map[string]string{"role": "cp", "disk": "fast"}, + }, + { + name: "When pods labels are provided, it should set Labels on the cluster", + podsLabels: map[string]string{"team": "hypershift"}, + expectLabels: map[string]string{"team": "hypershift"}, + }, + { + name: "When neither is provided, it should not modify the cluster", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{} + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + NodeSelector: tc.nodeSelector, + PodsLabels: tc.podsLabels, + }, + }, + }, + }, + } + + applySchedulingConfig(cluster, opts) + if tc.expectNodeSelector != nil { + g.Expect(cluster.Spec.NodeSelector).To(Equal(tc.expectNodeSelector)) + } + if tc.expectLabels != nil { + g.Expect(cluster.Spec.Labels).To(Equal(tc.expectLabels)) + } + // When neither node selector nor labels are provided, verify cluster remains unmodified + if tc.nodeSelector == nil && tc.podsLabels == nil { + g.Expect(cluster.Spec.NodeSelector).To(BeNil(), "NodeSelector should remain nil when not specified") + g.Expect(cluster.Spec.Labels).To(BeNil(), "Labels should remain nil when not specified") + } + }) + } +} + +func TestParseTolerationString(t *testing.T) { + tests := []struct { + name string + input string + expectError string + expected *corev1.Toleration + }{ + { + name: "When a full toleration is specified, it should parse all fields", + input: "key=node-role.kubernetes.io/master,operator=Exists,effect=NoSchedule", + expected: &corev1.Toleration{ + Key: "node-role.kubernetes.io/master", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + { + name: "When operator is Equal and a value is specified, it should parse correctly", + input: "key=mykey,value=myvalue,operator=Equal,effect=NoExecute", + expected: &corev1.Toleration{ + Key: "mykey", + Value: "myvalue", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoExecute, + }, + }, + { + name: "When tolerationSeconds is specified, it should parse the integer value", + input: "key=mykey,operator=Exists,effect=NoExecute,tolerationSeconds=300", + expected: &corev1.Toleration{ + Key: "mykey", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoExecute, + TolerationSeconds: ptr.To(int64(300)), + }, + }, + { + name: "When effect is PreferNoSchedule, it should normalize the case", + input: "key=mykey,effect=preferNoSchedule", + expected: &corev1.Toleration{ + Key: "mykey", + Effect: corev1.TaintEffectPreferNoSchedule, + }, + }, + { + name: "When an unknown operator type is provided, it should return an error", + input: "key=mykey,operator=Unknown", + expectError: "unknown operator type", + }, + { + name: "When an unknown effect type is provided, it should return an error", + input: "key=mykey,effect=Unknown", + expectError: "unknown effect type", + }, + { + name: "When a malformed key-value is provided, it should return an error", + input: "badformat", + expectError: "invalid toleration cli argument", + }, + { + name: "When an unknown field is provided, it should return an error", + input: "unknownfield=value", + expectError: "unknown field", + }, + { + name: "When tolerationSeconds is not an integer, it should return an error", + input: "key=mykey,tolerationSeconds=abc", + expectError: "failed to parse tolerationSeconds", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + result, err := parseTolerationString(tc.input) + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(result).To(Equal(tc.expected)) + } + }) + } +} + +func TestPostProcess(t *testing.T) { + t.Run("When secret encryption is nil, it should default to AESCBC", func(t *testing.T) { + g := NewWithT(t) + r := &resources{ + Cluster: &hyperv1.HostedCluster{}, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + Name: "test-cluster", + Namespace: "clusters", + }, + }, + }, + }, + } + + postProcess(r, opts) + + g.Expect(r.Cluster.Spec.SecretEncryption).NotTo(BeNil()) + g.Expect(r.Cluster.Spec.SecretEncryption.Type).To(Equal(hyperv1.AESCBC)) + g.Expect(r.Cluster.Spec.SecretEncryption.AESCBC).NotTo(BeNil()) + g.Expect(r.Cluster.Spec.SecretEncryption.AESCBC.ActiveKey.Name).To(Equal("test-cluster-etcd-encryption-key")) + g.Expect(r.Resources).To(HaveLen(1)) + }) + + t.Run("When secret encryption is already set, it should not override it", func(t *testing.T) { + g := NewWithT(t) + r := &resources{ + Cluster: &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + SecretEncryption: &hyperv1.SecretEncryptionSpec{ + Type: hyperv1.KMS, + }, + }, + }, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + Name: "test-cluster", + Namespace: "clusters", + }, + }, + }, + }, + } + + postProcess(r, opts) + + g.Expect(r.Cluster.Spec.SecretEncryption.Type).To(Equal(hyperv1.KMS)) + g.Expect(r.Resources).To(BeEmpty()) + }) +} + +func TestDefaultNodePool(t *testing.T) { + tests := []struct { + name string + clusterName string + suffix string + expectedName string + }{ + { + name: "When no suffix is provided, it should use the cluster name", + clusterName: "my-cluster", + suffix: "", + expectedName: "my-cluster", + }, + { + name: "When a suffix is provided, it should append it to the cluster name", + clusterName: "my-cluster", + suffix: "us-east-1a", + expectedName: "my-cluster-us-east-1a", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + replicas := int32(3) + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + Name: tc.clusterName, + Namespace: "clusters", + NodePoolReplicas: replicas, + ReleaseImage: "quay.io/openshift/ocp:4.16", + AutoRepair: true, + Arch: "amd64", + NodeDrainTimeout: 5 * time.Minute, + }, + }, + }, + }, + } + + constructor := defaultNodePool(opts) + np := constructor(hyperv1.AWSPlatform, tc.suffix) + + g.Expect(np.Name).To(Equal(tc.expectedName)) + g.Expect(np.Namespace).To(Equal("clusters")) + g.Expect(np.Spec.ClusterName).To(Equal(tc.clusterName)) + g.Expect(np.Spec.Platform.Type).To(Equal(hyperv1.AWSPlatform)) + g.Expect(*np.Spec.Replicas).To(Equal(replicas)) + g.Expect(np.Spec.Management.AutoRepair).To(BeTrue()) + g.Expect(np.Spec.Arch).To(Equal("amd64")) + g.Expect(np.Spec.Release.Image).To(Equal("quay.io/openshift/ocp:4.16")) + }) + } +} + +func TestValidateArchAndFeatureSet(t *testing.T) { + tests := []struct { + name string + arch string + featureSet string + expectError string + }{ + { + name: "When amd64 arch and Default feature set are provided, it should pass", + arch: "amd64", + featureSet: string(configv1.Default), + }, + { + name: "When arm64 arch is provided, it should pass", + arch: "arm64", + featureSet: string(configv1.Default), + }, + { + name: "When ppc64le arch is provided, it should pass", + arch: "ppc64le", + featureSet: string(configv1.Default), + }, + { + name: "When s390x arch is provided, it should pass", + arch: "s390x", + featureSet: string(configv1.Default), + }, + { + name: "When an unsupported arch is provided, it should return an error", + arch: "mips", + featureSet: string(configv1.Default), + expectError: "specified arch", + }, + { + name: "When TechPreviewNoUpgrade feature set is provided, it should pass", + arch: "amd64", + featureSet: string(configv1.TechPreviewNoUpgrade), + }, + { + name: "When DevPreviewNoUpgrade feature set is provided, it should pass", + arch: "amd64", + featureSet: string(configv1.DevPreviewNoUpgrade), + }, + { + name: "When CustomNoUpgrade feature set is provided, it should return an error", + arch: "amd64", + featureSet: string(configv1.CustomNoUpgrade), + expectError: "only a predefined feature set is supported", + }, + { + name: "When an unknown feature set is provided, it should return an error", + arch: "amd64", + featureSet: "SomethingRandom", + expectError: "specified feature set", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + opts := &RawCreateOptions{ + Arch: tc.arch, + FeatureSet: tc.featureSet, + } + + err := opts.validateArchAndFeatureSet() + if tc.expectError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(tc.expectError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + }) + } +} + +func TestApplyClusterCapabilities(t *testing.T) { + tests := []struct { + name string + enableCaps []string + disableCaps []string + expectEnabled []hyperv1.OptionalCapability + expectDisabled []hyperv1.OptionalCapability + }{ + { + name: "When both enable and disable capabilities are provided, it should set both", + enableCaps: []string{"baremetal", "Console"}, + disableCaps: []string{"ImageRegistry"}, + expectEnabled: []hyperv1.OptionalCapability{"baremetal", "Console"}, + expectDisabled: []hyperv1.OptionalCapability{"ImageRegistry"}, + }, + { + name: "When only disable capabilities are provided, it should set disabled only", + disableCaps: []string{"Insights", "NodeTuning"}, + expectDisabled: []hyperv1.OptionalCapability{"Insights", "NodeTuning"}, + }, + { + name: "When neither enable nor disable are provided, it should not set capabilities", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + Capabilities: &hyperv1.Capabilities{}, + }, + } + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + EnableClusterCapabilities: tc.enableCaps, + DisableClusterCapabilities: tc.disableCaps, + }, + }, + }, + }, + } + + applyClusterCapabilities(cluster, opts) + + if tc.expectEnabled != nil { + g.Expect(cluster.Spec.Capabilities.Enabled).To(Equal(tc.expectEnabled)) + } else { + g.Expect(cluster.Spec.Capabilities.Enabled).To(BeNil()) + } + if tc.expectDisabled != nil { + g.Expect(cluster.Spec.Capabilities.Disabled).To(Equal(tc.expectDisabled)) + } else { + g.Expect(cluster.Spec.Capabilities.Disabled).To(BeNil()) + } + }) + } +} + +func TestEnsureClusterNetworkOperatorSpec(t *testing.T) { + t.Run("When OperatorConfiguration is nil, it should initialize both OperatorConfiguration and ClusterNetworkOperator", func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{} + + ensureClusterNetworkOperatorSpec(cluster) + + g.Expect(cluster.Spec.OperatorConfiguration).NotTo(BeNil()) + g.Expect(cluster.Spec.OperatorConfiguration.ClusterNetworkOperator).NotTo(BeNil()) + }) + + t.Run("When OperatorConfiguration exists but ClusterNetworkOperator is nil, it should initialize ClusterNetworkOperator", func(t *testing.T) { + g := NewWithT(t) + cluster := &hyperv1.HostedCluster{ + Spec: hyperv1.HostedClusterSpec{ + OperatorConfiguration: &hyperv1.OperatorConfiguration{}, + }, + } + + ensureClusterNetworkOperatorSpec(cluster) + + g.Expect(cluster.Spec.OperatorConfiguration.ClusterNetworkOperator).NotTo(BeNil()) + }) +} diff --git a/cmd/fix/dr_oidc_iam.go b/cmd/fix/dr_oidc_iam.go index 131f216a750..381a3b8905e 100644 --- a/cmd/fix/dr_oidc_iam.go +++ b/cmd/fix/dr_oidc_iam.go @@ -449,38 +449,13 @@ func (o *DrOidcIamOptions) Run(ctx context.Context) error { o.Retryer = retryerFn() }) - // Step 1: Check if OIDC documents exist in S3 - fmt.Println("Step 1: Checking OIDC documents in S3") - oidcDocsExist := o.checkOIDCDocumentsExist(ctx, s3Client) - if oidcDocsExist && !o.ForceRecreate { - fmt.Printf("- (%s) OIDC documents found in S3\n", greenCheck()) - } else { - if o.ForceRecreate { - fmt.Printf("- (%s) Force recreate enabled - will regenerate OIDC documents\n", yellowForce()) - } else { - fmt.Printf("- (%s) OIDC documents missing in S3 - will create them\n", redX()) - } - } - - // Step 2: Check if OIDC identity provider exists in IAM - fmt.Println("Step 2: Checking OIDC identity provider in IAM") - providerARN, exists, err := o.checkOIDCProvider(ctx, iamClient) + oidcDocsExist, providerARN, exists, err := o.checkOIDCState(ctx, s3Client, iamClient) if err != nil { - return fmt.Errorf("failed to check OIDC provider: %w", err) + return err } - - if exists && !o.ForceRecreate { - fmt.Printf("- (%s) OIDC identity provider exists\n", greenCheck()) - if oidcDocsExist { - fmt.Println("\nAll OIDC components are in place - no action needed!") - return nil - } - } else { - if o.ForceRecreate { - fmt.Printf("- (%s) Force recreate enabled - will regenerate OIDC provider\n", yellowForce()) - } else { - fmt.Printf("- (%s) OIDC identity provider missing - needs to be recreated\n", redX()) - } + if oidcDocsExist && exists && !o.ForceRecreate { + fmt.Println("\nAll OIDC components are in place - no action needed!") + return nil } // Step 3: Create/ensure S3 bucket exists @@ -490,20 +465,8 @@ func (o *DrOidcIamOptions) Run(ctx context.Context) error { } fmt.Printf("- (%s) S3 bucket ready with public access enabled\n", greenCheck()) - // Step 4: Generate and upload OIDC documents using the EXISTING cluster key - if !oidcDocsExist || o.ForceRecreate { - fmt.Println("Step 4: Generating and uploading OIDC documents") - if o.DryRun { - fmt.Printf("- (%s) DRY RUN: Would generate and upload OIDC documents using existing cluster signing key\n", yellowQuestion()) - } else { - if err := o.generateAndUploadOIDCDocuments(ctx, k8sClient, s3Client); err != nil { - return fmt.Errorf("failed to generate OIDC documents: %w", err) - } - fmt.Printf("- (%s) OIDC documents generated and uploaded\n", greenCheck()) - } - } else { - fmt.Println("Step 4: OIDC documents") - fmt.Printf("- (%s) OIDC documents already exist\n", greenCheck()) + if err := o.ensureOIDCDocuments(ctx, k8sClient, s3Client, oidcDocsExist); err != nil { + return err } // Step 5: Get SSL certificate thumbprint @@ -514,26 +477,8 @@ func (o *DrOidcIamOptions) Run(ctx context.Context) error { } fmt.Printf("- (%s) SSL certificate thumbprint retrieved\n", greenCheck()) - // Step 6: Create/recreate OIDC provider - if !exists || o.ForceRecreate { - fmt.Println("Step 6: Creating/recreating OIDC identity provider") - if o.DryRun { - fmt.Printf("- (%s) DRY RUN: Would create OIDC provider\n", yellowQuestion()) - fmt.Printf(" Issuer: %s\n", o.Issuer) - fmt.Printf(" Thumbprint: %s\n", thumbprint) - fmt.Printf(" Allowed clients: openshift, sts.amazonaws.com\n") - } else { - if err := o.deleteOIDCProviderIfExists(ctx, iamClient, providerARN); err != nil { - return fmt.Errorf("failed to delete existing OIDC provider: %w", err) - } - if _, err := o.createOIDCProvider(ctx, iamClient, thumbprint); err != nil { - return fmt.Errorf("failed to create OIDC provider: %w", err) - } - fmt.Printf("- (%s) OIDC identity provider successfully created\n", greenCheck()) - } - } else { - fmt.Println("Step 6: OIDC identity provider") - fmt.Printf("- (%s) OIDC identity provider already exists\n", greenCheck()) + if err := o.ensureOIDCProvider(ctx, iamClient, providerARN, thumbprint, exists); err != nil { + return err } // Step 7: Verify configuration and update HostedCluster @@ -548,6 +493,80 @@ func (o *DrOidcIamOptions) Run(ctx context.Context) error { return nil } +func (o *DrOidcIamOptions) checkOIDCState(ctx context.Context, s3Client *s3.Client, iamClient *iam.Client) (bool, string, bool, error) { + fmt.Println("Step 1: Checking OIDC documents in S3") + oidcDocsExist := o.checkOIDCDocumentsExist(ctx, s3Client) + if oidcDocsExist && !o.ForceRecreate { + fmt.Printf("- (%s) OIDC documents found in S3\n", greenCheck()) + } else if o.ForceRecreate { + fmt.Printf("- (%s) Force recreate enabled - will regenerate OIDC documents\n", yellowForce()) + } else { + fmt.Printf("- (%s) OIDC documents missing in S3 - will create them\n", redX()) + } + + fmt.Println("Step 2: Checking OIDC identity provider in IAM") + providerARN, exists, err := o.checkOIDCProvider(ctx, iamClient) + if err != nil { + return false, "", false, fmt.Errorf("failed to check OIDC provider: %w", err) + } + + if exists && !o.ForceRecreate { + fmt.Printf("- (%s) OIDC identity provider exists\n", greenCheck()) + } else if o.ForceRecreate { + fmt.Printf("- (%s) Force recreate enabled - will regenerate OIDC provider\n", yellowForce()) + } else { + fmt.Printf("- (%s) OIDC identity provider missing - needs to be recreated\n", redX()) + } + + return oidcDocsExist, providerARN, exists, nil +} + +func (o *DrOidcIamOptions) ensureOIDCDocuments(ctx context.Context, k8sClient client.Client, s3Client *s3.Client, oidcDocsExist bool) error { + if oidcDocsExist && !o.ForceRecreate { + fmt.Println("Step 4: OIDC documents") + fmt.Printf("- (%s) OIDC documents already exist\n", greenCheck()) + return nil + } + + fmt.Println("Step 4: Generating and uploading OIDC documents") + if o.DryRun { + fmt.Printf("- (%s) DRY RUN: Would generate and upload OIDC documents using existing cluster signing key\n", yellowQuestion()) + return nil + } + + if err := o.generateAndUploadOIDCDocuments(ctx, k8sClient, s3Client); err != nil { + return fmt.Errorf("failed to generate OIDC documents: %w", err) + } + fmt.Printf("- (%s) OIDC documents generated and uploaded\n", greenCheck()) + return nil +} + +func (o *DrOidcIamOptions) ensureOIDCProvider(ctx context.Context, iamClient *iam.Client, providerARN, thumbprint string, exists bool) error { + if exists && !o.ForceRecreate { + fmt.Println("Step 6: OIDC identity provider") + fmt.Printf("- (%s) OIDC identity provider already exists\n", greenCheck()) + return nil + } + + fmt.Println("Step 6: Creating/recreating OIDC identity provider") + if o.DryRun { + fmt.Printf("- (%s) DRY RUN: Would create OIDC provider\n", yellowQuestion()) + fmt.Printf(" Issuer: %s\n", o.Issuer) + fmt.Printf(" Thumbprint: %s\n", thumbprint) + fmt.Printf(" Allowed clients: openshift, sts.amazonaws.com\n") + return nil + } + + if err := o.deleteOIDCProviderIfExists(ctx, iamClient, providerARN); err != nil { + return fmt.Errorf("failed to delete existing OIDC provider: %w", err) + } + if _, err := o.createOIDCProvider(ctx, iamClient, thumbprint); err != nil { + return fmt.Errorf("failed to create OIDC provider: %w", err) + } + fmt.Printf("- (%s) OIDC identity provider successfully created\n", greenCheck()) + return nil +} + func (o *DrOidcIamOptions) verifyAndUpdateHostedCluster(ctx context.Context, k8sClient client.Client, s3Client *s3.Client, iamClient *iam.Client) error { if o.DryRun { fmt.Printf(" - (%s) DRY RUN: Would verify OIDC configuration\n", yellowQuestion()) diff --git a/cmd/infra/aws/create.go b/cmd/infra/aws/create.go index 474736c2cb7..1c2e35eaca7 100644 --- a/cmd/infra/aws/create.go +++ b/cmd/infra/aws/create.go @@ -199,24 +199,87 @@ func (o *CreateInfraOptions) CreateInfra(ctx context.Context, l logr.Logger) (*C return nil, err } - awsSession, err := o.AWSCredentialsOpts.GetSession(ctx, "cli-create-infra", o.CredentialsSecretData, o.Region) + awsSession, vpcOwnerSession, err := o.initAWSSessions(ctx) + if err != nil { + return nil, err + } + + clusterCreatorEC2Client, ec2Client, route53Client, vpcOwnerRoute53Client := o.initAWSClients(awsSession, vpcOwnerSession) + + if err := o.parseAdditionalTags(); err != nil { + return nil, err + } + + result := &CreateInfraOutput{ + InfraID: o.InfraID, + MachineCIDR: o.VPCCIDR, + Region: o.Region, + Name: o.Name, + BaseDomain: o.BaseDomain, + BaseDomainPrefix: o.BaseDomainPrefix, + PublicOnly: o.PublicOnly, + } + if len(o.Zones) == 0 { + zone, err := o.firstZone(ctx, l, ec2Client) + if err != nil { + return nil, err + } + o.Zones = append(o.Zones, zone) + } + + igwID, err := o.createVPCResources(ctx, l, ec2Client, result) + if err != nil { + return nil, err + } + + publicSubnetIDs, err := o.createPerZoneResources(ctx, l, ec2Client, result, igwID) if err != nil { return nil, err } + + result.PublicZoneID, err = o.LookupPublicZone(ctx, l, route53Client) + if err != nil { + return nil, err + } + + if vpcOwnerSession != nil { + if err := o.shareSubnets(ctx, l, vpcOwnerSession, awsSession, publicSubnetIDs, result); err != nil { + return nil, err + } + } + + if err := o.createDNSZones(ctx, l, clusterCreatorEC2Client, route53Client, vpcOwnerRoute53Client, result); err != nil { + return nil, err + } + + if err := o.createProxyResources(ctx, l, ec2Client, result, publicSubnetIDs); err != nil { + return nil, err + } + + return result, nil +} + +func (o *CreateInfraOptions) initAWSSessions(ctx context.Context) (*aws.Config, *aws.Config, error) { + awsSession, err := o.AWSCredentialsOpts.GetSession(ctx, "cli-create-infra", o.CredentialsSecretData, o.Region) + if err != nil { + return nil, nil, err + } var vpcOwnerSession *aws.Config if o.VPCOwnerCredentialOpts.AWSCredentialsFile != "" { vpcOwnerSession, err = o.VPCOwnerCredentialOpts.GetSession(ctx, "cli-create-infra", nil, o.Region) if err != nil { - return nil, err + return nil, nil, err } } + return awsSession, vpcOwnerSession, nil +} - var clusterCreatorEC2Client, ec2Client awsapi.EC2API - var vpcOwnerRoute53Client, route53Client awsapi.ROUTE53API +func (o *CreateInfraOptions) initAWSClients(awsSession, vpcOwnerSession *aws.Config) (awsapi.EC2API, awsapi.EC2API, awsapi.ROUTE53API, awsapi.ROUTE53API) { awsConfig := awsutil.NewConfig() - clusterCreatorEC2Client = ec2.NewFromConfig(*awsSession, func(o *ec2.Options) { + clusterCreatorEC2Client := ec2.NewFromConfig(*awsSession, func(o *ec2.Options) { o.Retryer = awsConfig() }) + var ec2Client awsapi.EC2API if vpcOwnerSession != nil { ec2Client = ec2.NewFromConfig(*vpcOwnerSession, func(o *ec2.Options) { o.Retryer = awsConfig() @@ -224,9 +287,10 @@ func (o *CreateInfraOptions) CreateInfra(ctx context.Context, l logr.Logger) (*C } else { ec2Client = clusterCreatorEC2Client } - route53Client = route53.NewFromConfig(*awsSession, func(o *route53.Options) { + route53Client := route53.NewFromConfig(*awsSession, func(o *route53.Options) { o.Retryer = awsutil.NewRoute53Config()() }) + var vpcOwnerRoute53Client awsapi.ROUTE53API if vpcOwnerSession != nil { vpcOwnerRoute53Client = route53.NewFromConfig(*vpcOwnerSession, func(o *route53.Options) { o.Retryer = awsutil.NewRoute53Config()() @@ -234,42 +298,26 @@ func (o *CreateInfraOptions) CreateInfra(ctx context.Context, l logr.Logger) (*C } else { vpcOwnerRoute53Client = route53Client } + return clusterCreatorEC2Client, ec2Client, route53Client, vpcOwnerRoute53Client +} - if err := o.parseAdditionalTags(); err != nil { - return nil, err - } - - result := &CreateInfraOutput{ - InfraID: o.InfraID, - MachineCIDR: o.VPCCIDR, - Region: o.Region, - Name: o.Name, - BaseDomain: o.BaseDomain, - BaseDomainPrefix: o.BaseDomainPrefix, - PublicOnly: o.PublicOnly, - } - if len(o.Zones) == 0 { - zone, err := o.firstZone(ctx, l, ec2Client) - if err != nil { - return nil, err - } - o.Zones = append(o.Zones, zone) - } - - // VPC resources +func (o *CreateInfraOptions) createVPCResources(ctx context.Context, l logr.Logger, ec2Client awsapi.EC2API, result *CreateInfraOutput) (string, error) { + var err error result.VPCID, err = o.createVPC(ctx, l, ec2Client) if err != nil { - return nil, err + return "", err } if err = o.CreateDHCPOptions(ctx, l, ec2Client, result.VPCID); err != nil { - return nil, err + return "", err } igwID, err := o.CreateInternetGateway(ctx, l, ec2Client, result.VPCID) if err != nil { - return nil, err + return "", err } + return igwID, nil +} - // Per zone resources +func (o *CreateInfraOptions) createPerZoneResources(ctx context.Context, l logr.Logger, ec2Client awsapi.EC2API, result *CreateInfraOutput, igwID string) ([]string, error) { _, cidrNetwork, err := net.ParseCIDR(o.VPCCIDR) if err != nil { return nil, err @@ -285,10 +333,7 @@ func (o *CreateInfraOptions) CreateInfra(ctx context.Context, l logr.Logger) (*C var publicSubnetIDs []string var natGatewayID string for _, zone := range o.Zones { - var ( - privateSubnetID string - err error - ) + var privateSubnetID string if !o.PublicOnly { privateSubnetID, err = o.CreatePrivateSubnet(ctx, l, ec2Client, result.VPCID, zone, privateNetwork.String()) if err != nil { @@ -321,7 +366,6 @@ func (o *CreateInfraOptions) CreateInfra(ctx context.Context, l logr.Logger) (*C Name: zone, SubnetID: zoneSubnetID, }) - // increment each subnet by /20 privateNetwork.IP[2] = privateNetwork.IP[2] + 16 publicNetwork.IP[2] = publicNetwork.IP[2] + 16 } @@ -334,66 +378,64 @@ func (o *CreateInfraOptions) CreateInfra(ctx context.Context, l logr.Logger) (*C if err != nil { return nil, err } - result.PublicZoneID, err = o.LookupPublicZone(ctx, l, route53Client) - if err != nil { - return nil, err - } - - if vpcOwnerSession != nil { - if err := o.shareSubnets(ctx, l, vpcOwnerSession, awsSession, publicSubnetIDs, result); err != nil { - return nil, err - } - } + return publicSubnetIDs, nil +} +func (o *CreateInfraOptions) createDNSZones(ctx context.Context, l logr.Logger, clusterCreatorEC2Client awsapi.EC2API, route53Client, vpcOwnerRoute53Client awsapi.ROUTE53API, result *CreateInfraOutput) error { privateZoneClient := vpcOwnerRoute53Client var initialVPC string + var err error if o.PrivateZonesInClusterAccount { privateZoneClient = route53Client - - // Create a dummy vpc that we can use to create the private hosted zones if initialVPC, err = o.createVPC(ctx, l, clusterCreatorEC2Client); err != nil { - return nil, err + return err } + defer func() { + if initialVPC != "" { + if cleanupErr := o.deleteVPC(ctx, l, clusterCreatorEC2Client, initialVPC); cleanupErr != nil { + l.Error(cleanupErr, "Failed to clean up temporary VPC", "id", initialVPC) + } + } + }() } result.PrivateZoneID, err = o.CreatePrivateZone(ctx, l, privateZoneClient, ZoneName(o.Name, o.BaseDomainPrefix, o.BaseDomain), result.VPCID, o.PrivateZonesInClusterAccount, vpcOwnerRoute53Client, initialVPC) if err != nil { - return nil, err + return err } result.LocalZoneID, err = o.CreatePrivateZone(ctx, l, privateZoneClient, fmt.Sprintf("%s.%s", o.Name, hypershiftLocalZoneName), result.VPCID, o.PrivateZonesInClusterAccount, vpcOwnerRoute53Client, initialVPC) if err != nil { - return nil, err + return err } - if initialVPC != "" { - if err := o.deleteVPC(ctx, l, clusterCreatorEC2Client, initialVPC); err != nil { - return nil, err - } + return nil +} + +func (o *CreateInfraOptions) createProxyResources(ctx context.Context, l logr.Logger, ec2Client awsapi.EC2API, result *CreateInfraOutput, publicSubnetIDs []string) error { + if !o.EnableProxy && !o.EnableSecureProxy { + return nil + } + sgGroupID, err := o.createProxySecurityGroup(ctx, l, ec2Client, result.VPCID) + if err != nil { + return fmt.Errorf("failed to create security group for proxy: %w", err) } - if o.EnableProxy || o.EnableSecureProxy { - sgGroupID, err := o.createProxySecurityGroup(ctx, l, ec2Client, result.VPCID) + if o.ProxyVPCEndpointServiceName != "" { + result.ProxyAddr, err = o.createProxyVPCEndpoint(ctx, l, ec2Client, result.VPCID, result.Zones[0].SubnetID, sgGroupID) if err != nil { - return nil, fmt.Errorf("failed to create security group for proxy: %w", err) + return err } - - if o.ProxyVPCEndpointServiceName != "" { - result.ProxyAddr, err = o.createProxyVPCEndpoint(ctx, l, ec2Client, result.VPCID, result.Zones[0].SubnetID, sgGroupID) - if err != nil { - return nil, err - } - } else { - proxyResult, err := o.createProxyHost(ctx, l, ec2Client, result.Zones[0].SubnetID, sgGroupID, o.EnableSecureProxy) - if err != nil { - return nil, fmt.Errorf("failed to create proxy host: %w", err) - } - result.ProxyAddr = proxyResult.HTTPProxyURL - result.SecureProxyAddr = proxyResult.HTTPSProxyURL - result.ProxyCA = proxyResult.CA - result.ProxyPrivateSSHKey = proxyResult.PrivateKey + } else { + proxyResult, err := o.createProxyHost(ctx, l, ec2Client, publicSubnetIDs[0], sgGroupID, o.EnableSecureProxy) + if err != nil { + return fmt.Errorf("failed to create proxy host: %w", err) } + result.ProxyAddr = proxyResult.HTTPProxyURL + result.SecureProxyAddr = proxyResult.HTTPSProxyURL + result.ProxyCA = proxyResult.CA + result.ProxyPrivateSSHKey = proxyResult.PrivateKey } - return result, nil + return nil } func (o *CreateInfraOptions) createProxySecurityGroup(ctx context.Context, l logr.Logger, client awsapi.EC2API, vpcID string) (string, error) { diff --git a/cmd/infra/azure/create.go b/cmd/infra/azure/create.go index 4f10a4ef6cf..191ed895423 100644 --- a/cmd/infra/azure/create.go +++ b/cmd/infra/azure/create.go @@ -125,7 +125,6 @@ func BindProductFlags(opts *CreateInfraOptions, flags *pflag.FlagSet) { // Run is the main function responsible for creating the Azure infrastructure resources for a HostedCluster. func (o *CreateInfraOptions) Run(ctx context.Context, l logr.Logger) (*CreateInfraOutput, error) { - // Validate deployment model flags to prevent conflicts between ARO HCP and self-managed Azure if err := o.validateDeploymentModelFlags(); err != nil { return nil, err } @@ -136,180 +135,196 @@ func (o *CreateInfraOptions) Run(ctx context.Context, l logr.Logger) (*CreateInf BaseDomain: o.BaseDomain, } - // Setup subscription ID and Azure credential information subscriptionID, azureCreds, err := util.SetupAzureCredentials(l, o.Credentials, o.CredentialsFile) if err != nil { return nil, fmt.Errorf("failed to setup Azure credentials: %w", err) } - // Initialize managers rgMgr := NewResourceGroupManager(subscriptionID, azureCreds, o.Cloud) netMgr := NewNetworkManager(subscriptionID, azureCreds, o.Cloud) rbacMgr := NewRBACManager(subscriptionID, azureCreds) - // Create main resource group + resourceGroupName, nsgResourceGroupName, vnetResourceGroupName, err := o.createNetworkResources(ctx, l, rgMgr, netMgr, &result) + if err != nil { + return nil, err + } + + if err := o.handleIdentitiesAndRBAC(ctx, rbacMgr, &result, resourceGroupName, nsgResourceGroupName, vnetResourceGroupName); err != nil { + return nil, err + } + + if err := o.createDNSAndLBResources(ctx, l, netMgr, &result, resourceGroupName); err != nil { + return nil, err + } + + if err := o.writeOutputFile(l, result); err != nil { + return nil, err + } + + return &result, nil +} + +func (o *CreateInfraOptions) createNetworkResources(ctx context.Context, l logr.Logger, rgMgr *ResourceGroupManager, netMgr *NetworkManager, result *CreateInfraOutput) (string, string, string, error) { resourceGroupName, msg, err := rgMgr.CreateOrGetResourceGroup(ctx, o, "") if err != nil { - return nil, fmt.Errorf("failed to create a resource group: %w", err) + return "", "", "", fmt.Errorf("failed to create a resource group: %w", err) } result.ResourceGroupName = resourceGroupName l.Info(msg, "name", resourceGroupName) - // Get base DNS zone ID result.PublicZoneID, err = netMgr.GetBaseDomainID(ctx, o.BaseDomain) if err != nil { - return nil, err + return "", "", "", err } - // Handle network security group nsgResourceGroupName := "" if len(o.NetworkSecurityGroupID) > 0 { result.SecurityGroupID = o.NetworkSecurityGroupID _, nsgResourceGroupName, err = azureutil.GetNameAndResourceGroupFromNetworkSecurityGroupID(o.NetworkSecurityGroupID) if err != nil { - return nil, err + return "", "", "", err } l.Info("Using existing network security group", "ID", result.SecurityGroupID) } else { nsgResourceGroupName = o.Name + "-nsg" nsgResourceGroupName, msg, err = rgMgr.CreateOrGetResourceGroup(ctx, o, nsgResourceGroupName) if err != nil { - return nil, fmt.Errorf("failed to create resource group for network security group: %w", err) + return "", "", "", fmt.Errorf("failed to create resource group for network security group: %w", err) } l.Info(msg, "name", nsgResourceGroupName) nsgID, err := netMgr.CreateSecurityGroup(ctx, nsgResourceGroupName, o.Name, o.InfraID, o.Location) if err != nil { - return nil, err + return "", "", "", err } result.SecurityGroupID = nsgID l.Info("Successfully created network security group", "ID", result.SecurityGroupID) } - // Handle subnet if len(o.SubnetID) > 0 { result.SubnetID = o.SubnetID l.Info("Using existing subnet", "ID", result.SubnetID) } - // Handle virtual network vnetResourceGroupName := "" if len(o.VnetID) > 0 { result.VNetID = o.VnetID _, vnetResourceGroupName, err = azureutil.GetVnetNameAndResourceGroupFromVnetID(o.VnetID) if err != nil { - return nil, err + return "", "", "", err } l.Info("Using existing vnet", "ID", result.VNetID) } else { vnetResourceGroupName = o.Name + "-vnet" vnetResourceGroupName, msg, err = rgMgr.CreateOrGetResourceGroup(ctx, o, vnetResourceGroupName) if err != nil { - return nil, fmt.Errorf("failed to create resource group for virtual network: %w", err) + return "", "", "", fmt.Errorf("failed to create resource group for virtual network: %w", err) } l.Info(msg, "name", vnetResourceGroupName) vnet, err := netMgr.CreateVirtualNetwork(ctx, vnetResourceGroupName, o.Name, o.InfraID, o.Location, o.SubnetID, result.SecurityGroupID) if err != nil { - return nil, err + return "", "", "", err } result.SubnetID = *vnet.Properties.Subnets[0].ID result.VNetID = *vnet.ID l.Info("Successfully created vnet", "ID", result.VNetID) } - // Handle managed identities and RBAC + return resourceGroupName, nsgResourceGroupName, vnetResourceGroupName, nil +} + +func (o *CreateInfraOptions) handleIdentitiesAndRBAC(ctx context.Context, rbacMgr *RBACManager, result *CreateInfraOutput, resourceGroupName, nsgResourceGroupName, vnetResourceGroupName string) error { if o.ManagedIdentitiesFile != "" { result.ControlPlaneMIs = &hyperv1.AzureResourceManagedIdentities{} managedIdentitiesRaw, err := os.ReadFile(o.ManagedIdentitiesFile) if err != nil { - return nil, fmt.Errorf("failed to read --managed-identities-file %s: %w", o.ManagedIdentitiesFile, err) + return fmt.Errorf("failed to read --managed-identities-file %s: %w", o.ManagedIdentitiesFile, err) } if err := yaml.Unmarshal(managedIdentitiesRaw, &result.ControlPlaneMIs.ControlPlane); err != nil { - return nil, fmt.Errorf("failed to unmarshal --managed-identities-file: %w", err) + return fmt.Errorf("failed to unmarshal --managed-identities-file: %w", err) } - if o.AssignServicePrincipalRoles { if err := rbacMgr.AssignControlPlaneRoles(ctx, o, result.ControlPlaneMIs, resourceGroupName, nsgResourceGroupName, vnetResourceGroupName); err != nil { - return nil, err + return err } } } - // Handle data plane identities if o.DataPlaneIdentitiesFile != "" { dataPlaneIdentitiesRaw, err := os.ReadFile(o.DataPlaneIdentitiesFile) if err != nil { - return nil, fmt.Errorf("failed to read --data-plane-identities-file %s: %w", o.DataPlaneIdentitiesFile, err) + return fmt.Errorf("failed to read --data-plane-identities-file %s: %w", o.DataPlaneIdentitiesFile, err) } if err := yaml.Unmarshal(dataPlaneIdentitiesRaw, &result.DataPlaneIdentities); err != nil { - return nil, fmt.Errorf("failed to unmarshal --data-plane-identities-file: %w", err) + return fmt.Errorf("failed to unmarshal --data-plane-identities-file: %w", err) } - if o.AssignServicePrincipalRoles { if err := rbacMgr.AssignDataPlaneRoles(ctx, o, result.DataPlaneIdentities, resourceGroupName); err != nil { - return nil, err + return err } } } - // Handle workload identities if o.WorkloadIdentitiesFile != "" { workloadIdentitiesRaw, err := os.ReadFile(o.WorkloadIdentitiesFile) if err != nil { - return nil, fmt.Errorf("failed to read --workload-identities-file %s: %w", o.WorkloadIdentitiesFile, err) + return fmt.Errorf("failed to read --workload-identities-file %s: %w", o.WorkloadIdentitiesFile, err) } if err := json.Unmarshal(workloadIdentitiesRaw, &result.WorkloadIdentities); err != nil { - return nil, fmt.Errorf("failed to unmarshal --workload-identities-file: %w", err) + return fmt.Errorf("failed to unmarshal --workload-identities-file: %w", err) } - if o.AssignServicePrincipalRoles { if err := rbacMgr.AssignWorkloadIdentities(ctx, o, result.WorkloadIdentities, resourceGroupName, nsgResourceGroupName, vnetResourceGroupName); err != nil { - return nil, err + return err } } } - // Create DNS infrastructure + return nil +} + +func (o *CreateInfraOptions) createDNSAndLBResources(ctx context.Context, l logr.Logger, netMgr *NetworkManager, result *CreateInfraOutput, resourceGroupName string) error { privateDNSZoneID, privateDNSZoneName, err := netMgr.CreatePrivateDNSZone(ctx, resourceGroupName, o.Name, o.BaseDomain) if err != nil { - return nil, err + return err } result.PrivateZoneID = privateDNSZoneID l.Info("Successfully created private DNS zone", "name", privateDNSZoneName) err = netMgr.CreatePrivateDNSZoneLink(ctx, resourceGroupName, o.Name, o.InfraID, result.VNetID, privateDNSZoneName) if err != nil { - return nil, err + return err } l.Info("Successfully created private DNS zone link") - // Create load balancer infrastructure publicIPAddress, err := netMgr.CreatePublicIPAddressForLB(ctx, resourceGroupName, o.InfraID, o.Location) if err != nil { - return nil, err + return err } l.Info("Successfully created public IP address for guest cluster egress load balancer") err = netMgr.CreateLoadBalancer(ctx, resourceGroupName, o.InfraID, o.Location, publicIPAddress) if err != nil { - return nil, err + return err } l.Info("Successfully created guest cluster egress load balancer") + return nil +} - // Serialize the result to the output file if it was provided - if o.OutputFile != "" { - resultSerialized, err := yaml.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to serialize result: %w", err) - } - if err := os.WriteFile(o.OutputFile, resultSerialized, 0644); err != nil { - l.Error(err, "Writing output file failed", "Output File", o.OutputFile, "data", string(resultSerialized)) - return nil, fmt.Errorf("failed to write result to --output-file: %w", err) - } +func (o *CreateInfraOptions) writeOutputFile(l logr.Logger, result CreateInfraOutput) error { + if o.OutputFile == "" { + return nil } - - return &result, nil + resultSerialized, err := yaml.Marshal(result) + if err != nil { + return fmt.Errorf("failed to serialize result: %w", err) + } + if err := os.WriteFile(o.OutputFile, resultSerialized, 0600); err != nil { + l.Error(err, "Writing output file failed", "Output File", o.OutputFile) + return fmt.Errorf("failed to write result to --output-file: %w", err) + } + return nil } // Validate validates the CreateInfraOptions before running the command diff --git a/cmd/install/assets/hypershift_operator.go b/cmd/install/assets/hypershift_operator.go index 3d55c46b9c4..592531ff749 100644 --- a/cmd/install/assets/hypershift_operator.go +++ b/cmd/install/assets/hypershift_operator.go @@ -530,345 +530,23 @@ type HyperShiftOperatorDeployment struct { } func (o HyperShiftOperatorDeployment) Build() *appsv1.Deployment { - args := []string{ - "run", - "--namespace=$(MY_NAMESPACE)", - "--pod-name=$(MY_NAME)", - "--metrics-addr=:9000", - fmt.Sprintf("--enable-dedicated-request-serving-isolation=%t", o.EnableDedicatedRequestServingIsolation), - fmt.Sprintf("--enable-ocp-cluster-monitoring=%t", o.EnableOCPClusterMonitoring), - fmt.Sprintf("--enable-ci-debug-output=%t", o.EnableCIDebugOutput), - fmt.Sprintf("--private-platform=%s", o.PrivatePlatform), - } - if o.RegistryOverrides != "" { - args = append(args, fmt.Sprintf("--registry-overrides=%s", o.RegistryOverrides)) - } - + args := o.buildArgs() + envVars := o.buildEnvVars() var volumeMounts []corev1.VolumeMount var initVolumeMounts []corev1.VolumeMount var volumes []corev1.Volume - envVars := []corev1.EnvVar{ - { - Name: "MY_NAMESPACE", - ValueFrom: &corev1.EnvVarSource{ - FieldRef: &corev1.ObjectFieldSelector{ - FieldPath: "metadata.namespace", - }, - }, - }, - { - Name: "MY_NAME", - ValueFrom: &corev1.EnvVarSource{ - FieldRef: &corev1.ObjectFieldSelector{ - FieldPath: "metadata.name", - }, - }, - }, - metrics.MetricsSetToEnv(o.MetricsSet), - { - Name: "CERT_ROTATION_SCALE", - Value: o.CertRotationScale.String(), - }, - { - Name: "KUBE_FEATURE_WatchListClient", - Value: "false", - }, - } - - // Add any additional environment variables specified if they don't already exist. - for key, value := range o.AdditionalOperatorEnvVars { - trimmedKey := strings.TrimSpace(key) - trimmedValue := strings.TrimSpace(value) - - if trimmedKey != "" && - trimmedValue != "" && - !slices.ContainsFunc(envVars, func(e corev1.EnvVar) bool { - return e.Name == trimmedKey - }) { - envVars = append(envVars, corev1.EnvVar{ - Name: trimmedKey, - Value: trimmedValue, - }) - } - } - - // Add the new HYPERSHIFT_FEATURESET env var if TPNU is set. - if o.TechPreviewNoUpgrade { - envVars = append(envVars, corev1.EnvVar{ - Name: "HYPERSHIFT_FEATURESET", - Value: string(configv1.TechPreviewNoUpgrade), - }) - } - - // Add audit log persistence env var if enabled - if o.EnableAuditLogPersistence { - envVars = append(envVars, corev1.EnvVar{ - Name: "ENABLE_AUDIT_LOG_PERSISTENCE", - Value: "true", - }) - } - if o.EnableWebhook { - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "serving-cert", - MountPath: "/var/run/secrets/serving-cert", - }) - volumes = append(volumes, corev1.Volume{ - Name: "serving-cert", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: "manager-serving-cert", - }, - }, - }) - args = append(args, - "--cert-dir=/var/run/secrets/serving-cert", - ) - - if o.EnableValidatingWebhook { - args = append(args, "--enable-validating-webhook=true") - } - } - - if len(o.OIDCBucketName) > 0 && len(o.OIDCBucketRegion) > 0 && len(o.OIDCStorageProviderS3SecretKey) > 0 && - o.OIDCStorageProviderS3Secret != nil && len(o.OIDCStorageProviderS3Secret.Name) > 0 { - args = append(args, - "--oidc-storage-provider-s3-bucket-name="+o.OIDCBucketName, - "--oidc-storage-provider-s3-region="+o.OIDCBucketRegion, - "--oidc-storage-provider-s3-credentials=/etc/oidc-storage-provider-s3-creds/"+o.OIDCStorageProviderS3SecretKey, - ) - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "oidc-storage-provider-s3-creds", - MountPath: "/etc/oidc-storage-provider-s3-creds", - }) - volumes = append(volumes, corev1.Volume{ - Name: "oidc-storage-provider-s3-creds", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: o.OIDCStorageProviderS3Secret.Name, - }, - }, - }) - } - - if o.ScaleFromZeroSecret != nil && len(o.ScaleFromZeroSecret.Name) > 0 && len(o.ScaleFromZeroSecretKey) > 0 && len(o.ScaleFromZeroProvider) > 0 { - args = append(args, - "--scale-from-zero-provider="+o.ScaleFromZeroProvider, - "--scale-from-zero-creds=/etc/scale-from-zero-creds/"+o.ScaleFromZeroSecretKey, - ) - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "scale-from-zero-creds", - MountPath: "/etc/scale-from-zero-creds", - ReadOnly: true, - }) - volumes = append(volumes, corev1.Volume{ - Name: "scale-from-zero-creds", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: o.ScaleFromZeroSecret.Name, - }, - }, - }) - } + o.addWebhookResources(&args, &volumeMounts, &volumes) + o.addOIDCResources(&args, &volumeMounts, &volumes) + o.addScaleFromZeroResources(&args, &volumeMounts, &volumes) if o.UWMTelemetry { args = append(args, "--enable-uwm-telemetry-remote-write") } - if o.EnableCVOManagementClusterMetricsAccess { - envVars = append(envVars, corev1.EnvVar{ - Name: config.EnableCVOManagementClusterMetricsAccessEnvVar, - Value: "1", - }) - } - - if len(o.ManagedService) > 0 { - envVars = append(envVars, corev1.EnvVar{ - Name: "MANAGED_SERVICE", - Value: o.ManagedService, - }) - } - - if len(o.AROHCPKeyVaultUsersClientID) > 0 { - envVars = append(envVars, corev1.EnvVar{ - Name: config.AROHCPKeyVaultManagedIdentityClientID, - Value: o.AROHCPKeyVaultUsersClientID, - }) - } - - if o.EnableSizeTagging { - envVars = append(envVars, corev1.EnvVar{ - Name: "ENABLE_SIZE_TAGGING", - Value: "1", - }) - } - - if o.EnableEtcdRecovery { - envVars = append(envVars, corev1.EnvVar{ - Name: config.EnableEtcdRecoveryEnvVar, - Value: "1", - }) - } - - if o.EnableCPOOverrides { - envVars = append(envVars, corev1.EnvVar{ - Name: controlplaneoperatoroverrides.CPOOverridesEnvVar, - Value: "1", - }) - } - - if len(o.PlatformsInstalled) > 0 { - envVars = append(envVars, corev1.EnvVar{ - Name: "PLATFORMS_INSTALLED", - Value: o.PlatformsInstalled, - }) - } - - image := o.OperatorImage - - if mapImage, ok := o.Images["hypershift-operator"]; ok { - image = mapImage - } - tagMapping := images.TagMapping() - for tag, ref := range o.Images { - if envVar, exists := tagMapping[tag]; exists { - envVars = append(envVars, corev1.EnvVar{ - Name: envVar, - Value: ref, - }) - } - } - - privatePlatformType := hyperv1.PlatformType(o.PrivatePlatform) - if privatePlatformType != hyperv1.NonePlatform { - // Add platform specific settings - switch privatePlatformType { - case hyperv1.AWSPlatform: - // Add AWS credentials secret volume - volumes = append(volumes, corev1.Volume{ - Name: "credentials", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: o.AWSPrivateSecret.Name, - }, - }, - }) - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "credentials", - MountPath: "/etc/provider", - }) - envVars = append(envVars, - corev1.EnvVar{ - Name: "AWS_SHARED_CREDENTIALS_FILE", - Value: "/etc/provider/" + o.AWSPrivateSecretKey, - }, - corev1.EnvVar{ - Name: "AWS_REGION", - Value: o.AWSPrivateRegion, - }, - corev1.EnvVar{ - Name: "AWS_SDK_LOAD_CONFIG", - Value: "1", - }) - case hyperv1.AzurePlatform: - if o.AzurePLSResourceGroup != "" { - envVars = append(envVars, corev1.EnvVar{ - Name: "AZURE_RESOURCE_GROUP", - Value: o.AzurePLSResourceGroup, - }) - } - if o.AzurePLSManagedIdentityClientID != "" { - // Workload identity mode: the SA annotation triggers Azure AD Workload Identity - // webhook to inject federated tokens. Set the client ID as an env var so the - // HO platform controller can construct credentials. - envVars = append(envVars, corev1.EnvVar{ - Name: "AZURE_PLS_CLIENT_ID", - Value: o.AzurePLSManagedIdentityClientID, - }) - if o.AzurePLSSubscriptionID != "" { - envVars = append(envVars, corev1.EnvVar{ - Name: "AZURE_SUBSCRIPTION_ID", - Value: o.AzurePLSSubscriptionID, - }) - } - } else if o.AzurePrivateSecret != nil && len(o.AzurePrivateSecret.Name) > 0 && len(o.AzurePrivateSecretKey) > 0 { - volumes = append(volumes, corev1.Volume{ - Name: "azure-credentials", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: o.AzurePrivateSecret.Name, - }, - }, - }) - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "azure-credentials", - MountPath: "/etc/azure-provider", - ReadOnly: true, - }) - envVars = append(envVars, corev1.EnvVar{ - Name: "AZURE_CREDENTIALS_FILE", - Value: "/etc/azure-provider/" + o.AzurePrivateSecretKey, - }) - } - case hyperv1.GCPPlatform: - if o.GCPProject != "" { - envVars = append(envVars, corev1.EnvVar{ - Name: "GCP_PROJECT", - Value: o.GCPProject, - }) - } - if o.GCPRegion != "" { - envVars = append(envVars, corev1.EnvVar{ - Name: "GCP_REGION", - Value: o.GCPRegion, - }) - } - } - - // Add AWS-specific volumes if AWS platform - if privatePlatformType == hyperv1.AWSPlatform { - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "token", - MountPath: "/var/run/secrets/openshift/serviceaccount", - }) - volumes = append(volumes, corev1.Volume{ - Name: "token", - VolumeSource: corev1.VolumeSource{ - Projected: &corev1.ProjectedVolumeSource{ - Sources: []corev1.VolumeProjection{ - { - ServiceAccountToken: &corev1.ServiceAccountTokenProjection{ - Audience: "openshift", - Path: "token", - }, - }, - }, - }, - }, - }) - } - } - - if o.RHOBSMonitoring { - envVars = append(envVars, corev1.EnvVar{ - Name: rhobsmonitoring.EnvironmentVariable, - Value: "1", - }) - } - - if o.CVOPrometheusURL != "" { - envVars = append(envVars, corev1.EnvVar{ - Name: config.CVOPrometheusURLEnvVar, - Value: o.CVOPrometheusURL, - }) - } - - if o.MonitoringDashboards { - envVars = append(envVars, corev1.EnvVar{ - Name: "MONITORING_DASHBOARDS", - Value: "1", - }) - } + image := o.resolveImage() + o.addImageTagEnvVars(&envVars) + o.addPrivatePlatformResources(&envVars, &volumeMounts, &volumes) deployment := &appsv1.Deployment{ TypeMeta: metav1.TypeMeta{ @@ -1067,6 +745,280 @@ func (o HyperShiftOperatorDeployment) Build() *appsv1.Deployment { return deployment } +func (o HyperShiftOperatorDeployment) buildArgs() []string { + args := []string{ + "run", + "--namespace=$(MY_NAMESPACE)", + "--pod-name=$(MY_NAME)", + "--metrics-addr=:9000", + fmt.Sprintf("--enable-dedicated-request-serving-isolation=%t", o.EnableDedicatedRequestServingIsolation), + fmt.Sprintf("--enable-ocp-cluster-monitoring=%t", o.EnableOCPClusterMonitoring), + fmt.Sprintf("--enable-ci-debug-output=%t", o.EnableCIDebugOutput), + fmt.Sprintf("--private-platform=%s", o.PrivatePlatform), + } + if o.RegistryOverrides != "" { + args = append(args, fmt.Sprintf("--registry-overrides=%s", o.RegistryOverrides)) + } + return args +} + +func (o HyperShiftOperatorDeployment) buildEnvVars() []corev1.EnvVar { + envVars := []corev1.EnvVar{ + { + Name: "MY_NAMESPACE", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "metadata.namespace", + }, + }, + }, + { + Name: "MY_NAME", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "metadata.name", + }, + }, + }, + metrics.MetricsSetToEnv(o.MetricsSet), + { + Name: "CERT_ROTATION_SCALE", + Value: o.CertRotationScale.String(), + }, + { + Name: "KUBE_FEATURE_WatchListClient", + Value: "false", + }, + } + + for key, value := range o.AdditionalOperatorEnvVars { + trimmedKey := strings.TrimSpace(key) + trimmedValue := strings.TrimSpace(value) + if trimmedKey != "" && + trimmedValue != "" && + !slices.ContainsFunc(envVars, func(e corev1.EnvVar) bool { + return e.Name == trimmedKey + }) { + envVars = append(envVars, corev1.EnvVar{ + Name: trimmedKey, + Value: trimmedValue, + }) + } + } + + if o.TechPreviewNoUpgrade { + envVars = append(envVars, corev1.EnvVar{Name: "HYPERSHIFT_FEATURESET", Value: string(configv1.TechPreviewNoUpgrade)}) + } + if o.EnableAuditLogPersistence { + envVars = append(envVars, corev1.EnvVar{Name: "ENABLE_AUDIT_LOG_PERSISTENCE", Value: "true"}) + } + if o.EnableCVOManagementClusterMetricsAccess { + envVars = append(envVars, corev1.EnvVar{Name: config.EnableCVOManagementClusterMetricsAccessEnvVar, Value: "1"}) + } + if len(o.ManagedService) > 0 { + envVars = append(envVars, corev1.EnvVar{Name: "MANAGED_SERVICE", Value: o.ManagedService}) + } + if len(o.AROHCPKeyVaultUsersClientID) > 0 { + envVars = append(envVars, corev1.EnvVar{Name: config.AROHCPKeyVaultManagedIdentityClientID, Value: o.AROHCPKeyVaultUsersClientID}) + } + if o.EnableSizeTagging { + envVars = append(envVars, corev1.EnvVar{Name: "ENABLE_SIZE_TAGGING", Value: "1"}) + } + if o.EnableEtcdRecovery { + envVars = append(envVars, corev1.EnvVar{Name: config.EnableEtcdRecoveryEnvVar, Value: "1"}) + } + if o.EnableCPOOverrides { + envVars = append(envVars, corev1.EnvVar{Name: controlplaneoperatoroverrides.CPOOverridesEnvVar, Value: "1"}) + } + if len(o.PlatformsInstalled) > 0 { + envVars = append(envVars, corev1.EnvVar{Name: "PLATFORMS_INSTALLED", Value: o.PlatformsInstalled}) + } + if o.RHOBSMonitoring { + envVars = append(envVars, corev1.EnvVar{Name: rhobsmonitoring.EnvironmentVariable, Value: "1"}) + } + if o.CVOPrometheusURL != "" { + envVars = append(envVars, corev1.EnvVar{Name: config.CVOPrometheusURLEnvVar, Value: o.CVOPrometheusURL}) + } + if o.MonitoringDashboards { + envVars = append(envVars, corev1.EnvVar{Name: "MONITORING_DASHBOARDS", Value: "1"}) + } + return envVars +} + +func (o HyperShiftOperatorDeployment) addWebhookResources(args *[]string, volumeMounts *[]corev1.VolumeMount, volumes *[]corev1.Volume) { + if !o.EnableWebhook { + return + } + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "serving-cert", + MountPath: "/var/run/secrets/serving-cert", + }) + *volumes = append(*volumes, corev1.Volume{ + Name: "serving-cert", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "manager-serving-cert", + }, + }, + }) + *args = append(*args, "--cert-dir=/var/run/secrets/serving-cert") + if o.EnableValidatingWebhook { + *args = append(*args, "--enable-validating-webhook=true") + } +} + +func (o HyperShiftOperatorDeployment) addOIDCResources(args *[]string, volumeMounts *[]corev1.VolumeMount, volumes *[]corev1.Volume) { + if len(o.OIDCBucketName) == 0 || len(o.OIDCBucketRegion) == 0 || len(o.OIDCStorageProviderS3SecretKey) == 0 || + o.OIDCStorageProviderS3Secret == nil || len(o.OIDCStorageProviderS3Secret.Name) == 0 { + return + } + *args = append(*args, + "--oidc-storage-provider-s3-bucket-name="+o.OIDCBucketName, + "--oidc-storage-provider-s3-region="+o.OIDCBucketRegion, + "--oidc-storage-provider-s3-credentials=/etc/oidc-storage-provider-s3-creds/"+o.OIDCStorageProviderS3SecretKey, + ) + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "oidc-storage-provider-s3-creds", + MountPath: "/etc/oidc-storage-provider-s3-creds", + }) + *volumes = append(*volumes, corev1.Volume{ + Name: "oidc-storage-provider-s3-creds", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: o.OIDCStorageProviderS3Secret.Name, + }, + }, + }) +} + +func (o HyperShiftOperatorDeployment) addScaleFromZeroResources(args *[]string, volumeMounts *[]corev1.VolumeMount, volumes *[]corev1.Volume) { + if o.ScaleFromZeroSecret == nil || len(o.ScaleFromZeroSecret.Name) == 0 || len(o.ScaleFromZeroSecretKey) == 0 || len(o.ScaleFromZeroProvider) == 0 { + return + } + *args = append(*args, + "--scale-from-zero-provider="+o.ScaleFromZeroProvider, + "--scale-from-zero-creds=/etc/scale-from-zero-creds/"+o.ScaleFromZeroSecretKey, + ) + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "scale-from-zero-creds", + MountPath: "/etc/scale-from-zero-creds", + ReadOnly: true, + }) + *volumes = append(*volumes, corev1.Volume{ + Name: "scale-from-zero-creds", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: o.ScaleFromZeroSecret.Name, + }, + }, + }) +} + +func (o HyperShiftOperatorDeployment) resolveImage() string { + image := o.OperatorImage + if mapImage, ok := o.Images["hypershift-operator"]; ok { + image = mapImage + } + return image +} + +func (o HyperShiftOperatorDeployment) addImageTagEnvVars(envVars *[]corev1.EnvVar) { + tagMapping := images.TagMapping() + for tag, ref := range o.Images { + if envVar, exists := tagMapping[tag]; exists { + *envVars = append(*envVars, corev1.EnvVar{ + Name: envVar, + Value: ref, + }) + } + } +} + +func (o HyperShiftOperatorDeployment) addPrivatePlatformResources(envVars *[]corev1.EnvVar, volumeMounts *[]corev1.VolumeMount, volumes *[]corev1.Volume) { + privatePlatformType := hyperv1.PlatformType(o.PrivatePlatform) + if privatePlatformType == hyperv1.NonePlatform { + return + } + switch privatePlatformType { + case hyperv1.AWSPlatform: + *volumes = append(*volumes, corev1.Volume{ + Name: "credentials", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: o.AWSPrivateSecret.Name, + }, + }, + }) + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "credentials", + MountPath: "/etc/provider", + }) + *envVars = append(*envVars, + corev1.EnvVar{Name: "AWS_SHARED_CREDENTIALS_FILE", Value: "/etc/provider/" + o.AWSPrivateSecretKey}, + corev1.EnvVar{Name: "AWS_REGION", Value: o.AWSPrivateRegion}, + corev1.EnvVar{Name: "AWS_SDK_LOAD_CONFIG", Value: "1"}, + ) + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "token", + MountPath: "/var/run/secrets/openshift/serviceaccount", + }) + *volumes = append(*volumes, corev1.Volume{ + Name: "token", + VolumeSource: corev1.VolumeSource{ + Projected: &corev1.ProjectedVolumeSource{ + Sources: []corev1.VolumeProjection{ + { + ServiceAccountToken: &corev1.ServiceAccountTokenProjection{ + Audience: "openshift", + Path: "token", + }, + }, + }, + }, + }, + }) + case hyperv1.AzurePlatform: + o.addAzurePlatformResources(envVars, volumeMounts, volumes) + case hyperv1.GCPPlatform: + if o.GCPProject != "" { + *envVars = append(*envVars, corev1.EnvVar{Name: "GCP_PROJECT", Value: o.GCPProject}) + } + if o.GCPRegion != "" { + *envVars = append(*envVars, corev1.EnvVar{Name: "GCP_REGION", Value: o.GCPRegion}) + } + } +} + +func (o HyperShiftOperatorDeployment) addAzurePlatformResources(envVars *[]corev1.EnvVar, volumeMounts *[]corev1.VolumeMount, volumes *[]corev1.Volume) { + if o.AzurePLSResourceGroup != "" { + *envVars = append(*envVars, corev1.EnvVar{Name: "AZURE_RESOURCE_GROUP", Value: o.AzurePLSResourceGroup}) + } + if o.AzurePLSManagedIdentityClientID != "" { + *envVars = append(*envVars, corev1.EnvVar{Name: "AZURE_PLS_CLIENT_ID", Value: o.AzurePLSManagedIdentityClientID}) + if o.AzurePLSSubscriptionID != "" { + *envVars = append(*envVars, corev1.EnvVar{Name: "AZURE_SUBSCRIPTION_ID", Value: o.AzurePLSSubscriptionID}) + } + } else if o.AzurePrivateSecret != nil && len(o.AzurePrivateSecret.Name) > 0 && len(o.AzurePrivateSecretKey) > 0 { + *volumes = append(*volumes, corev1.Volume{ + Name: "azure-credentials", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: o.AzurePrivateSecret.Name, + }, + }, + }) + *volumeMounts = append(*volumeMounts, corev1.VolumeMount{ + Name: "azure-credentials", + MountPath: "/etc/azure-provider", + ReadOnly: true, + }) + *envVars = append(*envVars, corev1.EnvVar{ + Name: "AZURE_CREDENTIALS_FILE", + Value: "/etc/azure-provider/" + o.AzurePrivateSecretKey, + }) + } +} + type HyperShiftOperatorService struct { Namespace *corev1.Namespace } diff --git a/cmd/install/assets/hypershift_operator_test.go b/cmd/install/assets/hypershift_operator_test.go index 5ca4665d1a9..7dcca5ccf9a 100644 --- a/cmd/install/assets/hypershift_operator_test.go +++ b/cmd/install/assets/hypershift_operator_test.go @@ -3,10 +3,16 @@ package assets import ( "fmt" "testing" + "time" . "github.com/onsi/gomega" hyperv1 "github.com/openshift/hypershift/api/hypershift/v1beta1" + controlplaneoperatoroverrides "github.com/openshift/hypershift/hypershift-operator/controlplaneoperator-overrides" + "github.com/openshift/hypershift/support/config" + "github.com/openshift/hypershift/support/rhobsmonitoring" + + configv1 "github.com/openshift/api/config/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" @@ -799,3 +805,540 @@ func TestHyperShiftOperatorClusterRole_WebhookRBAC(t *testing.T) { }))) }) } + +func TestBuildArgs(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectContains []string + expectNotContains []string + }{ + { + name: "When registry overrides is set, it should include the flag in args", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.NonePlatform), + RegistryOverrides: "quay.io=mirror.example.com", + }, + expectContains: []string{"--registry-overrides=quay.io=mirror.example.com"}, + }, + { + name: "When registry overrides is empty, it should not include the flag", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.NonePlatform), + }, + expectNotContains: []string{"--registry-overrides"}, + }, + { + name: "When all standard options are set, it should include them in args", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.AWSPlatform), + EnableOCPClusterMonitoring: true, + EnableCIDebugOutput: true, + EnableDedicatedRequestServingIsolation: true, + }, + expectContains: []string{ + "run", + "--namespace=$(MY_NAMESPACE)", + "--pod-name=$(MY_NAME)", + "--metrics-addr=:9000", + "--enable-dedicated-request-serving-isolation=true", + "--enable-ocp-cluster-monitoring=true", + "--enable-ci-debug-output=true", + fmt.Sprintf("--private-platform=%s", hyperv1.AWSPlatform), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + args := tc.deployment.buildArgs() + for _, expected := range tc.expectContains { + g.Expect(args).To(ContainElement(expected)) + } + for _, notExpected := range tc.expectNotContains { + for _, arg := range args { + g.Expect(arg).NotTo(HavePrefix(notExpected)) + } + } + }) + } +} + +func TestBuildEnvVars(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectContains []corev1.EnvVar + expectAbsent []string + }{ + { + name: "When TechPreviewNoUpgrade is enabled, it should include HYPERSHIFT_FEATURESET env var", + deployment: HyperShiftOperatorDeployment{ + TechPreviewNoUpgrade: true, + }, + expectContains: []corev1.EnvVar{ + {Name: "HYPERSHIFT_FEATURESET", Value: string(configv1.TechPreviewNoUpgrade)}, + }, + }, + { + name: "When EnableAuditLogPersistence is enabled, it should include ENABLE_AUDIT_LOG_PERSISTENCE env var", + deployment: HyperShiftOperatorDeployment{ + EnableAuditLogPersistence: true, + }, + expectContains: []corev1.EnvVar{ + {Name: "ENABLE_AUDIT_LOG_PERSISTENCE", Value: "true"}, + }, + }, + { + name: "When EnableCVOManagementClusterMetricsAccess is enabled, it should include the env var", + deployment: HyperShiftOperatorDeployment{ + EnableCVOManagementClusterMetricsAccess: true, + }, + expectContains: []corev1.EnvVar{ + {Name: config.EnableCVOManagementClusterMetricsAccessEnvVar, Value: "1"}, + }, + }, + { + name: "When ManagedService is set, it should include MANAGED_SERVICE env var", + deployment: HyperShiftOperatorDeployment{ + ManagedService: hyperv1.AroHCP, + }, + expectContains: []corev1.EnvVar{ + {Name: "MANAGED_SERVICE", Value: hyperv1.AroHCP}, + }, + }, + { + name: "When EnableSizeTagging is enabled, it should include ENABLE_SIZE_TAGGING env var", + deployment: HyperShiftOperatorDeployment{ + EnableSizeTagging: true, + }, + expectContains: []corev1.EnvVar{ + {Name: "ENABLE_SIZE_TAGGING", Value: "1"}, + }, + }, + { + name: "When EnableEtcdRecovery is enabled, it should include the env var", + deployment: HyperShiftOperatorDeployment{ + EnableEtcdRecovery: true, + }, + expectContains: []corev1.EnvVar{ + {Name: config.EnableEtcdRecoveryEnvVar, Value: "1"}, + }, + }, + { + name: "When EnableCPOOverrides is enabled, it should include the env var", + deployment: HyperShiftOperatorDeployment{ + EnableCPOOverrides: true, + }, + expectContains: []corev1.EnvVar{ + {Name: controlplaneoperatoroverrides.CPOOverridesEnvVar, Value: "1"}, + }, + }, + { + name: "When PlatformsInstalled is set, it should include PLATFORMS_INSTALLED env var", + deployment: HyperShiftOperatorDeployment{ + PlatformsInstalled: "aws,azure", + }, + expectContains: []corev1.EnvVar{ + {Name: "PLATFORMS_INSTALLED", Value: "aws,azure"}, + }, + }, + { + name: "When RHOBSMonitoring is enabled, it should include the env var", + deployment: HyperShiftOperatorDeployment{ + RHOBSMonitoring: true, + }, + expectContains: []corev1.EnvVar{ + {Name: rhobsmonitoring.EnvironmentVariable, Value: "1"}, + }, + }, + { + name: "When CVOPrometheusURL is set, it should include the env var", + deployment: HyperShiftOperatorDeployment{ + CVOPrometheusURL: "https://prometheus.example.com", + }, + expectContains: []corev1.EnvVar{ + {Name: config.CVOPrometheusURLEnvVar, Value: "https://prometheus.example.com"}, + }, + }, + { + name: "When MonitoringDashboards is enabled, it should include MONITORING_DASHBOARDS env var", + deployment: HyperShiftOperatorDeployment{ + MonitoringDashboards: true, + }, + expectContains: []corev1.EnvVar{ + {Name: "MONITORING_DASHBOARDS", Value: "1"}, + }, + }, + { + name: "When no optional features are enabled, it should not include optional env vars", + deployment: HyperShiftOperatorDeployment{ + CertRotationScale: 24 * time.Hour, + }, + expectAbsent: []string{ + "HYPERSHIFT_FEATURESET", + "ENABLE_AUDIT_LOG_PERSISTENCE", + "MANAGED_SERVICE", + "ENABLE_SIZE_TAGGING", + "MONITORING_DASHBOARDS", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + envVars := tc.deployment.buildEnvVars() + for _, expected := range tc.expectContains { + g.Expect(envVars).To(ContainElement(expected)) + } + for _, absent := range tc.expectAbsent { + for _, env := range envVars { + g.Expect(env.Name).NotTo(Equal(absent)) + } + } + }) + } +} + +func TestAddWebhookResources(t *testing.T) { + tests := []struct { + name string + enableWebhook bool + enableValidatingWebhook bool + expectArgs []string + expectVolumeMountCount int + expectVolumeCount int + }{ + { + name: "When webhook is disabled, it should not add any resources", + enableWebhook: false, + expectVolumeMountCount: 0, + expectVolumeCount: 0, + }, + { + name: "When webhook is enabled without validating webhook, it should add serving-cert resources and cert-dir arg", + enableWebhook: true, + expectArgs: []string{"--cert-dir=/var/run/secrets/serving-cert"}, + expectVolumeMountCount: 1, + expectVolumeCount: 1, + }, + { + name: "When webhook and validating webhook are both enabled, it should add cert-dir and enable-validating-webhook args", + enableWebhook: true, + enableValidatingWebhook: true, + expectArgs: []string{"--cert-dir=/var/run/secrets/serving-cert", "--enable-validating-webhook=true"}, + expectVolumeMountCount: 1, + expectVolumeCount: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + d := HyperShiftOperatorDeployment{ + EnableWebhook: tc.enableWebhook, + EnableValidatingWebhook: tc.enableValidatingWebhook, + } + var args []string + var volumeMounts []corev1.VolumeMount + var volumes []corev1.Volume + + d.addWebhookResources(&args, &volumeMounts, &volumes) + + g.Expect(volumeMounts).To(HaveLen(tc.expectVolumeMountCount)) + g.Expect(volumes).To(HaveLen(tc.expectVolumeCount)) + for _, expected := range tc.expectArgs { + g.Expect(args).To(ContainElement(expected)) + } + }) + } +} + +func TestAddOIDCResources(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectVolumeMountCount int + expectVolumeCount int + expectArgCount int + }{ + { + name: "When all OIDC parameters are set, it should add OIDC resources", + deployment: HyperShiftOperatorDeployment{ + OIDCBucketName: "my-bucket", + OIDCBucketRegion: "us-east-1", + OIDCStorageProviderS3SecretKey: "mykey", + OIDCStorageProviderS3Secret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "oidc-secret"}, + }, + }, + expectVolumeMountCount: 1, + expectVolumeCount: 1, + expectArgCount: 3, + }, + { + name: "When OIDC bucket name is empty, it should not add any resources", + deployment: HyperShiftOperatorDeployment{ + OIDCBucketRegion: "us-east-1", + OIDCStorageProviderS3SecretKey: "mykey", + OIDCStorageProviderS3Secret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "oidc-secret"}, + }, + }, + expectVolumeMountCount: 0, + expectVolumeCount: 0, + expectArgCount: 0, + }, + { + name: "When OIDC secret is nil, it should not add any resources", + deployment: HyperShiftOperatorDeployment{ + OIDCBucketName: "my-bucket", + OIDCBucketRegion: "us-east-1", + OIDCStorageProviderS3SecretKey: "mykey", + }, + expectVolumeMountCount: 0, + expectVolumeCount: 0, + expectArgCount: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + var args []string + var volumeMounts []corev1.VolumeMount + var volumes []corev1.Volume + + tc.deployment.addOIDCResources(&args, &volumeMounts, &volumes) + + g.Expect(volumeMounts).To(HaveLen(tc.expectVolumeMountCount)) + g.Expect(volumes).To(HaveLen(tc.expectVolumeCount)) + g.Expect(args).To(HaveLen(tc.expectArgCount)) + }) + } +} + +func TestAddScaleFromZeroResources(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectVolumeMountCount int + expectArgCount int + }{ + { + name: "When all scale-from-zero parameters are set, it should add resources", + deployment: HyperShiftOperatorDeployment{ + ScaleFromZeroSecret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "sfz-secret"}, + }, + ScaleFromZeroSecretKey: "credentials", + ScaleFromZeroProvider: "aws", + }, + expectVolumeMountCount: 1, + expectArgCount: 2, + }, + { + name: "When scale-from-zero secret is nil, it should not add any resources", + deployment: HyperShiftOperatorDeployment{ + ScaleFromZeroSecretKey: "credentials", + ScaleFromZeroProvider: "aws", + }, + expectVolumeMountCount: 0, + expectArgCount: 0, + }, + { + name: "When scale-from-zero provider is empty, it should not add any resources", + deployment: HyperShiftOperatorDeployment{ + ScaleFromZeroSecret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "sfz-secret"}, + }, + ScaleFromZeroSecretKey: "credentials", + }, + expectVolumeMountCount: 0, + expectArgCount: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + var args []string + var volumeMounts []corev1.VolumeMount + var volumes []corev1.Volume + + tc.deployment.addScaleFromZeroResources(&args, &volumeMounts, &volumes) + + g.Expect(volumeMounts).To(HaveLen(tc.expectVolumeMountCount)) + g.Expect(args).To(HaveLen(tc.expectArgCount)) + }) + } +} + +func TestResolveImage(t *testing.T) { + tests := []struct { + name string + operatorImage string + images map[string]string + expected string + }{ + { + name: "When no image override in map, it should use OperatorImage", + operatorImage: "default-image:latest", + images: map[string]string{}, + expected: "default-image:latest", + }, + { + name: "When image override exists in map, it should use the override", + operatorImage: "default-image:latest", + images: map[string]string{"hypershift-operator": "override-image:v1"}, + expected: "override-image:v1", + }, + { + name: "When images map is nil, it should use OperatorImage", + operatorImage: "default-image:latest", + images: nil, + expected: "default-image:latest", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + d := HyperShiftOperatorDeployment{ + OperatorImage: tc.operatorImage, + Images: tc.images, + } + result := d.resolveImage() + g.Expect(result).To(Equal(tc.expected)) + }) + } +} + +func TestAddPrivatePlatformResources(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectEnvNames []string + expectAbsentEnvs []string + }{ + { + name: "When private platform is None, it should not add any platform resources", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.NonePlatform), + }, + expectAbsentEnvs: []string{"AWS_SHARED_CREDENTIALS_FILE", "AZURE_CREDENTIALS_FILE", "GCP_PROJECT"}, + }, + { + name: "When private platform is AWS, it should add AWS env vars and volumes", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.AWSPlatform), + AWSPrivateRegion: "us-east-1", + AWSPrivateSecret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "aws-creds"}, + }, + AWSPrivateSecretKey: "credentials", + }, + expectEnvNames: []string{"AWS_SHARED_CREDENTIALS_FILE", "AWS_REGION", "AWS_SDK_LOAD_CONFIG"}, + }, + { + name: "When private platform is GCP with project and region, it should add GCP env vars", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.GCPPlatform), + GCPProject: "my-project", + GCPRegion: "us-central1", + }, + expectEnvNames: []string{"GCP_PROJECT", "GCP_REGION"}, + }, + { + name: "When private platform is GCP without project, it should not add GCP_PROJECT env var", + deployment: HyperShiftOperatorDeployment{ + PrivatePlatform: string(hyperv1.GCPPlatform), + }, + expectAbsentEnvs: []string{"GCP_PROJECT", "GCP_REGION"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + var envVars []corev1.EnvVar + var volumeMounts []corev1.VolumeMount + var volumes []corev1.Volume + + tc.deployment.addPrivatePlatformResources(&envVars, &volumeMounts, &volumes) + + envNames := make([]string, 0, len(envVars)) + for _, e := range envVars { + envNames = append(envNames, e.Name) + } + for _, expected := range tc.expectEnvNames { + g.Expect(envNames).To(ContainElement(expected)) + } + for _, absent := range tc.expectAbsentEnvs { + g.Expect(envNames).NotTo(ContainElement(absent)) + } + }) + } +} + +func TestAddAzurePlatformResources(t *testing.T) { + tests := []struct { + name string + deployment HyperShiftOperatorDeployment + expectEnvNames []string + expectVolCount int + }{ + { + name: "When Azure managed identity is provided, it should set PLS client ID and subscription ID env vars", + deployment: HyperShiftOperatorDeployment{ + AzurePLSManagedIdentityClientID: "client-id", + AzurePLSSubscriptionID: "sub-id", + AzurePLSResourceGroup: "rg-mgmt", + }, + expectEnvNames: []string{"AZURE_RESOURCE_GROUP", "AZURE_PLS_CLIENT_ID", "AZURE_SUBSCRIPTION_ID"}, + expectVolCount: 0, + }, + { + name: "When Azure credentials file is provided, it should mount the credentials volume", + deployment: HyperShiftOperatorDeployment{ + AzurePLSResourceGroup: "rg-mgmt", + AzurePrivateSecret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "azure-creds"}, + }, + AzurePrivateSecretKey: "credentials", + }, + expectEnvNames: []string{"AZURE_RESOURCE_GROUP", "AZURE_CREDENTIALS_FILE"}, + expectVolCount: 1, + }, + { + name: "When neither managed identity nor credentials file is provided, it should not add env vars or volumes", + deployment: HyperShiftOperatorDeployment{ + AzurePLSResourceGroup: "rg-mgmt", + }, + expectEnvNames: []string{"AZURE_RESOURCE_GROUP"}, + expectVolCount: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + var envVars []corev1.EnvVar + var volumeMounts []corev1.VolumeMount + var volumes []corev1.Volume + + tc.deployment.addAzurePlatformResources(&envVars, &volumeMounts, &volumes) + + envNames := make([]string, 0, len(envVars)) + for _, e := range envVars { + envNames = append(envNames, e.Name) + } + for _, expected := range tc.expectEnvNames { + g.Expect(envNames).To(ContainElement(expected)) + } + g.Expect(volumes).To(HaveLen(tc.expectVolCount)) + }) + } +} diff --git a/cmd/install/install.go b/cmd/install/install.go index e022f1ecb7b..ed798ea8f4f 100644 --- a/cmd/install/install.go +++ b/cmd/install/install.go @@ -166,46 +166,67 @@ func (o *Options) Validate() error { o.ScaleFromZeroProvider = strings.ToLower(o.ScaleFromZeroProvider) } + errs = append(errs, o.validatePlatformConfig()...) + errs = append(errs, o.validateOIDCConfig()...) + errs = append(errs, o.validateExternalDNSConfig()...) + errs = append(errs, o.validateImageConfig()...) + errs = append(errs, o.validateScaleFromZeroConfig()...) + errs = append(errs, o.validateMonitoringConfig()...) + errs = append(errs, o.validateMiscConfig()...) + + return errors.NewAggregate(errs) +} + +func (o *Options) validatePlatformConfig() []error { + var errs []error switch hyperv1.PlatformType(o.PrivatePlatform) { case hyperv1.AWSPlatform: if (len(o.AWSPrivateCreds) == 0 && len(o.AWSPrivateCredentialsSecret) == 0) || len(o.AWSPrivateRegion) == 0 { errs = append(errs, fmt.Errorf("--aws-private-region and --aws-private-creds or --aws-private-secret are required with --private-platform=%s", hyperv1.AWSPlatform)) } case hyperv1.GCPPlatform: - // GCP uses Workload Identity Federation, no credentials required. - // However, --gcp-project and --gcp-region must be set together. if (o.GCPProject == "") != (o.GCPRegion == "") { errs = append(errs, fmt.Errorf("--gcp-project and --gcp-region must be set together when --private-platform=%s", hyperv1.GCPPlatform)) } case hyperv1.AzurePlatform: - if o.ManagedService != hyperv1.AroHCP { - hasCredFile := len(o.AzurePrivateCreds) != 0 || len(o.AzurePrivateCredentialsSecret) != 0 - hasManagedIdentity := len(o.AzurePLSManagedIdentityClientID) != 0 - if !hasCredFile && !hasManagedIdentity { - errs = append(errs, fmt.Errorf("--azure-private-creds, --azure-private-secret, or --azure-pls-managed-identity-client-id is required with --private-platform=%s", hyperv1.AzurePlatform)) - } - if hasCredFile && hasManagedIdentity { - errs = append(errs, fmt.Errorf("--azure-pls-managed-identity-client-id cannot be used with --azure-private-creds or --azure-private-secret")) - } - if hasManagedIdentity && len(o.AzurePLSSubscriptionID) == 0 { - errs = append(errs, fmt.Errorf("--azure-pls-subscription-id is required when using --azure-pls-managed-identity-client-id")) - } - if len(o.AzurePrivateCreds) != 0 && len(o.AzurePrivateCredentialsSecret) != 0 { - errs = append(errs, fmt.Errorf("only one of --azure-private-creds or --azure-private-secret is supported")) - } - if len(o.AzurePLSResourceGroup) == 0 { - errs = append(errs, fmt.Errorf("--azure-pls-resource-group is required with --private-platform=%s", hyperv1.AzurePlatform)) - } - } + errs = append(errs, o.validateAzurePlatformConfig()...) case hyperv1.NonePlatform: default: errs = append(errs, fmt.Errorf("--private-platform must be either %s, %s, %s, or %s", hyperv1.AWSPlatform, hyperv1.AzurePlatform, hyperv1.GCPPlatform, hyperv1.NonePlatform)) } + return errs +} +func (o *Options) validateAzurePlatformConfig() []error { + if o.ManagedService == hyperv1.AroHCP { + return nil + } + var errs []error + hasCredFile := len(o.AzurePrivateCreds) != 0 || len(o.AzurePrivateCredentialsSecret) != 0 + hasManagedIdentity := len(o.AzurePLSManagedIdentityClientID) != 0 + if !hasCredFile && !hasManagedIdentity { + errs = append(errs, fmt.Errorf("--azure-private-creds, --azure-private-secret, or --azure-pls-managed-identity-client-id is required with --private-platform=%s", hyperv1.AzurePlatform)) + } + if hasCredFile && hasManagedIdentity { + errs = append(errs, fmt.Errorf("--azure-pls-managed-identity-client-id cannot be used with --azure-private-creds or --azure-private-secret")) + } + if hasManagedIdentity && len(o.AzurePLSSubscriptionID) == 0 { + errs = append(errs, fmt.Errorf("--azure-pls-subscription-id is required when using --azure-pls-managed-identity-client-id")) + } + if len(o.AzurePrivateCreds) != 0 && len(o.AzurePrivateCredentialsSecret) != 0 { + errs = append(errs, fmt.Errorf("only one of --azure-private-creds or --azure-private-secret is supported")) + } + if len(o.AzurePLSResourceGroup) == 0 { + errs = append(errs, fmt.Errorf("--azure-pls-resource-group is required with --private-platform=%s", hyperv1.AzurePlatform)) + } + return errs +} + +func (o *Options) validateOIDCConfig() []error { + var errs []error if len(o.OIDCStorageProviderS3CredentialsSecret) > 0 && len(o.OIDCStorageProviderS3Credentials) > 0 { errs = append(errs, fmt.Errorf("only one of --oidc-storage-provider-s3-secret or --oidc-storage-provider-s3-credentials is supported")) } - if (len(o.OIDCStorageProviderS3CredentialsSecret) > 0 || len(o.OIDCStorageProviderS3Credentials) > 0) && (len(o.OIDCStorageProviderS3BucketName) == 0 || len(o.OIDCStorageProviderS3Region) == 0 || len(o.OIDCStorageProviderS3CredentialsSecretKey) == 0) { errs = append(errs, fmt.Errorf("all required oidc information is not set")) @@ -213,91 +234,102 @@ func (o *Options) Validate() error { if strings.Contains(o.OIDCStorageProviderS3BucketName, ".") { errs = append(errs, fmt.Errorf("oidc bucket name must not contain dots (.); see the notes on HTTPS at https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html")) } + return errs +} - if len(o.ExternalDNSProvider) > 0 { - // Credentials are optional for GCP when using Workload Identity - credentialsRequired := o.ExternalDNSProvider != "google" - if credentialsRequired && len(o.ExternalDNSCredentials) == 0 && len(o.ExternalDNSCredentialsSecret) == 0 { - errs = append(errs, fmt.Errorf("--external-dns-credentials or --external-dns-credentials-secret are required with --external-dns-provider")) - } - if len(o.ExternalDNSCredentials) != 0 && len(o.ExternalDNSCredentialsSecret) != 0 { - errs = append(errs, fmt.Errorf("only one of --external-dns-credentials or --external-dns-credentials-secret is supported")) - } - if len(o.ExternalDNSDomainFilter) == 0 { - errs = append(errs, fmt.Errorf("--external-dns-domain-filter is required with --external-dns-provider")) +func (o *Options) validateExternalDNSConfig() []error { + if len(o.ExternalDNSProvider) == 0 { + return nil + } + var errs []error + credentialsRequired := o.ExternalDNSProvider != "google" + if credentialsRequired && len(o.ExternalDNSCredentials) == 0 && len(o.ExternalDNSCredentialsSecret) == 0 { + errs = append(errs, fmt.Errorf("--external-dns-credentials or --external-dns-credentials-secret are required with --external-dns-provider")) + } + if len(o.ExternalDNSCredentials) != 0 && len(o.ExternalDNSCredentialsSecret) != 0 { + errs = append(errs, fmt.Errorf("only one of --external-dns-credentials or --external-dns-credentials-secret is supported")) + } + if len(o.ExternalDNSDomainFilter) == 0 { + errs = append(errs, fmt.Errorf("--external-dns-domain-filter is required with --external-dns-provider")) + } + if len(o.ExternalDNSInterval) > 0 { + if _, err := time.ParseDuration(o.ExternalDNSInterval); err != nil { + errs = append(errs, fmt.Errorf("--external-dns-interval is not a valid duration: %w", err)) } - if len(o.ExternalDNSInterval) > 0 { - if _, err := time.ParseDuration(o.ExternalDNSInterval); err != nil { - errs = append(errs, fmt.Errorf("--external-dns-interval is not a valid duration: %w", err)) - } + } + if len(o.ExternalDNSAWSZonesCacheDuration) > 0 { + if _, err := time.ParseDuration(o.ExternalDNSAWSZonesCacheDuration); err != nil { + errs = append(errs, fmt.Errorf("--external-dns-aws-zones-cache-duration is not a valid duration: %w", err)) } - if len(o.ExternalDNSAWSZonesCacheDuration) > 0 { - if _, err := time.ParseDuration(o.ExternalDNSAWSZonesCacheDuration); err != nil { - errs = append(errs, fmt.Errorf("--external-dns-aws-zones-cache-duration is not a valid duration: %w", err)) - } - if o.ExternalDNSProvider != "aws" { - errs = append(errs, fmt.Errorf("--external-dns-aws-zones-cache-duration is only effective with --external-dns-provider=aws")) - } + if o.ExternalDNSProvider != "aws" { + errs = append(errs, fmt.Errorf("--external-dns-aws-zones-cache-duration is only effective with --external-dns-provider=aws")) } } + return errs +} + +func (o *Options) validateImageConfig() []error { + var errs []error if o.HyperShiftImage != HyperShiftImage && len(o.ImageRefsFile) > 0 { errs = append(errs, fmt.Errorf("only one of --hypershift-image or --image-refs-file should be specified")) } if o.RHOBSMonitoring && os.Getenv(rhobsmonitoring.EnvironmentVariable) != "1" { errs = append(errs, fmt.Errorf("when invoking this command with the --rhobs-monitoring flag, the RHOBS_MONITORING environment variable must be set to \"1\"")) } - if o.CertRotationScale > 24*time.Hour { errs = append(errs, fmt.Errorf("cannot set --cert-rotation-scale longer than 24h, invalid value: %s", o.CertRotationScale.String())) } + return errs +} - // Validate scale-from-zero credentials +func (o *Options) validateScaleFromZeroConfig() []error { + if len(o.ScaleFromZeroCreds) == 0 && len(o.ScaleFromZeroCredentialsSecret) == 0 { + return nil + } + var errs []error supportedProviders := set.New("aws") - if len(o.ScaleFromZeroCreds) != 0 || len(o.ScaleFromZeroCredentialsSecret) != 0 { - // Check mutual exclusivity - only one of file or secret should be provided - if len(o.ScaleFromZeroCreds) != 0 && len(o.ScaleFromZeroCredentialsSecret) != 0 { - errs = append(errs, fmt.Errorf("only one of --scale-from-zero-creds or --scale-from-zero-secret is supported")) - } - - // Provider is required when using scale-from-zero credentials - if len(o.ScaleFromZeroProvider) == 0 { - errs = append(errs, fmt.Errorf("--scale-from-zero-provider is required when using scale-from-zero credentials")) - } else if !supportedProviders.Has(o.ScaleFromZeroProvider) { - errs = append(errs, fmt.Errorf("invalid --scale-from-zero-provider: %s (must be one of: %v)", o.ScaleFromZeroProvider, supportedProviders.UnsortedList())) - } - - // Validate credentials file exists and is accessible if provided - if len(o.ScaleFromZeroCreds) > 0 { - if _, err := os.Stat(o.ScaleFromZeroCreds); err != nil { - if os.IsNotExist(err) { - errs = append(errs, fmt.Errorf("--scale-from-zero-creds file does not exist: %s", o.ScaleFromZeroCreds)) - } else { - errs = append(errs, fmt.Errorf("--scale-from-zero-creds file is not accessible: %w", err)) - } + if len(o.ScaleFromZeroCreds) != 0 && len(o.ScaleFromZeroCredentialsSecret) != 0 { + errs = append(errs, fmt.Errorf("only one of --scale-from-zero-creds or --scale-from-zero-secret is supported")) + } + if len(o.ScaleFromZeroProvider) == 0 { + errs = append(errs, fmt.Errorf("--scale-from-zero-provider is required when using scale-from-zero credentials")) + } else if !supportedProviders.Has(o.ScaleFromZeroProvider) { + errs = append(errs, fmt.Errorf("invalid --scale-from-zero-provider: %s (must be one of: %v)", o.ScaleFromZeroProvider, supportedProviders.UnsortedList())) + } + if len(o.ScaleFromZeroCreds) > 0 { + if _, err := os.Stat(o.ScaleFromZeroCreds); err != nil { + if os.IsNotExist(err) { + errs = append(errs, fmt.Errorf("--scale-from-zero-creds file does not exist: %s", o.ScaleFromZeroCreds)) + } else { + errs = append(errs, fmt.Errorf("--scale-from-zero-creds file is not accessible: %w", err)) } } } + return errs +} +func (o *Options) validateMonitoringConfig() []error { + var errs []error if o.RHOBSMonitoring && o.EnableCVOManagementClusterMetricsAccess { errs = append(errs, fmt.Errorf("when invoking this command with the --rhobs-monitoring flag, the --enable-cvo-management-cluster-metrics-access flag is not supported ")) } - if len(o.CVOPrometheusURL) > 0 && !o.RHOBSMonitoring && !o.EnableCVOManagementClusterMetricsAccess { errs = append(errs, fmt.Errorf("--cvo-prometheus-url requires either --rhobs-monitoring or --enable-cvo-management-cluster-metrics-access to be enabled")) } + return errs +} +func (o *Options) validateMiscConfig() []error { + var errs []error if len(o.ManagedService) > 0 && o.ManagedService != hyperv1.AroHCP { errs = append(errs, fmt.Errorf("not a valid managed service type: %s", o.ManagedService)) } - - // Validate all the platforms in the list are valid for _, platform := range o.PlatformsToInstall { platformToCheck := strings.ToLower(platform) if !ValidPlatforms.Has(platformToCheck) { errs = append(errs, fmt.Errorf("not a valid platform type: %s", platform)) } } - if len(o.ImagePullPolicy) > 0 { normalized := strings.ToLower(o.ImagePullPolicy) switch normalized { @@ -306,8 +338,7 @@ func (o *Options) Validate() error { errs = append(errs, fmt.Errorf("invalid --image-pull-policy: %s (want Always|Never|IfNotPresent)", o.ImagePullPolicy)) } } - - return errors.NewAggregate(errs) + return errs } func (o *Options) ApplyDefaults() { diff --git a/cmd/install/install_test.go b/cmd/install/install_test.go index b01725f0b92..33375544104 100644 --- a/cmd/install/install_test.go +++ b/cmd/install/install_test.go @@ -13,12 +13,15 @@ import ( . "github.com/onsi/gomega" hyperv1 "github.com/openshift/hypershift/api/hypershift/v1beta1" + "github.com/openshift/hypershift/cmd/install/assets" crdassets "github.com/openshift/hypershift/cmd/install/assets/crds" "github.com/openshift/hypershift/hypershift-operator/controllers/sharedingress" hyperapi "github.com/openshift/hypershift/support/api" + "github.com/openshift/hypershift/support/metrics" operatorv1alpha1 "github.com/openshift/api/operator/v1alpha1" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" @@ -30,6 +33,8 @@ import ( crclient "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/controller-runtime/pkg/client/interceptor" + + prometheusoperatorv1 "github.com/prometheus-operator/prometheus-operator/pkg/apis/monitoring/v1" ) func TestOptions_Validate(t *testing.T) { @@ -968,3 +973,785 @@ func TestWaitForCAPIOperatorSync(t *testing.T) { }) } } +func TestApplyDefaults(t *testing.T) { + tests := []struct { + name string + opts Options + expectedReplicas int32 + }{ + { + name: "When Development mode is enabled, it should set replicas to 0", + opts: Options{Development: true}, + expectedReplicas: 0, + }, + { + name: "When defaulting webhook is enabled, it should set replicas to 2", + opts: Options{EnableDefaultingWebhook: true}, + expectedReplicas: 2, + }, + { + name: "When conversion webhook is enabled, it should set replicas to 2", + opts: Options{EnableConversionWebhook: true}, + expectedReplicas: 2, + }, + { + name: "When validating webhook is enabled, it should set replicas to 2", + opts: Options{EnableValidatingWebhook: true}, + expectedReplicas: 2, + }, + { + name: "When no special options are set, it should default replicas to 1", + opts: Options{}, + expectedReplicas: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + tc.opts.ApplyDefaults() + g.Expect(tc.opts.HyperShiftOperatorReplicas).To(Equal(tc.expectedReplicas)) + g.Expect(tc.opts.RenderNamespace).To(BeTrue()) + }) + } +} + +func TestIsAWSPlatformEnabled(t *testing.T) { + tests := []struct { + name string + platformsToInstall []string + expected bool + }{ + { + name: "When no platforms are specified, it should return true (all platforms enabled)", + platformsToInstall: nil, + expected: true, + }, + { + name: "When empty platforms list is specified, it should return true", + platformsToInstall: []string{}, + expected: true, + }, + { + name: "When AWS is in the list, it should return true", + platformsToInstall: []string{"AWS", "Azure"}, + expected: true, + }, + { + name: "When aws (lowercase) is in the list, it should return true", + platformsToInstall: []string{"aws"}, + expected: true, + }, + { + name: "When AWS is not in the list, it should return false", + platformsToInstall: []string{"Azure", "GCP"}, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + result := isAWSPlatformEnabled(tc.platformsToInstall) + g.Expect(result).To(Equal(tc.expected)) + }) + } +} + +func TestIsAzurePlatformEnabled(t *testing.T) { + tests := []struct { + name string + platformsToInstall []string + expected bool + }{ + { + name: "When no platforms are specified, it should return true (all platforms enabled)", + platformsToInstall: nil, + expected: true, + }, + { + name: "When Azure is in the list, it should return true", + platformsToInstall: []string{"Azure", "AWS"}, + expected: true, + }, + { + name: "When azure (lowercase) is in the list, it should return true", + platformsToInstall: []string{"azure"}, + expected: true, + }, + { + name: "When Azure is not in the list, it should return false", + platformsToInstall: []string{"AWS", "GCP"}, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + result := isAzurePlatformEnabled(tc.platformsToInstall) + g.Expect(result).To(Equal(tc.expected)) + }) + } +} + +func TestValidatePlatformConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When private platform is None, it should pass", + opts: Options{PrivatePlatform: string(hyperv1.NonePlatform)}, + expectError: false, + }, + { + name: "When private platform is an unsupported value, it should fail", + opts: Options{PrivatePlatform: "Unsupported"}, + expectError: true, + }, + { + name: "When private platform is AWS with creds and region, it should pass", + opts: Options{ + PrivatePlatform: string(hyperv1.AWSPlatform), + AWSPrivateCreds: "/path/to/creds", + AWSPrivateRegion: "us-east-1", + }, + expectError: false, + }, + { + name: "When private platform is AWS without creds or region, it should fail", + opts: Options{ + PrivatePlatform: string(hyperv1.AWSPlatform), + }, + expectError: true, + }, + { + name: "When private platform is GCP with only project set, it should fail", + opts: Options{ + PrivatePlatform: string(hyperv1.GCPPlatform), + GCPProject: "my-project", + }, + expectError: true, + }, + { + name: "When private platform is GCP with both project and region, it should pass", + opts: Options{ + PrivatePlatform: string(hyperv1.GCPPlatform), + GCPProject: "my-project", + GCPRegion: "us-central1", + }, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validatePlatformConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestValidateOIDCConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When no OIDC options are set, it should pass", + opts: Options{}, + expectError: false, + }, + { + name: "When both OIDC secret and credentials are set, it should fail", + opts: Options{ + OIDCStorageProviderS3CredentialsSecret: "my-secret", + OIDCStorageProviderS3Credentials: "my-creds", + }, + expectError: true, + }, + { + name: "When OIDC credentials are set without bucket name, it should fail", + opts: Options{ + OIDCStorageProviderS3Credentials: "my-creds", + }, + expectError: true, + }, + { + name: "When OIDC bucket name contains dots, it should fail", + opts: Options{ + OIDCStorageProviderS3BucketName: "my.bucket.name", + }, + expectError: true, + }, + { + name: "When all OIDC parameters are provided correctly, it should pass", + opts: Options{ + OIDCStorageProviderS3Credentials: "my-creds", + OIDCStorageProviderS3BucketName: "mybucket", + OIDCStorageProviderS3Region: "us-east-1", + OIDCStorageProviderS3CredentialsSecretKey: "mykey", + }, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validateOIDCConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestValidateExternalDNSConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When no external DNS provider is set, it should pass", + opts: Options{}, + expectError: false, + }, + { + name: "When external DNS provider is set without credentials or domain filter, it should fail", + opts: Options{ + ExternalDNSProvider: "aws", + }, + expectError: true, + }, + { + name: "When external DNS provider is set with credentials and domain filter, it should pass", + opts: Options{ + ExternalDNSProvider: "aws", + ExternalDNSCredentials: "/path/to/creds", + ExternalDNSDomainFilter: "example.com", + }, + expectError: false, + }, + { + name: "When external DNS interval is an invalid duration, it should fail", + opts: Options{ + ExternalDNSProvider: "aws", + ExternalDNSCredentials: "/path/to/creds", + ExternalDNSDomainFilter: "example.com", + ExternalDNSInterval: "not-a-duration", + }, + expectError: true, + }, + { + name: "When AWS zones cache duration is set with non-AWS provider, it should fail", + opts: Options{ + ExternalDNSProvider: "azure", + ExternalDNSCredentials: "/path/to/creds", + ExternalDNSDomainFilter: "example.com", + ExternalDNSAWSZonesCacheDuration: "1h", + }, + expectError: true, + }, + { + name: "When google provider is set without credentials, it should pass (Workload Identity)", + opts: Options{ + ExternalDNSProvider: "google", + ExternalDNSDomainFilter: "example.com", + }, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validateExternalDNSConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestValidateMonitoringConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When no monitoring options are set, it should pass", + opts: Options{}, + expectError: false, + }, + { + name: "When both RHOBS monitoring and CVO management cluster metrics access are set, it should fail", + opts: Options{ + RHOBSMonitoring: true, + EnableCVOManagementClusterMetricsAccess: true, + }, + expectError: true, + }, + { + name: "When CVO prometheus URL is set without RHOBS or CVO metrics, it should fail", + opts: Options{ + CVOPrometheusURL: "https://prometheus.example.com", + }, + expectError: true, + }, + { + name: "When CVO prometheus URL is set with RHOBS monitoring, it should pass", + opts: Options{ + CVOPrometheusURL: "https://prometheus.example.com", + RHOBSMonitoring: true, + }, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validateMonitoringConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestValidateMiscConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When no misc options are set, it should pass", + opts: Options{}, + expectError: false, + }, + { + name: "When managed service is an invalid value, it should fail", + opts: Options{ + ManagedService: "INVALID", + }, + expectError: true, + }, + { + name: "When managed service is ARO-HCP, it should pass", + opts: Options{ + ManagedService: hyperv1.AroHCP, + }, + expectError: false, + }, + { + name: "When an invalid platform is specified, it should fail", + opts: Options{ + PlatformsToInstall: []string{"invalid-platform"}, + }, + expectError: true, + }, + { + name: "When valid platforms are specified, it should pass", + opts: Options{ + PlatformsToInstall: []string{"aws", "Azure"}, + }, + expectError: false, + }, + { + name: "When invalid image pull policy is specified, it should fail", + opts: Options{ + ImagePullPolicy: "WheneverYouFeel", + }, + expectError: true, + }, + { + name: "When Always image pull policy is specified, it should pass", + opts: Options{ + ImagePullPolicy: "Always", + }, + expectError: false, + }, + { + name: "When Never image pull policy is specified, it should pass", + opts: Options{ + ImagePullPolicy: "Never", + }, + expectError: false, + }, + { + name: "When IfNotPresent image pull policy is specified, it should pass", + opts: Options{ + ImagePullPolicy: "IfNotPresent", + }, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validateMiscConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestSetupMonitoring(t *testing.T) { + tests := []struct { + name string + opts Options + expectSLOAlerts bool + expectDashboards bool + minResourceCount int + }{ + { + name: "When SLOs alerts and monitoring dashboards are disabled, it should return base monitoring resources", + opts: Options{ + PlatformMonitoring: metrics.PlatformMonitoringAll, + }, + expectSLOAlerts: false, + expectDashboards: false, + minResourceCount: 4, + }, + { + name: "When SLOs alerts are enabled, it should include alerting rule", + opts: Options{ + SLOsAlerts: true, + }, + expectSLOAlerts: true, + minResourceCount: 5, + }, + { + name: "When monitoring dashboards are enabled, it should include dashboard template", + opts: Options{ + Namespace: "hypershift", + MonitoringDashboards: true, + }, + expectDashboards: true, + minResourceCount: 5, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + ns := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "hypershift"}} + objects := setupMonitoring(tc.opts, ns) + + g.Expect(len(objects)).To(BeNumerically(">=", tc.minResourceCount)) + + // Check SLO alerts + foundAlertingRule := false + for _, obj := range objects { + if rule, ok := obj.(*prometheusoperatorv1.PrometheusRule); ok && rule.Namespace == "openshift-monitoring" { + foundAlertingRule = true + break + } + } + if tc.expectSLOAlerts { + g.Expect(foundAlertingRule).To(BeTrue(), "expected SLO alerting rule to be present when SLOsAlerts is enabled") + } else { + g.Expect(foundAlertingRule).To(BeFalse(), "expected no SLO alerting rule when SLOsAlerts is disabled") + } + + // Check dashboards + foundDashboard := false + for _, obj := range objects { + if cm, ok := obj.(*corev1.ConfigMap); ok && cm.Name == "monitoring-dashboard-template" { + foundDashboard = true + break + } + } + if tc.expectDashboards { + g.Expect(foundDashboard).To(BeTrue(), "expected monitoring dashboard to be present when MonitoringDashboards is enabled") + } else { + g.Expect(foundDashboard).To(BeFalse(), "expected no monitoring dashboard when MonitoringDashboards is disabled") + } + }) + } +} + +func TestSetupRBAC(t *testing.T) { + tests := []struct { + name string + enableAdminRBAC bool + azureManagedIdentityClientID string + expectSAAnnotation bool + minObjectCount int + }{ + { + name: "When admin RBAC is disabled, it should return base RBAC resources", + minObjectCount: 6, + }, + { + name: "When admin RBAC is enabled, it should include client and reader RBAC resources", + enableAdminRBAC: true, + minObjectCount: 11, + }, + { + name: "When Azure managed identity is set, it should annotate the service account", + azureManagedIdentityClientID: "test-client-id", + expectSAAnnotation: true, + minObjectCount: 6, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + ns := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "hypershift"}} + opts := Options{ + EnableAdminRBACGeneration: tc.enableAdminRBAC, + AzurePLSManagedIdentityClientID: tc.azureManagedIdentityClientID, + } + + sa, objects := setupRBAC(opts, ns) + g.Expect(len(objects)).To(BeNumerically(">=", tc.minObjectCount)) + g.Expect(sa).NotTo(BeNil()) + if tc.expectSAAnnotation { + g.Expect(sa.Annotations).To(HaveKeyWithValue("azure.workload.identity/client-id", tc.azureManagedIdentityClientID)) + } + + // When admin RBAC is enabled, verify specific admin RBAC objects exist + if tc.enableAdminRBAC { + foundClientClusterRole := false + foundReaderClusterRole := false + for _, obj := range objects { + if obj.GetObjectKind().GroupVersionKind().Kind == "ClusterRole" { + if obj.GetName() == "hypershift-client" { + foundClientClusterRole = true + } + if obj.GetName() == "hypershift-readers" { + foundReaderClusterRole = true + } + } + } + g.Expect(foundClientClusterRole).To(BeTrue(), "expected hypershift-client ClusterRole to be present when admin RBAC is enabled") + g.Expect(foundReaderClusterRole).To(BeTrue(), "expected hypershift-readers ClusterRole to be present when admin RBAC is enabled") + } + }) + } +} + +func TestSetupCA(t *testing.T) { + t.Run("When no additional trust bundle is provided, it should return only the managed trust bundle", func(t *testing.T) { + g := NewGomegaWithT(t) + ns := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "hypershift"}} + userCA, trustedCA, objects, err := setupCA(Options{}, ns) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(userCA).To(BeNil()) + g.Expect(trustedCA).NotTo(BeNil()) + g.Expect(trustedCA.Name).To(Equal("openshift-config-managed-trusted-ca-bundle")) + g.Expect(objects).To(HaveLen(1)) + }) +} + +func TestValidateImageConfig(t *testing.T) { + tests := []struct { + name string + opts Options + expectError bool + }{ + { + name: "When default image and no refs file are set, it should pass", + opts: Options{HyperShiftImage: HyperShiftImage}, + expectError: false, + }, + { + name: "When both custom image and refs file are set, it should fail", + opts: Options{ + HyperShiftImage: "custom-image:latest", + ImageRefsFile: "/path/to/refs", + }, + expectError: true, + }, + { + name: "When cert rotation scale is longer than 24h, it should fail", + opts: Options{ + HyperShiftImage: HyperShiftImage, + CertRotationScale: 48 * time.Hour, + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + errs := tc.opts.validateImageConfig() + if tc.expectError { + g.Expect(errs).NotTo(BeEmpty()) + } else { + g.Expect(errs).To(BeEmpty()) + } + }) + } +} + +func TestSetupSharedIngress(t *testing.T) { + t.Run("When setupSharedIngress is called, it should return namespace, ClusterRole, and ClusterRoleBinding", func(t *testing.T) { + g := NewGomegaWithT(t) + objects := setupSharedIngress() + g.Expect(objects).To(HaveLen(3)) + + g.Expect(objects[0].GetName()).To(Equal(sharedingress.RouterNamespace)) + g.Expect(objects[0].GetLabels()).To(HaveKeyWithValue("hypershift.openshift.io/component", "shared-ingress")) + + g.Expect(objects[1].GetName()).To(Equal(sharedingress.ConfigGeneratorName)) + g.Expect(objects[2].GetName()).To(Equal(sharedingress.ConfigGeneratorName)) + + crb := objects[2].(*rbacv1.ClusterRoleBinding) + g.Expect(crb.RoleRef.Name).To(Equal(sharedingress.ConfigGeneratorName)) + g.Expect(crb.Subjects).To(HaveLen(1)) + g.Expect(crb.Subjects[0].Name).To(Equal("router")) + g.Expect(crb.Subjects[0].Namespace).To(Equal(sharedingress.RouterNamespace)) + }) +} + +func TestSetupAdminRBAC(t *testing.T) { + t.Run("When setupAdminRBAC is called, it should return client and reader RBAC resources", func(t *testing.T) { + g := NewGomegaWithT(t) + ns := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "hypershift"}} + objects := setupAdminRBAC(ns) + + g.Expect(objects).To(HaveLen(5)) + + objectNames := make([]string, len(objects)) + for i, obj := range objects { + objectNames[i] = obj.GetName() + } + g.Expect(objectNames).To(ContainElement("hypershift-client")) + g.Expect(objectNames).To(ContainElement("hypershift-readers")) + }) +} + +func TestGetDeploymentCondition(t *testing.T) { + tests := []struct { + name string + deployConditions []appsv1.DeploymentCondition + condType string + expectFound bool + }{ + { + name: "When the condition exists, it should return it", + deployConditions: []appsv1.DeploymentCondition{ + {Type: "Progressing", Reason: "NewReplicaSetAvailable"}, + {Type: "Available", Reason: "MinimumReplicasAvailable"}, + }, + condType: "Available", + expectFound: true, + }, + { + name: "When the condition does not exist, it should return nil", + deployConditions: []appsv1.DeploymentCondition{ + {Type: "Progressing", Reason: "NewReplicaSetAvailable"}, + }, + condType: "Available", + expectFound: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + status := appsv1.DeploymentStatus{} + for _, c := range tc.deployConditions { + status.Conditions = append(status.Conditions, appsv1.DeploymentCondition{ + Type: appsv1.DeploymentConditionType(c.Type), + Reason: c.Reason, + }) + } + cond := GetDeploymentCondition(status, appsv1.DeploymentConditionType(tc.condType)) + if tc.expectFound { + g.Expect(cond).NotTo(BeNil()) + } else { + g.Expect(cond).To(BeNil()) + } + }) + } +} + +func TestNewInstallOptionsWithDefaults(t *testing.T) { + t.Run("When NewInstallOptionsWithDefaults is called, it should set all expected defaults", func(t *testing.T) { + g := NewGomegaWithT(t) + opts := NewInstallOptionsWithDefaults() + + g.Expect(opts.Namespace).To(Equal("hypershift")) + g.Expect(opts.PrivatePlatform).To(Equal(string(hyperv1.NonePlatform))) + g.Expect(opts.HyperShiftImage).To(Equal(HyperShiftImage)) + g.Expect(opts.ExternalDNSImage).To(Equal(ExternalDNSImage)) + g.Expect(opts.CertRotationScale).To(Equal(24 * time.Hour)) + g.Expect(opts.ImagePullPolicy).To(Equal("IfNotPresent")) + g.Expect(opts.EnableConversionWebhook).To(BeTrue()) + g.Expect(opts.EnableDedicatedRequestServingIsolation).To(BeTrue()) + g.Expect(opts.EnableEtcdRecovery).To(BeTrue()) + g.Expect(opts.MetricsSet).To(Equal(metrics.DefaultMetricsSet)) + g.Expect(opts.AWSPrivateCredentialsSecretKey).To(Equal("credentials")) + g.Expect(opts.ScaleFromZeroCredentialsSecretKey).To(Equal("credentials")) + g.Expect(opts.OIDCStorageProviderS3CredentialsSecretKey).To(Equal("credentials")) + g.Expect(opts.Development).To(BeFalse()) + g.Expect(opts.EnableAdminRBACGeneration).To(BeFalse()) + g.Expect(opts.AdditionalOperatorEnvVars).NotTo(BeNil()) + }) +} + +func TestHyperShiftNamespaceBuild(t *testing.T) { + tests := []struct { + name string + nsConfig assets.HyperShiftNamespace + expectClusterMonitoringLabel bool + }{ + { + name: "When OCP cluster monitoring is enabled, it should include the monitoring label", + nsConfig: assets.HyperShiftNamespace{ + Name: "hypershift", + EnableOCPClusterMonitoring: true, + }, + expectClusterMonitoringLabel: true, + }, + { + name: "When OCP cluster monitoring is disabled, it should not include the monitoring label", + nsConfig: assets.HyperShiftNamespace{ + Name: "hypershift", + EnableOCPClusterMonitoring: false, + }, + expectClusterMonitoringLabel: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + ns := tc.nsConfig.Build() + g.Expect(ns.Name).To(Equal(tc.nsConfig.Name)) + g.Expect(ns.Labels).To(HaveKeyWithValue("hypershift.openshift.io/component", "operator")) + if tc.expectClusterMonitoringLabel { + g.Expect(ns.Labels).To(HaveKeyWithValue("openshift.io/cluster-monitoring", "true")) + } else { + g.Expect(ns.Labels).NotTo(HaveKey("openshift.io/cluster-monitoring")) + } + }) + } +} diff --git a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go index 1dba7b85b32..c93e7adc6f3 100644 --- a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go +++ b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go @@ -603,173 +603,201 @@ func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointService(ctx context.C return err } + endpointID, endpointDNSEntries, err := r.ensureVPCEndpoint(ctx, ec2Client, awsEndpointService, hcp, log) + if err != nil { + return err + } + + if len(endpointDNSEntries) == 0 { + log.Info("endpoint has no DNS entries, skipping DNS record creation", "endpointID", endpointID) + return nil + } + + fqdns, zoneID, err := r.reconcileEndpointDNSRecords(ctx, route53Client, awsEndpointService, hcp, endpointDNSEntries, log) + if err != nil { + return err + } + + awsEndpointService.Status.DNSNames = fqdns + awsEndpointService.Status.DNSZoneID = zoneID + + return r.reconcileExternalNameServices(ctx, hcp, endpointDNSEntries, log) +} + +func (r *AWSEndpointServiceReconciler) ensureVPCEndpoint(ctx context.Context, ec2Client awsapi.EC2API, awsEndpointService *hyperv1.AWSEndpointService, hcp *hyperv1.HostedControlPlane, log logr.Logger) (string, []ec2types.DnsEntry, error) { endpointID := awsEndpointService.Status.EndpointID - var endpointDNSEntries []ec2types.DnsEntry if endpointID != "" { - // check if Endpoint exists in AWS - output, err := ec2Client.DescribeVpcEndpoints(ctx, &ec2.DescribeVpcEndpointsInput{ - VpcEndpointIds: []string{endpointID}, - }) - if err != nil { - log.Error(err, "failed to describe vpc endpoint", "endpointID", endpointID) - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - if apiErr.ErrorCode() == "InvalidVpcEndpointId.NotFound" { - // clear the EndpointID so a new Endpoint is created on the requeue - awsEndpointService.Status.EndpointID = "" - return fmt.Errorf("endpoint with id %s not found, resetting status", endpointID) - } else { - return errors.New(apiErr.ErrorCode()) - } + return r.reconcileExistingEndpoint(ctx, ec2Client, awsEndpointService, endpointID, log) + } + return r.reconcileNewEndpoint(ctx, ec2Client, awsEndpointService, hcp, log) +} + +func (r *AWSEndpointServiceReconciler) reconcileExistingEndpoint(ctx context.Context, ec2Client awsapi.EC2API, awsEndpointService *hyperv1.AWSEndpointService, endpointID string, log logr.Logger) (string, []ec2types.DnsEntry, error) { + output, err := ec2Client.DescribeVpcEndpoints(ctx, &ec2.DescribeVpcEndpointsInput{ + VpcEndpointIds: []string{endpointID}, + }) + if err != nil { + log.Error(err, "failed to describe vpc endpoint", "endpointID", endpointID) + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + if apiErr.ErrorCode() == "InvalidVpcEndpointId.NotFound" { + awsEndpointService.Status.EndpointID = "" + return "", nil, fmt.Errorf("endpoint with id %s not found, resetting status", endpointID) + } else { + return "", nil, errors.New(apiErr.ErrorCode()) } - return err } + return "", nil, err + } - if aws.ToString(output.VpcEndpoints[0].ServiceName) != awsEndpointService.Status.EndpointServiceName { - log.Info("endpoint links to wrong endpointservice, deleting...", "LinkedVPCEndpointServiceName", aws.ToString(output.VpcEndpoints[0].ServiceName), "WantedVPCEndpointService", awsEndpointService.Status.EndpointServiceName) - if _, err := ec2Client.DeleteVpcEndpoints(ctx, &ec2.DeleteVpcEndpointsInput{ - VpcEndpointIds: []string{aws.ToString(output.VpcEndpoints[0].VpcEndpointId)}, - }); err != nil { - log.Error(err, "failed to delete vpc endpoint", "id", aws.ToString(output.VpcEndpoints[0].VpcEndpointId)) - return fmt.Errorf("error deleting AWSEndpoint: %w", err) - } + if len(output.VpcEndpoints) == 0 { + awsEndpointService.Status.EndpointID = "" + return "", nil, fmt.Errorf("endpoint with id %s not found, resetting status", endpointID) + } - // Once the VPC Endpoint is deleted, we need to send an error in order to reexecute the reconcilliation - return fmt.Errorf("current endpoint %s is not pointing to the existing .Status.EndpointServiceName, reconciling by deleting endpoint", aws.ToString(output.VpcEndpoints[0].ServiceName)) - } + if err := deleteEndpointIfWrongService(ctx, ec2Client, output.VpcEndpoints[0], awsEndpointService.Status.EndpointServiceName, log); err != nil { + return "", nil, err + } - if len(output.VpcEndpoints) == 0 { - // This should not happen but just in case - // clear the EndpointID so a new Endpoint is created on the requeue - awsEndpointService.Status.EndpointID = "" - return fmt.Errorf("endpoint with id %s not found, resetting status", endpointID) - } - log.Info("endpoint exists", "endpointID", endpointID) - endpointDNSEntries = output.VpcEndpoints[0].DnsEntries + log.Info("endpoint exists", "endpointID", endpointID) - // Ensure endpoint has the right subnets. - addedSubnet, removedSubnet := diffIDs(awsEndpointService.Spec.SubnetIDs, output.VpcEndpoints[0].SubnetIds) + if err := modifyEndpointIfNeeded(ctx, ec2Client, awsEndpointService, output.VpcEndpoints[0], endpointID, log); err != nil { + return "", nil, err + } - // Ensure endpoint has the right SG. - existingSG := make([]string, 0) - for _, group := range output.VpcEndpoints[0].Groups { - existingSG = append(existingSG, aws.ToString(group.GroupId)) - } - addedSG, _ := diffIDs([]string{awsEndpointService.Status.SecurityGroupID}, existingSG) - - if addedSubnet != nil || removedSubnet != nil || addedSG != nil { - log.Info("endpoint subnets or security groups have changed") - _, err := ec2Client.ModifyVpcEndpoint(ctx, &ec2.ModifyVpcEndpointInput{ - VpcEndpointId: aws.String(endpointID), - AddSubnetIds: addedSubnet, - RemoveSubnetIds: removedSubnet, - AddSecurityGroupIds: addedSG, - }) - if err != nil { - log.Error(err, "failed to modify vpc endpoint", "id", endpointID, "addSubnets", addedSubnet, "removeSubnets", removedSubnet, "addSG", addedSG) - msg := err.Error() - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - msg = apiErr.ErrorCode() - } - log.Error(err, "failed to modify vpc endpoint") - return fmt.Errorf("failed to modify vpc endpoint: %s", msg) - } - log.Info("endpoint subnets updated") - } else { - log.Info("endpoint subnets are unchanged") + return endpointID, output.VpcEndpoints[0].DnsEntries, nil +} + +func deleteEndpointIfWrongService(ctx context.Context, ec2Client awsapi.EC2API, endpoint ec2types.VpcEndpoint, expectedServiceName string, log logr.Logger) error { + if aws.ToString(endpoint.ServiceName) == expectedServiceName { + return nil + } + log.Info("endpoint links to wrong endpointservice, deleting...", "LinkedVPCEndpointServiceName", aws.ToString(endpoint.ServiceName), "WantedVPCEndpointService", expectedServiceName) + if _, err := ec2Client.DeleteVpcEndpoints(ctx, &ec2.DeleteVpcEndpointsInput{ + VpcEndpointIds: []string{aws.ToString(endpoint.VpcEndpointId)}, + }); err != nil { + log.Error(err, "failed to delete vpc endpoint", "id", aws.ToString(endpoint.VpcEndpointId)) + return fmt.Errorf("error deleting AWSEndpoint: %w", err) + } + return fmt.Errorf("current endpoint %s is not pointing to the existing .Status.EndpointServiceName, reconciling by deleting endpoint", aws.ToString(endpoint.ServiceName)) +} + +func modifyEndpointIfNeeded(ctx context.Context, ec2Client awsapi.EC2API, awsEndpointService *hyperv1.AWSEndpointService, endpoint ec2types.VpcEndpoint, endpointID string, log logr.Logger) error { + addedSubnet, removedSubnet := diffIDs(awsEndpointService.Spec.SubnetIDs, endpoint.SubnetIds) + + existingSG := make([]string, 0) + for _, group := range endpoint.Groups { + existingSG = append(existingSG, aws.ToString(group.GroupId)) + } + addedSG, _ := diffIDs([]string{awsEndpointService.Status.SecurityGroupID}, existingSG) + + if addedSubnet == nil && removedSubnet == nil && addedSG == nil { + log.Info("endpoint subnets are unchanged") + return nil + } + + log.Info("endpoint subnets or security groups have changed") + _, err := ec2Client.ModifyVpcEndpoint(ctx, &ec2.ModifyVpcEndpointInput{ + VpcEndpointId: aws.String(endpointID), + AddSubnetIds: addedSubnet, + RemoveSubnetIds: removedSubnet, + AddSecurityGroupIds: addedSG, + }) + if err != nil { + log.Error(err, "failed to modify vpc endpoint", "id", endpointID, "addSubnets", addedSubnet, "removeSubnets", removedSubnet, "addSG", addedSG) + msg := err.Error() + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + msg = apiErr.ErrorCode() } - } else { - if !hasAWSConfig(&hcp.Spec.Platform) { - return fmt.Errorf("AWS platform information not provided in HostedControlPlane") + log.Error(err, "failed to modify vpc endpoint") + return fmt.Errorf("failed to modify vpc endpoint: %s", msg) + } + log.Info("endpoint subnets updated") + return nil +} + +func (r *AWSEndpointServiceReconciler) reconcileNewEndpoint(ctx context.Context, ec2Client awsapi.EC2API, awsEndpointService *hyperv1.AWSEndpointService, hcp *hyperv1.HostedControlPlane, log logr.Logger) (string, []ec2types.DnsEntry, error) { + if !hasAWSConfig(&hcp.Spec.Platform) { + return "", nil, fmt.Errorf("AWS platform information not provided in HostedControlPlane") + } + + output, err := ec2Client.DescribeVpcEndpoints(ctx, &ec2.DescribeVpcEndpointsInput{ + Filters: apiTagToEC2Filter(awsEndpointService.Name, hcp.Spec.Platform.AWS.ResourceTags), + }) + if err != nil { + msg := err.Error() + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + msg = apiErr.ErrorCode() } + log.Error(err, "failed to describe vpc endpoints") + return "", nil, fmt.Errorf("failed to describe vpc endpoints: %s", msg) + } - // Verify there is not already an Endpoint that we can adopt - // This can happen if we have a stale status on AWSEndpointService or encountered - // an error updating the AWSEndpointService on the previous reconcile - output, err := ec2Client.DescribeVpcEndpoints(ctx, &ec2.DescribeVpcEndpointsInput{ - Filters: apiTagToEC2Filter(awsEndpointService.Name, hcp.Spec.Platform.AWS.ResourceTags), - }) - if err != nil { - msg := err.Error() - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - msg = apiErr.ErrorCode() - } - log.Error(err, "failed to describe vpc endpoints") - return fmt.Errorf("failed to describe vpc endpoints: %s", msg) + if len(output.VpcEndpoints) != 0 { + if err := deleteEndpointIfWrongService(ctx, ec2Client, output.VpcEndpoints[0], awsEndpointService.Status.EndpointServiceName, log); err != nil { + return "", nil, err } - if len(output.VpcEndpoints) != 0 { - if aws.ToString(output.VpcEndpoints[0].ServiceName) != awsEndpointService.Status.EndpointServiceName { - log.Info("endpoint links to wrong endpointservice, deleting...", "LinkedVPCEndpointServiceName", aws.ToString(output.VpcEndpoints[0].ServiceName), "WantedVPCEndpointService", awsEndpointService.Status.EndpointServiceName) - if _, err := ec2Client.DeleteVpcEndpoints(ctx, &ec2.DeleteVpcEndpointsInput{ - VpcEndpointIds: []string{aws.ToString(output.VpcEndpoints[0].VpcEndpointId)}, - }); err != nil { - log.Error(err, "failed to delete vpc endpoint", "id", aws.ToString(output.VpcEndpoints[0].VpcEndpointId)) - return fmt.Errorf("error deleting AWSEndpoint: %w", err) - } + endpointID := aws.ToString(output.VpcEndpoints[0].VpcEndpointId) + log.Info("endpoint already exists, adopting", "endpointID", endpointID) + awsEndpointService.Status.EndpointID = endpointID + return endpointID, output.VpcEndpoints[0].DnsEntries, nil + } - // Once the VPC Endpoint is deleted, we need to send an error in order to reexecute the reconcilliation - return fmt.Errorf("current endpoint %s is not pointing to the existing .Status.EndpointServiceName, reconciling by deleting endpoint", aws.ToString(output.VpcEndpoints[0].ServiceName)) - } - endpointID = aws.ToString(output.VpcEndpoints[0].VpcEndpointId) - log.Info("endpoint already exists, adopting", "endpointID", endpointID) - awsEndpointService.Status.EndpointID = endpointID - endpointDNSEntries = output.VpcEndpoints[0].DnsEntries - } else { - log.Info("endpoint does not already exist") - - if awsEndpointService.Status.SecurityGroupID == "" { - return fmt.Errorf("security group ID doesn't exist yet for the endpoint to use") - } - output, err := ec2Client.CreateVpcEndpoint(ctx, &ec2.CreateVpcEndpointInput{ - SecurityGroupIds: []string{awsEndpointService.Status.SecurityGroupID}, - ServiceName: aws.String(awsEndpointService.Status.EndpointServiceName), - VpcId: aws.String(hcp.Spec.Platform.AWS.CloudProviderConfig.VPC), - VpcEndpointType: ec2types.VpcEndpointTypeInterface, - SubnetIds: awsEndpointService.Spec.SubnetIDs, - TagSpecifications: []ec2types.TagSpecification{{ - ResourceType: ec2types.ResourceTypeVpcEndpoint, - Tags: apiTagToEC2Tag(awsEndpointService.Name, hcp.Spec.Platform.AWS.ResourceTags), - }}, - }) - if err != nil { - msg := err.Error() - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - msg = apiErr.ErrorCode() - } - log.Error(err, "failed to create vpc endpoint") - return fmt.Errorf("failed to create vpc endpoint: %s", msg) - } - if output == nil || output.VpcEndpoint == nil { - return fmt.Errorf("CreateVpcEndpoint output is nil") - } + return r.createVPCEndpoint(ctx, ec2Client, awsEndpointService, hcp, log) +} - endpointID = aws.ToString(output.VpcEndpoint.VpcEndpointId) - log.Info("endpoint created", "endpointID", endpointID) - awsEndpointService.Status.EndpointID = endpointID - endpointDNSEntries = output.VpcEndpoint.DnsEntries +func (r *AWSEndpointServiceReconciler) createVPCEndpoint(ctx context.Context, ec2Client awsapi.EC2API, awsEndpointService *hyperv1.AWSEndpointService, hcp *hyperv1.HostedControlPlane, log logr.Logger) (string, []ec2types.DnsEntry, error) { + log.Info("endpoint does not already exist") + + if awsEndpointService.Status.SecurityGroupID == "" { + return "", nil, fmt.Errorf("security group ID doesn't exist yet for the endpoint to use") + } + output, err := ec2Client.CreateVpcEndpoint(ctx, &ec2.CreateVpcEndpointInput{ + SecurityGroupIds: []string{awsEndpointService.Status.SecurityGroupID}, + ServiceName: aws.String(awsEndpointService.Status.EndpointServiceName), + VpcId: aws.String(hcp.Spec.Platform.AWS.CloudProviderConfig.VPC), + VpcEndpointType: ec2types.VpcEndpointTypeInterface, + SubnetIds: awsEndpointService.Spec.SubnetIDs, + TagSpecifications: []ec2types.TagSpecification{{ + ResourceType: ec2types.ResourceTypeVpcEndpoint, + Tags: apiTagToEC2Tag(awsEndpointService.Name, hcp.Spec.Platform.AWS.ResourceTags), + }}, + }) + if err != nil { + msg := err.Error() + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + msg = apiErr.ErrorCode() } + log.Error(err, "failed to create vpc endpoint") + return "", nil, fmt.Errorf("failed to create vpc endpoint: %s", msg) } - - if len(endpointDNSEntries) == 0 { - log.Info("endpoint has no DNS entries, skipping DNS record creation", "endpointID", endpointID) - return nil + if output == nil || output.VpcEndpoint == nil { + return "", nil, fmt.Errorf("CreateVpcEndpoint output is nil") } + endpointID := aws.ToString(output.VpcEndpoint.VpcEndpointId) + log.Info("endpoint created", "endpointID", endpointID) + awsEndpointService.Status.EndpointID = endpointID + return endpointID, output.VpcEndpoint.DnsEntries, nil +} + +func (r *AWSEndpointServiceReconciler) reconcileEndpointDNSRecords(ctx context.Context, route53Client awsapi.ROUTE53API, awsEndpointService *hyperv1.AWSEndpointService, hcp *hyperv1.HostedControlPlane, endpointDNSEntries []ec2types.DnsEntry, log logr.Logger) ([]string, string, error) { recordNames := recordsForService(awsEndpointService, hcp) if len(recordNames) == 0 { log.Info("WARNING: no mapping from AWSEndpointService to DNS") - return nil + return nil, "", nil } - zoneName := zoneName(hcp.Name) + zn := zoneName(hcp.Name) var zoneID string if r.awsClientBuilder.getLocalHostedZoneID() == "" { - zoneID, err = lookupZoneID(ctx, route53Client, zoneName) + var err error + zoneID, err = lookupZoneID(ctx, route53Client, zn) if err != nil { - return err + return nil, "", err } r.awsClientBuilder.setLocalHostedZoneID(zoneID) } else { @@ -778,21 +806,23 @@ func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointService(ctx context.C var fqdns []string for _, recordName := range recordNames { - fqdn := fmt.Sprintf("%s.%s", recordName, zoneName) + fqdn := fmt.Sprintf("%s.%s", recordName, zn) fqdns = append(fqdns, fqdn) - err = CreateRecord(ctx, route53Client, zoneID, fqdn, aws.ToString(endpointDNSEntries[0].DnsName), route53types.RRTypeCname) + err := CreateRecord(ctx, route53Client, zoneID, fqdn, aws.ToString(endpointDNSEntries[0].DnsName), route53types.RRTypeCname) if err != nil { - return err + return nil, "", err } log.Info("DNS record created", "fqdn", fqdn) } - awsEndpointService.Status.DNSNames = fqdns - awsEndpointService.Status.DNSZoneID = zoneID + return fqdns, zoneID, nil +} + +func (r *AWSEndpointServiceReconciler) reconcileExternalNameServices(ctx context.Context, hcp *hyperv1.HostedControlPlane, endpointDNSEntries []ec2types.DnsEntry, log logr.Logger) error { + isPublic := netutil.IsPublicHCP(hcp) + externalNames := hcpExternalNames(hcp) - if isPublic, externalNames := netutil.IsPublicHCP(hcp), hcpExternalNames(hcp); !isPublic && len(externalNames) > 0 { - // only if not public and external names are configured, create services of type ExternalName so external-dns - // can create records for them + if !isPublic && len(externalNames) > 0 { var errs []error for svcType, externalName := range externalNames { var svc *corev1.Service @@ -812,27 +842,26 @@ func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointService(ctx context.C if len(errs) > 0 { return fmt.Errorf("failed to create external services for private endpoints: %w", utilerrors.NewAggregate(errs)) } - } else { - // if the cluster is public, ensure that any ExternalName services are removed - privateExternalServices := &corev1.ServiceList{} - if err := r.List(ctx, privateExternalServices, client.HasLabels{externalPrivateServiceLabel}); err != nil { - return fmt.Errorf("cannot list private external services: %w", err) - } - if len(privateExternalServices.Items) > 0 { - log.Info("Removing private external services", "count", len(privateExternalServices.Items)) - var errs []error - for i := range privateExternalServices.Items { - svc := &privateExternalServices.Items[i] - if err := r.Delete(ctx, svc); err != nil { - errs = append(errs, fmt.Errorf("failed to delete private external service %s: %w", svc.Name, err)) - } - } - if len(errs) > 0 { - return utilerrors.NewAggregate(errs) + return nil + } + + privateExternalServices := &corev1.ServiceList{} + if err := r.List(ctx, privateExternalServices, client.InNamespace(hcp.Namespace), client.HasLabels{externalPrivateServiceLabel}); err != nil { + return fmt.Errorf("cannot list private external services: %w", err) + } + if len(privateExternalServices.Items) > 0 { + log.Info("Removing private external services", "count", len(privateExternalServices.Items)) + var errs []error + for i := range privateExternalServices.Items { + svc := &privateExternalServices.Items[i] + if err := r.Delete(ctx, svc); err != nil { + errs = append(errs, fmt.Errorf("failed to delete private external service %s: %w", svc.Name, err)) } } + if len(errs) > 0 { + return utilerrors.NewAggregate(errs) + } } - return nil } diff --git a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go index 539a96f3248..907469a5626 100644 --- a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go +++ b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go @@ -20,6 +20,7 @@ import ( route53types "github.com/aws/aws-sdk-go-v2/service/route53/types" "github.com/aws/smithy-go" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" @@ -769,6 +770,716 @@ func TestDeleteSecurityGroup(t *testing.T) { // // A proper fix requires persisting the SharedVPC role ARNs in the AWSEndpointService // status so the deletion path can authenticate independently of the HCP. +func TestHasAWSConfig(t *testing.T) { + tests := []struct { + name string + platform hyperv1.PlatformSpec + expected bool + }{ + { + name: "When all AWS config fields are present, it should return true", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + CloudProviderConfig: &hyperv1.AWSCloudProviderConfig{ + Subnet: &hyperv1.AWSResourceReference{ + ID: aws.String("subnet-123"), + }, + }, + }, + }, + expected: true, + }, + { + name: "When platform type is not AWS, it should return false", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AzurePlatform, + }, + expected: false, + }, + { + name: "When AWS spec is nil, it should return false", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: nil, + }, + expected: false, + }, + { + name: "When CloudProviderConfig is nil, it should return false", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + CloudProviderConfig: nil, + }, + }, + expected: false, + }, + { + name: "When Subnet is nil, it should return false", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + CloudProviderConfig: &hyperv1.AWSCloudProviderConfig{ + Subnet: nil, + }, + }, + }, + expected: false, + }, + { + name: "When Subnet.ID is nil, it should return false", + platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + CloudProviderConfig: &hyperv1.AWSCloudProviderConfig{ + Subnet: &hyperv1.AWSResourceReference{ + ID: nil, + }, + }, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(hasAWSConfig(&tt.platform)).To(Equal(tt.expected)) + }) + } +} + +func TestVPCEndpointPort(t *testing.T) { + tests := []struct { + name string + svcName string + expected int32 + }{ + { + name: "When service is kube-apiserver-private, it should return 6443", + svcName: "kube-apiserver-private", + expected: 6443, + }, + { + name: "When service is private-router, it should return 443", + svcName: "private-router", + expected: 443, + }, + { + name: "When service is an unknown name, it should return 443", + svcName: "some-other-service", + expected: 443, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + aes := &hyperv1.AWSEndpointService{ + ObjectMeta: metav1.ObjectMeta{Name: tt.svcName}, + } + g.Expect(vpcEndpointPort(aes)).To(Equal(tt.expected)) + }) + } +} + +func TestVPCEndpointSecurityGroupName(t *testing.T) { + tests := []struct { + name string + infraID string + endpointName string + expected string + }{ + { + name: "When given infraID and endpoint name, it should format correctly", + infraID: "my-infra", + endpointName: "kube-apiserver-private", + expected: "my-infra-vpce-kube-apiserver-private", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(vpcEndpointSecurityGroupName(tt.infraID, tt.endpointName)).To(Equal(tt.expected)) + }) + } +} + +func TestVPCEndpointSecurityGroupFilter(t *testing.T) { + tests := []struct { + name string + infraID string + endpointName string + }{ + { + name: "When given infraID and endpoint name, it should return cluster tag and name tag filters", + infraID: "test-infra", + endpointName: "kube-apiserver-private", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + filters := vpcEndpointSecurityGroupFilter(tt.infraID, tt.endpointName) + g.Expect(filters).To(HaveLen(2)) + g.Expect(aws.ToString(filters[0].Name)).To(Equal("tag:kubernetes.io/cluster/test-infra")) + g.Expect(filters[0].Values).To(Equal([]string{"owned"})) + g.Expect(aws.ToString(filters[1].Name)).To(Equal("tag:Name")) + g.Expect(filters[1].Values).To(Equal([]string{"test-infra-vpce-kube-apiserver-private"})) + }) + } +} + +func TestApiTagToEC2Tag(t *testing.T) { + tests := []struct { + name string + svcName string + tags []hyperv1.AWSResourceTag + expected []ec2types.Tag + }{ + { + name: "When no resource tags are provided, it should return only the AWSEndpointService tag", + svcName: "my-svc", + tags: nil, + expected: []ec2types.Tag{ + {Key: aws.String("AWSEndpointService"), Value: aws.String("my-svc")}, + }, + }, + { + name: "When resource tags are provided, it should include them plus the AWSEndpointService tag", + svcName: "my-svc", + tags: []hyperv1.AWSResourceTag{ + {Key: "env", Value: "prod"}, + {Key: "team", Value: "platform"}, + }, + expected: []ec2types.Tag{ + {Key: aws.String("env"), Value: aws.String("prod")}, + {Key: aws.String("team"), Value: aws.String("platform")}, + {Key: aws.String("AWSEndpointService"), Value: aws.String("my-svc")}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + result := apiTagToEC2Tag(tt.svcName, tt.tags) + g.Expect(result).To(Equal(tt.expected)) + }) + } +} + +func TestApiTagToEC2Filter(t *testing.T) { + tests := []struct { + name string + svcName string + tags []hyperv1.AWSResourceTag + expected []ec2types.Filter + }{ + { + name: "When no resource tags are provided, it should return only the AWSEndpointService filter", + svcName: "my-svc", + tags: nil, + expected: []ec2types.Filter{ + {Name: aws.String("tag:AWSEndpointService"), Values: []string{"my-svc"}}, + }, + }, + { + name: "When resource tags are provided, it should include them as tag filters plus the AWSEndpointService filter", + svcName: "my-svc", + tags: []hyperv1.AWSResourceTag{ + {Key: "env", Value: "prod"}, + }, + expected: []ec2types.Filter{ + {Name: aws.String("tag:env"), Values: []string{"prod"}}, + {Name: aws.String("tag:AWSEndpointService"), Values: []string{"my-svc"}}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + result := apiTagToEC2Filter(tt.svcName, tt.tags) + g.Expect(result).To(Equal(tt.expected)) + }) + } +} + +func TestZoneName(t *testing.T) { + tests := []struct { + name string + hcpName string + expected string + }{ + { + name: "When given an HCP name, it should append the hypershift.local suffix", + hcpName: "my-cluster", + expected: "my-cluster.hypershift.local", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(zoneName(tt.hcpName)).To(Equal(tt.expected)) + }) + } +} + +func TestRouterZoneName(t *testing.T) { + tests := []struct { + name string + hcpName string + expected string + }{ + { + name: "When given an HCP name, it should prepend apps and append the hypershift.local suffix", + hcpName: "my-cluster", + expected: "apps.my-cluster.hypershift.local", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(RouterZoneName(tt.hcpName)).To(Equal(tt.expected)) + }) + } +} + +func TestControllerName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "When given a name, it should append the observer suffix", + input: "kube-apiserver", + expected: "kube-apiserver-observer", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(ControllerName(tt.input)).To(Equal(tt.expected)) + }) + } +} + +func TestNameMapper(t *testing.T) { + tests := []struct { + name string + watchedNames []string + incomingName string + incomingNS string + expectRequests int + }{ + { + name: "When incoming object name matches a watched name, it should return a reconcile request", + watchedNames: []string{"kube-apiserver-private", "private-router"}, + incomingName: "kube-apiserver-private", + incomingNS: "test-ns", + expectRequests: 1, + }, + { + name: "When incoming object name does not match any watched name, it should return nil", + watchedNames: []string{"kube-apiserver-private"}, + incomingName: "other-service", + incomingNS: "test-ns", + expectRequests: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + mapFn := nameMapper(tt.watchedNames) + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: tt.incomingName, + Namespace: tt.incomingNS, + }, + } + requests := mapFn(t.Context(), svc) + g.Expect(requests).To(HaveLen(tt.expectRequests)) + if tt.expectRequests > 0 { + g.Expect(requests[0].Name).To(Equal(tt.incomingName)) + g.Expect(requests[0].Namespace).To(Equal(tt.incomingNS)) + } + }) + } +} + +func TestHCPExternalNames(t *testing.T) { + tests := []struct { + name string + hcp *hyperv1.HostedControlPlane + expected map[string]string + }{ + { + name: "When API and OAuth both have Route strategies with hostnames, it should return both", + hcp: &hyperv1.HostedControlPlane{ + Spec: hyperv1.HostedControlPlaneSpec{ + Services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.APIServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{ + Type: hyperv1.Route, + Route: &hyperv1.RoutePublishingStrategy{Hostname: "api.example.com"}, + }, + }, + { + Service: hyperv1.OAuthServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{ + Type: hyperv1.Route, + Route: &hyperv1.RoutePublishingStrategy{Hostname: "oauth.example.com"}, + }, + }, + }, + }, + }, + expected: map[string]string{ + "api": "api.example.com", + "oauth": "oauth.example.com", + }, + }, + { + name: "When no Route strategies are configured, it should return empty map", + hcp: &hyperv1.HostedControlPlane{ + Spec: hyperv1.HostedControlPlaneSpec{ + Services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.APIServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{ + Type: hyperv1.LoadBalancer, + }, + }, + }, + }, + }, + expected: map[string]string{}, + }, + { + name: "When Route strategy has no hostname, it should not include that entry", + hcp: &hyperv1.HostedControlPlane{ + Spec: hyperv1.HostedControlPlaneSpec{ + Services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.APIServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{ + Type: hyperv1.Route, + Route: &hyperv1.RoutePublishingStrategy{Hostname: ""}, + }, + }, + }, + }, + }, + expected: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + result := hcpExternalNames(tt.hcp) + g.Expect(result).To(Equal(tt.expected)) + }) + } +} + +func TestReconcileExternalService(t *testing.T) { + tests := []struct { + name string + hostName string + targetCName string + expectedLabels map[string]string + }{ + { + name: "When reconciling an external service, it should set type, labels, annotations, and external name", + hostName: "api.example.com", + targetCName: "vpce-abc.vpce-svc.us-east-1.vpce.amazonaws.com", + expectedLabels: map[string]string{ + externalPrivateServiceLabel: "true", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + hcp := &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: "test-ns", + UID: "test-uid", + }, + } + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver-private-external", + Namespace: "test-ns", + }, + } + + err := reconcileExternalService(svc, hcp, tt.hostName, tt.targetCName) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(svc.Spec.Type).To(Equal(corev1.ServiceTypeExternalName)) + g.Expect(svc.Spec.ExternalName).To(Equal(tt.targetCName)) + g.Expect(svc.Labels[externalPrivateServiceLabel]).To(Equal("true")) + g.Expect(svc.Annotations[hyperv1.ExternalDNSHostnameAnnotation]).To(Equal(tt.hostName)) + g.Expect(svc.OwnerReferences).To(HaveLen(1)) + g.Expect(svc.OwnerReferences[0].Name).To(Equal("test-hcp")) + }) + } +} + +func TestDeleteEndpointIfWrongService(t *testing.T) { + tests := []struct { + name string + endpointServiceName string + expectedServiceName string + setupEC2Mock func(*gomock.Controller) *awsapi.MockEC2API + expectError bool + }{ + { + name: "When endpoint points to the correct service, it should return nil", + endpointServiceName: "com.amazonaws.vpce-svc-abc", + expectedServiceName: "com.amazonaws.vpce-svc-abc", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + return awsapi.NewMockEC2API(mockCtrl) + }, + expectError: false, + }, + { + name: "When endpoint points to wrong service, it should delete and return error", + endpointServiceName: "com.amazonaws.vpce-svc-wrong", + expectedServiceName: "com.amazonaws.vpce-svc-correct", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().DeleteVpcEndpoints(gomock.Any(), gomock.Any()).Return(&ec2v2.DeleteVpcEndpointsOutput{}, nil) + return m + }, + expectError: true, + }, + { + name: "When endpoint points to wrong service and delete fails, it should return the delete error", + endpointServiceName: "com.amazonaws.vpce-svc-wrong", + expectedServiceName: "com.amazonaws.vpce-svc-correct", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().DeleteVpcEndpoints(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("access denied")) + return m + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + mockCtrl := gomock.NewController(t) + mockEC2 := tt.setupEC2Mock(mockCtrl) + + endpoint := ec2types.VpcEndpoint{ + VpcEndpointId: aws.String("vpce-123"), + ServiceName: aws.String(tt.endpointServiceName), + } + ctx := ctrl.LoggerInto(t.Context(), ctrl.Log.WithName("test")) + err := deleteEndpointIfWrongService(ctx, mockEC2, endpoint, tt.expectedServiceName, ctrl.Log.WithName("test")) + if tt.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + } + }) + } +} + +func TestModifyEndpointIfNeeded(t *testing.T) { + tests := []struct { + name string + specSubnetIDs []string + endpointSubnetIDs []string + specSecurityGroupID string + endpointGroups []ec2types.SecurityGroupIdentifier + setupEC2Mock func(*gomock.Controller) *awsapi.MockEC2API + expectError bool + }{ + { + name: "When subnets and security groups are unchanged, it should not modify", + specSubnetIDs: []string{"subnet-1", "subnet-2"}, + endpointSubnetIDs: []string{"subnet-1", "subnet-2"}, + specSecurityGroupID: "sg-123", + endpointGroups: []ec2types.SecurityGroupIdentifier{{GroupId: aws.String("sg-123")}}, + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + return awsapi.NewMockEC2API(mockCtrl) + }, + expectError: false, + }, + { + name: "When subnets have changed, it should call ModifyVpcEndpoint", + specSubnetIDs: []string{"subnet-1", "subnet-3"}, + endpointSubnetIDs: []string{"subnet-1", "subnet-2"}, + specSecurityGroupID: "sg-123", + endpointGroups: []ec2types.SecurityGroupIdentifier{{GroupId: aws.String("sg-123")}}, + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().ModifyVpcEndpoint(gomock.Any(), gomock.Any()).Return(&ec2v2.ModifyVpcEndpointOutput{}, nil) + return m + }, + expectError: false, + }, + { + name: "When security group needs adding, it should call ModifyVpcEndpoint", + specSubnetIDs: []string{"subnet-1"}, + endpointSubnetIDs: []string{"subnet-1"}, + specSecurityGroupID: "sg-new", + endpointGroups: []ec2types.SecurityGroupIdentifier{{GroupId: aws.String("sg-old")}}, + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().ModifyVpcEndpoint(gomock.Any(), gomock.Any()).Return(&ec2v2.ModifyVpcEndpointOutput{}, nil) + return m + }, + expectError: false, + }, + { + name: "When ModifyVpcEndpoint fails, it should return error", + specSubnetIDs: []string{"subnet-1", "subnet-3"}, + endpointSubnetIDs: []string{"subnet-1", "subnet-2"}, + specSecurityGroupID: "sg-123", + endpointGroups: []ec2types.SecurityGroupIdentifier{{GroupId: aws.String("sg-123")}}, + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().ModifyVpcEndpoint(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("throttling")) + return m + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + mockCtrl := gomock.NewController(t) + mockEC2 := tt.setupEC2Mock(mockCtrl) + + aes := &hyperv1.AWSEndpointService{ + Spec: hyperv1.AWSEndpointServiceSpec{ + SubnetIDs: tt.specSubnetIDs, + }, + Status: hyperv1.AWSEndpointServiceStatus{ + SecurityGroupID: tt.specSecurityGroupID, + }, + } + + endpoint := ec2types.VpcEndpoint{ + SubnetIds: tt.endpointSubnetIDs, + Groups: tt.endpointGroups, + } + + ctx := ctrl.LoggerInto(t.Context(), ctrl.Log.WithName("test")) + err := modifyEndpointIfNeeded(ctx, mockEC2, aes, endpoint, "vpce-123", ctrl.Log.WithName("test")) + if tt.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + } + }) + } +} + +func TestReconcileExistingEndpoint(t *testing.T) { + tests := []struct { + name string + endpointID string + setupEC2Mock func(*gomock.Controller) *awsapi.MockEC2API + expectError bool + expectEndpointID string + }{ + { + name: "When DescribeVpcEndpoints returns empty results, it should reset EndpointID and return error", + endpointID: "vpce-123", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().DescribeVpcEndpoints(gomock.Any(), gomock.Any()).Return(&ec2v2.DescribeVpcEndpointsOutput{ + VpcEndpoints: []ec2types.VpcEndpoint{}, + }, nil) + return m + }, + expectError: true, + expectEndpointID: "", + }, + { + name: "When DescribeVpcEndpoints returns NotFound API error, it should reset EndpointID and return error", + endpointID: "vpce-gone", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().DescribeVpcEndpoints(gomock.Any(), gomock.Any()).Return(nil, &smithy.GenericAPIError{Code: "InvalidVpcEndpointId.NotFound", Message: "not found"}) + return m + }, + expectError: true, + expectEndpointID: "", + }, + { + name: "When endpoint exists and matches service, it should return the endpoint ID", + endpointID: "vpce-active", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().DescribeVpcEndpoints(gomock.Any(), gomock.Any()).Return(&ec2v2.DescribeVpcEndpointsOutput{ + VpcEndpoints: []ec2types.VpcEndpoint{ + { + VpcEndpointId: aws.String("vpce-active"), + ServiceName: aws.String("com.amazonaws.vpce-svc-test"), + SubnetIds: []string{"subnet-1"}, + Groups: []ec2types.SecurityGroupIdentifier{{GroupId: aws.String("sg-test")}}, + }, + }, + }, nil) + return m + }, + expectError: false, + expectEndpointID: "vpce-active", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + mockCtrl := gomock.NewController(t) + mockEC2 := tt.setupEC2Mock(mockCtrl) + + awsEndpointService := &hyperv1.AWSEndpointService{ + Status: hyperv1.AWSEndpointServiceStatus{ + EndpointID: tt.endpointID, + EndpointServiceName: "com.amazonaws.vpce-svc-test", + SecurityGroupID: "sg-test", + }, + Spec: hyperv1.AWSEndpointServiceSpec{ + SubnetIDs: []string{"subnet-1"}, + }, + } + + r := &AWSEndpointServiceReconciler{} + ctx := ctrl.LoggerInto(t.Context(), ctrl.Log.WithName("test")) + resultID, _, err := r.reconcileExistingEndpoint(ctx, mockEC2, awsEndpointService, tt.endpointID, ctrl.Log.WithName("test")) + if tt.expectError { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + } + g.Expect(resultID).To(Equal(tt.expectEndpointID)) + g.Expect(awsEndpointService.Status.EndpointID).To(Equal(tt.expectEndpointID)) + }) + } +} + func TestReconcileDeletionSharedVPC(t *testing.T) { now := metav1.NewTime(time.Now()) diff --git a/control-plane-operator/controllers/azureprivatelinkservice/controller.go b/control-plane-operator/controllers/azureprivatelinkservice/controller.go index a43b09b57be..13cd2255fdc 100644 --- a/control-plane-operator/controllers/azureprivatelinkservice/controller.go +++ b/control-plane-operator/controllers/azureprivatelinkservice/controller.go @@ -860,191 +860,182 @@ func (r *AzurePrivateLinkServiceReconciler) handleAzureError(ctx context.Context // CPO finalizer has completed PE cleanup and removed its finalizer from the CR. func (r *AzurePrivateLinkServiceReconciler) reconcileDelete(ctx context.Context, azPLS *hyperv1.AzurePrivateLinkService, log logr.Logger) error { resourceGroup := azPLS.Spec.ResourceGroupName - - // 1. Delete DNS resources using the zone name persisted in status. - // This avoids a dependency on the HostedControlPlane during deletion, which may - // already be torn down or unavailable when the finalizer runs. dnsZoneName := azPLS.Status.DNSZoneName + if dnsZoneName != "" { - // Delete both A records (KAS apex and wildcard apps) - for _, recordName := range []string{kasARecordName, appsARecordName} { - log.Info("Deleting A record", "record", recordName, "zone", dnsZoneName) - deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) - if _, err := r.RecordSets.Delete(deleteCtx, resourceGroup, dnsZoneName, armprivatedns.RecordTypeA, recordName, nil); err != nil { - cancel() - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete A record %q: %w", recordName, err) - } - } - cancel() + if err := r.deleteDNSResources(ctx, resourceGroup, dnsZoneName, azPLS.Name, log); err != nil { + return err } + } else { + log.V(1).Info("DNSZoneName not set in status, skipping DNS cleanup") + } - // 2. Delete VNet link (must be deleted before zone) - linkName := vnetLinkName(azPLS.Name) - log.Info("Deleting VNet Link", "name", linkName) - deleteCtx2, cancel2 := context.WithTimeout(ctx, azureAPITimeout) - defer cancel2() - linkPoller, err := r.VirtualNetworkLinks.BeginDelete(deleteCtx2, resourceGroup, dnsZoneName, linkName, nil) - if err != nil { + if err := r.deleteBaseDomainResources(ctx, azPLS, resourceGroup, dnsZoneName, log); err != nil { + return err + } + + return r.deletePrivateEndpoint(ctx, resourceGroup, azPLS.Name, azPLS.Status.PrivateEndpointID, log) +} + +func (r *AzurePrivateLinkServiceReconciler) deleteDNSResources(ctx context.Context, resourceGroup, dnsZoneName, crName string, log logr.Logger) error { + for _, recordName := range []string{kasARecordName, appsARecordName} { + log.Info("Deleting A record", "record", recordName, "zone", dnsZoneName) + deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) + if _, err := r.RecordSets.Delete(deleteCtx, resourceGroup, dnsZoneName, armprivatedns.RecordTypeA, recordName, nil); err != nil { + cancel() if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to begin deleting VNet Link: %w", err) - } - } else if linkPoller != nil { - linkPollCtx, linkPollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) - defer linkPollCancel() - - if _, err := linkPoller.PollUntilDone(linkPollCtx, nil); err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete VNet Link: %w", err) - } + return fmt.Errorf("failed to delete A record %q: %w", recordName, err) } } + cancel() + } - // 3. Delete Private DNS Zone - log.Info("Deleting Private DNS Zone", "zone", dnsZoneName) - deleteCtx3, cancel3 := context.WithTimeout(ctx, azureAPITimeout) - defer cancel3() - zonePoller, err := r.PrivateDNSZones.BeginDelete(deleteCtx3, resourceGroup, dnsZoneName, nil) - if err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to begin deleting Private DNS Zone: %w", err) - } - } else if zonePoller != nil { - zonePollCtx, zonePollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) - defer zonePollCancel() - - if _, err := zonePoller.PollUntilDone(zonePollCtx, nil); err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete Private DNS Zone: %w", err) - } - } + if err := r.deleteVNetLink(ctx, resourceGroup, dnsZoneName, vnetLinkName(crName), log); err != nil { + return err + } + + return r.deleteDNSZone(ctx, resourceGroup, dnsZoneName, log) +} + +func (r *AzurePrivateLinkServiceReconciler) deleteVNetLink(ctx context.Context, resourceGroup, zoneName, linkName string, log logr.Logger) error { + log.Info("Deleting VNet Link", "name", linkName) + deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) + defer cancel() + linkPoller, err := r.VirtualNetworkLinks.BeginDelete(deleteCtx, resourceGroup, zoneName, linkName, nil) + if err != nil { + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to begin deleting VNet Link: %w", err) } - } else { - log.V(1).Info("DNSZoneName not set in status, skipping DNS cleanup") + return nil } + if linkPoller == nil { + return nil + } + linkPollCtx, linkPollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) + defer linkPollCancel() - // 4. Delete base domain DNS resources (A records, VNet link, zone) - baseDomain := azPLS.Spec.BaseDomain - if baseDomain != "" { - // Extract cluster name from the hypershift.local zone name (format: ".hypershift.local") - clusterName := "" - if dnsZoneName != "" { - // Strip the ".hypershift.local" suffix to get the cluster name - if name, ok := strings.CutSuffix(dnsZoneName, "."+privateDNSZoneSuffix); ok { - clusterName = name - } + if _, err := linkPoller.PollUntilDone(linkPollCtx, nil); err != nil { + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to delete VNet Link: %w", err) } + } + return nil +} - if clusterName != "" { - // Determine which base domain A records this CR owns based on its name. - // This mirrors the creation logic in reconcileBaseDomainDNS. - var baseDomainRecords []string - if azPLS.Name == privateRouterCRName { - baseDomainRecords = append(baseDomainRecords, kasBaseDomainRecordPrefix+clusterName) - // Only delete the oauth record if there is no sibling OAuth CR that - // owns it. This prevents the private-router deletion from removing - // an oauth record that now belongs to the dedicated OAuth CR. - hasSiblings, err := r.hasSiblingCR(ctx, azPLS) - if err != nil { - return fmt.Errorf("failed to check for sibling CRs during base domain cleanup: %w", err) - } - if !hasSiblings { - baseDomainRecords = append(baseDomainRecords, oauthBaseDomainRecordPrefix+clusterName) - } - } else { - baseDomainRecords = append(baseDomainRecords, oauthBaseDomainRecordPrefix+clusterName) - } +func (r *AzurePrivateLinkServiceReconciler) deleteDNSZone(ctx context.Context, resourceGroup, zoneName string, log logr.Logger) error { + log.Info("Deleting Private DNS Zone", "zone", zoneName) + deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) + defer cancel() + zonePoller, err := r.PrivateDNSZones.BeginDelete(deleteCtx, resourceGroup, zoneName, nil) + if err != nil { + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to begin deleting Private DNS Zone: %w", err) + } + return nil + } + if zonePoller == nil { + return nil + } + zonePollCtx, zonePollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) + defer zonePollCancel() - for _, recordName := range baseDomainRecords { - log.Info("Deleting base domain A record", "record", recordName, "zone", baseDomain) - deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) - if _, err := r.RecordSets.Delete(deleteCtx, resourceGroup, baseDomain, armprivatedns.RecordTypeA, recordName, nil); err != nil { - cancel() - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete base domain A record %q: %w", recordName, err) - } - } - cancel() - } + if _, err := zonePoller.PollUntilDone(zonePollCtx, nil); err != nil { + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to delete Private DNS Zone: %w", err) } + } + return nil +} - // Delete base domain VNet link - bdLinkName := baseDomainVNetLinkName(azPLS.Name) - log.Info("Deleting base domain VNet Link", "name", bdLinkName) - bdLinkCtx, bdLinkCancel := context.WithTimeout(ctx, azureAPITimeout) - defer bdLinkCancel() - bdLinkPoller, err := r.VirtualNetworkLinks.BeginDelete(bdLinkCtx, resourceGroup, baseDomain, bdLinkName, nil) - if err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to begin deleting base domain VNet Link: %w", err) - } - } else if bdLinkPoller != nil { - bdLinkPollCtx, bdLinkPollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) - defer bdLinkPollCancel() - - if _, err := bdLinkPoller.PollUntilDone(bdLinkPollCtx, nil); err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete base domain VNet Link: %w", err) - } - } +func (r *AzurePrivateLinkServiceReconciler) deleteBaseDomainResources(ctx context.Context, azPLS *hyperv1.AzurePrivateLinkService, resourceGroup, dnsZoneName string, log logr.Logger) error { + baseDomain := azPLS.Spec.BaseDomain + if baseDomain == "" { + return nil + } + + if err := r.deleteBaseDomainARecords(ctx, azPLS, resourceGroup, baseDomain, dnsZoneName, log); err != nil { + return err + } + + if err := r.deleteVNetLink(ctx, resourceGroup, baseDomain, baseDomainVNetLinkName(azPLS.Name), log); err != nil { + return err + } + + hasSiblings, err := r.hasSiblingCR(ctx, azPLS) + if err != nil { + return fmt.Errorf("failed to check for sibling CRs during base domain zone cleanup: %w", err) + } + + if !hasSiblings { + log.Info("Deleting base domain Private DNS Zone (last CR using this zone)", "zone", baseDomain) + return r.deleteDNSZone(ctx, resourceGroup, baseDomain, log) + } + log.Info("Skipping base domain zone deletion, other CRs still use it", "zone", baseDomain) + return nil +} + +func (r *AzurePrivateLinkServiceReconciler) deleteBaseDomainARecords(ctx context.Context, azPLS *hyperv1.AzurePrivateLinkService, resourceGroup, baseDomain, dnsZoneName string, log logr.Logger) error { + clusterName := "" + if dnsZoneName != "" { + if name, ok := strings.CutSuffix(dnsZoneName, "."+privateDNSZoneSuffix); ok { + clusterName = name } + } + if clusterName == "" { + return nil + } - // Only delete the base domain DNS zone if no other CRs share it. - // When multiple CRs (e.g., private-router and oauth-openshift) use the same - // base domain zone, the zone must not be deleted until the last CR is removed. + var baseDomainRecords []string + if azPLS.Name == privateRouterCRName { + baseDomainRecords = append(baseDomainRecords, kasBaseDomainRecordPrefix+clusterName) hasSiblings, err := r.hasSiblingCR(ctx, azPLS) if err != nil { - return fmt.Errorf("failed to check for sibling CRs during base domain zone cleanup: %w", err) + return fmt.Errorf("failed to check for sibling CRs during base domain cleanup: %w", err) } - if !hasSiblings { - log.Info("Deleting base domain Private DNS Zone (last CR using this zone)", "zone", baseDomain) - bdZoneCtx, bdZoneCancel := context.WithTimeout(ctx, azureAPITimeout) - defer bdZoneCancel() - bdZonePoller, err := r.PrivateDNSZones.BeginDelete(bdZoneCtx, resourceGroup, baseDomain, nil) - if err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to begin deleting base domain Private DNS Zone: %w", err) - } - } else if bdZonePoller != nil { - bdZonePollCtx, bdZonePollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) - defer bdZonePollCancel() - - if _, err := bdZonePoller.PollUntilDone(bdZonePollCtx, nil); err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete base domain Private DNS Zone: %w", err) - } - } + baseDomainRecords = append(baseDomainRecords, oauthBaseDomainRecordPrefix+clusterName) + } + } else { + baseDomainRecords = append(baseDomainRecords, oauthBaseDomainRecordPrefix+clusterName) + } + + for _, recordName := range baseDomainRecords { + log.Info("Deleting base domain A record", "record", recordName, "zone", baseDomain) + deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) + if _, err := r.RecordSets.Delete(deleteCtx, resourceGroup, baseDomain, armprivatedns.RecordTypeA, recordName, nil); err != nil { + cancel() + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to delete base domain A record %q: %w", recordName, err) } - } else { - log.Info("Skipping base domain zone deletion, other CRs still use it", "zone", baseDomain) } + cancel() } + return nil +} - // 5. Delete Private Endpoint - // Always attempt deletion by deterministic name, even when PrivateEndpointID is empty. - // If the status was never populated (e.g., status update failed after PE creation), - // relying solely on PrivateEndpointID would orphan the PE in the customer's subscription. - peName := privateEndpointName(azPLS.Name) - log.Info("Deleting Private Endpoint", "name", peName, "hasStatusID", azPLS.Status.PrivateEndpointID != "") - deleteCtx4, cancel4 := context.WithTimeout(ctx, azureAPITimeout) - defer cancel4() - pePoller, err := r.PrivateEndpoints.BeginDelete(deleteCtx4, resourceGroup, peName, nil) +func (r *AzurePrivateLinkServiceReconciler) deletePrivateEndpoint(ctx context.Context, resourceGroup, crName, statusPEID string, log logr.Logger) error { + peName := privateEndpointName(crName) + log.Info("Deleting Private Endpoint", "name", peName, "hasStatusID", statusPEID != "") + deleteCtx, cancel := context.WithTimeout(ctx, azureAPITimeout) + defer cancel() + pePoller, err := r.PrivateEndpoints.BeginDelete(deleteCtx, resourceGroup, peName, nil) if err != nil { if !azureutil.IsAzureNotFoundError(err) { return fmt.Errorf("failed to begin deleting Private Endpoint: %w", err) } - } else if pePoller != nil { - pePollCtx, pePollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) - defer pePollCancel() + return nil + } + if pePoller == nil { + return nil + } + pePollCtx, pePollCancel := context.WithTimeout(ctx, azureutil.PollTimeout) + defer pePollCancel() - if _, err := pePoller.PollUntilDone(pePollCtx, nil); err != nil { - if !azureutil.IsAzureNotFoundError(err) { - return fmt.Errorf("failed to delete Private Endpoint: %w", err) - } + if _, err := pePoller.PollUntilDone(pePollCtx, nil); err != nil { + if !azureutil.IsAzureNotFoundError(err) { + return fmt.Errorf("failed to delete Private Endpoint: %w", err) } } - return nil } diff --git a/control-plane-operator/controllers/azureprivatelinkservice/controller_test.go b/control-plane-operator/controllers/azureprivatelinkservice/controller_test.go index bf377fa6c9d..3e63b043a3e 100644 --- a/control-plane-operator/controllers/azureprivatelinkservice/controller_test.go +++ b/control-plane-operator/controllers/azureprivatelinkservice/controller_test.go @@ -1715,22 +1715,88 @@ func TestReconcileDelete_WhenNoSiblingCRs_ItShouldDeleteBaseDomainZone(t *testin "should delete both VNet links when the last CR goes away") } -func TestHasSiblingCR(t *testing.T) { +func TestBaseDomainVNetLinkName(t *testing.T) { + t.Parallel() + tests := []struct { + name string + crName string + expected string + }{ + { + name: "When CR name is private-router, it should append basedomain VNet link suffix", + crName: "private-router", + expected: "private-router-basedomain-vnet-link", + }, + { + name: "When CR name is oauth-openshift, it should append basedomain VNet link suffix", + crName: "oauth-openshift", + expected: "oauth-openshift-basedomain-vnet-link", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + g := NewGomegaWithT(t) + result := baseDomainVNetLinkName(tt.crName) + g.Expect(result).To(Equal(tt.expected)) + }) + } +} + +func TestDNSZoneConfigErrMsgQualifier(t *testing.T) { t.Parallel() tests := []struct { name string - azPLS *hyperv1.AzurePrivateLinkService - siblings []client.Object - expectHas bool - expectErr bool + logPrefix string + expected string }{ { - name: "When sibling with same base domain exists it should return true", - azPLS: func() *hyperv1.AzurePrivateLinkService { - cr := newTestAzurePLS(t, "private-router", "test-ns") - cr.Spec.BaseDomain = "example.com" - return cr - }(), + name: "When logPrefix is empty, it should return empty string", + logPrefix: "", + expected: "", + }, + { + name: "When logPrefix is set, it should return prefix followed by a space", + logPrefix: "base domain", + expected: "base domain ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + g := NewGomegaWithT(t) + cfg := dnsZoneConfig{logPrefix: tt.logPrefix} + g.Expect(cfg.errMsgQualifier()).To(Equal(tt.expected)) + }) + } +} + +func TestRecordNamesForCR(t *testing.T) { + tests := []struct { + name string + crName string + clusterName string + siblings []client.Object + expected []string + }{ + { + name: "When CR is not private-router, it should return only the oauth record", + crName: "oauth-openshift", + clusterName: "my-cluster", + expected: []string{"oauth-my-cluster"}, + }, + { + name: "When CR is private-router with no sibling, it should return api and oauth records", + crName: "private-router", + clusterName: "my-cluster", + expected: []string{"api-my-cluster", "oauth-my-cluster"}, + }, + { + name: "When CR is private-router with sibling OAuth CR, it should return only api record", + crName: "private-router", + clusterName: "my-cluster", siblings: []client.Object{ func() *hyperv1.AzurePrivateLinkService { cr := newTestAzurePLS(t, "oauth-openshift", "test-ns") @@ -1738,122 +1804,253 @@ func TestHasSiblingCR(t *testing.T) { return cr }(), }, - expectHas: true, + expected: []string{"api-my-cluster"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + scheme := newTestScheme(t, g) + + azPLS := newTestAzurePLS(t, tt.crName, "test-ns") + azPLS.Spec.BaseDomain = "example.com" + + objs := append([]client.Object{azPLS}, tt.siblings...) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(objs...). + Build() + + r := &AzurePrivateLinkServiceReconciler{Client: fakeClient} + result, err := r.recordNamesForCR(t.Context(), azPLS, tt.clusterName, testr.New(t)) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(result).To(Equal(tt.expected)) + }) + } +} + +func TestDeleteBaseDomainARecords(t *testing.T) { + tests := []struct { + name string + crName string + dnsZoneName string + baseDomain string + siblings []client.Object + expectedDeletedRecords []string + expectDeleteCalled bool + }{ + { + name: "When dnsZoneName is empty, it should skip deletion", + crName: "private-router", + dnsZoneName: "", + baseDomain: "example.com", + expectDeleteCalled: false, }, { - name: "When no siblings exist it should return false", - azPLS: func() *hyperv1.AzurePrivateLinkService { - cr := newTestAzurePLS(t, "private-router", "test-ns") - cr.Spec.BaseDomain = "example.com" - return cr - }(), - siblings: []client.Object{}, - expectHas: false, + name: "When dnsZoneName has wrong suffix, it should skip deletion", + crName: "private-router", + dnsZoneName: "cluster.wrong.suffix", + baseDomain: "example.com", + expectDeleteCalled: false, }, { - name: "When sibling has different base domain it should return false", - azPLS: func() *hyperv1.AzurePrivateLinkService { - cr := newTestAzurePLS(t, "private-router", "test-ns") - cr.Spec.BaseDomain = "example.com" - return cr - }(), + name: "When private-router with no siblings, it should delete api and oauth records", + crName: "private-router", + dnsZoneName: "my-cluster.hypershift.local", + baseDomain: "example.com", + expectedDeletedRecords: []string{"api-my-cluster", "oauth-my-cluster"}, + expectDeleteCalled: true, + }, + { + name: "When private-router with sibling OAuth CR, it should delete only api record", + crName: "private-router", + dnsZoneName: "my-cluster.hypershift.local", + baseDomain: "example.com", siblings: []client.Object{ func() *hyperv1.AzurePrivateLinkService { cr := newTestAzurePLS(t, "oauth-openshift", "test-ns") - cr.Spec.BaseDomain = "other.com" + cr.Spec.BaseDomain = "example.com" return cr }(), }, - expectHas: false, + expectedDeletedRecords: []string{"api-my-cluster"}, + expectDeleteCalled: true, + }, + { + name: "When non-private-router CR, it should delete only oauth record", + crName: "oauth-openshift", + dnsZoneName: "my-cluster.hypershift.local", + baseDomain: "example.com", + expectedDeletedRecords: []string{"oauth-my-cluster"}, + expectDeleteCalled: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() g := NewGomegaWithT(t) scheme := newTestScheme(t, g) - objs := append([]client.Object{tt.azPLS}, tt.siblings...) + azPLS := newTestAzurePLS(t, tt.crName, "test-ns") + azPLS.Spec.BaseDomain = tt.baseDomain + azPLS.Status.DNSZoneName = tt.dnsZoneName + + objs := append([]client.Object{azPLS}, tt.siblings...) fakeClient := fake.NewClientBuilder(). WithScheme(scheme). WithObjects(objs...). Build() - r := &AzurePrivateLinkServiceReconciler{Client: fakeClient} - has, err := r.hasSiblingCR(t.Context(), tt.azPLS) + mockRecords := &mockRecordSets{} + r := &AzurePrivateLinkServiceReconciler{ + Client: fakeClient, + RecordSets: mockRecords, + } - if tt.expectErr { - g.Expect(err).To(HaveOccurred()) - } else { - g.Expect(err).ToNot(HaveOccurred()) - g.Expect(has).To(Equal(tt.expectHas)) + err := r.deleteBaseDomainARecords(t.Context(), azPLS, "test-rg", tt.baseDomain, tt.dnsZoneName, testr.New(t)) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(mockRecords.deleteCalled).To(Equal(tt.expectDeleteCalled)) + if tt.expectDeleteCalled { + g.Expect(mockRecords.deletedRecordNames).To(Equal(tt.expectedDeletedRecords)) } }) } } func TestMapHCPToAzurePLS(t *testing.T) { - t.Parallel() tests := []struct { - name string - hcp *hyperv1.HostedControlPlane - existingPLS []client.Object - expectedLen int + name string + hcp *hyperv1.HostedControlPlane + plsCRs []client.Object + expectRequests int }{ { - name: "When HCP has PLS finalizer and AzurePLS CRs exist it should return reconcile requests", + name: "When HCP has the Azure PLS finalizer and PLS CRs exist, it should return requests for all PLS CRs", hcp: func() *hyperv1.HostedControlPlane { hcp := newTestHCP(t, "test-hcp", "test-ns", "api.test.example.com") hcp.Finalizers = []string{hcpAzurePLSFinalizerName} return hcp }(), - existingPLS: []client.Object{ - newTestAzurePLS(t, "pls-1", "test-ns"), - newTestAzurePLS(t, "pls-2", "test-ns"), + plsCRs: []client.Object{ + newTestAzurePLS(t, "private-router", "test-ns"), + newTestAzurePLS(t, "oauth-openshift", "test-ns"), }, - expectedLen: 2, + expectRequests: 2, }, { - name: "When HCP does not have PLS finalizer it should return nil", - hcp: newTestHCP(t, "test-hcp", "test-ns", "api.test.example.com"), - expectedLen: 0, + name: "When HCP does not have the Azure PLS finalizer, it should return no requests", + hcp: newTestHCP(t, "test-hcp", "test-ns", "api.test.example.com"), + plsCRs: []client.Object{ + newTestAzurePLS(t, "private-router", "test-ns"), + }, + expectRequests: 0, }, { - name: "When HCP has PLS finalizer but no AzurePLS CRs exist it should return empty list", + name: "When HCP has the finalizer but no PLS CRs exist, it should return no requests", hcp: func() *hyperv1.HostedControlPlane { hcp := newTestHCP(t, "test-hcp", "test-ns", "api.test.example.com") hcp.Finalizers = []string{hcpAzurePLSFinalizerName} return hcp }(), - expectedLen: 0, + plsCRs: []client.Object{}, + expectRequests: 0, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() g := NewGomegaWithT(t) scheme := newTestScheme(t, g) - objs := []client.Object{tt.hcp} - objs = append(objs, tt.existingPLS...) - + objs := append([]client.Object{tt.hcp}, tt.plsCRs...) fakeClient := fake.NewClientBuilder(). WithScheme(scheme). WithObjects(objs...). Build() r := &AzurePrivateLinkServiceReconciler{Client: fakeClient} - mapFunc := r.mapHCPToAzurePLS() - + mapFn := r.mapHCPToAzurePLS() ctx := log.IntoContext(t.Context(), testr.New(t)) - requests := mapFunc(ctx, tt.hcp) + requests := mapFn(ctx, tt.hcp) + g.Expect(requests).To(HaveLen(tt.expectRequests)) + }) + } +} - if tt.expectedLen == 0 { - g.Expect(requests).To(BeEmpty()) +func TestHasSiblingCR(t *testing.T) { + t.Parallel() + tests := []struct { + name string + azPLS *hyperv1.AzurePrivateLinkService + siblings []client.Object + expectHas bool + expectErr bool + }{ + { + name: "When sibling with same base domain exists it should return true", + azPLS: func() *hyperv1.AzurePrivateLinkService { + cr := newTestAzurePLS(t, "private-router", "test-ns") + cr.Spec.BaseDomain = "example.com" + return cr + }(), + siblings: []client.Object{ + func() *hyperv1.AzurePrivateLinkService { + cr := newTestAzurePLS(t, "oauth-openshift", "test-ns") + cr.Spec.BaseDomain = "example.com" + return cr + }(), + }, + expectHas: true, + }, + { + name: "When no siblings exist it should return false", + azPLS: func() *hyperv1.AzurePrivateLinkService { + cr := newTestAzurePLS(t, "private-router", "test-ns") + cr.Spec.BaseDomain = "example.com" + return cr + }(), + siblings: []client.Object{}, + expectHas: false, + }, + { + name: "When sibling has different base domain it should return false", + azPLS: func() *hyperv1.AzurePrivateLinkService { + cr := newTestAzurePLS(t, "private-router", "test-ns") + cr.Spec.BaseDomain = "example.com" + return cr + }(), + siblings: []client.Object{ + func() *hyperv1.AzurePrivateLinkService { + cr := newTestAzurePLS(t, "oauth-openshift", "test-ns") + cr.Spec.BaseDomain = "other.com" + return cr + }(), + }, + expectHas: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + g := NewGomegaWithT(t) + scheme := newTestScheme(t, g) + + objs := append([]client.Object{tt.azPLS}, tt.siblings...) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(objs...). + Build() + + r := &AzurePrivateLinkServiceReconciler{Client: fakeClient} + has, err := r.hasSiblingCR(t.Context(), tt.azPLS) + + if tt.expectErr { + g.Expect(err).To(HaveOccurred()) } else { - g.Expect(requests).To(HaveLen(tt.expectedLen)) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(has).To(Equal(tt.expectHas)) } }) } @@ -2331,7 +2528,7 @@ func TestReconcileDelete_WhenBaseDomainVNetLinkDeleteFails_ItShouldReturnError(t err := r.reconcileDelete(t.Context(), azPLS, testr.New(t)) g.Expect(err).To(HaveOccurred()) - g.Expect(err).To(MatchError(ContainSubstring("failed to begin deleting base domain VNet Link"))) + g.Expect(err).To(MatchError(ContainSubstring("failed to begin deleting VNet Link"))) } func TestReconcileDelete_WhenBaseDomainZoneDeleteFails_ItShouldReturnError(t *testing.T) { @@ -2358,7 +2555,7 @@ func TestReconcileDelete_WhenBaseDomainZoneDeleteFails_ItShouldReturnError(t *te err := r.reconcileDelete(t.Context(), azPLS, testr.New(t)) g.Expect(err).To(HaveOccurred()) - g.Expect(err).To(MatchError(ContainSubstring("failed to begin deleting base domain Private DNS Zone"))) + g.Expect(err).To(MatchError(ContainSubstring("failed to begin deleting Private DNS Zone"))) } // Tests CR deletion with DNSZoneName set in status, which causes reconcileDelete @@ -2921,7 +3118,7 @@ func TestReconcileDelete_WhenBaseDomainVNetLinkPollerFails_ItShouldReturnError(t err := r.reconcileDelete(t.Context(), azPLS, testr.New(t)) g.Expect(err).To(HaveOccurred()) - g.Expect(err).To(MatchError(ContainSubstring("failed to delete base domain VNet Link"))) + g.Expect(err).To(MatchError(ContainSubstring("failed to delete VNet Link"))) } func TestReconcileDelete_WhenBaseDomainZonePollerFails_ItShouldReturnError(t *testing.T) { @@ -2948,7 +3145,7 @@ func TestReconcileDelete_WhenBaseDomainZonePollerFails_ItShouldReturnError(t *te err := r.reconcileDelete(t.Context(), azPLS, testr.New(t)) g.Expect(err).To(HaveOccurred()) - g.Expect(err).To(MatchError(ContainSubstring("failed to delete base domain Private DNS Zone"))) + g.Expect(err).To(MatchError(ContainSubstring("failed to delete Private DNS Zone"))) } // Tests CR deletion without DNSZoneName, which skips DNS cleanup and only verifies diff --git a/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go b/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go index 5f923e8d5d5..017ac10a8d3 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go +++ b/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go @@ -367,153 +367,101 @@ func (r *HostedControlPlaneReconciler) eventHandlers(scheme *runtime.Scheme, res return handlers } -func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { - r.Log = ctrl.LoggerFrom(ctx) - r.Log.Info("Reconciling") - - // Fetch the hostedControlPlane instance - hostedControlPlane := &hyperv1.HostedControlPlane{} - err := r.Client.Get(ctx, req.NamespacedName, hostedControlPlane) - if err != nil { - if apierrors.IsNotFound(err) { - return ctrl.Result{}, nil - } - return ctrl.Result{}, err - } - - originalHostedControlPlane := hostedControlPlane.DeepCopy() - - // Return early if deleted - if !hostedControlPlane.DeletionTimestamp.IsZero() { - condition := &metav1.Condition{ - Type: string(hyperv1.AWSDefaultSecurityGroupDeleted), - } - if shouldCleanupCloudResources(r.Log, hostedControlPlane) { - if code, destroyErr := r.destroyAWSDefaultSecurityGroup(ctx, hostedControlPlane); destroyErr != nil { - condition.Message = "failed to delete AWS default security group" - if code == "DependencyViolation" { - condition.Message = destroyErr.Error() - } - condition.Reason = hyperv1.AWSErrorReason - condition.Status = metav1.ConditionFalse - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, *condition) - - if err := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for security group deletion: %w. Condition error message: %v", err, condition.Message) - } - - if code == "UnauthorizedOperation" { - r.Log.Error(destroyErr, "Skipping AWS default security group deletion because of unauthorized operation.") - } - if code == "DependencyViolation" { - r.Log.Error(destroyErr, "Skipping AWS default security group deletion because of dependency violation.") - } else { - return ctrl.Result{}, fmt.Errorf("failed to delete AWS default security group: %w", destroyErr) - } - } else { - condition.Message = hyperv1.AllIsWellMessage - condition.Reason = hyperv1.AsExpectedReason - condition.Status = metav1.ConditionTrue - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, *condition) +func (r *HostedControlPlaneReconciler) reconcileDeletion(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane, originalHostedControlPlane *hyperv1.HostedControlPlane) (ctrl.Result, error) { + condition := &metav1.Condition{ + Type: string(hyperv1.AWSDefaultSecurityGroupDeleted), + } + if shouldCleanupCloudResources(r.Log, hostedControlPlane) { + if code, destroyErr := r.destroyAWSDefaultSecurityGroup(ctx, hostedControlPlane); destroyErr != nil { + condition.Message = "failed to delete AWS default security group" + if code == "DependencyViolation" { + condition.Message = destroyErr.Error() + } + condition.Reason = hyperv1.AWSErrorReason + condition.Status = metav1.ConditionFalse + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, *condition) - if err := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for security group deletion: %w. Condition message: %v", err, condition.Message) - } + if err := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for security group deletion: %w. Condition error message: %v", err, condition.Message) } - done, err := r.removeCloudResources(ctx, hostedControlPlane) - if err != nil { - return ctrl.Result{}, fmt.Errorf("failed to ensure cloud resources are removed: %w", err) + switch code { + case "UnauthorizedOperation": + r.Log.Error(destroyErr, "Skipping AWS default security group deletion because of unauthorized operation.") + case "DependencyViolation": + r.Log.Error(destroyErr, "Skipping AWS default security group deletion because of dependency violation.") + default: + return ctrl.Result{}, fmt.Errorf("failed to delete AWS default security group: %w", destroyErr) } - if !done { - return ctrl.Result{RequeueAfter: time.Minute}, nil + } else { + condition.Message = hyperv1.AllIsWellMessage + condition.Reason = hyperv1.AsExpectedReason + condition.Status = metav1.ConditionTrue + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, *condition) + + if err := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for security group deletion: %w. Condition message: %v", err, condition.Message) } } - if controllerutil.ContainsFinalizer(hostedControlPlane, finalizer) { - originalHCP := hostedControlPlane.DeepCopy() - controllerutil.RemoveFinalizer(hostedControlPlane, finalizer) - if err := r.Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHCP, client.MergeFromWithOptimisticLock{})); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to remove finalizer from cluster: %w", err) - } + done, err := r.removeCloudResources(ctx, hostedControlPlane) + if err != nil { + return ctrl.Result{}, fmt.Errorf("failed to ensure cloud resources are removed: %w", err) + } + if !done { + return ctrl.Result{RequeueAfter: time.Minute}, nil } - return ctrl.Result{}, nil } - // Ensure the hostedControlPlane has a finalizer for cleanup - if !controllerutil.ContainsFinalizer(hostedControlPlane, finalizer) { + if controllerutil.ContainsFinalizer(hostedControlPlane, finalizer) { originalHCP := hostedControlPlane.DeepCopy() - controllerutil.AddFinalizer(hostedControlPlane, finalizer) + controllerutil.RemoveFinalizer(hostedControlPlane, finalizer) if err := r.Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHCP, client.MergeFromWithOptimisticLock{})); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to add finalizer to hostedControlPlane: %w", err) + return ctrl.Result{}, fmt.Errorf("failed to remove finalizer from cluster: %w", err) } } + return ctrl.Result{}, nil +} - if r.OperateOnReleaseImage != "" && r.OperateOnReleaseImage != util.HCPControlPlaneReleaseImage(hostedControlPlane) { - r.Log.Info("releaseImage is " + util.HCPControlPlaneReleaseImage(hostedControlPlane) + ", but this operator is configured for " + r.OperateOnReleaseImage + ", skipping reconciliation") - return ctrl.Result{}, nil - } - - // Reconcile global configuration validation status - { - condition := metav1.Condition{ - Type: string(hyperv1.ValidHostedControlPlaneConfiguration), - ObservedGeneration: hostedControlPlane.Generation, - } - if err := r.validateConfigAndClusterCapabilities(ctx, hostedControlPlane); err != nil { - condition.Status = metav1.ConditionFalse - condition.Message = err.Error() - condition.Reason = hyperv1.InsufficientClusterCapabilitiesReason - } else { - condition.Status = metav1.ConditionTrue - condition.Message = "Configuration passes validation" - condition.Reason = hyperv1.AsExpectedReason - } - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) - } - - // Reconcile etcd cluster status - { - newCondition := metav1.Condition{ - Type: string(hyperv1.EtcdAvailable), - Status: metav1.ConditionUnknown, - Reason: hyperv1.StatusUnknownReason, - } - switch hostedControlPlane.Spec.Etcd.ManagementType { - case hyperv1.Managed: - r.Log.Info("Reconciling etcd cluster status for managed strategy") - sts := manifests.EtcdStatefulSet(hostedControlPlane.Namespace) - if err := r.Get(ctx, client.ObjectKeyFromObject(sts), sts); err != nil { - if apierrors.IsNotFound(err) { - newCondition = metav1.Condition{ - Type: string(hyperv1.EtcdAvailable), - Status: metav1.ConditionFalse, - Reason: hyperv1.EtcdStatefulSetNotFoundReason, - } - } else { - return ctrl.Result{}, fmt.Errorf("failed to fetch etcd statefulset %s/%s: %w", sts.Namespace, sts.Name, err) +func (r *HostedControlPlaneReconciler) reconcileEtcdStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) error { + newCondition := metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionUnknown, + Reason: hyperv1.StatusUnknownReason, + } + switch hostedControlPlane.Spec.Etcd.ManagementType { + case hyperv1.Managed: + r.Log.Info("Reconciling etcd cluster status for managed strategy") + sts := manifests.EtcdStatefulSet(hostedControlPlane.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(sts), sts); err != nil { + if apierrors.IsNotFound(err) { + newCondition = metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.EtcdStatefulSetNotFoundReason, } } else { - conditionPtr, err := r.etcdStatefulSetCondition(ctx, sts) - if err != nil { - return ctrl.Result{}, fmt.Errorf("failed to get etcd statefulset status: %w", err) - } - newCondition = *conditionPtr + return fmt.Errorf("failed to fetch etcd statefulset %s/%s: %w", sts.Namespace, sts.Name, err) } - case hyperv1.Unmanaged: - r.Log.Info("Assuming Etcd cluster is running in unmanaged etcd strategy") - newCondition = metav1.Condition{ - Type: string(hyperv1.EtcdAvailable), - Status: metav1.ConditionTrue, - Reason: "EtcdRunning", - Message: "Etcd cluster is assumed to be running in unmanaged state", + } else { + conditionPtr, err := r.etcdStatefulSetCondition(ctx, sts) + if err != nil { + return fmt.Errorf("failed to get etcd statefulset status: %w", err) } + newCondition = *conditionPtr + } + case hyperv1.Unmanaged: + r.Log.Info("Assuming Etcd cluster is running in unmanaged etcd strategy") + newCondition = metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: "EtcdRunning", + Message: "Etcd cluster is assumed to be running in unmanaged state", } - newCondition.ObservedGeneration = hostedControlPlane.Generation - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) } + newCondition.ObservedGeneration = hostedControlPlane.Generation + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) - // Reconcile etcd restore status if hostedControlPlane.Spec.Etcd.ManagementType == hyperv1.Managed && hostedControlPlane.Spec.Etcd.Managed != nil && len(hostedControlPlane.Spec.Etcd.Managed.Storage.RestoreSnapshotURL) > 0 { restoreCondition := meta.FindStatusCondition(hostedControlPlane.Status.Conditions, string(hyperv1.EtcdSnapshotRestored)) @@ -521,218 +469,176 @@ func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.R r.Log.Info("Reconciling etcd cluster restore status") sts := manifests.EtcdStatefulSet(hostedControlPlane.Namespace) if err := r.Get(ctx, client.ObjectKeyFromObject(sts), sts); err == nil { - newCondition := metav1.Condition{} + rc := metav1.Condition{} conditionPtr := r.etcdRestoredCondition(ctx, sts) if conditionPtr != nil { - newCondition = *conditionPtr - newCondition.ObservedGeneration = hostedControlPlane.Generation - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) + rc = *conditionPtr + rc.ObservedGeneration = hostedControlPlane.Generation + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, rc) } } } } - // Validate KMS config - switch hostedControlPlane.Spec.Platform.Type { - case hyperv1.AWSPlatform: - r.validateAWSKMSConfig(ctx, hostedControlPlane) - case hyperv1.AzurePlatform: - r.validateAzureKMSConfig(ctx, hostedControlPlane) - } + return nil +} - // Reconcile Kube APIServer status - { - newCondition := metav1.Condition{ +func (r *HostedControlPlaneReconciler) reconcileKASStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) error { + newCondition := metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionUnknown, + Reason: hyperv1.StatusUnknownReason, + } + deployment := manifests.KASDeployment(hostedControlPlane.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(deployment), deployment); err != nil { + if apierrors.IsNotFound(err) { + newCondition = metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.NotFoundReason, + Message: "Kube APIServer deployment not found", + } + } else { + return fmt.Errorf("failed to fetch Kube APIServer deployment %s/%s: %w", deployment.Namespace, deployment.Name, err) + } + } else { + newCondition = metav1.Condition{ Type: string(hyperv1.KubeAPIServerAvailable), - Status: metav1.ConditionUnknown, + Status: metav1.ConditionFalse, Reason: hyperv1.StatusUnknownReason, } - deployment := manifests.KASDeployment(hostedControlPlane.Namespace) - if err := r.Get(ctx, client.ObjectKeyFromObject(deployment), deployment); err != nil { - if apierrors.IsNotFound(err) { - newCondition = metav1.Condition{ - Type: string(hyperv1.KubeAPIServerAvailable), - Status: metav1.ConditionFalse, - Reason: hyperv1.NotFoundReason, - Message: "Kube APIServer deployment not found", - } - } else { - return ctrl.Result{}, fmt.Errorf("failed to fetch Kube APIServer deployment %s/%s: %w", deployment.Namespace, deployment.Name, err) - } - } else { - // Assume the deployment is unavailable until proven otherwise. - newCondition = metav1.Condition{ - Type: string(hyperv1.KubeAPIServerAvailable), - Status: metav1.ConditionFalse, - Reason: hyperv1.StatusUnknownReason, - } - for _, cond := range deployment.Status.Conditions { - if cond.Type == appsv1.DeploymentAvailable { - if cond.Status == corev1.ConditionTrue { - newCondition = metav1.Condition{ - Type: string(hyperv1.KubeAPIServerAvailable), - Status: metav1.ConditionTrue, - Reason: hyperv1.AsExpectedReason, - Message: "Kube APIServer deployment is available", - } - } else { - newCondition = metav1.Condition{ - Type: string(hyperv1.KubeAPIServerAvailable), - Status: metav1.ConditionFalse, - Reason: hyperv1.WaitingForAvailableReason, - Message: "Waiting for Kube APIServer deployment to become available", - } + for _, cond := range deployment.Status.Conditions { + if cond.Type == appsv1.DeploymentAvailable { + if cond.Status == corev1.ConditionTrue { + newCondition = metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + Message: "Kube APIServer deployment is available", + } + } else { + newCondition = metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.WaitingForAvailableReason, + Message: "Waiting for Kube APIServer deployment to become available", } - break } + break } } - newCondition.ObservedGeneration = hostedControlPlane.Generation - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) } + newCondition.ObservedGeneration = hostedControlPlane.Generation + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) + return nil +} - // Reconcile Degraded status - { - condition := metav1.Condition{ - Type: string(hyperv1.HostedControlPlaneDegraded), - Status: metav1.ConditionFalse, - Reason: hyperv1.AsExpectedReason, - ObservedGeneration: hostedControlPlane.Generation, +func (r *HostedControlPlaneReconciler) reconcileDegradedStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) error { + condition := metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionFalse, + Reason: hyperv1.AsExpectedReason, + ObservedGeneration: hostedControlPlane.Generation, + } + cpoManagedDeploymentList := &appsv1.DeploymentList{} + if err := r.List(ctx, cpoManagedDeploymentList, client.MatchingLabels{ + component.ManagedByLabel: "control-plane-operator", + }, client.InNamespace(hostedControlPlane.Namespace)); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to list managed deployments in namespace %s: %w", hostedControlPlane.Namespace, err) } - cpoManagedDeploymentList := &appsv1.DeploymentList{} - if err := r.List(ctx, cpoManagedDeploymentList, client.MatchingLabels{ - component.ManagedByLabel: "control-plane-operator", - }, client.InNamespace(hostedControlPlane.Namespace)); err != nil { - if !apierrors.IsNotFound(err) { - return ctrl.Result{}, fmt.Errorf("failed to list managed deployments in namespace %s: %w", hostedControlPlane.Namespace, err) - } + } + var errs []error + sort.SliceStable(cpoManagedDeploymentList.Items, func(i, j int) bool { + return cpoManagedDeploymentList.Items[i].Name < cpoManagedDeploymentList.Items[j].Name + }) + for _, deployment := range cpoManagedDeploymentList.Items { + if deployment.Status.UnavailableReplicas > 0 { + errs = append(errs, fmt.Errorf("%s deployment has %d unavailable replicas", deployment.Name, deployment.Status.UnavailableReplicas)) } - var errs []error - sort.SliceStable(cpoManagedDeploymentList.Items, func(i, j int) bool { - return cpoManagedDeploymentList.Items[i].Name < cpoManagedDeploymentList.Items[j].Name - }) - for _, deployment := range cpoManagedDeploymentList.Items { - if deployment.Status.UnavailableReplicas > 0 { - errs = append(errs, fmt.Errorf("%s deployment has %d unavailable replicas", deployment.Name, deployment.Status.UnavailableReplicas)) - } + } + err := utilerrors.NewAggregate(errs) + if err != nil { + condition.Status = metav1.ConditionTrue + condition.Reason = "UnavailableReplicas" + condition.Message = err.Error() + } + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) + return nil +} + +func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + r.Log = ctrl.LoggerFrom(ctx) + r.Log.Info("Reconciling") + + hostedControlPlane := &hyperv1.HostedControlPlane{} + err := r.Client.Get(ctx, req.NamespacedName, hostedControlPlane) + if err != nil { + if apierrors.IsNotFound(err) { + return ctrl.Result{}, nil } - err := utilerrors.NewAggregate(errs) - if err != nil { - condition.Status = metav1.ConditionTrue - condition.Reason = "UnavailableReplicas" - condition.Message = err.Error() + return ctrl.Result{}, err + } + + originalHostedControlPlane := hostedControlPlane.DeepCopy() + + if !hostedControlPlane.DeletionTimestamp.IsZero() { + return r.reconcileDeletion(ctx, hostedControlPlane, originalHostedControlPlane) + } + + if !controllerutil.ContainsFinalizer(hostedControlPlane, finalizer) { + originalHCP := hostedControlPlane.DeepCopy() + controllerutil.AddFinalizer(hostedControlPlane, finalizer) + if err := r.Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHCP, client.MergeFromWithOptimisticLock{})); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to add finalizer to hostedControlPlane: %w", err) } - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) } - // Reconcile infrastructure status + if r.OperateOnReleaseImage != "" && r.OperateOnReleaseImage != util.HCPControlPlaneReleaseImage(hostedControlPlane) { + r.Log.Info("releaseImage is " + util.HCPControlPlaneReleaseImage(hostedControlPlane) + ", but this operator is configured for " + r.OperateOnReleaseImage + ", skipping reconciliation") + return ctrl.Result{}, nil + } + { - r.Log.Info("Reconciling infrastructure status") - newCondition := metav1.Condition{ - Type: string(hyperv1.InfrastructureReady), - Status: metav1.ConditionUnknown, - Reason: hyperv1.StatusUnknownReason, + condition := metav1.Condition{ + Type: string(hyperv1.ValidHostedControlPlaneConfiguration), + ObservedGeneration: hostedControlPlane.Generation, } - infraStatus, err := r.reconcileInfrastructureStatus(ctx, hostedControlPlane) - if err != nil { - newCondition = metav1.Condition{ - Type: string(hyperv1.InfrastructureReady), - Status: metav1.ConditionUnknown, - Reason: hyperv1.InfraStatusFailureReason, - Message: err.Error(), - } - r.Log.Error(err, "failed to determine infrastructure status") + if err := r.validateConfigAndClusterCapabilities(ctx, hostedControlPlane); err != nil { + condition.Status = metav1.ConditionFalse + condition.Message = err.Error() + condition.Reason = hyperv1.InsufficientClusterCapabilitiesReason } else { - if infraStatus.IsReady() { - hostedControlPlane.Status.ControlPlaneEndpoint = hyperv1.APIEndpoint{ - Host: infraStatus.APIHost, - Port: infraStatus.APIPort, - } - newCondition = metav1.Condition{ - Type: string(hyperv1.InfrastructureReady), - Status: metav1.ConditionTrue, - Message: hyperv1.AllIsWellMessage, - Reason: hyperv1.AsExpectedReason, - } - if util.HCPOAuthEnabled(hostedControlPlane) { - hostedControlPlane.Status.OAuthCallbackURLTemplate = fmt.Sprintf("https://%s:%d/oauth2callback/[identity-provider-name]", infraStatus.OAuthHost, infraStatus.OAuthPort) - } - } else { - message := "Cluster infrastructure is still provisioning" - if len(infraStatus.Message) > 0 { - message = infraStatus.Message - } - newCondition = metav1.Condition{ - Type: string(hyperv1.InfrastructureReady), - Status: metav1.ConditionFalse, - Reason: hyperv1.WaitingOnInfrastructureReadyReason, - Message: message, - } - r.Log.Info("Infrastructure is not yet ready") - } + condition.Status = metav1.ConditionTrue + condition.Message = "Configuration passes validation" + condition.Reason = hyperv1.AsExpectedReason } - newCondition.ObservedGeneration = hostedControlPlane.Generation - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) } - // Reconcile external DNS status - { - newCondition := metav1.Condition{ - Type: string(hyperv1.ExternalDNSReachable), - Status: metav1.ConditionUnknown, - Reason: hyperv1.StatusUnknownReason, - } + if err := r.reconcileEtcdStatus(ctx, hostedControlPlane); err != nil { + return ctrl.Result{}, err + } - kasExternalHostname := netutil.ServiceExternalDNSHostname(hostedControlPlane, hyperv1.APIServer) - if kasExternalHostname != "" { - if err := netutil.ResolveDNSHostname(ctx, kasExternalHostname); err != nil { - newCondition = metav1.Condition{ - Type: string(hyperv1.ExternalDNSReachable), - Status: metav1.ConditionFalse, - Reason: hyperv1.ExternalDNSHostNotReachableReason, - Message: err.Error(), - } - } else { - newCondition = metav1.Condition{ - Type: string(hyperv1.ExternalDNSReachable), - Status: metav1.ConditionTrue, - Message: hyperv1.AllIsWellMessage, - Reason: hyperv1.AsExpectedReason, - } - } - } else { - newCondition.Message = "External DNS is not configured" - } + switch hostedControlPlane.Spec.Platform.Type { + case hyperv1.AWSPlatform: + r.validateAWSKMSConfig(ctx, hostedControlPlane) + case hyperv1.AzurePlatform: + r.validateAzureKMSConfig(ctx, hostedControlPlane) + } - newCondition.ObservedGeneration = hostedControlPlane.Generation - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) + if err := r.reconcileKASStatus(ctx, hostedControlPlane); err != nil { + return ctrl.Result{}, err } - // Reconcile hostedcontrolplane availability and Ready flag - { - healthCheckErr := r.healthCheckKASLoadBalancers(ctx, hostedControlPlane) - - // Check control plane components availability only if the HCP is not already marked as available - var componentsNotAvailableMsg string - var componentsErr error - availableCondition := meta.FindStatusCondition(hostedControlPlane.Status.Conditions, string(hyperv1.HostedControlPlaneAvailable)) - alreadyAvailable := availableCondition != nil && availableCondition.Status == metav1.ConditionTrue - if !alreadyAvailable { - componentsNotAvailableMsg, componentsErr = r.controlPlaneComponentsAvailable(ctx, hostedControlPlane) - } - - ready, condition := reconcileAvailabilityStatus( - hostedControlPlane.Status.Conditions, - hostedControlPlane.Status.KubeConfig != nil, - healthCheckErr, - componentsNotAvailableMsg, - componentsErr, - hostedControlPlane.Generation, - ) - hostedControlPlane.Status.Ready = ready - meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) + if err := r.reconcileDegradedStatus(ctx, hostedControlPlane); err != nil { + return ctrl.Result{}, err } + r.reconcileInfrastructureStatusCondition(ctx, hostedControlPlane) + r.reconcileExternalDNSStatusCondition(ctx, hostedControlPlane) + r.reconcileAvailabilityAndReadyStatus(ctx, hostedControlPlane) + // Admin Kubeconfig kubeconfig := manifests.KASAdminKubeconfigSecret(hostedControlPlane.Namespace, hostedControlPlane.Spec.KubeConfig) if err := r.Get(ctx, client.ObjectKeyFromObject(kubeconfig), kubeconfig); err != nil { @@ -754,20 +660,8 @@ func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.R return reconcile.Result{}, err } - explicitOauthConfig := hostedControlPlane.Spec.Configuration != nil && hostedControlPlane.Spec.Configuration.OAuth != nil - if explicitOauthConfig { - hostedControlPlane.Status.KubeadminPassword = nil - } else { - kubeadminPassword := common.KubeadminPasswordSecret(hostedControlPlane.Namespace) - if err := r.Get(ctx, client.ObjectKeyFromObject(kubeadminPassword), kubeadminPassword); err != nil { - if !apierrors.IsNotFound(err) { - return reconcile.Result{}, fmt.Errorf("failed to get kubeadmin password: %w", err) - } - } else { - hostedControlPlane.Status.KubeadminPassword = &corev1.LocalObjectReference{ - Name: kubeadminPassword.Name, - } - } + if err := r.reconcileKubeadminPasswordStatus(ctx, hostedControlPlane); err != nil { + return reconcile.Result{}, err } // Reconcile valid release info status @@ -776,38 +670,8 @@ func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.R return reconcile.Result{}, fmt.Errorf("failed to look up release image metadata: %w", err) } - // Reconcile controlPlaneVersion status. - // This runs after LookupReleaseImage so we can use the version and resolved - // digest from the release image metadata. - { - clk := r.clock - if clk == nil { - clk = clock.RealClock{} - } - // Resolve the release image to its digest so controlPlaneVersion records - // the immutable image reference, consistent with how CVO records images. - pullSecret := common.PullSecret(hostedControlPlane.Namespace) - if err := r.Client.Get(ctx, client.ObjectKeyFromObject(pullSecret), pullSecret); err != nil { - return reconcile.Result{}, fmt.Errorf("failed to get pull secret for version reconciliation: %w", err) - } - _, resolvedRef, err := r.ImageMetadataProvider.GetDigest(ctx, util.HCPControlPlaneReleaseImage(hostedControlPlane), pullSecret.Data[corev1.DockerConfigJsonKey]) - if err != nil { - return reconcile.Result{}, fmt.Errorf("failed to resolve control plane release image digest: %w", err) - } - resolvedImage := resolvedRef.String() - - componentsList := &hyperv1.ControlPlaneComponentList{} - if listErr := r.Client.List(ctx, componentsList, client.InNamespace(hostedControlPlane.Namespace)); listErr != nil { - // On list failure, ensure a Partial entry exists so consumers - // know an upgrade was attempted. Preserve observedGeneration. - hostedControlPlane.Status.ControlPlaneVersion = ensureControlPlaneVersionPartial(hostedControlPlane, clk, releaseImage.Version(), resolvedImage) - // Persist the Partial entry before returning the error. - if patchErr := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); patchErr != nil { - return reconcile.Result{}, fmt.Errorf("failed to patch status after component list failure: %w (list error: %v)", patchErr, listErr) - } - return reconcile.Result{}, fmt.Errorf("failed to list control plane components for version reconciliation: %w", listErr) - } - hostedControlPlane.Status.ControlPlaneVersion = reconcileControlPlaneVersion(hostedControlPlane, componentsList.Items, clk, releaseImage.Version(), resolvedImage) + if err := r.reconcileControlPlaneVersionStatus(ctx, hostedControlPlane, originalHostedControlPlane, releaseImage); err != nil { + return reconcile.Result{}, err } hostedControlPlane.Status.Initialized = true @@ -848,8 +712,157 @@ func (r *HostedControlPlaneReconciler) Reconcile(ctx context.Context, req ctrl.R if !hostedControlPlane.Status.Ready { return ctrl.Result{RequeueAfter: hcpNotReadyRequeueInterval}, nil } - - return ctrl.Result{RequeueAfter: hcpReadyRequeueInterval}, nil + + return ctrl.Result{RequeueAfter: hcpReadyRequeueInterval}, nil +} + +func (r *HostedControlPlaneReconciler) reconcileInfrastructureStatusCondition(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) { + r.Log.Info("Reconciling infrastructure status") + newCondition := metav1.Condition{ + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionUnknown, + Reason: hyperv1.StatusUnknownReason, + } + infraStatus, err := r.reconcileInfrastructureStatus(ctx, hostedControlPlane) + if err != nil { + newCondition = metav1.Condition{ + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionUnknown, + Reason: hyperv1.InfraStatusFailureReason, + Message: err.Error(), + } + r.Log.Error(err, "failed to determine infrastructure status") + } else if infraStatus.IsReady() { + hostedControlPlane.Status.ControlPlaneEndpoint = hyperv1.APIEndpoint{ + Host: infraStatus.APIHost, + Port: infraStatus.APIPort, + } + newCondition = metav1.Condition{ + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionTrue, + Message: hyperv1.AllIsWellMessage, + Reason: hyperv1.AsExpectedReason, + } + if util.HCPOAuthEnabled(hostedControlPlane) { + hostedControlPlane.Status.OAuthCallbackURLTemplate = fmt.Sprintf("https://%s:%d/oauth2callback/[identity-provider-name]", infraStatus.OAuthHost, infraStatus.OAuthPort) + } + } else { + message := "Cluster infrastructure is still provisioning" + if len(infraStatus.Message) > 0 { + message = infraStatus.Message + } + newCondition = metav1.Condition{ + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionFalse, + Reason: hyperv1.WaitingOnInfrastructureReadyReason, + Message: message, + } + r.Log.Info("Infrastructure is not yet ready") + } + newCondition.ObservedGeneration = hostedControlPlane.Generation + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) +} + +func (r *HostedControlPlaneReconciler) reconcileExternalDNSStatusCondition(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) { + newCondition := metav1.Condition{ + Type: string(hyperv1.ExternalDNSReachable), + Status: metav1.ConditionUnknown, + Reason: hyperv1.StatusUnknownReason, + } + + kasExternalHostname := netutil.ServiceExternalDNSHostname(hostedControlPlane, hyperv1.APIServer) + if kasExternalHostname != "" { + if err := netutil.ResolveDNSHostname(ctx, kasExternalHostname); err != nil { + newCondition = metav1.Condition{ + Type: string(hyperv1.ExternalDNSReachable), + Status: metav1.ConditionFalse, + Reason: hyperv1.ExternalDNSHostNotReachableReason, + Message: err.Error(), + } + } else { + newCondition = metav1.Condition{ + Type: string(hyperv1.ExternalDNSReachable), + Status: metav1.ConditionTrue, + Message: hyperv1.AllIsWellMessage, + Reason: hyperv1.AsExpectedReason, + } + } + } else { + newCondition.Message = "External DNS is not configured" + } + + newCondition.ObservedGeneration = hostedControlPlane.Generation + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, newCondition) +} + +func (r *HostedControlPlaneReconciler) reconcileAvailabilityAndReadyStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) { + healthCheckErr := r.healthCheckKASLoadBalancers(ctx, hostedControlPlane) + + var componentsNotAvailableMsg string + var componentsErr error + availableCondition := meta.FindStatusCondition(hostedControlPlane.Status.Conditions, string(hyperv1.HostedControlPlaneAvailable)) + alreadyAvailable := availableCondition != nil && availableCondition.Status == metav1.ConditionTrue + if !alreadyAvailable { + componentsNotAvailableMsg, componentsErr = r.controlPlaneComponentsAvailable(ctx, hostedControlPlane) + } + + ready, condition := reconcileAvailabilityStatus( + hostedControlPlane.Status.Conditions, + hostedControlPlane.Status.KubeConfig != nil, + healthCheckErr, + componentsNotAvailableMsg, + componentsErr, + hostedControlPlane.Generation, + ) + hostedControlPlane.Status.Ready = ready + meta.SetStatusCondition(&hostedControlPlane.Status.Conditions, condition) +} + +func (r *HostedControlPlaneReconciler) reconcileKubeadminPasswordStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane) error { + explicitOauthConfig := hostedControlPlane.Spec.Configuration != nil && hostedControlPlane.Spec.Configuration.OAuth != nil + if explicitOauthConfig { + hostedControlPlane.Status.KubeadminPassword = nil + return nil + } + kubeadminPassword := common.KubeadminPasswordSecret(hostedControlPlane.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(kubeadminPassword), kubeadminPassword); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to get kubeadmin password: %w", err) + } + hostedControlPlane.Status.KubeadminPassword = nil + } else { + hostedControlPlane.Status.KubeadminPassword = &corev1.LocalObjectReference{ + Name: kubeadminPassword.Name, + } + } + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileControlPlaneVersionStatus(ctx context.Context, hostedControlPlane *hyperv1.HostedControlPlane, originalHostedControlPlane *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) error { + clk := r.clock + if clk == nil { + clk = clock.RealClock{} + } + pullSecret := common.PullSecret(hostedControlPlane.Namespace) + if err := r.Client.Get(ctx, client.ObjectKeyFromObject(pullSecret), pullSecret); err != nil { + return fmt.Errorf("failed to get pull secret for version reconciliation: %w", err) + } + _, resolvedRef, err := r.ImageMetadataProvider.GetDigest(ctx, util.HCPControlPlaneReleaseImage(hostedControlPlane), pullSecret.Data[corev1.DockerConfigJsonKey]) + if err != nil { + return fmt.Errorf("failed to resolve control plane release image digest: %w", err) + } + resolvedImage := resolvedRef.String() + + componentsList := &hyperv1.ControlPlaneComponentList{} + if listErr := r.Client.List(ctx, componentsList, client.InNamespace(hostedControlPlane.Namespace)); listErr != nil { + hostedControlPlane.Status.ControlPlaneVersion = ensureControlPlaneVersionPartial(hostedControlPlane, clk, releaseImage.Version(), resolvedImage) + if patchErr := r.Client.Status().Patch(ctx, hostedControlPlane, client.MergeFromWithOptions(originalHostedControlPlane, client.MergeFromWithOptimisticLock{})); patchErr != nil { + return fmt.Errorf("failed to patch status after component list failure: %w (list error: %v)", patchErr, listErr) + } + return fmt.Errorf("failed to list control plane components for version reconciliation: %w", listErr) + } + hostedControlPlane.Status.ControlPlaneVersion = reconcileControlPlaneVersion(hostedControlPlane, componentsList.Items, clk, releaseImage.Version(), resolvedImage) + return nil } // reconcileAvailabilityStatus determines the HostedControlPlane availability condition @@ -1293,35 +1306,7 @@ func (r *HostedControlPlaneReconciler) reconcileKubeadminPassword(ctx context.Co return nil } -func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hyperv1.HostedControlPlane, infraStatus infra.InfrastructureStatus, createOrUpdate upsert.CreateOrUpdateFN) error { - p := pki.NewPKIParams(hcp, infraStatus.APIHost, infraStatus.OAuthHost, infraStatus.KonnectivityHost) - - // Root CA - rootCASecret := manifests.RootCASecret(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, rootCASecret, func() error { - return pki.ReconcileRootCA(rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile root CA: %w", err) - } - - var observedDefaultIngressCert *corev1.ConfigMap - if capabilities.IsIngressCapabilityEnabled(hcp.Spec.Capabilities) { - observedDefaultIngressCert = manifests.IngressObservedDefaultIngressCertCA(hcp.Namespace) - if err := r.Get(ctx, client.ObjectKeyFromObject(observedDefaultIngressCert), observedDefaultIngressCert); err != nil { - if !apierrors.IsNotFound(err) { - return fmt.Errorf("failed to get observed default ingress cert: %w", err) - } - observedDefaultIngressCert = nil - } - } - rootCAConfigMap := manifests.RootCAConfigMap(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, rootCAConfigMap, func() error { - return pki.ReconcileRootCAConfigMap(rootCAConfigMap, p.OwnerRef, rootCASecret, observedDefaultIngressCert) - }); err != nil { - return fmt.Errorf("failed to reconcile root CA configmap: %w", err) - } - - // Etcd signer for all the etcd-related certs +func (r *HostedControlPlaneReconciler) reconcileEtcdCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN) error { etcdSignerSecret := manifests.EtcdSignerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdSignerSecret, func() error { return pki.ReconcileEtcdSignerSecret(etcdSignerSecret, p.OwnerRef) @@ -1336,7 +1321,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd signer CA configmap: %w", err) } - // Etcd client secret etcdClientSecret := manifests.EtcdClientSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdClientSecret, func() error { return pki.ReconcileEtcdClientSecret(etcdClientSecret, etcdSignerSecret, p.OwnerRef) @@ -1344,7 +1328,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd client secret: %w", err) } - // Etcd server secret etcdServerSecret := manifests.EtcdServerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdServerSecret, func() error { return pki.ReconcileEtcdServerSecret(etcdServerSecret, etcdSignerSecret, p.OwnerRef) @@ -1352,7 +1335,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd server secret: %w", err) } - // Etcd peer secret etcdPeerSecret := manifests.EtcdPeerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdPeerSecret, func() error { return pki.ReconcileEtcdPeerSecret(etcdPeerSecret, etcdSignerSecret, p.OwnerRef) @@ -1360,8 +1342,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd peer secret: %w", err) } - // Etcd metrics signer - // Etcd signer for all the etcd-related certs etcdMetricsSignerSecret := manifests.EtcdMetricsSignerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdMetricsSignerSecret, func() error { return pki.ReconcileEtcdMetricsSignerSecret(etcdMetricsSignerSecret, p.OwnerRef) @@ -1376,7 +1356,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd signer CA configmap: %w", err) } - // Etcd client secret etcdMetricsClientSecret := manifests.EtcdMetricsClientSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, etcdMetricsClientSecret, func() error { return pki.ReconcileEtcdMetricsClientSecret(etcdMetricsClientSecret, etcdMetricsSignerSecret, p.OwnerRef) @@ -1384,7 +1363,10 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile etcd client secret: %w", err) } - // KAS server secret + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileKASCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { kasServerSecret := manifests.KASServerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, kasServerSecret, func() error { return pki.ReconcileKASServerCertSecret(kasServerSecret, rootCASecret, p.OwnerRef, p.ExternalAPIAddress, p.InternalAPIAddress, p.ServiceCIDR, p.NodeInternalAPIServerIP) @@ -1392,7 +1374,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile kas server secret: %w", err) } - // KAS server private secret kasServerPrivateSecret := manifests.KASServerPrivateCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, kasServerPrivateSecret, func() error { return pki.ReconcileKASServerPrivateCertSecret(kasServerPrivateSecret, rootCASecret, p.OwnerRef) @@ -1409,7 +1390,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return err } - // Service account signing key secret serviceAccountSigningKeySecret := manifests.ServiceAccountSigningKeySecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, serviceAccountSigningKeySecret, func() error { return pki.ReconcileServiceAccountSigningKeySecret(serviceAccountSigningKeySecret, p.OwnerRef) @@ -1417,7 +1397,10 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile api server service account key secret: %w", err) } - // OpenShift APIServer + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileOpenshiftCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { openshiftAPIServerCertSecret := manifests.OpenShiftAPIServerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, openshiftAPIServerCertSecret, func() error { return pki.ReconcileOpenShiftAPIServerCertSecret(openshiftAPIServerCertSecret, rootCASecret, p.OwnerRef) @@ -1426,7 +1409,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy } if util.HCPOAuthEnabled(hcp) { - // OpenShift OAuth APIServer openshiftOAuthAPIServerCertSecret := manifests.OpenShiftOAuthAPIServerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, openshiftOAuthAPIServerCertSecret, func() error { return pki.ReconcileOpenShiftOAuthAPIServerCertSecret(openshiftOAuthAPIServerCertSecret, rootCASecret, p.OwnerRef) @@ -1435,7 +1417,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy } } - // OpenShift ControllerManager Cert openshiftControllerManagerCertSecret := manifests.OpenShiftControllerManagerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, openshiftControllerManagerCertSecret, func() error { return pki.ReconcileOpenShiftControllerManagerCertSecret(openshiftControllerManagerCertSecret, rootCASecret, p.OwnerRef) @@ -1443,7 +1424,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile openshift controller manager cert: %w", err) } - // OpenShift Route ControllerManager Cert openshiftRouteControllerManagerCertSecret := manifests.OpenShiftRouteControllerManagerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, openshiftRouteControllerManagerCertSecret, func() error { return pki.ReconcileOpenShiftControllerManagerCertSecret(openshiftRouteControllerManagerCertSecret, rootCASecret, p.OwnerRef) @@ -1451,7 +1431,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile openshift route controller manager cert: %w", err) } - // Cluster Policy Controller Cert clusterPolicyControllerCertSecret := manifests.ClusterPolicyControllerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, clusterPolicyControllerCertSecret, func() error { return pki.ReconcileOpenShiftControllerManagerCertSecret(clusterPolicyControllerCertSecret, rootCASecret, p.OwnerRef) @@ -1459,6 +1438,10 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile cluster policy controller cert: %w", err) } + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileKonnectivityCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN) error { konnectivitySigner := manifests.KonnectivitySignerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, konnectivitySigner, func() error { return pki.ReconcileKonnectivitySignerSecret(konnectivitySigner, p.OwnerRef) @@ -1473,7 +1456,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile konnectivity CA config map: %v", err) } - // Konnectivity Server Cert konnectivityServerSecret := manifests.KonnectivityServerSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, konnectivityServerSecret, func() error { return pki.ReconcileKonnectivityServerSecret(konnectivityServerSecret, konnectivitySigner, p.OwnerRef) @@ -1481,7 +1463,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile konnectivity server cert: %w", err) } - // Konnectivity Cluster Cert konnectivityClusterSecret := manifests.KonnectivityClusterSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, konnectivityClusterSecret, func() error { return pki.ReconcileKonnectivityClusterSecret(konnectivityClusterSecret, konnectivitySigner, p.OwnerRef, p.ExternalKconnectivityAddress) @@ -1489,7 +1470,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile konnectivity cluster cert: %w", err) } - // Konnectivity Client Cert konnectivityClientSecret := manifests.KonnectivityClientSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, konnectivityClientSecret, func() error { return pki.ReconcileKonnectivityClientSecret(konnectivityClientSecret, konnectivitySigner, p.OwnerRef) @@ -1497,7 +1477,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile konnectivity client cert: %w", err) } - // Konnectivity Agent Cert konnectivityAgentSecret := manifests.KonnectivityAgentSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, konnectivityAgentSecret, func() error { return pki.ReconcileKonnectivityAgentSecret(konnectivityAgentSecret, konnectivitySigner, p.OwnerRef) @@ -1505,84 +1484,51 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile konnectivity agent cert: %w", err) } - // Reconcile ingress serving certificate only if Ingress capability is enabled. - if capabilities.IsIngressCapabilityEnabled(hcp.Spec.Capabilities) { - // Ingress Cert - ingressCert := manifests.IngressCert(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, ingressCert, func() error { - return pki.ReconcileIngressCert(ingressCert, rootCASecret, p.OwnerRef, p.IngressSubdomain) - }); err != nil { - return fmt.Errorf("failed to reconcile ingress cert secret: %w", err) - } - } - - var userCABundles []client.ObjectKey - if hcp.Spec.AdditionalTrustBundle != nil { - userCABundles = append(userCABundles, client.ObjectKey{Namespace: hcp.Namespace, Name: hcp.Spec.AdditionalTrustBundle.Name}) - } - if hcp.Spec.Configuration != nil && hcp.Spec.Configuration.Proxy != nil && hcp.Spec.Configuration.Proxy.TrustedCA.Name != "" { - userCABundles = append(userCABundles, client.ObjectKey{Namespace: hcp.Namespace, Name: hcp.Spec.Configuration.Proxy.TrustedCA.Name}) - } + return nil +} - trustedCABundle := manifests.TrustedCABundleConfigMap(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, trustedCABundle, func() error { - return r.reconcileManagedTrustedCABundle(ctx, trustedCABundle, userCABundles) +func (r *HostedControlPlaneReconciler) reconcileOAuthCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret, trustedCABundle *corev1.ConfigMap) error { + oauthServerCert := manifests.OpenShiftOAuthServerCert(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, oauthServerCert, func() error { + return pki.ReconcileOAuthServerCert(oauthServerCert, rootCASecret, p.OwnerRef, p.ExternalOauthAddress) }); err != nil { - return fmt.Errorf("failed to reconcile managed UserCA configMap: %w", err) + return fmt.Errorf("failed to reconcile oauth cert secret: %w", err) } - if util.HCPOAuthEnabled(hcp) { - // OAuth server Cert - oauthServerCert := manifests.OpenShiftOAuthServerCert(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, oauthServerCert, func() error { - return pki.ReconcileOAuthServerCert(oauthServerCert, rootCASecret, p.OwnerRef, p.ExternalOauthAddress) - }); err != nil { - return fmt.Errorf("failed to reconcile oauth cert secret: %w", err) - } - - // OAuth master CA Bundle - bundleSources := []*corev1.Secret{oauthServerCert} - if hcp.Spec.Configuration != nil && hcp.Spec.Configuration.APIServer != nil { - for _, namedCert := range hcp.Spec.Configuration.APIServer.ServingCerts.NamedCertificates { - if namedCert.ServingCertificate.Name == "" { - continue - } - certSecret := &corev1.Secret{ObjectMeta: metav1.ObjectMeta{Name: namedCert.ServingCertificate.Name, Namespace: hcp.Namespace}} - if err := r.Get(ctx, client.ObjectKeyFromObject(certSecret), certSecret); err != nil { - return fmt.Errorf("failed to get named certificate %s: %w", certSecret.Name, err) - } - bundleSources = append(bundleSources, certSecret) + bundleSources := []*corev1.Secret{oauthServerCert} + if hcp.Spec.Configuration != nil && hcp.Spec.Configuration.APIServer != nil { + for _, namedCert := range hcp.Spec.Configuration.APIServer.ServingCerts.NamedCertificates { + if namedCert.ServingCertificate.Name == "" { + continue } + certSecret := &corev1.Secret{ObjectMeta: metav1.ObjectMeta{Name: namedCert.ServingCertificate.Name, Namespace: hcp.Namespace}} + if err := r.Get(ctx, client.ObjectKeyFromObject(certSecret), certSecret); err != nil { + return fmt.Errorf("failed to get named certificate %s: %w", certSecret.Name, err) + } + bundleSources = append(bundleSources, certSecret) } + } - if trustedCABundle.Data[certs.UserCABundleMapKey] != "" { - bundleSources = append(bundleSources, &corev1.Secret{ - Data: map[string][]byte{ - certs.CASignerCertMapKey: []byte(trustedCABundle.Data[certs.UserCABundleMapKey]), - }, - }) - } - - oauthMasterCABundle := manifests.OpenShiftOAuthMasterCABundle(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, oauthMasterCABundle, func() error { - return pki.ReconcileOAuthMasterCABundle(oauthMasterCABundle, p.OwnerRef, bundleSources) - }); err != nil { - return fmt.Errorf("failed to reconcile oauth cert secret: %w", err) - } + if trustedCABundle.Data[certs.UserCABundleMapKey] != "" { + bundleSources = append(bundleSources, &corev1.Secret{ + Data: map[string][]byte{ + certs.CASignerCertMapKey: []byte(trustedCABundle.Data[certs.UserCABundleMapKey]), + }, + }) } - // MCS Cert - if _, exists := hcp.Annotations[hyperv1.DisableIgnitionServerAnnotation]; !exists { - machineConfigServerCert := manifests.MachineConfigServerCert(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, machineConfigServerCert, func() error { - return pki.ReconcileMachineConfigServerCert(machineConfigServerCert, rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile machine config server cert secret: %w", err) - } + oauthMasterCABundle := manifests.OpenShiftOAuthMasterCABundle(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, oauthMasterCABundle, func() error { + return pki.ReconcileOAuthMasterCABundle(oauthMasterCABundle, p.OwnerRef, bundleSources) + }); err != nil { + return fmt.Errorf("failed to reconcile oauth cert secret: %w", err) } - var err error + + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileOLMAndMiscCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { if capabilities.IsNodeTuningCapabilityEnabled(hcp.Spec.Capabilities) { - // Cluster Node Tuning Operator metrics Serving Cert NodeTuningOperatorServingCert := manifests.ClusterNodeTuningOperatorServingCertSecret(hcp.Namespace) NodeTuningOperatorService := manifests.ClusterNodeTuningOperatorMetricsService(hcp.Namespace) err := removeServiceCAAnnotationAndSecret(ctx, r.Client, NodeTuningOperatorService, NodeTuningOperatorServingCert) @@ -1595,7 +1541,7 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile node tuning operator serving cert: %w", err) } } - // OLM PackageServer Cert + packageServerCertSecret := manifests.OLMPackageServerCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, packageServerCertSecret, func() error { return pki.ReconcileOLMPackageServerCertSecret(packageServerCertSecret, rootCASecret, p.OwnerRef) @@ -1603,7 +1549,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile packageserver cert: %w", err) } - // OLM Catalog Operator Serving Cert catalogOperatorServingCert := manifests.OLMCatalogOperatorServingCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, catalogOperatorServingCert, func() error { return pki.ReconcileOLMCatalogOperatorServingCertSecret(catalogOperatorServingCert, rootCASecret, p.OwnerRef) @@ -1611,7 +1556,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile olm catalog operator serving cert: %w", err) } - // OLM Operator Serving Cert olmOperatorServingCert := manifests.OLMOperatorServingCertSecret(hcp.Namespace) if _, err := createOrUpdate(ctx, r, olmOperatorServingCert, func() error { return pki.ReconcileOLMOperatorServingCertSecret(olmOperatorServingCert, rootCASecret, p.OwnerRef) @@ -1620,7 +1564,6 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy } if capabilities.IsImageRegistryCapabilityEnabled(hcp.Spec.Capabilities) { - // Image Registry Operator Serving Cert imageRegistryOperatorServingCert := manifests.ImageRegistryOperatorServingCert(hcp.Namespace) if _, err := createOrUpdate(ctx, r, imageRegistryOperatorServingCert, func() error { return pki.ReconcileRegistryOperatorServingCert(imageRegistryOperatorServingCert, rootCASecret, p.OwnerRef) @@ -1643,31 +1586,26 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy return fmt.Errorf("failed to reconcile cvo serving cert: %w", err) } - // For the Multus Admission Controller, Network Node Identity, and OVN Control Plane Metrics Serving Certs: - // We want to remove the secret if there was an existing one created by service-ca; otherwise, it will cause - // issues in cases where you are upgrading an older CPO prior to us adding the feature to reconcile the serving - // cert secret ourselves. + return nil +} - // Multus Admission Controller Serving Cert - only if Multus is not disabled +func (r *HostedControlPlaneReconciler) reconcileNetworkServingCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { if !netutil.IsDisableMultiNetwork(hcp) { multusAdmissionControllerService := manifests.MultusAdmissionControllerService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(multusAdmissionControllerService), multusAdmissionControllerService); err != nil { + if err := r.Get(ctx, client.ObjectKeyFromObject(multusAdmissionControllerService), multusAdmissionControllerService); err != nil { if !apierrors.IsNotFound(err) { return fmt.Errorf("failed to retrieve multus-admission-controller service: %w", err) } } - // If the service doesn't have the service ca annotation, delete any previous secret with the annotation and - // reconcile the secret with our own rootCA; otherwise, skip reconciling the secret with our own rootCA. if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(multusAdmissionControllerService); !hasServiceCAAnnotation { multusAdmissionControllerServingCertSecret := manifests.MultusAdmissionControllerServingCert(hcp.Namespace) - err = removeServiceCASecret(ctx, r.Client, multusAdmissionControllerServingCertSecret) - if err != nil { + if err := removeServiceCASecret(ctx, r.Client, multusAdmissionControllerServingCertSecret); err != nil { return err } - if _, err = createOrUpdate(ctx, r, multusAdmissionControllerServingCertSecret, func() error { + if _, err := createOrUpdate(ctx, r, multusAdmissionControllerServingCertSecret, func() error { return pki.ReconcileMultusAdmissionControllerServingCertSecret(multusAdmissionControllerServingCertSecret, rootCASecret, p.OwnerRef) }); err != nil { return fmt.Errorf("failed to reconcile multus admission controller serving cert: %w", err) @@ -1675,179 +1613,272 @@ func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hy } } - // Network Node Identity Serving Cert networkNodeIdentityService := manifests.NetworkNodeIdentityService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(networkNodeIdentityService), networkNodeIdentityService); err != nil { + if err := r.Get(ctx, client.ObjectKeyFromObject(networkNodeIdentityService), networkNodeIdentityService); err != nil { if !apierrors.IsNotFound(err) { return fmt.Errorf("failed to retrieve network-node-identity service: %w", err) } } - // If the service doesn't have the service ca annotation, delete any previous secret with the annotation and - // reconcile the secret with our own rootCA; otherwise, skip reconciling the secret with our own rootCA. if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(networkNodeIdentityService); !hasServiceCAAnnotation { networkNodeIdentityServingCertSecret := manifests.NetworkNodeIdentityControllerServingCert(hcp.Namespace) - err = removeServiceCASecret(ctx, r.Client, networkNodeIdentityServingCertSecret) - if err != nil { + if err := removeServiceCASecret(ctx, r.Client, networkNodeIdentityServingCertSecret); err != nil { return err } - if _, err = createOrUpdate(ctx, r, networkNodeIdentityServingCertSecret, func() error { + if _, err := createOrUpdate(ctx, r, networkNodeIdentityServingCertSecret, func() error { return pki.ReconcileNetworkNodeIdentityServingCertSecret(networkNodeIdentityServingCertSecret, rootCASecret, p.OwnerRef) }); err != nil { return fmt.Errorf("failed to reconcile network node identity serving cert: %w", err) } } - // OVN Control Plane Metrics Serving Cert ovnControlPlaneService := manifests.OVNKubernetesControlPlaneService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(ovnControlPlaneService), ovnControlPlaneService); err != nil { + if err := r.Get(ctx, client.ObjectKeyFromObject(ovnControlPlaneService), ovnControlPlaneService); err != nil { if !apierrors.IsNotFound(err) { return fmt.Errorf("failed to retrieve ovn-kubernetes-control-plane service: %w", err) } } - // If the service doesn't have the service ca annotation, delete any previous secret with the annotation and - // reconcile the secret with our own rootCA; otherwise, skip reconciling the secret with our own rootCA. if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(ovnControlPlaneService); !hasServiceCAAnnotation { ovnControlPlaneMetricsServingCertSecret := manifests.OVNControlPlaneMetricsServingCert(hcp.Namespace) - err = removeServiceCASecret(ctx, r.Client, ovnControlPlaneMetricsServingCertSecret) - if err != nil { + if err := removeServiceCASecret(ctx, r.Client, ovnControlPlaneMetricsServingCertSecret); err != nil { return err } - if _, err = createOrUpdate(ctx, r, ovnControlPlaneMetricsServingCertSecret, func() error { + if _, err := createOrUpdate(ctx, r, ovnControlPlaneMetricsServingCertSecret, func() error { return pki.ReconcileOVNControlPlaneMetricsServingCertSecret(ovnControlPlaneMetricsServingCertSecret, rootCASecret, p.OwnerRef) }); err != nil { return fmt.Errorf("failed to reconcile OVN control plane serving cert: %w", err) } } - if _, exists := hcp.Annotations[hyperv1.DisableIgnitionServerAnnotation]; !exists { - if hcp.Spec.Platform.Type != hyperv1.IBMCloudPlatform { - ignitionServerCert := manifests.IgnitionServerCertSecret(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, ignitionServerCert, func() error { - return pki.ReconcileIgnitionServerCertSecret(ignitionServerCert, rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile ignition server cert: %w", err) - } + return nil +} + +func (r *HostedControlPlaneReconciler) reconcileAWSPlatformCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { + awsPodIdentityWebhookServingCert := manifests.AWSPodIdentityWebhookServingCert(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, awsPodIdentityWebhookServingCert, func() error { + return pki.ReconcileAWSPodIdentityWebhookServingCert(awsPodIdentityWebhookServingCert, rootCASecret, p.OwnerRef) + }); err != nil { + return fmt.Errorf("failed to reconcile %s secret: %w", awsPodIdentityWebhookServingCert.Name, err) + } + + awsEBSCsiDriverControllerMetricsService := manifests.AWSEBSCsiDriverControllerMetricsService(hcp.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(awsEBSCsiDriverControllerMetricsService), awsEBSCsiDriverControllerMetricsService); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to retrieve aws-ebs-csi-driver-controller-metrics service: %w", err) } } - // Platform specific certs - switch hcp.Spec.Platform.Type { - case hyperv1.AWSPlatform: - awsPodIdentityWebhookServingCert := manifests.AWSPodIdentityWebhookServingCert(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, awsPodIdentityWebhookServingCert, func() error { - return pki.ReconcileAWSPodIdentityWebhookServingCert(awsPodIdentityWebhookServingCert, rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile %s secret: %w", awsPodIdentityWebhookServingCert.Name, err) + if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(awsEBSCsiDriverControllerMetricsService); !hasServiceCAAnnotation { + awsEBSCsiDriverControllerMetricsServingCert := manifests.AWSEBSCsiDriverControllerMetricsServingCert(hcp.Namespace) + + if err := removeServiceCASecret(ctx, r.Client, awsEBSCsiDriverControllerMetricsServingCert); err != nil { + return err } - awsEBSCsiDriverControllerMetricsService := manifests.AWSEBSCsiDriverControllerMetricsService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(awsEBSCsiDriverControllerMetricsService), awsEBSCsiDriverControllerMetricsService); err != nil { - if !apierrors.IsNotFound(err) { - return fmt.Errorf("failed to retrieve aws-ebs-csi-driver-controller-metrics service: %w", err) - } + if _, err := createOrUpdate(ctx, r, awsEBSCsiDriverControllerMetricsServingCert, func() error { + return pki.ReconcileAWSEBSCsiDriverControllerMetricsServingCertSecret(awsEBSCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) + }); err != nil { + return fmt.Errorf("failed to reconcile aws ebs csi driver controller metrics serving cert: %w", err) } + } - if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(awsEBSCsiDriverControllerMetricsService); !hasServiceCAAnnotation { - awsEBSCsiDriverControllerMetricsServingCert := manifests.AWSEBSCsiDriverControllerMetricsServingCert(hcp.Namespace) + return nil +} - err = removeServiceCASecret(ctx, r.Client, awsEBSCsiDriverControllerMetricsServingCert) - if err != nil { - return err - } +func (r *HostedControlPlaneReconciler) reconcileAzurePlatformCerts(ctx context.Context, hcp *hyperv1.HostedControlPlane, p *pki.PKIParams, createOrUpdate upsert.CreateOrUpdateFN, rootCASecret *corev1.Secret) error { + azureWorkloadIdentityWebhookServingCert := manifests.AzureWorkloadIdentityWebhookServingCert(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, azureWorkloadIdentityWebhookServingCert, func() error { + return pki.ReconcileAzureWorkloadIdentityWebhookServingCert(azureWorkloadIdentityWebhookServingCert, rootCASecret, p.OwnerRef) + }); err != nil { + return fmt.Errorf("failed to reconcile %s secret: %w", azureWorkloadIdentityWebhookServingCert.Name, err) + } - if _, err = createOrUpdate(ctx, r, awsEBSCsiDriverControllerMetricsServingCert, func() error { - return pki.ReconcileAWSEBSCsiDriverControllerMetricsServingCertSecret(awsEBSCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile aws ebs csi driver controller metrics serving cert: %w", err) - } + AzureDiskCsiDriverOperatorServingCert := manifests.AzureDiskCSIDriverOperatorServingCertSecret(hcp.Namespace) + AzureDiskCsiDriverOperatorService := manifests.AzureDiskCSIDriverOperatorMetricsService(hcp.Namespace) + if err := removeServiceCAAnnotationAndSecret(ctx, r.Client, AzureDiskCsiDriverOperatorService, AzureDiskCsiDriverOperatorServingCert); err != nil { + r.Log.Error(err, "failed to remove service ca annotation and secret: %w") + } + if _, err := createOrUpdate(ctx, r, AzureDiskCsiDriverOperatorServingCert, func() error { + z := pki.ReconcileAzureDiskCsiDriverOperatorMetricsServingCertSecret(AzureDiskCsiDriverOperatorServingCert, rootCASecret, p.OwnerRef) + return z + }); err != nil { + return fmt.Errorf("failed to reconcile azure-disk csi driver operator serving cert: %w", err) + } + + azureDiskCsiDriverControllerMetricsService := manifests.AzureDiskCsiDriverControllerMetricsService(hcp.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(azureDiskCsiDriverControllerMetricsService), azureDiskCsiDriverControllerMetricsService); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to retrieve azure-disk-csi-driver-controller-metrics service: %w", err) } - case hyperv1.AzurePlatform: - azureWorkloadIdentityWebhookServingCert := manifests.AzureWorkloadIdentityWebhookServingCert(hcp.Namespace) - if _, err := createOrUpdate(ctx, r, azureWorkloadIdentityWebhookServingCert, func() error { - return pki.ReconcileAzureWorkloadIdentityWebhookServingCert(azureWorkloadIdentityWebhookServingCert, rootCASecret, p.OwnerRef) + } + + if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(azureDiskCsiDriverControllerMetricsService); !hasServiceCAAnnotation { + azureDiskCsiDriverControllerMetricsServingCert := manifests.AzureDiskCsiDriverControllerMetricsServingCert(hcp.Namespace) + + if err := removeServiceCASecret(ctx, r.Client, azureDiskCsiDriverControllerMetricsServingCert); err != nil { + return err + } + + if _, err := createOrUpdate(ctx, r, azureDiskCsiDriverControllerMetricsServingCert, func() error { + return pki.ReconcileAzureDiskCsiDriverControllerMetricsServingCertSecret(azureDiskCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) }); err != nil { - return fmt.Errorf("failed to reconcile %s secret: %w", azureWorkloadIdentityWebhookServingCert.Name, err) + return fmt.Errorf("failed to reconcile azure disk csi driver controller metrics serving cert: %w", err) } + } - // Azure-disk CSI driver Operator metrics Serving Cert - AzureDiskCsiDriverOperatorServingCert := manifests.AzureDiskCSIDriverOperatorServingCertSecret(hcp.Namespace) - AzureDiskCsiDriverOperatorService := manifests.AzureDiskCSIDriverOperatorMetricsService(hcp.Namespace) - err := removeServiceCAAnnotationAndSecret(ctx, r.Client, AzureDiskCsiDriverOperatorService, AzureDiskCsiDriverOperatorServingCert) - if err != nil { - r.Log.Error(err, "failed to remove service ca annotation and secret: %w") + AzureFileCsiDriverOperatorServingCert := manifests.AzureFileCSIDriverOperatorServingCertSecret(hcp.Namespace) + AzureFileCsiDriverOperatorService := manifests.AzureFileCSIDriverOperatorMetricsService(hcp.Namespace) + if err := removeServiceCAAnnotationAndSecret(ctx, r.Client, AzureFileCsiDriverOperatorService, AzureFileCsiDriverOperatorServingCert); err != nil { + r.Log.Error(err, "failed to remove service ca annotation and secret: %w") + } + if _, err := createOrUpdate(ctx, r, AzureFileCsiDriverOperatorServingCert, func() error { + z := pki.ReconcileAzureFileCsiDriverOperatorMetricsServingCertSecret(AzureFileCsiDriverOperatorServingCert, rootCASecret, p.OwnerRef) + return z + }); err != nil { + return fmt.Errorf("failed to reconcile azure-file csi driver operator serving cert: %w", err) + } + + azureFileCsiDriverControllerMetricsService := manifests.AzureFileCsiDriverControllerMetricsService(hcp.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(azureFileCsiDriverControllerMetricsService), azureFileCsiDriverControllerMetricsService); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to retrieve azure-file-csi-driver-controller-metrics service: %w", err) + } + } + + if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(azureFileCsiDriverControllerMetricsService); !hasServiceCAAnnotation { + azureFileCsiDriverControllerMetricsServingCert := manifests.AzureFileCsiDriverControllerMetricsServingCert(hcp.Namespace) + + if err := removeServiceCASecret(ctx, r.Client, azureFileCsiDriverControllerMetricsServingCert); err != nil { + return err } - if _, err = createOrUpdate(ctx, r, AzureDiskCsiDriverOperatorServingCert, func() error { - z := pki.ReconcileAzureDiskCsiDriverOperatorMetricsServingCertSecret(AzureDiskCsiDriverOperatorServingCert, rootCASecret, p.OwnerRef) - return z + + if _, err := createOrUpdate(ctx, r, azureFileCsiDriverControllerMetricsServingCert, func() error { + return pki.ReconcileAzureFileCsiDriverControllerMetricsServingCertSecret(azureFileCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) }); err != nil { - return fmt.Errorf("failed to reconcile azure-disk csi driver operator serving cert: %w", err) + return fmt.Errorf("failed to reconcile azure file csi driver controller metrics serving cert: %w", err) } + } + + return nil +} + +func (r *HostedControlPlaneReconciler) reconcilePKI(ctx context.Context, hcp *hyperv1.HostedControlPlane, infraStatus infra.InfrastructureStatus, createOrUpdate upsert.CreateOrUpdateFN) error { + p := pki.NewPKIParams(hcp, infraStatus.APIHost, infraStatus.OAuthHost, infraStatus.KonnectivityHost) + + rootCASecret := manifests.RootCASecret(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, rootCASecret, func() error { + return pki.ReconcileRootCA(rootCASecret, p.OwnerRef) + }); err != nil { + return fmt.Errorf("failed to reconcile root CA: %w", err) + } - azureDiskCsiDriverControllerMetricsService := manifests.AzureDiskCsiDriverControllerMetricsService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(azureDiskCsiDriverControllerMetricsService), azureDiskCsiDriverControllerMetricsService); err != nil { + var observedDefaultIngressCert *corev1.ConfigMap + if capabilities.IsIngressCapabilityEnabled(hcp.Spec.Capabilities) { + observedDefaultIngressCert = manifests.IngressObservedDefaultIngressCertCA(hcp.Namespace) + if err := r.Get(ctx, client.ObjectKeyFromObject(observedDefaultIngressCert), observedDefaultIngressCert); err != nil { if !apierrors.IsNotFound(err) { - return fmt.Errorf("failed to retrieve azure-disk-csi-driver-controller-metrics service: %w", err) + return fmt.Errorf("failed to get observed default ingress cert: %w", err) } + observedDefaultIngressCert = nil } + } + rootCAConfigMap := manifests.RootCAConfigMap(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, rootCAConfigMap, func() error { + return pki.ReconcileRootCAConfigMap(rootCAConfigMap, p.OwnerRef, rootCASecret, observedDefaultIngressCert) + }); err != nil { + return fmt.Errorf("failed to reconcile root CA configmap: %w", err) + } - if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(azureDiskCsiDriverControllerMetricsService); !hasServiceCAAnnotation { - azureDiskCsiDriverControllerMetricsServingCert := manifests.AzureDiskCsiDriverControllerMetricsServingCert(hcp.Namespace) + if err := r.reconcileEtcdCerts(ctx, hcp, p, createOrUpdate); err != nil { + return err + } - err = removeServiceCASecret(ctx, r.Client, azureDiskCsiDriverControllerMetricsServingCert) - if err != nil { - return err - } + if err := r.reconcileKASCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } - if _, err = createOrUpdate(ctx, r, azureDiskCsiDriverControllerMetricsServingCert, func() error { - return pki.ReconcileAzureDiskCsiDriverControllerMetricsServingCertSecret(azureDiskCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) - }); err != nil { - return fmt.Errorf("failed to reconcile azure disk csi driver controller metrics serving cert: %w", err) - } - } + if err := r.reconcileOpenshiftCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } - // Azure-file CSI driver Operator metrics Serving Cert - AzureFileCsiDriverOperatorServingCert := manifests.AzureFileCSIDriverOperatorServingCertSecret(hcp.Namespace) - AzureFileCsiDriverOperatorService := manifests.AzureFileCSIDriverOperatorMetricsService(hcp.Namespace) - err = removeServiceCAAnnotationAndSecret(ctx, r.Client, AzureFileCsiDriverOperatorService, AzureFileCsiDriverOperatorServingCert) - if err != nil { - r.Log.Error(err, "failed to remove service ca annotation and secret: %w") - } - if _, err = createOrUpdate(ctx, r, AzureFileCsiDriverOperatorServingCert, func() error { - z := pki.ReconcileAzureFileCsiDriverOperatorMetricsServingCertSecret(AzureFileCsiDriverOperatorServingCert, rootCASecret, p.OwnerRef) - return z + if err := r.reconcileKonnectivityCerts(ctx, hcp, p, createOrUpdate); err != nil { + return err + } + + if capabilities.IsIngressCapabilityEnabled(hcp.Spec.Capabilities) { + ingressCert := manifests.IngressCert(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, ingressCert, func() error { + return pki.ReconcileIngressCert(ingressCert, rootCASecret, p.OwnerRef, p.IngressSubdomain) }); err != nil { - return fmt.Errorf("failed to reconcile azure-file csi driver operator serving cert: %w", err) + return fmt.Errorf("failed to reconcile ingress cert secret: %w", err) } + } - azureFileCsiDriverControllerMetricsService := manifests.AzureFileCsiDriverControllerMetricsService(hcp.Namespace) - if err = r.Get(ctx, client.ObjectKeyFromObject(azureFileCsiDriverControllerMetricsService), azureFileCsiDriverControllerMetricsService); err != nil { - if !apierrors.IsNotFound(err) { - return fmt.Errorf("failed to retrieve azure-file-csi-driver-controller-metrics service: %w", err) - } + var userCABundles []client.ObjectKey + if hcp.Spec.AdditionalTrustBundle != nil { + userCABundles = append(userCABundles, client.ObjectKey{Namespace: hcp.Namespace, Name: hcp.Spec.AdditionalTrustBundle.Name}) + } + if hcp.Spec.Configuration != nil && hcp.Spec.Configuration.Proxy != nil && hcp.Spec.Configuration.Proxy.TrustedCA.Name != "" { + userCABundles = append(userCABundles, client.ObjectKey{Namespace: hcp.Namespace, Name: hcp.Spec.Configuration.Proxy.TrustedCA.Name}) + } + + trustedCABundle := manifests.TrustedCABundleConfigMap(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, trustedCABundle, func() error { + return r.reconcileManagedTrustedCABundle(ctx, trustedCABundle, userCABundles) + }); err != nil { + return fmt.Errorf("failed to reconcile managed UserCA configMap: %w", err) + } + + if util.HCPOAuthEnabled(hcp) { + if err := r.reconcileOAuthCerts(ctx, hcp, p, createOrUpdate, rootCASecret, trustedCABundle); err != nil { + return err } + } + + if _, exists := hcp.Annotations[hyperv1.DisableIgnitionServerAnnotation]; !exists { + machineConfigServerCert := manifests.MachineConfigServerCert(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, machineConfigServerCert, func() error { + return pki.ReconcileMachineConfigServerCert(machineConfigServerCert, rootCASecret, p.OwnerRef) + }); err != nil { + return fmt.Errorf("failed to reconcile machine config server cert secret: %w", err) + } + } - if hasServiceCAAnnotation := doesServiceHaveServiceCAAnnotation(azureFileCsiDriverControllerMetricsService); !hasServiceCAAnnotation { - azureFileCsiDriverControllerMetricsServingCert := manifests.AzureFileCsiDriverControllerMetricsServingCert(hcp.Namespace) + if err := r.reconcileOLMAndMiscCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } - err = removeServiceCASecret(ctx, r.Client, azureFileCsiDriverControllerMetricsServingCert) - if err != nil { - return err - } + if err := r.reconcileNetworkServingCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } - if _, err := createOrUpdate(ctx, r, azureFileCsiDriverControllerMetricsServingCert, func() error { - return pki.ReconcileAzureFileCsiDriverControllerMetricsServingCertSecret(azureFileCsiDriverControllerMetricsServingCert, rootCASecret, p.OwnerRef) + if _, exists := hcp.Annotations[hyperv1.DisableIgnitionServerAnnotation]; !exists { + if hcp.Spec.Platform.Type != hyperv1.IBMCloudPlatform { + ignitionServerCert := manifests.IgnitionServerCertSecret(hcp.Namespace) + if _, err := createOrUpdate(ctx, r, ignitionServerCert, func() error { + return pki.ReconcileIgnitionServerCertSecret(ignitionServerCert, rootCASecret, p.OwnerRef) }); err != nil { - return fmt.Errorf("failed to reconcile azure file csi driver controller metrics serving cert: %w", err) + return fmt.Errorf("failed to reconcile ignition server cert: %w", err) } } } + switch hcp.Spec.Platform.Type { + case hyperv1.AWSPlatform: + if err := r.reconcileAWSPlatformCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } + case hyperv1.AzurePlatform: + if err := r.reconcileAzurePlatformCerts(ctx, hcp, p, createOrUpdate, rootCASecret); err != nil { + return err + } + } + return nil } diff --git a/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller_test.go b/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller_test.go index f4f43ab5151..3ddf48cd6be 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller_test.go +++ b/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller_test.go @@ -38,14 +38,19 @@ import ( "github.com/openshift/hypershift/support/releaseinfo/testutils" "github.com/openshift/hypershift/support/testutil" "github.com/openshift/hypershift/support/thirdparty/library-go/pkg/image/dockerv1client" + "github.com/openshift/hypershift/support/thirdparty/library-go/pkg/image/reference" "github.com/openshift/hypershift/support/upsert" + "github.com/openshift/hypershift/support/util" "github.com/openshift/hypershift/support/util/fakeimagemetadataprovider" configv1 "github.com/openshift/api/config/v1" "github.com/openshift/api/image/docker10" routev1 "github.com/openshift/api/route/v1" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/smithy-go" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" @@ -54,6 +59,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/util/workqueue" + "k8s.io/utils/clock" + testingclock "k8s.io/utils/clock/testing" "k8s.io/utils/ptr" ctrl "sigs.k8s.io/controller-runtime" @@ -67,7 +74,9 @@ import ( "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/yaml" + "github.com/docker/distribution" "github.com/go-logr/zapr" + "github.com/opencontainers/go-digest" "go.uber.org/mock/gomock" "go.uber.org/zap/zaptest" ) @@ -2873,3 +2882,1549 @@ func TestRemoveCloudResources(t *testing.T) { }) } } +func TestReconcileEtcdStatus(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + expectedCondType string + expectedCondition metav1.Condition + expectError bool + }{ + { + name: "When etcd management type is Unmanaged, it should set EtcdAvailable to True", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 3, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Unmanaged, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: "EtcdRunning", + Message: "Etcd cluster is assumed to be running in unmanaged state", + ObservedGeneration: 3, + }, + }, + { + name: "When etcd management type is Managed and StatefulSet is not found, it should set EtcdAvailable to False with NotFound reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 5, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Managed, + }, + }, + }, + existingObjects: []client.Object{}, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.EtcdStatefulSetNotFoundReason, + ObservedGeneration: 5, + }, + }, + { + name: "When etcd management type is Managed and StatefulSet exists with quorum, it should set EtcdAvailable to True", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Managed, + }, + }, + }, + existingObjects: []client.Object{ + &appsv1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd", + Namespace: testNamespace, + }, + Spec: appsv1.StatefulSetSpec{ + Replicas: ptr.To[int32](3), + }, + Status: appsv1.StatefulSetStatus{ + ReadyReplicas: 3, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.EtcdQuorumAvailableReason, + ObservedGeneration: 2, + }, + }, + { + name: "When etcd management type is Managed and StatefulSet has no quorum, it should set EtcdAvailable to False", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 4, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Managed, + }, + }, + }, + existingObjects: []client.Object{ + &appsv1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd", + Namespace: testNamespace, + }, + Spec: appsv1.StatefulSetSpec{ + Replicas: ptr.To[int32](3), + }, + Status: appsv1.StatefulSetStatus{ + ReadyReplicas: 0, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.EtcdWaitingForQuorumReason, + ObservedGeneration: 4, + }, + }, + { + name: "When etcd management type is Managed and Get returns unexpected error, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Managed, + }, + }, + }, + expectError: true, + }, + { + name: "When etcd management type is Managed with RestoreSnapshotURL and StatefulSet has ready pods, it should set EtcdSnapshotRestored condition", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 6, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: hyperv1.Managed, + Managed: &hyperv1.ManagedEtcdSpec{ + Storage: hyperv1.ManagedEtcdStorageSpec{ + RestoreSnapshotURL: []string{"https://example.com/snapshot"}, + }, + }, + }, + }, + }, + existingObjects: []client.Object{ + &appsv1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd", + Namespace: testNamespace, + }, + Spec: appsv1.StatefulSetSpec{ + Replicas: ptr.To[int32](1), + }, + Status: appsv1.StatefulSetStatus{ + ReadyReplicas: 1, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "etcd-0", + Namespace: testNamespace, + Labels: map[string]string{ + "app": "etcd", + }, + }, + Status: corev1.PodStatus{ + InitContainerStatuses: []corev1.ContainerStatus{ + { + Name: "etcd-init", + Ready: true, + }, + }, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.EtcdQuorumAvailableReason, + ObservedGeneration: 6, + }, + }, + { + name: "When etcd management type is empty, it should set condition to Unknown", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Etcd: hyperv1.EtcdSpec{ + ManagementType: "", + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionUnknown, + Reason: hyperv1.StatusUnknownReason, + ObservedGeneration: 1, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + var c client.Client + if tc.expectError { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithInterceptorFuncs(interceptor.Funcs{ + Get: func(ctx context.Context, client client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + return fmt.Errorf("simulated get error") + }, + }). + Build() + } else { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(tc.existingObjects...). + Build() + } + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + err := r.reconcileEtcdStatus(t.Context(), tc.hcp) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.EtcdAvailable)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Type).To(Equal(tc.expectedCondition.Type)) + g.Expect(cond.Status).To(Equal(tc.expectedCondition.Status)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondition.Reason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.expectedCondition.ObservedGeneration)) + if tc.expectedCondition.Message != "" { + g.Expect(cond.Message).To(Equal(tc.expectedCondition.Message)) + } + + // For the restore snapshot test case, also verify EtcdSnapshotRestored condition + if tc.name == "When etcd management type is Managed with RestoreSnapshotURL and StatefulSet has ready pods, it should set EtcdSnapshotRestored condition" { + restoreCond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.EtcdSnapshotRestored)) + g.Expect(restoreCond).ToNot(BeNil()) + g.Expect(restoreCond.Status).To(Equal(metav1.ConditionTrue)) + g.Expect(restoreCond.Reason).To(Equal(hyperv1.AsExpectedReason)) + } + }) + } +} + +func TestReconcileKASStatus(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + expectedCondition metav1.Condition + expectError bool + }{ + { + name: "When KAS deployment is not found, it should set KubeAPIServerAvailable to False with NotFound reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 3, + }, + }, + existingObjects: []client.Object{}, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.NotFoundReason, + Message: "Kube APIServer deployment not found", + ObservedGeneration: 3, + }, + }, + { + name: "When KAS deployment exists and is Available, it should set KubeAPIServerAvailable to True", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 5, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + }, + Status: appsv1.DeploymentStatus{ + Conditions: []appsv1.DeploymentCondition{ + { + Type: appsv1.DeploymentAvailable, + Status: corev1.ConditionTrue, + }, + }, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + Message: "Kube APIServer deployment is available", + ObservedGeneration: 5, + }, + }, + { + name: "When KAS deployment exists but Available condition is False, it should set KubeAPIServerAvailable to False with WaitingForAvailable reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 7, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + }, + Status: appsv1.DeploymentStatus{ + Conditions: []appsv1.DeploymentCondition{ + { + Type: appsv1.DeploymentAvailable, + Status: corev1.ConditionFalse, + }, + }, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.WaitingForAvailableReason, + Message: "Waiting for Kube APIServer deployment to become available", + ObservedGeneration: 7, + }, + }, + { + name: "When KAS deployment exists but has no Available condition, it should set KubeAPIServerAvailable to False with StatusUnknown reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + }, + Status: appsv1.DeploymentStatus{ + Conditions: []appsv1.DeploymentCondition{ + { + Type: appsv1.DeploymentProgressing, + Status: corev1.ConditionTrue, + }, + }, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.StatusUnknownReason, + ObservedGeneration: 2, + }, + }, + { + name: "When KAS deployment exists with empty conditions, it should set KubeAPIServerAvailable to False with StatusUnknown reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + }, + Status: appsv1.DeploymentStatus{}, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionFalse, + Reason: hyperv1.StatusUnknownReason, + ObservedGeneration: 1, + }, + }, + { + name: "When Get returns unexpected error, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + var c client.Client + if tc.expectError { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithInterceptorFuncs(interceptor.Funcs{ + Get: func(ctx context.Context, client client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + return fmt.Errorf("simulated get error") + }, + }). + Build() + } else { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(tc.existingObjects...). + Build() + } + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + err := r.reconcileKASStatus(t.Context(), tc.hcp) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.KubeAPIServerAvailable)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Type).To(Equal(tc.expectedCondition.Type)) + g.Expect(cond.Status).To(Equal(tc.expectedCondition.Status)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondition.Reason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.expectedCondition.ObservedGeneration)) + if tc.expectedCondition.Message != "" { + g.Expect(cond.Message).To(Equal(tc.expectedCondition.Message)) + } + }) + } +} + +func TestReconcileDegradedStatus(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + expectedCondition metav1.Condition + expectError bool + }{ + { + name: "When no CPO-managed deployments exist, it should set Degraded to False", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + }, + existingObjects: []client.Object{}, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionFalse, + Reason: hyperv1.AsExpectedReason, + ObservedGeneration: 2, + }, + }, + { + name: "When all CPO-managed deployments are fully available, it should set Degraded to False", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 3, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + Labels: map[string]string{ + controlplanecomponent.ManagedByLabel: "control-plane-operator", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 0, + }, + }, + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-controller-manager", + Namespace: testNamespace, + Labels: map[string]string{ + controlplanecomponent.ManagedByLabel: "control-plane-operator", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 0, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionFalse, + Reason: hyperv1.AsExpectedReason, + ObservedGeneration: 3, + }, + }, + { + name: "When a single CPO-managed deployment has unavailable replicas, it should set Degraded to True", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 4, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + Labels: map[string]string{ + controlplanecomponent.ManagedByLabel: "control-plane-operator", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 2, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionTrue, + Reason: "UnavailableReplicas", + Message: "kube-apiserver deployment has 2 unavailable replicas", + ObservedGeneration: 4, + }, + }, + { + name: "When multiple CPO-managed deployments have unavailable replicas, it should aggregate all errors in message", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 5, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-apiserver", + Namespace: testNamespace, + Labels: map[string]string{ + controlplanecomponent.ManagedByLabel: "control-plane-operator", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 1, + }, + }, + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kube-controller-manager", + Namespace: testNamespace, + Labels: map[string]string{ + controlplanecomponent.ManagedByLabel: "control-plane-operator", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 3, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionTrue, + Reason: "UnavailableReplicas", + ObservedGeneration: 5, + }, + }, + { + name: "When deployments exist without the CPO managed-by label, it should ignore them and set Degraded to False", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 6, + }, + }, + existingObjects: []client.Object{ + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "unmanaged-deployment", + Namespace: testNamespace, + Labels: map[string]string{ + "app": "something-else", + }, + }, + Status: appsv1.DeploymentStatus{ + UnavailableReplicas: 5, + }, + }, + }, + expectedCondition: metav1.Condition{ + Type: string(hyperv1.HostedControlPlaneDegraded), + Status: metav1.ConditionFalse, + Reason: hyperv1.AsExpectedReason, + ObservedGeneration: 6, + }, + }, + { + name: "When List returns unexpected error, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + var c client.Client + if tc.expectError { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithInterceptorFuncs(interceptor.Funcs{ + List: func(ctx context.Context, client client.WithWatch, list client.ObjectList, opts ...client.ListOption) error { + return fmt.Errorf("simulated list error") + }, + }). + Build() + } else { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(tc.existingObjects...). + Build() + } + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + err := r.reconcileDegradedStatus(t.Context(), tc.hcp) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.HostedControlPlaneDegraded)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Type).To(Equal(tc.expectedCondition.Type)) + g.Expect(cond.Status).To(Equal(tc.expectedCondition.Status)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondition.Reason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.expectedCondition.ObservedGeneration)) + if tc.expectedCondition.Message != "" { + g.Expect(cond.Message).To(ContainSubstring(tc.expectedCondition.Message)) + } + + // For the multi-deployment case, verify that both deployment names appear in the message + if tc.name == "When multiple CPO-managed deployments have unavailable replicas, it should aggregate all errors in message" { + g.Expect(cond.Message).To(ContainSubstring("kube-apiserver")) + g.Expect(cond.Message).To(ContainSubstring("kube-controller-manager")) + } + }) + } +} + +func TestReconcileInfrastructureStatusCondition(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + infraStatus infra.InfrastructureStatus + infraErr error + expectedCondStatus metav1.ConditionStatus + expectedCondReason string + expectedEndpoint hyperv1.APIEndpoint + expectOAuthCallback bool + }{ + { + name: "When infrastructure is ready, it should set InfrastructureReady to True and populate endpoint", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 3, + }, + }, + infraStatus: infra.InfrastructureStatus{ + APIHost: "api.example.com", + APIPort: 6443, + KonnectivityHost: "konnectivity.example.com", + KonnectivityPort: 8091, + }, + expectedCondStatus: metav1.ConditionTrue, + expectedCondReason: hyperv1.AsExpectedReason, + expectedEndpoint: hyperv1.APIEndpoint{ + Host: "api.example.com", + Port: 6443, + }, + }, + { + name: "When infrastructure is not ready, it should set InfrastructureReady to False", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 4, + }, + }, + infraStatus: infra.InfrastructureStatus{ + APIHost: "", + APIPort: 0, + Message: "Load balancer pending", + }, + expectedCondStatus: metav1.ConditionFalse, + expectedCondReason: hyperv1.WaitingOnInfrastructureReadyReason, + }, + { + name: "When infrastructure status returns error, it should set InfrastructureReady to Unknown", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 5, + }, + }, + infraErr: fmt.Errorf("failed to get infrastructure status"), + expectedCondStatus: metav1.ConditionUnknown, + expectedCondReason: hyperv1.InfraStatusFailureReason, + }, + { + name: "When infrastructure is not ready with empty message, it should use default provisioning message", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 6, + }, + }, + infraStatus: infra.InfrastructureStatus{ + APIHost: "", + APIPort: 0, + }, + expectedCondStatus: metav1.ConditionFalse, + expectedCondReason: hyperv1.WaitingOnInfrastructureReadyReason, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + r := &HostedControlPlaneReconciler{ + Client: fake.NewClientBuilder().WithScheme(api.Scheme).Build(), + Log: zapr.NewLogger(zaptest.NewLogger(t)), + reconcileInfrastructureStatus: func(ctx context.Context, hcp *hyperv1.HostedControlPlane) (infra.InfrastructureStatus, error) { + return tc.infraStatus, tc.infraErr + }, + } + + r.reconcileInfrastructureStatusCondition(t.Context(), tc.hcp) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.InfrastructureReady)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Status).To(Equal(tc.expectedCondStatus)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondReason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.hcp.Generation)) + + if tc.expectedCondStatus == metav1.ConditionTrue { + g.Expect(tc.hcp.Status.ControlPlaneEndpoint).To(Equal(tc.expectedEndpoint)) + } + + if tc.expectedCondStatus == metav1.ConditionFalse && tc.infraStatus.Message == "" { + g.Expect(cond.Message).To(Equal("Cluster infrastructure is still provisioning")) + } + if tc.expectedCondStatus == metav1.ConditionFalse && tc.infraStatus.Message != "" { + g.Expect(cond.Message).To(Equal(tc.infraStatus.Message)) + } + }) + } +} + +func TestReconcileExternalDNSStatusCondition(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + expectedCondStatus metav1.ConditionStatus + expectedCondReason string + expectedMessage string + }{ + { + name: "When no external DNS hostname is configured, it should set ExternalDNSReachable to Unknown", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Networking: hyperv1.ClusterNetworking{ + APIServer: &hyperv1.APIServerNetworking{ + Port: ptr.To[int32](6443), + }, + }, + }, + }, + expectedCondStatus: metav1.ConditionUnknown, + expectedCondReason: hyperv1.StatusUnknownReason, + expectedMessage: "External DNS is not configured", + }, + { + name: "When HCP is private (no PublicZoneID), it should set ExternalDNSReachable to Unknown", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.Private, + }, + }, + Services: []hyperv1.ServicePublishingStrategyMapping{ + { + Service: hyperv1.APIServer, + ServicePublishingStrategy: hyperv1.ServicePublishingStrategy{ + Type: hyperv1.LoadBalancer, + LoadBalancer: &hyperv1.LoadBalancerPublishingStrategy{ + Hostname: "api.example.com", + }, + }, + }, + }, + }, + }, + expectedCondStatus: metav1.ConditionUnknown, + expectedCondReason: hyperv1.StatusUnknownReason, + expectedMessage: "External DNS is not configured", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + r := &HostedControlPlaneReconciler{ + Client: fake.NewClientBuilder().WithScheme(api.Scheme).Build(), + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + r.reconcileExternalDNSStatusCondition(t.Context(), tc.hcp) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.ExternalDNSReachable)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Status).To(Equal(tc.expectedCondStatus)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondReason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.hcp.Generation)) + if tc.expectedMessage != "" { + g.Expect(cond.Message).To(Equal(tc.expectedMessage)) + } + }) + } +} + +func TestReconcileAvailabilityAndReadyStatus(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + expectedReady bool + expectedCondStatus metav1.ConditionStatus + expectedCondReason string + }{ + { + name: "When no status conditions exist and no kubeconfig, it should set not ready with Unknown reason", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + }, + expectedReady: false, + expectedCondStatus: metav1.ConditionFalse, + expectedCondReason: hyperv1.StatusUnknownReason, + }, + { + name: "When infrastructure condition is False, it should propagate infrastructure failure", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + Status: hyperv1.HostedControlPlaneStatus{ + KubeConfig: &hyperv1.KubeconfigSecretRef{ + Name: "admin-kubeconfig", + Key: "kubeconfig", + }, + Conditions: []metav1.Condition{ + { + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionFalse, + Reason: hyperv1.WaitingOnInfrastructureReadyReason, + }, + { + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.EtcdQuorumAvailableReason, + }, + { + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + }, + }, + }, + }, + expectedReady: false, + expectedCondStatus: metav1.ConditionFalse, + expectedCondReason: hyperv1.WaitingOnInfrastructureReadyReason, + }, + { + name: "When health check fails but no conditions are set, it should report KAS LB not reachable", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 3, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + // No service strategy for APIServer - health check returns error + Services: []hyperv1.ServicePublishingStrategyMapping{}, + }, + Status: hyperv1.HostedControlPlaneStatus{ + KubeConfig: &hyperv1.KubeconfigSecretRef{ + Name: "admin-kubeconfig", + Key: "kubeconfig", + }, + Conditions: []metav1.Condition{ + { + Type: string(hyperv1.InfrastructureReady), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + }, + { + Type: string(hyperv1.EtcdAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.EtcdQuorumAvailableReason, + }, + { + Type: string(hyperv1.KubeAPIServerAvailable), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + }, + }, + }, + }, + expectedReady: false, + expectedCondStatus: metav1.ConditionFalse, + expectedCondReason: hyperv1.KASLoadBalancerNotReachableReason, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + objs := []client.Object{} + objs = append(objs, tc.existingObjects...) + + c := fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(objs...). + Build() + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + r.reconcileAvailabilityAndReadyStatus(t.Context(), tc.hcp) + + g.Expect(tc.hcp.Status.Ready).To(Equal(tc.expectedReady)) + + cond := meta.FindStatusCondition(tc.hcp.Status.Conditions, string(hyperv1.HostedControlPlaneAvailable)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Status).To(Equal(tc.expectedCondStatus)) + g.Expect(cond.Reason).To(Equal(tc.expectedCondReason)) + g.Expect(cond.ObservedGeneration).To(Equal(tc.hcp.Generation)) + }) + } +} + +func TestReconcileKubeadminPasswordStatus(t *testing.T) { + testNamespace := "test-namespace" + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + expectedPasswordRef *corev1.LocalObjectReference + expectError bool + }{ + { + name: "When explicit OAuth config is specified, it should clear kubeadmin password status", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + Configuration: &hyperv1.ClusterConfiguration{ + OAuth: &configv1.OAuthSpec{ + IdentityProviders: []configv1.IdentityProvider{ + { + Name: "test-idp", + IdentityProviderConfig: configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeOpenID, + }, + }, + }, + }, + }, + }, + Status: hyperv1.HostedControlPlaneStatus{ + KubeadminPassword: &corev1.LocalObjectReference{ + Name: "old-kubeadmin-password", + }, + }, + }, + expectedPasswordRef: nil, + }, + { + name: "When no OAuth config and kubeadmin password secret exists, it should set password status", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + }, + }, + existingObjects: []client.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kubeadmin-password", + Namespace: testNamespace, + }, + Data: map[string][]byte{ + "password": []byte("test-password"), + }, + }, + }, + expectedPasswordRef: &corev1.LocalObjectReference{ + Name: "kubeadmin-password", + }, + }, + { + name: "When no OAuth config and kubeadmin password secret does not exist, it should leave password status nil", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + }, + }, + existingObjects: []client.Object{}, + expectedPasswordRef: nil, + }, + { + name: "When no OAuth config and Get returns unexpected error, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + }, + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + var c client.Client + if tc.expectError { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithInterceptorFuncs(interceptor.Funcs{ + Get: func(ctx context.Context, client client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + return fmt.Errorf("simulated get error") + }, + }). + Build() + } else { + c = fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(tc.existingObjects...). + Build() + } + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + } + + err := r.reconcileKubeadminPasswordStatus(t.Context(), tc.hcp) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + if tc.expectedPasswordRef == nil { + g.Expect(tc.hcp.Status.KubeadminPassword).To(BeNil()) + } else { + g.Expect(tc.hcp.Status.KubeadminPassword).ToNot(BeNil()) + g.Expect(tc.hcp.Status.KubeadminPassword.Name).To(Equal(tc.expectedPasswordRef.Name)) + } + }) + } +} + +// fakeVersionImageMetadataProvider is a simple test double for ImageMetadataProvider +// that returns deterministic results without contacting a registry. +type fakeVersionImageMetadataProvider struct { + fakeDigest string + fakeRef *reference.DockerImageReference + digestErr error +} + +func (f *fakeVersionImageMetadataProvider) ImageMetadata(_ context.Context, _ string, _ []byte) (*dockerv1client.DockerImageConfig, error) { + return &dockerv1client.DockerImageConfig{}, nil +} + +func (f *fakeVersionImageMetadataProvider) GetManifest(_ context.Context, _ string, _ []byte) (distribution.Manifest, error) { + return nil, nil +} + +func (f *fakeVersionImageMetadataProvider) GetDigest(_ context.Context, _ string, _ []byte) (digest.Digest, *reference.DockerImageReference, error) { + if f.digestErr != nil { + return "", nil, f.digestErr + } + return digest.Digest(f.fakeDigest), f.fakeRef, nil +} + +func (f *fakeVersionImageMetadataProvider) GetMetadata(_ context.Context, _ string, _ []byte) (*dockerv1client.DockerImageConfig, []distribution.Descriptor, distribution.BlobStore, error) { + return &dockerv1client.DockerImageConfig{}, nil, nil, nil +} + +func (f *fakeVersionImageMetadataProvider) GetOverride(_ context.Context, _ string, _ []byte) (*reference.DockerImageReference, error) { + return f.fakeRef, nil +} + +func TestReconcileControlPlaneVersionStatus(t *testing.T) { + testNamespace := "test-namespace" + fakeClock := testingclock.NewFakeClock(time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)) + + testCases := []struct { + name string + hcp *hyperv1.HostedControlPlane + existingObjects []client.Object + digestErr error + expectError bool + expectedDesired configv1.Release + expectedState configv1.UpdateState + }{ + { + name: "When pull secret exists and components are listed successfully with first population, it should create Partial history entry", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + ReleaseImage: "quay.io/openshift-release-dev/ocp-release:4.20.0", + }, + }, + existingObjects: []client.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pull-secret", + Namespace: testNamespace, + }, + Data: map[string][]byte{ + corev1.DockerConfigJsonKey: []byte("{}"), + }, + }, + }, + expectedDesired: configv1.Release{ + Version: "4.20.0", + Image: "quay.io/openshift-release-dev/ocp-release@sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + }, + expectedState: configv1.PartialUpdate, + }, + { + name: "When pull secret is missing, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + ReleaseImage: "quay.io/openshift-release-dev/ocp-release:4.20.0", + }, + }, + existingObjects: []client.Object{}, + expectError: true, + }, + { + name: "When GetDigest fails, it should return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 1, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + ReleaseImage: "quay.io/openshift-release-dev/ocp-release:4.20.0", + }, + }, + existingObjects: []client.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pull-secret", + Namespace: testNamespace, + }, + Data: map[string][]byte{ + corev1.DockerConfigJsonKey: []byte("{}"), + }, + }, + }, + digestErr: fmt.Errorf("failed to resolve digest"), + expectError: true, + }, + { + name: "When component list fails, it should set partial version and return error", + hcp: &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: testNamespace, + Generation: 2, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + ReleaseImage: "quay.io/openshift-release-dev/ocp-release:4.20.0", + }, + }, + existingObjects: []client.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pull-secret", + Namespace: testNamespace, + }, + Data: map[string][]byte{ + corev1.DockerConfigJsonKey: []byte("{}"), + }, + }, + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + releaseImage := testutils.InitReleaseImageOrDie("4.20.0") + + resolvedRef := &reference.DockerImageReference{ + Registry: "quay.io", + Namespace: "openshift-release-dev", + Name: "ocp-release", + ID: "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + } + + imgProvider := &fakeVersionImageMetadataProvider{ + fakeDigest: "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + fakeRef: resolvedRef, + digestErr: tc.digestErr, + } + + clientBuilder := fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(tc.existingObjects...) + + // For the "component list fails" case, intercept the List call to fail + if tc.name == "When component list fails, it should set partial version and return error" { + clientBuilder = clientBuilder.WithInterceptorFuncs(interceptor.Funcs{ + List: func(ctx context.Context, c client.WithWatch, list client.ObjectList, opts ...client.ListOption) error { + if _, ok := list.(*hyperv1.ControlPlaneComponentList); ok { + return fmt.Errorf("simulated list error") + } + return c.List(ctx, list, opts...) + }, + }) + // Also need to add status subresource for patching + clientBuilder = clientBuilder.WithStatusSubresource(tc.hcp) + } + + c := clientBuilder.Build() + + // For the component list failure case, we need the HCP to exist in the fake client + if tc.name == "When component list fails, it should set partial version and return error" { + g.Expect(c.Create(t.Context(), tc.hcp)).To(Succeed()) + } + + r := &HostedControlPlaneReconciler{ + Client: c, + Log: zapr.NewLogger(zaptest.NewLogger(t)), + ImageMetadataProvider: imgProvider, + clock: fakeClock, + } + + originalHCP := tc.hcp.DeepCopy() + err := r.reconcileControlPlaneVersionStatus(t.Context(), tc.hcp, originalHCP, releaseImage) + if tc.expectError { + g.Expect(err).To(HaveOccurred()) + + // For the component list failure case, verify partial status was still populated + if tc.name == "When component list fails, it should set partial version and return error" { + g.Expect(tc.hcp.Status.ControlPlaneVersion.Desired.Version).To(Equal("4.20.0")) + g.Expect(tc.hcp.Status.ControlPlaneVersion.Desired.Image).ToNot(BeEmpty()) + g.Expect(tc.hcp.Status.ControlPlaneVersion.History).ToNot(BeEmpty()) + } + return + } + g.Expect(err).ToNot(HaveOccurred()) + + g.Expect(tc.hcp.Status.ControlPlaneVersion.Desired.Version).To(Equal(tc.expectedDesired.Version)) + g.Expect(tc.hcp.Status.ControlPlaneVersion.Desired.Image).To(Equal(tc.expectedDesired.Image)) + g.Expect(tc.hcp.Status.ControlPlaneVersion.History).ToNot(BeEmpty()) + g.Expect(tc.hcp.Status.ControlPlaneVersion.History[0].State).To(Equal(tc.expectedState)) + }) + } +} + +func TestReconcileDeletion(t *testing.T) { + tests := []struct { + name string + setupEC2Mock func(*gomock.Controller) *awsapi.MockEC2API + wantErr bool + wantCondStatus metav1.ConditionStatus + }{ + { + name: "When destroyAWSDefaultSecurityGroup returns UnauthorizedOperation, it should skip gracefully and not return error", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().DescribeSecurityGroups(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []ec2types.SecurityGroup{ + {GroupId: aws.String("sg-123")}, + }, + }, nil) + m.EXPECT().DeleteSecurityGroup(gomock.Any(), gomock.Any()).Return(nil, + &smithy.GenericAPIError{Code: "UnauthorizedOperation", Message: "not authorized"}) + return m + }, + wantErr: false, + wantCondStatus: metav1.ConditionFalse, + }, + { + name: "When destroyAWSDefaultSecurityGroup returns DependencyViolation, it should skip gracefully and not return error", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().DescribeSecurityGroups(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []ec2types.SecurityGroup{ + {GroupId: aws.String("sg-123")}, + }, + }, nil) + m.EXPECT().DeleteSecurityGroup(gomock.Any(), gomock.Any()).Return(nil, + &smithy.GenericAPIError{Code: "DependencyViolation", Message: "resource has dependent object"}) + return m + }, + wantErr: false, + wantCondStatus: metav1.ConditionFalse, + }, + { + name: "When destroyAWSDefaultSecurityGroup returns unexpected error, it should propagate the error", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + m.EXPECT().DescribeSecurityGroups(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []ec2types.SecurityGroup{ + {GroupId: aws.String("sg-123")}, + }, + }, nil) + m.EXPECT().DeleteSecurityGroup(gomock.Any(), gomock.Any()).Return(nil, + &smithy.GenericAPIError{Code: "InternalError", Message: "something broke"}) + return m + }, + wantErr: true, + wantCondStatus: metav1.ConditionFalse, + }, + { + name: "When destroyAWSDefaultSecurityGroup succeeds, it should set condition to true", + setupEC2Mock: func(mockCtrl *gomock.Controller) *awsapi.MockEC2API { + m := awsapi.NewMockEC2API(mockCtrl) + // First call: find the SG + m.EXPECT().DescribeSecurityGroups(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []ec2types.SecurityGroup{ + {GroupId: aws.String("sg-123")}, + }, + }, nil) + m.EXPECT().DeleteSecurityGroup(gomock.Any(), gomock.Any()).Return(&ec2.DeleteSecurityGroupOutput{}, nil) + // Second call: verify SG is gone + m.EXPECT().DescribeSecurityGroups(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []ec2types.SecurityGroup{}, + }, nil) + return m + }, + wantErr: false, + wantCondStatus: metav1.ConditionTrue, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + mockCtrl := gomock.NewController(t) + mockEC2 := tt.setupEC2Mock(mockCtrl) + + hcp := &hyperv1.HostedControlPlane{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-hcp", + Namespace: "test-ns", + Annotations: map[string]string{ + hyperv1.CleanupCloudResourcesAnnotation: "true", + }, + }, + Spec: hyperv1.HostedControlPlaneSpec{ + InfraID: "test-infra", + Platform: hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{}, + }, + }, + Status: hyperv1.HostedControlPlaneStatus{ + Conditions: []metav1.Condition{ + { + Type: string(hyperv1.CloudResourcesDestroyed), + Status: metav1.ConditionTrue, + Reason: hyperv1.AsExpectedReason, + }, + }, + }, + } + fakeClient := fake.NewClientBuilder(). + WithScheme(api.Scheme). + WithObjects(hcp). + WithStatusSubresource(&hyperv1.HostedControlPlane{}). + Build() + + ctx := ctrl.LoggerInto(t.Context(), ctrl.Log.WithName("test")) + + // Re-read from fake client so the object has a ResourceVersion for OptimisticLock + g.Expect(fakeClient.Get(ctx, client.ObjectKeyFromObject(hcp), hcp)).To(Succeed()) + originalHCP := hcp.DeepCopy() + + r := &HostedControlPlaneReconciler{ + Client: fakeClient, + Log: ctrl.Log.WithName("test"), + ec2Client: mockEC2, + } + + _, err := r.reconcileDeletion(ctx, hcp, originalHCP) + if tt.wantErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + } + + cond := meta.FindStatusCondition(hcp.Status.Conditions, string(hyperv1.AWSDefaultSecurityGroupDeleted)) + g.Expect(cond).ToNot(BeNil()) + g.Expect(cond.Status).To(Equal(tt.wantCondStatus)) + }) + } +} + +// Compile-time assertion that fakeVersionImageMetadataProvider satisfies the interface. +var _ util.ImageMetadataProvider = &fakeVersionImageMetadataProvider{} + +// Compile-time assertion for clock interface used by tests. +var _ clock.Clock = &testingclock.FakeClock{} diff --git a/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert.go b/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert.go index 60dd8db418a..4aeb258be96 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert.go +++ b/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert.go @@ -137,6 +137,223 @@ func defaultIDPMappingMethods(identityProviders []configv1.IdentityProvider) []c return out } +func convertBasicAuthIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + basicAuthConfig := providerConfig.BasicAuth + if basicAuthConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.BasicAuthPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "BasicAuthPasswordIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + RemoteConnectionInfo: configv1.RemoteConnectionInfo{URL: basicAuthConfig.URL}, + } + if basicAuthConfig.CA.Name != "" { + provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, basicAuthConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + if basicAuthConfig.TLSClientCert.Name != "" { + provider.RemoteConnectionInfo.CertFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) + } + if basicAuthConfig.TLSClientKey.Name != "" { + provider.RemoteConnectionInfo.KeyFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} + +func convertGitHubIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + githubConfig := providerConfig.GitHub + if githubConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.GitHubIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "GitHubIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + ClientID: githubConfig.ClientID, + ClientSecret: configv1.StringSource{StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, githubConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), + }}, + Organizations: githubConfig.Organizations, + Teams: githubConfig.Teams, + Hostname: githubConfig.Hostname, + } + if githubConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, githubConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: false, login: true}, nil +} + +func convertGitLabIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + gitlabConfig := providerConfig.GitLab + if gitlabConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.GitLabIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "GitLabIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + URL: gitlabConfig.URL, + ClientID: gitlabConfig.ClientID, + ClientSecret: configv1.StringSource{StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, gitlabConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), + }}, + Legacy: new(bool), + } + if gitlabConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, gitlabConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} + +func convertGoogleIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + googleConfig := providerConfig.Google + if googleConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.GoogleIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "GoogleIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + ClientID: googleConfig.ClientID, + ClientSecret: configv1.StringSource{StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, googleConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), + }}, + HostedDomain: googleConfig.HostedDomain, + } + return &idpData{provider: provider, challenge: false, login: true}, nil +} + +func convertHTPasswdIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + if providerConfig.HTPasswd == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.HTPasswdPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "HTPasswdPasswordIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + File: idpVolumeMounts.SecretPath(i, providerConfig.HTPasswd.FileData.Name, "file-data", configv1.HTPasswdDataKey), + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} + +func convertKeystoneIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + keystoneConfig := providerConfig.Keystone + if keystoneConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.KeystonePasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "KeystonePasswordIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + RemoteConnectionInfo: configv1.RemoteConnectionInfo{URL: keystoneConfig.URL}, + DomainName: keystoneConfig.DomainName, + UseKeystoneIdentity: true, + } + if keystoneConfig.CA.Name != "" { + provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, keystoneConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + if keystoneConfig.TLSClientCert.Name != "" { + provider.RemoteConnectionInfo.CertInfo.CertFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) + } + if keystoneConfig.TLSClientKey.Name != "" { + provider.RemoteConnectionInfo.CertInfo.KeyFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} + +func convertLDAPIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + ldapConfig := providerConfig.LDAP + if ldapConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.LDAPPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "LDAPPasswordIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + URL: ldapConfig.URL, + BindDN: ldapConfig.BindDN, + Insecure: ldapConfig.Insecure, + Attributes: osinv1.LDAPAttributeMapping{ + ID: ldapConfig.Attributes.ID, + PreferredUsername: ldapConfig.Attributes.PreferredUsername, + Name: ldapConfig.Attributes.Name, + Email: ldapConfig.Attributes.Email, + }, + } + if ldapConfig.BindPassword.Name != "" { + provider.BindPassword = configv1.StringSource{StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, ldapConfig.BindPassword.Name, "bind-password", configv1.BindPasswordKey), + }} + } + if ldapConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, ldapConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} + +func convertOpenIDIDP(ctx context.Context, providerConfig *configv1.IdentityProviderConfig, configOverride *ConfigOverride, i int, idpVolumeMounts *IDPVolumeMountInfo, kclient crclient.Client, namespace string, skipKonnectivityDialer bool) (*idpData, error) { + openIDConfig := providerConfig.OpenID + if openIDConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + openIDProvider := &osinv1.OpenIDIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "OpenIDIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + ClientID: openIDConfig.ClientID, + ClientSecret: configv1.StringSource{StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, openIDConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), + }}, + ExtraScopes: openIDConfig.ExtraScopes, + ExtraAuthorizeParameters: openIDConfig.ExtraAuthorizeParameters, + } + if configOverride != nil { + openIDProvider.URLs = configOverride.URLs + openIDProvider.Claims = configOverride.Claims + } else { + urls, err := discoverOpenIDURLs(ctx, kclient, openIDConfig.Issuer, corev1.ServiceAccountRootCAKey, namespace, openIDConfig.CA, skipKonnectivityDialer) + if err != nil { + return nil, err + } + openIDProvider.URLs = *urls + var groups []string + if len(openIDConfig.Claims.Groups) > 0 { + groups = make([]string, len(openIDConfig.Claims.Groups)) + for i, group := range openIDConfig.Claims.Groups { + groups[i] = string(group) + } + } + openIDProvider.Claims = osinv1.OpenIDClaims{ + ID: []string{configv1.UserIDClaim}, + PreferredUsername: openIDConfig.Claims.PreferredUsername, + Name: openIDConfig.Claims.Name, + Email: openIDConfig.Claims.Email, + Groups: groups, + } + } + if len(openIDConfig.CA.Name) > 0 { + openIDProvider.CA = idpVolumeMounts.ConfigMapPath(i, openIDConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + data := &idpData{provider: openIDProvider, login: true} + if configOverride != nil && configOverride.Challenge != nil { + data.challenge = *configOverride.Challenge + } else { + challengeFlowsAllowed, err := checkOIDCPasswordGrantFlow(ctx, kclient, openIDProvider.URLs.Token, openIDConfig.ClientID, namespace, openIDConfig.CA, openIDConfig.ClientSecret, skipKonnectivityDialer) + if err != nil { + return nil, fmt.Errorf("error attempting password grant flow: %v", err) + } + data.challenge = challengeFlowsAllowed + } + return data, nil +} + +func convertRequestHeaderIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + requestHeaderConfig := providerConfig.RequestHeader + if requestHeaderConfig == nil { + return nil, fmt.Errorf("type %s was specified, but its configuration is missing", providerConfig.Type) + } + provider := &osinv1.RequestHeaderIdentityProvider{ + TypeMeta: metav1.TypeMeta{Kind: "RequestHeaderIdentityProvider", APIVersion: osinv1.GroupVersion.String()}, + LoginURL: requestHeaderConfig.LoginURL, + ChallengeURL: requestHeaderConfig.ChallengeURL, + ClientCA: idpVolumeMounts.ConfigMapPath(i, requestHeaderConfig.ClientCA.Name, "ca", corev1.ServiceAccountRootCAKey), + ClientCommonNames: requestHeaderConfig.ClientCommonNames, + Headers: requestHeaderConfig.Headers, + PreferredUsernameHeaders: requestHeaderConfig.PreferredUsernameHeaders, + NameHeaders: requestHeaderConfig.NameHeaders, + EmailHeaders: requestHeaderConfig.EmailHeaders, + } + return &idpData{ + provider: provider, + challenge: len(requestHeaderConfig.ChallengeURL) > 0, + login: len(requestHeaderConfig.LoginURL) > 0, + }, nil +} + func convertProviderConfigToIDPData( ctx context.Context, providerConfig *configv1.IdentityProviderConfig, @@ -147,289 +364,28 @@ func convertProviderConfigToIDPData( namespace string, skipKonnectivityDialer bool, ) (*idpData, error) { - const missingProviderFmt string = "type %s was specified, but its configuration is missing" - - data := &idpData{login: true} - switch providerConfig.Type { case configv1.IdentityProviderTypeBasicAuth: - basicAuthConfig := providerConfig.BasicAuth - if basicAuthConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - provider := &osinv1.BasicAuthPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "BasicAuthPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - RemoteConnectionInfo: configv1.RemoteConnectionInfo{ - URL: basicAuthConfig.URL, - }, - } - if basicAuthConfig.CA.Name != "" { - provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, basicAuthConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - if basicAuthConfig.TLSClientCert.Name != "" { - provider.RemoteConnectionInfo.CertFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientCert.Name, "tls-client-key", corev1.TLSCertKey) - } - if basicAuthConfig.TLSClientKey.Name != "" { - provider.RemoteConnectionInfo.KeyFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) - } - - data.provider = provider - data.challenge = true - + return convertBasicAuthIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeGitHub: - githubConfig := providerConfig.GitHub - if githubConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - provider := &osinv1.GitHubIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GitHubIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: githubConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, githubConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, - }, - Organizations: githubConfig.Organizations, - Teams: githubConfig.Teams, - Hostname: githubConfig.Hostname, - } - if githubConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, githubConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = provider - data.challenge = false - + return convertGitHubIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeGitLab: - gitlabConfig := providerConfig.GitLab - if gitlabConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - provider := &osinv1.GitLabIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GitLabIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - - URL: gitlabConfig.URL, - ClientID: gitlabConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, gitlabConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, - }, - Legacy: new(bool), // we require OIDC for GitLab now - } - if gitlabConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, gitlabConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - - data.provider = provider - data.challenge = true - + return convertGitLabIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeGoogle: - googleConfig := providerConfig.Google - if googleConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - data.provider = &osinv1.GoogleIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GoogleIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: googleConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, googleConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, - }, - HostedDomain: googleConfig.HostedDomain, - } - data.challenge = false - + return convertGoogleIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeHTPasswd: - if providerConfig.HTPasswd == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - data.provider = &osinv1.HTPasswdPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "HTPasswdPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - File: idpVolumeMounts.SecretPath(i, providerConfig.HTPasswd.FileData.Name, "file-data", configv1.HTPasswdDataKey), - } - data.challenge = true - + return convertHTPasswdIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeKeystone: - keystoneConfig := providerConfig.Keystone - if keystoneConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - provider := &osinv1.KeystonePasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "KeystonePasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - RemoteConnectionInfo: configv1.RemoteConnectionInfo{ - URL: keystoneConfig.URL, - }, - DomainName: keystoneConfig.DomainName, - UseKeystoneIdentity: true, // force use of keystone ID - } - if keystoneConfig.CA.Name != "" { - provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, keystoneConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - if keystoneConfig.TLSClientCert.Name != "" { - provider.RemoteConnectionInfo.CertInfo.CertFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) - } - if keystoneConfig.TLSClientKey.Name != "" { - provider.RemoteConnectionInfo.CertInfo.KeyFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) - } - data.provider = provider - data.challenge = true - + return convertKeystoneIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeLDAP: - ldapConfig := providerConfig.LDAP - if ldapConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - provider := &osinv1.LDAPPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "LDAPPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - URL: ldapConfig.URL, - BindDN: ldapConfig.BindDN, - Insecure: ldapConfig.Insecure, - Attributes: osinv1.LDAPAttributeMapping{ - ID: ldapConfig.Attributes.ID, - PreferredUsername: ldapConfig.Attributes.PreferredUsername, - Name: ldapConfig.Attributes.Name, - Email: ldapConfig.Attributes.Email, - }, - } - if ldapConfig.BindPassword.Name != "" { - provider.BindPassword = configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, ldapConfig.BindPassword.Name, "bind-password", configv1.BindPasswordKey), - }, - } - } - if ldapConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, ldapConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = provider - data.challenge = true - + return convertLDAPIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeOpenID: - openIDConfig := providerConfig.OpenID - if openIDConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - - openIDProvider := &osinv1.OpenIDIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "OpenIDIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: openIDConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, openIDConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, - }, - ExtraScopes: openIDConfig.ExtraScopes, - ExtraAuthorizeParameters: openIDConfig.ExtraAuthorizeParameters, - } - //Handle special case for IBM Cloud's OIDC provider (need to override some fields not available in public api) - if configOverride != nil { - openIDProvider.URLs = configOverride.URLs - openIDProvider.Claims = configOverride.Claims - } else { - urls, err := discoverOpenIDURLs(ctx, kclient, openIDConfig.Issuer, corev1.ServiceAccountRootCAKey, namespace, openIDConfig.CA, skipKonnectivityDialer) - if err != nil { - return nil, err - } - openIDProvider.URLs = *urls - var groups []string - if len(openIDConfig.Claims.Groups) > 0 { - groups = make([]string, len(openIDConfig.Claims.Groups)) - for i, group := range openIDConfig.Claims.Groups { - groups[i] = string(group) - } - } - openIDProvider.Claims = osinv1.OpenIDClaims{ - // There is no longer a user-facing setting for ID as it is considered unsafe - ID: []string{configv1.UserIDClaim}, - PreferredUsername: openIDConfig.Claims.PreferredUsername, - Name: openIDConfig.Claims.Name, - Email: openIDConfig.Claims.Email, - Groups: groups, - } - } - if len(openIDConfig.CA.Name) > 0 { - openIDProvider.CA = idpVolumeMounts.ConfigMapPath(i, openIDConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = openIDProvider - - if configOverride != nil && configOverride.Challenge != nil { - data.challenge = *configOverride.Challenge - } else { - // openshift CR validating in kube-apiserver does not allow - // challenge-redirecting IdPs to be configured with OIDC so it is safe - // to allow challenge-issuing flow if it's available on the OIDC side - challengeFlowsAllowed, err := checkOIDCPasswordGrantFlow( - ctx, - kclient, - openIDProvider.URLs.Token, - openIDConfig.ClientID, - namespace, - openIDConfig.CA, - openIDConfig.ClientSecret, - skipKonnectivityDialer, - ) - if err != nil { - return nil, fmt.Errorf("error attempting password grant flow: %v", err) - } - data.challenge = challengeFlowsAllowed - } + return convertOpenIDIDP(ctx, providerConfig, configOverride, i, idpVolumeMounts, kclient, namespace, skipKonnectivityDialer) case configv1.IdentityProviderTypeRequestHeader: - requestHeaderConfig := providerConfig.RequestHeader - if requestHeaderConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - data.provider = &osinv1.RequestHeaderIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "RequestHeaderIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - LoginURL: requestHeaderConfig.LoginURL, - ChallengeURL: requestHeaderConfig.ChallengeURL, - ClientCA: idpVolumeMounts.ConfigMapPath(i, requestHeaderConfig.ClientCA.Name, "ca", corev1.ServiceAccountRootCAKey), - ClientCommonNames: requestHeaderConfig.ClientCommonNames, - Headers: requestHeaderConfig.Headers, - PreferredUsernameHeaders: requestHeaderConfig.PreferredUsernameHeaders, - NameHeaders: requestHeaderConfig.NameHeaders, - EmailHeaders: requestHeaderConfig.EmailHeaders, - } - data.challenge = len(requestHeaderConfig.ChallengeURL) > 0 - data.login = len(requestHeaderConfig.LoginURL) > 0 - + return convertRequestHeaderIDP(providerConfig, i, idpVolumeMounts) default: return nil, fmt.Errorf("the identity provider type '%s' is not supported", providerConfig.Type) - } // switch - - return data, nil + } } const ( diff --git a/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert_test.go b/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert_test.go index f92a638fb7e..d96a6ce199e 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert_test.go +++ b/control-plane-operator/controllers/hostedcontrolplane/oauth/idp_convert_test.go @@ -210,6 +210,721 @@ func TestOpenIDProviderConversion(t *testing.T) { } } +func TestDefaultIDPMappingMethods(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input []configv1.IdentityProvider + expected []configv1.MappingMethodType + }{ + { + name: "When no identity providers are given, it should return an empty slice", + input: []configv1.IdentityProvider{}, + expected: []configv1.MappingMethodType{}, + }, + { + name: "When mapping method is empty, it should default to claim", + input: []configv1.IdentityProvider{ + {Name: "test", MappingMethod: ""}, + }, + expected: []configv1.MappingMethodType{configv1.MappingMethodClaim}, + }, + { + name: "When mapping method is already set, it should preserve it", + input: []configv1.IdentityProvider{ + {Name: "test", MappingMethod: configv1.MappingMethodAdd}, + }, + expected: []configv1.MappingMethodType{configv1.MappingMethodAdd}, + }, + { + name: "When multiple providers have mixed mapping methods, it should default only empty ones", + input: []configv1.IdentityProvider{ + {Name: "a", MappingMethod: ""}, + {Name: "b", MappingMethod: configv1.MappingMethodLookup}, + {Name: "c", MappingMethod: ""}, + }, + expected: []configv1.MappingMethodType{ + configv1.MappingMethodClaim, + configv1.MappingMethodLookup, + configv1.MappingMethodClaim, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + result := defaultIDPMappingMethods(tc.input) + g.Expect(result).To(HaveLen(len(tc.expected))) + for i, idp := range result { + g.Expect(idp.MappingMethod).To(Equal(tc.expected[i])) + } + }) + } +} + +func newTestVolumeMountInfo() *IDPVolumeMountInfo { + return &IDPVolumeMountInfo{ + Container: oauthContainerMain().Name, + VolumeMounts: podspec.VolumeMounts{ + oauthContainerMain().Name: podspec.ContainerMounts{}, + }, + } +} + +func TestConvertBasicAuthIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When BasicAuth config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + }, + expectErr: true, + }, + { + name: "When BasicAuth config has only URL, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://auth.example.com", + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When BasicAuth config has CA and TLS certs, it should set volume mount paths", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://auth.example.com", + CA: configv1.ConfigMapNameReference{Name: "my-ca"}, + TLSClientCert: configv1.SecretNameReference{Name: "my-cert"}, + TLSClientKey: configv1.SecretNameReference{Name: "my-key"}, + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertBasicAuthIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("configuration is missing")) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.BasicAuthPasswordIdentityProvider) + g.Expect(provider.RemoteConnectionInfo.URL).To(Equal("https://auth.example.com")) + if tc.config.BasicAuth.CA.Name != "" { + g.Expect(provider.RemoteConnectionInfo.CA).To(ContainSubstring("idp_cm_0_ca")) + } + if tc.config.BasicAuth.TLSClientCert.Name != "" { + g.Expect(provider.RemoteConnectionInfo.CertFile).To(ContainSubstring("idp_secret_0_tls-client-cert")) + } + if tc.config.BasicAuth.TLSClientKey.Name != "" { + g.Expect(provider.RemoteConnectionInfo.KeyFile).To(ContainSubstring("idp_secret_0_tls-client-key")) + } + } + }) + } +} + +func TestConvertGitHubIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When GitHub config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + }, + expectErr: true, + }, + { + name: "When GitHub config is provided with organizations, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "my-client-id", + ClientSecret: configv1.SecretNameReference{Name: "gh-secret"}, + Organizations: []string{"org1", "org2"}, + Teams: []string{"org1/team1"}, + Hostname: "github.example.com", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + { + name: "When GitHub config has a CA, it should set the CA path", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "my-client-id", + ClientSecret: configv1.SecretNameReference{Name: "gh-secret"}, + CA: configv1.ConfigMapNameReference{Name: "gh-ca"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGitHubIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GitHubIdentityProvider) + g.Expect(provider.ClientID).To(Equal(tc.config.GitHub.ClientID)) + g.Expect(provider.Organizations).To(Equal(tc.config.GitHub.Organizations)) + if tc.config.GitHub.CA.Name != "" { + g.Expect(provider.CA).To(ContainSubstring("idp_cm_0_ca")) + } + } + }) + } +} + +func TestConvertGitLabIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When GitLab config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitLab, + }, + expectErr: true, + }, + { + name: "When GitLab config is provided, it should produce a valid provider with legacy set", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitLab, + GitLab: &configv1.GitLabIdentityProvider{ + URL: "https://gitlab.example.com", + ClientID: "gl-client", + ClientSecret: configv1.SecretNameReference{Name: "gl-secret"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGitLabIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GitLabIdentityProvider) + g.Expect(provider.URL).To(Equal("https://gitlab.example.com")) + g.Expect(provider.Legacy).ToNot(BeNil()) + g.Expect(*provider.Legacy).To(BeFalse()) + } + }) + } +} + +func TestConvertGoogleIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When Google config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGoogle, + }, + expectErr: true, + }, + { + name: "When Google config is provided with hosted domain, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGoogle, + Google: &configv1.GoogleIdentityProvider{ + ClientID: "google-client", + ClientSecret: configv1.SecretNameReference{Name: "google-secret"}, + HostedDomain: "example.com", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGoogleIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GoogleIdentityProvider) + g.Expect(provider.ClientID).To(Equal("google-client")) + g.Expect(provider.HostedDomain).To(Equal("example.com")) + } + }) + } +} + +func TestConvertHTPasswdIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When HTPasswd config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + }, + expectErr: true, + }, + { + name: "When HTPasswd config is provided, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + HTPasswd: &configv1.HTPasswdIdentityProvider{ + FileData: configv1.SecretNameReference{Name: "htpasswd-secret"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertHTPasswdIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.HTPasswdPasswordIdentityProvider) + g.Expect(provider.File).To(ContainSubstring("idp_secret_0_file-data")) + } + }) + } +} + +func TestConvertKeystoneIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When Keystone config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeKeystone, + }, + expectErr: true, + }, + { + name: "When Keystone config is provided with TLS certs, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeKeystone, + Keystone: &configv1.KeystoneIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://keystone.example.com", + CA: configv1.ConfigMapNameReference{Name: "ks-ca"}, + TLSClientCert: configv1.SecretNameReference{Name: "ks-cert"}, + TLSClientKey: configv1.SecretNameReference{Name: "ks-key"}, + }, + DomainName: "my-domain", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertKeystoneIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.KeystonePasswordIdentityProvider) + g.Expect(provider.RemoteConnectionInfo.URL).To(Equal("https://keystone.example.com")) + g.Expect(provider.DomainName).To(Equal("my-domain")) + g.Expect(provider.UseKeystoneIdentity).To(BeTrue()) + g.Expect(provider.RemoteConnectionInfo.CA).To(ContainSubstring("idp_cm_0_ca")) + g.Expect(provider.RemoteConnectionInfo.CertInfo.CertFile).To(ContainSubstring("idp_secret_0_tls-client-cert")) + g.Expect(provider.RemoteConnectionInfo.CertInfo.KeyFile).To(ContainSubstring("idp_secret_0_tls-client-key")) + } + }) + } +} + +func TestConvertLDAPIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When LDAP config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + }, + expectErr: true, + }, + { + name: "When LDAP config is provided with bind password, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + LDAP: &configv1.LDAPIdentityProvider{ + URL: "ldap://ldap.example.com", + BindDN: "cn=admin,dc=example,dc=com", + BindPassword: configv1.SecretNameReference{Name: "ldap-bind-pw"}, + Insecure: true, + Attributes: configv1.LDAPAttributeMapping{ + ID: []string{"dn"}, + PreferredUsername: []string{"uid"}, + Name: []string{"cn"}, + Email: []string{"mail"}, + }, + CA: configv1.ConfigMapNameReference{Name: "ldap-ca"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When LDAP config has no bind password, it should not set the bind password field", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + LDAP: &configv1.LDAPIdentityProvider{ + URL: "ldap://ldap.example.com", + Insecure: false, + Attributes: configv1.LDAPAttributeMapping{ + ID: []string{"dn"}, + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertLDAPIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.LDAPPasswordIdentityProvider) + g.Expect(provider.URL).To(Equal(tc.config.LDAP.URL)) + g.Expect(provider.BindDN).To(Equal(tc.config.LDAP.BindDN)) + g.Expect(provider.Insecure).To(Equal(tc.config.LDAP.Insecure)) + g.Expect(provider.Attributes.ID).To(Equal(tc.config.LDAP.Attributes.ID)) + if tc.config.LDAP.BindPassword.Name != "" { + g.Expect(provider.BindPassword.StringSourceSpec.File).To(ContainSubstring("idp_secret_0_bind-password")) + } else { + g.Expect(provider.BindPassword.StringSourceSpec.File).To(BeEmpty()) + } + if tc.config.LDAP.CA.Name != "" { + g.Expect(provider.CA).To(ContainSubstring("idp_cm_0_ca")) + } + } + }) + } +} + +func TestConvertRequestHeaderIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When RequestHeader config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + }, + expectErr: true, + }, + { + name: "When RequestHeader has login and challenge URLs, it should set login and challenge to true", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + RequestHeader: &configv1.RequestHeaderIdentityProvider{ + LoginURL: "https://login.example.com", + ChallengeURL: "https://challenge.example.com", + ClientCA: configv1.ConfigMapNameReference{Name: "rh-ca"}, + ClientCommonNames: []string{"client1"}, + Headers: []string{"X-Remote-User"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When RequestHeader has empty login and challenge URLs, it should set login and challenge to false", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + RequestHeader: &configv1.RequestHeaderIdentityProvider{ + ClientCA: configv1.ConfigMapNameReference{Name: "rh-ca"}, + Headers: []string{"X-Remote-User"}, + }, + }, + expectErr: false, + expectLogin: false, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertRequestHeaderIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.RequestHeaderIdentityProvider) + g.Expect(provider.Headers).To(Equal(tc.config.RequestHeader.Headers)) + g.Expect(provider.ClientCA).To(ContainSubstring("idp_cm_0_ca")) + } + }) + } +} + +func TestConvertProviderConfigToIDPData_UnsupportedType(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + config := &configv1.IdentityProviderConfig{ + Type: "UnsupportedType", + } + _, err := convertProviderConfigToIDPData(t.Context(), config, nil, 0, vmi, nil, "test", true) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("is not supported")) +} + +func TestConvertProviderConfigToIDPData_Routing(t *testing.T) { + t.Parallel() + tests := []struct { + name string + providerType configv1.IdentityProviderType + config *configv1.IdentityProviderConfig + expectedKind string + }{ + { + name: "When type is BasicAuth, it should route to BasicAuth converter", + providerType: configv1.IdentityProviderTypeBasicAuth, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{URL: "https://example.com"}, + }, + }, + expectedKind: "BasicAuthPasswordIdentityProvider", + }, + { + name: "When type is GitHub, it should route to GitHub converter", + providerType: configv1.IdentityProviderTypeGitHub, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "id", + ClientSecret: configv1.SecretNameReference{Name: "s"}, + }, + }, + expectedKind: "GitHubIdentityProvider", + }, + { + name: "When type is HTPasswd, it should route to HTPasswd converter", + providerType: configv1.IdentityProviderTypeHTPasswd, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + HTPasswd: &configv1.HTPasswdIdentityProvider{ + FileData: configv1.SecretNameReference{Name: "htpasswd"}, + }, + }, + expectedKind: "HTPasswdPasswordIdentityProvider", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertProviderConfigToIDPData(t.Context(), tc.config, nil, 0, vmi, nil, "test", true) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.provider.GetObjectKind().GroupVersionKind().Kind).To(Equal(tc.expectedKind)) + }) + } +} + +func TestIDPVolumeMountInfo_ConfigMapPath(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + result := vmi.ConfigMapPath(2, "my-configmap", "ca", "ca.crt") + g.Expect(result).To(Equal("/etc/oauth/idp/idp_cm_2_ca/ca.crt")) + g.Expect(vmi.Volumes).To(HaveLen(1)) + g.Expect(vmi.Volumes[0].Name).To(Equal("idp-cm-2-ca")) + g.Expect(vmi.Volumes[0].ConfigMap.Name).To(Equal("my-configmap")) +} + +func TestIDPVolumeMountInfo_SecretPath(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + result := vmi.SecretPath(3, "my-secret", "client-secret", "clientSecret") + g.Expect(result).To(Equal("/etc/oauth/idp/idp_secret_3_client-secret/clientSecret")) + g.Expect(vmi.Volumes).To(HaveLen(1)) + g.Expect(vmi.Volumes[0].Name).To(Equal("idp-secret-3-client-secret")) + g.Expect(vmi.Volumes[0].Secret.SecretName).To(Equal("my-secret")) + g.Expect(vmi.Volumes[0].Secret.DefaultMode).ToNot(BeNil()) + g.Expect(*vmi.Volumes[0].Secret.DefaultMode).To(Equal(int32(0640))) +} + +func TestIsValidURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + rawurl string + optional bool + expected bool + }{ + { + name: "When URL is empty and optional, it should return true", + rawurl: "", + optional: true, + expected: true, + }, + { + name: "When URL is empty and required, it should return false", + rawurl: "", + optional: false, + expected: false, + }, + { + name: "When URL is a valid https URL, it should return true", + rawurl: "https://example.com/auth", + optional: false, + expected: true, + }, + { + name: "When URL uses http scheme, it should return false", + rawurl: "http://example.com/auth", + optional: false, + expected: false, + }, + { + name: "When URL has a fragment, it should return false", + rawurl: "https://example.com/auth#fragment", + optional: false, + expected: false, + }, + { + name: "When URL has no host, it should return false", + rawurl: "https://", + optional: false, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + g.Expect(isValidURL(tc.rawurl, tc.optional)).To(Equal(tc.expected)) + }) + } +} + func TestTransportForCARef(t *testing.T) { namespace := "test" diff --git a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/deployment_test.go b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/deployment_test.go index a13fc371951..cfb173c7403 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/deployment_test.go +++ b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/deployment_test.go @@ -49,6 +49,30 @@ func newPodMonitorWithTLS(name, namespace string, tlsCfg *prometheusoperatorv1.S } } +func newCertVolumeTestContext(namespace string, scheme *runtime.Scheme, objects ...runtime.Object) component.WorkloadContext { + fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(objects...).Build() + return component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, + } +} + +func assertCertVolumeCount(t *testing.T, cpContext component.WorkloadContext, namespace string, expectedVolumes, expectedMounts int) ([]corev1.Volume, []corev1.VolumeMount) { + t.Helper() + volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(volumes) != expectedVolumes { + t.Errorf("expected %d volumes, got %d", expectedVolumes, len(volumes)) + } + if len(mounts) != expectedMounts { + t.Errorf("expected %d mounts, got %d", expectedMounts, len(mounts)) + } + return volumes, mounts +} + func TestCertVolumesFromMonitors(t *testing.T) { t.Parallel() @@ -81,7 +105,6 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }, }) - // sm2 references the same root-ca and metrics-client — should deduplicate. sm2 := newServiceMonitorWithTLS("kube-controller-manager", namespace, &prometheusoperatorv1.TLSConfig{ SafeTLSConfig: prometheusoperatorv1.SafeTLSConfig{ CA: prometheusoperatorv1.SecretOrConfigMap{ @@ -103,27 +126,9 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm1, sm2).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm1, sm2) + volumes, _ := assertCertVolumeCount(t, cpContext, namespace, 2, 2) - // Expect 2 unique volumes: metrics-client (Secret), root-ca (ConfigMap). - if len(volumes) != 2 { - t.Errorf("expected 2 volumes, got %d", len(volumes)) - } - if len(mounts) != 2 { - t.Errorf("expected 2 mounts, got %d", len(mounts)) - } - - // Verify sorted order: metrics-client before root-ca. if len(volumes) >= 2 { if volumes[0].Name != "metrics-client" { t.Errorf("expected first volume to be metrics-client, got %s", volumes[0].Name) @@ -158,21 +163,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm1, sm2).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, _, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(volumes) != 1 { - t.Fatalf("expected 1 volume, got %d", len(volumes)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm1, sm2) + volumes, _ := assertCertVolumeCount(t, cpContext, namespace, 1, 1) cm := volumes[0].VolumeSource.ConfigMap if cm == nil { @@ -207,21 +199,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm1, sm2).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, _, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(volumes) != 1 { - t.Fatalf("expected 1 volume, got %d", len(volumes)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm1, sm2) + volumes, _ := assertCertVolumeCount(t, cpContext, namespace, 1, 1) cm := volumes[0].VolumeSource.ConfigMap if cm == nil { @@ -242,45 +221,15 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, } - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(volumes) != 0 { - t.Errorf("expected 0 volumes, got %d", len(volumes)) - } - if len(mounts) != 0 { - t.Errorf("expected 0 mounts, got %d", len(mounts)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm) + assertCertVolumeCount(t, cpContext, namespace, 0, 0) }) t.Run("When no monitors exist, it should return empty slices", func(t *testing.T) { t.Parallel() - fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(volumes) != 0 { - t.Errorf("expected 0 volumes, got %d", len(volumes)) - } - if len(mounts) != 0 { - t.Errorf("expected 0 mounts, got %d", len(mounts)) - } + cpContext := newCertVolumeTestContext(namespace, scheme) + assertCertVolumeCount(t, cpContext, namespace, 0, 0) }) t.Run("When Secret volumes are generated, they should have optional=true", func(t *testing.T) { @@ -297,21 +246,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(volumes) != 1 { - t.Fatalf("expected 1 volume, got %d", len(volumes)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm) + volumes, mounts := assertCertVolumeCount(t, cpContext, namespace, 1, 1) secret := volumes[0].VolumeSource.Secret if secret == nil { @@ -321,9 +257,6 @@ func TestCertVolumesFromMonitors(t *testing.T) { t.Error("expected Secret volume to have optional=true") } - if len(mounts) != 1 { - t.Fatalf("expected 1 mount, got %d", len(mounts)) - } expectedPath := certBasePath + "/test-secret" if mounts[0].MountPath != expectedPath { t.Errorf("expected mount path %q, got %q", expectedPath, mounts[0].MountPath) @@ -344,21 +277,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, _, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(volumes) != 1 { - t.Fatalf("expected 1 volume, got %d", len(volumes)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm) + volumes, _ := assertCertVolumeCount(t, cpContext, namespace, 1, 1) cm := volumes[0].VolumeSource.ConfigMap if cm == nil { @@ -381,27 +301,12 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(pm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + cpContext := newCertVolumeTestContext(namespace, scheme, pm) + volumes, _ := assertCertVolumeCount(t, cpContext, namespace, 1, 1) - if len(volumes) != 1 { - t.Fatalf("expected 1 volume, got %d", len(volumes)) - } if volumes[0].Name != "root-ca" { t.Errorf("expected volume name root-ca, got %s", volumes[0].Name) } - if len(mounts) != 1 { - t.Fatalf("expected 1 mount, got %d", len(mounts)) - } }) t.Run("When PodMonitor has no TLS config, it should be skipped", func(t *testing.T) { @@ -409,23 +314,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { pm := newPodMonitorWithTLS("cluster-autoscaler", namespace, nil) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(pm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, mounts, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(volumes) != 0 { - t.Errorf("expected 0 volumes, got %d", len(volumes)) - } - if len(mounts) != 0 { - t.Errorf("expected 0 mounts, got %d", len(mounts)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, pm) + assertCertVolumeCount(t, cpContext, namespace, 0, 0) }) t.Run("When ServiceMonitor and PodMonitor share the same CA ref, it should deduplicate", func(t *testing.T) { @@ -450,21 +340,8 @@ func TestCertVolumesFromMonitors(t *testing.T) { }, }) - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(sm, pm).Build() - cpContext := component.WorkloadContext{ - Context: context.Background(), - Client: fakeClient, - HCP: &hyperv1.HostedControlPlane{ObjectMeta: metav1.ObjectMeta{Namespace: namespace}}, - } - - volumes, _, err := certVolumesFromMonitors(cpContext, namespace) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(volumes) != 1 { - t.Errorf("expected 1 deduplicated volume, got %d", len(volumes)) - } + cpContext := newCertVolumeTestContext(namespace, scheme, sm, pm) + assertCertVolumeCount(t, cpContext, namespace, 1, 1) }) } diff --git a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config.go b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config.go index 325a2d37e33..6068e51c5ef 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config.go +++ b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config.go @@ -31,7 +31,6 @@ const ( ) func adaptScrapeConfig(cpContext component.WorkloadContext, cm *corev1.ConfigMap) error { - log := logr.FromContextOrDiscard(cpContext) namespace := cpContext.HCP.Namespace endpointResolverURL := fmt.Sprintf("https://endpoint-resolver.%s.svc", namespace) @@ -43,164 +42,193 @@ func adaptScrapeConfig(cpContext component.WorkloadContext, cm *corev1.ConfigMap }, } - // Process ServiceMonitors. + smComponents, err := componentsFromServiceMonitors(cpContext, namespace) + if err != nil { + return err + } + cfg.Components = append(cfg.Components, smComponents...) + + pmComponents, err := componentsFromPodMonitors(cpContext, namespace) + if err != nil { + return err + } + cfg.Components = append(cfg.Components, pmComponents...) + + sort.Slice(cfg.Components, func(i, j int) bool { + return cfg.Components[i].Name < cfg.Components[j].Name + }) + + data, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("failed to marshal scrape config: %w", err) + } + + if cm.Data == nil { + cm.Data = make(map[string]string) + } + cm.Data["config.yaml"] = string(data) + return nil +} + +func componentsFromServiceMonitors(cpContext component.WorkloadContext, namespace string) ([]metricsproxybin.ComponentFileConfig, error) { + log := logr.FromContextOrDiscard(cpContext) + smList := &prometheusoperatorv1.ServiceMonitorList{} if err := cpContext.Client.List(cpContext, smList, client.InNamespace(namespace)); err != nil { - return fmt.Errorf("failed to list ServiceMonitors: %w", err) + return nil, fmt.Errorf("failed to list ServiceMonitors: %w", err) } + var components []metricsproxybin.ComponentFileConfig for i := range smList.Items { - sm := &smList.Items[i] - if len(sm.Spec.Endpoints) == 0 { - continue - } - ep := sm.Spec.Endpoints[0] - - serviceName, podSelector, err := findServiceForMonitor(cpContext, namespace, sm.Name, sm.Spec.Selector) - if err != nil { - log.V(4).Info("skipping ServiceMonitor: service not found", "serviceMonitor", sm.Name, "error", err) - continue - } - if len(podSelector) == 0 { - log.V(4).Info("skipping ServiceMonitor: service has no pod selector", "serviceMonitor", sm.Name, "service", serviceName) - continue + comp, ok := componentFromServiceMonitor(cpContext, log, namespace, &smList.Items[i]) + if ok { + components = append(components, comp) } + } + return components, nil +} - portRef := ep.Port - if portRef == "" && ep.TargetPort != nil { - portRef = ep.TargetPort.String() - } - if portRef == "" { - log.V(4).Info("skipping ServiceMonitor: no port reference", "serviceMonitor", sm.Name) - continue - } +func componentFromServiceMonitor(cpContext component.WorkloadContext, log logr.Logger, namespace string, sm *prometheusoperatorv1.ServiceMonitor) (metricsproxybin.ComponentFileConfig, bool) { + if len(sm.Spec.Endpoints) == 0 { + return metricsproxybin.ComponentFileConfig{}, false + } + ep := sm.Spec.Endpoints[0] - port, err := resolveServicePort(cpContext, namespace, serviceName, portRef, podSelector) - if err != nil { - log.V(4).Info("skipping ServiceMonitor: port not resolvable", "serviceMonitor", sm.Name, "port", portRef, "error", err) - continue - } + serviceName, podSelector, err := findServiceForMonitor(cpContext, namespace, sm.Name, sm.Spec.Selector) + if err != nil { + log.V(4).Info("skipping ServiceMonitor: service not found", "serviceMonitor", sm.Name, "error", err) + return metricsproxybin.ComponentFileConfig{}, false + } + if len(podSelector) == 0 { + log.V(4).Info("skipping ServiceMonitor: service has no pod selector", "serviceMonitor", sm.Name, "service", serviceName) + return metricsproxybin.ComponentFileConfig{}, false + } - scheme := "http" - if ep.Scheme != nil { - scheme = ep.Scheme.String() - } + portRef := ep.Port + if portRef == "" && ep.TargetPort != nil { + portRef = ep.TargetPort.String() + } + if portRef == "" { + log.V(4).Info("skipping ServiceMonitor: no port reference", "serviceMonitor", sm.Name) + return metricsproxybin.ComponentFileConfig{}, false + } - metricsPath := "/metrics" - if ep.Path != "" { - metricsPath = ep.Path - } + port, err := resolveServicePort(cpContext, namespace, serviceName, portRef, podSelector) + if err != nil { + log.V(4).Info("skipping ServiceMonitor: port not resolvable", "serviceMonitor", sm.Name, "port", portRef, "error", err) + return metricsproxybin.ComponentFileConfig{}, false + } - var serverName string - if ep.TLSConfig != nil && ep.TLSConfig.ServerName != nil { - serverName = *ep.TLSConfig.ServerName - } + comp := metricsproxybin.ComponentFileConfig{ + Selector: podSelector, + MetricsPort: port, + MetricsPath: endpointMetricsPath(ep.Path), + MetricsScheme: endpointScheme(ep.Scheme), + TLSServerName: safeTLSServerName(ep.TLSConfig), + } - comp := metricsproxybin.ComponentFileConfig{ - Selector: podSelector, - MetricsPort: port, - MetricsPath: metricsPath, - MetricsScheme: scheme, - TLSServerName: serverName, - } + if ep.TLSConfig != nil { + populateTLSFilePaths(&comp, ep.TLSConfig.CA, ep.TLSConfig.Cert, ep.TLSConfig.KeySecret) + } - if ep.TLSConfig != nil { - comp.CAFile = certFilePathFromSecretOrConfigMap(ep.TLSConfig.CA) - comp.CertFile = certFilePathFromSecretOrConfigMap(ep.TLSConfig.Cert) - if ep.TLSConfig.KeySecret != nil { - comp.KeyFile = filepath.Join(certBasePath, ep.TLSConfig.KeySecret.Name, ep.TLSConfig.KeySecret.Key) - } - } + comp.Name = sm.Name + populateMetricsLabelsFromAnnotations(&comp, sm.Annotations) + return comp, true +} - comp.Name = sm.Name - populateMetricsLabelsFromAnnotations(&comp, sm.Annotations) - cfg.Components = append(cfg.Components, comp) - } +func componentsFromPodMonitors(cpContext component.WorkloadContext, namespace string) ([]metricsproxybin.ComponentFileConfig, error) { + log := logr.FromContextOrDiscard(cpContext) - // Process PodMonitors. pmList := &prometheusoperatorv1.PodMonitorList{} if err := cpContext.Client.List(cpContext, pmList, client.InNamespace(namespace)); err != nil { - return fmt.Errorf("failed to list PodMonitors: %w", err) + return nil, fmt.Errorf("failed to list PodMonitors: %w", err) } + var components []metricsproxybin.ComponentFileConfig for i := range pmList.Items { - pm := &pmList.Items[i] - if len(pm.Spec.PodMetricsEndpoints) == 0 { - continue + comp, ok := componentFromPodMonitor(cpContext, log, namespace, &pmList.Items[i]) + if ok { + components = append(components, comp) } - ep := pm.Spec.PodMetricsEndpoints[0] + } + return components, nil +} - portName := "" - if ep.Port != nil { - portName = *ep.Port - } - if portName == "" { - log.V(4).Info("skipping PodMonitor: no port name", "podMonitor", pm.Name) - continue - } +func componentFromPodMonitor(cpContext component.WorkloadContext, log logr.Logger, namespace string, pm *prometheusoperatorv1.PodMonitor) (metricsproxybin.ComponentFileConfig, bool) { + if len(pm.Spec.PodMetricsEndpoints) == 0 { + return metricsproxybin.ComponentFileConfig{}, false + } + ep := pm.Spec.PodMetricsEndpoints[0] - // Resolve the port number from a Pod matching the PodMonitor's selector. - podSelector, err := metav1.LabelSelectorAsSelector(&pm.Spec.Selector) - if err != nil { - log.V(4).Info("skipping PodMonitor: invalid selector", "podMonitor", pm.Name, "error", err) - continue - } - port, err := resolvePodPort(cpContext, namespace, podSelector, portName) - if err != nil { - log.V(4).Info("skipping PodMonitor: port not resolvable", "podMonitor", pm.Name, "port", portName, "error", err) - continue - } + portName := "" + if ep.Port != nil { + portName = *ep.Port + } + if portName == "" { + log.V(4).Info("skipping PodMonitor: no port name", "podMonitor", pm.Name) + return metricsproxybin.ComponentFileConfig{}, false + } - scheme := "http" - if ep.Scheme != nil { - scheme = ep.Scheme.String() - } + podSelector, err := metav1.LabelSelectorAsSelector(&pm.Spec.Selector) + if err != nil { + log.V(4).Info("skipping PodMonitor: invalid selector", "podMonitor", pm.Name, "error", err) + return metricsproxybin.ComponentFileConfig{}, false + } + port, err := resolvePodPort(cpContext, namespace, podSelector, portName) + if err != nil { + log.V(4).Info("skipping PodMonitor: port not resolvable", "podMonitor", pm.Name, "port", portName, "error", err) + return metricsproxybin.ComponentFileConfig{}, false + } - metricsPath := "/metrics" - if ep.Path != "" { - metricsPath = ep.Path - } + var serverName string + if ep.TLSConfig != nil && ep.TLSConfig.ServerName != nil { + serverName = *ep.TLSConfig.ServerName + } - var serverName string - if ep.TLSConfig != nil && ep.TLSConfig.ServerName != nil { - serverName = *ep.TLSConfig.ServerName - } + comp := metricsproxybin.ComponentFileConfig{ + Selector: pm.Spec.Selector.MatchLabels, + MetricsPort: port, + MetricsPath: endpointMetricsPath(ep.Path), + MetricsScheme: endpointScheme(ep.Scheme), + TLSServerName: serverName, + } - comp := metricsproxybin.ComponentFileConfig{ - Selector: pm.Spec.Selector.MatchLabels, - MetricsPort: port, - MetricsPath: metricsPath, - MetricsScheme: scheme, - TLSServerName: serverName, - } + if ep.TLSConfig != nil { + populateTLSFilePaths(&comp, ep.TLSConfig.CA, ep.TLSConfig.Cert, ep.TLSConfig.KeySecret) + } - if ep.TLSConfig != nil { - comp.CAFile = certFilePathFromSecretOrConfigMap(ep.TLSConfig.CA) - comp.CertFile = certFilePathFromSecretOrConfigMap(ep.TLSConfig.Cert) - if ep.TLSConfig.KeySecret != nil { - comp.KeyFile = filepath.Join(certBasePath, ep.TLSConfig.KeySecret.Name, ep.TLSConfig.KeySecret.Key) - } - } + comp.Name = pm.Name + populateMetricsLabelsFromAnnotations(&comp, pm.Annotations) + return comp, true +} - comp.Name = pm.Name - populateMetricsLabelsFromAnnotations(&comp, pm.Annotations) - cfg.Components = append(cfg.Components, comp) +func endpointScheme(scheme *prometheusoperatorv1.Scheme) string { + if scheme != nil { + return scheme.String() } + return "http" +} - sort.Slice(cfg.Components, func(i, j int) bool { - return cfg.Components[i].Name < cfg.Components[j].Name - }) +func endpointMetricsPath(path string) string { + if path != "" { + return path + } + return "/metrics" +} - data, err := yaml.Marshal(cfg) - if err != nil { - return fmt.Errorf("failed to marshal scrape config: %w", err) +func safeTLSServerName(tlsCfg *prometheusoperatorv1.TLSConfig) string { + if tlsCfg != nil && tlsCfg.ServerName != nil { + return *tlsCfg.ServerName } + return "" +} - if cm.Data == nil { - cm.Data = make(map[string]string) +func populateTLSFilePaths(comp *metricsproxybin.ComponentFileConfig, ca, cert prometheusoperatorv1.SecretOrConfigMap, keySecret *corev1.SecretKeySelector) { + comp.CAFile = certFilePathFromSecretOrConfigMap(ca) + comp.CertFile = certFilePathFromSecretOrConfigMap(cert) + if keySecret != nil { + comp.KeyFile = filepath.Join(certBasePath, keySecret.Name, keySecret.Key) } - cm.Data["config.yaml"] = string(data) - return nil } // certFilePathFromSecretOrConfigMap returns the file path for a volume-mounted diff --git a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config_test.go b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config_test.go index 7652cbc8208..c613de30fa0 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config_test.go +++ b/control-plane-operator/controllers/hostedcontrolplane/v2/metrics_proxy/scrape_config_test.go @@ -20,6 +20,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/yaml" + "github.com/go-logr/logr" prometheusoperatorv1 "github.com/prometheus-operator/prometheus-operator/pkg/apis/monitoring/v1" ) @@ -624,6 +625,557 @@ func TestAdaptScrapeConfigPodMonitorMissingPod(t *testing.T) { }) } +func TestEndpointScheme(t *testing.T) { + t.Parallel() + tests := []struct { + name string + scheme *prometheusoperatorv1.Scheme + expected string + }{ + { + name: "When scheme is nil, it should default to http", + scheme: nil, + expected: "http", + }, + { + name: "When scheme is https, it should return https", + scheme: (*prometheusoperatorv1.Scheme)(ptr.To("https")), + expected: "https", + }, + { + name: "When scheme is http, it should return http", + scheme: (*prometheusoperatorv1.Scheme)(ptr.To("http")), + expected: "http", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(endpointScheme(tc.scheme)).To(Equal(tc.expected)) + }) + } +} + +func TestEndpointMetricsPath(t *testing.T) { + t.Parallel() + tests := []struct { + name string + path string + expected string + }{ + { + name: "When path is empty, it should default to /metrics", + path: "", + expected: "/metrics", + }, + { + name: "When path is provided, it should return that path", + path: "/custom/metrics", + expected: "/custom/metrics", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(endpointMetricsPath(tc.path)).To(Equal(tc.expected)) + }) + } +} + +func TestSafeTLSServerName(t *testing.T) { + t.Parallel() + tests := []struct { + name string + tlsCfg *prometheusoperatorv1.TLSConfig + expected string + }{ + { + name: "When TLS config is nil, it should return empty string", + tlsCfg: nil, + expected: "", + }, + { + name: "When ServerName is nil, it should return empty string", + tlsCfg: &prometheusoperatorv1.TLSConfig{ + SafeTLSConfig: prometheusoperatorv1.SafeTLSConfig{ + ServerName: nil, + }, + }, + expected: "", + }, + { + name: "When ServerName is set, it should return the server name", + tlsCfg: &prometheusoperatorv1.TLSConfig{ + SafeTLSConfig: prometheusoperatorv1.SafeTLSConfig{ + ServerName: ptr.To("kube-apiserver"), + }, + }, + expected: "kube-apiserver", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(safeTLSServerName(tc.tlsCfg)).To(Equal(tc.expected)) + }) + } +} + +func TestCertFilePathFromSecretOrConfigMap(t *testing.T) { + t.Parallel() + tests := []struct { + name string + ref prometheusoperatorv1.SecretOrConfigMap + expected string + }{ + { + name: "When both Secret and ConfigMap are nil, it should return empty string", + ref: prometheusoperatorv1.SecretOrConfigMap{}, + expected: "", + }, + { + name: "When ConfigMap is set, it should return configmap path", + ref: prometheusoperatorv1.SecretOrConfigMap{ + ConfigMap: &corev1.ConfigMapKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "root-ca"}, + Key: "ca.crt", + }, + }, + expected: certBasePath + "/root-ca/ca.crt", + }, + { + name: "When Secret is set, it should return secret path", + ref: prometheusoperatorv1.SecretOrConfigMap{ + Secret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "my-secret"}, + Key: "tls.crt", + }, + }, + expected: certBasePath + "/my-secret/tls.crt", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + g.Expect(certFilePathFromSecretOrConfigMap(tc.ref)).To(Equal(tc.expected)) + }) + } +} + +func TestPopulateTLSFilePaths(t *testing.T) { + t.Parallel() + tests := []struct { + name string + ca prometheusoperatorv1.SecretOrConfigMap + cert prometheusoperatorv1.SecretOrConfigMap + keySecret *corev1.SecretKeySelector + expectedCA string + expectedCert string + expectedKey string + }{ + { + name: "When all TLS refs are empty, it should set empty paths", + ca: prometheusoperatorv1.SecretOrConfigMap{}, + cert: prometheusoperatorv1.SecretOrConfigMap{}, + keySecret: nil, + expectedCA: "", + expectedCert: "", + expectedKey: "", + }, + { + name: "When CA, cert, and key are all provided, it should set all paths", + ca: prometheusoperatorv1.SecretOrConfigMap{ + ConfigMap: &corev1.ConfigMapKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "root-ca"}, + Key: "ca.crt", + }, + }, + cert: prometheusoperatorv1.SecretOrConfigMap{ + Secret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "metrics-client"}, + Key: "tls.crt", + }, + }, + keySecret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "metrics-client"}, + Key: "tls.key", + }, + expectedCA: certBasePath + "/root-ca/ca.crt", + expectedCert: certBasePath + "/metrics-client/tls.crt", + expectedKey: certBasePath + "/metrics-client/tls.key", + }, + { + name: "When only CA is provided, it should set only CA path", + ca: prometheusoperatorv1.SecretOrConfigMap{ + ConfigMap: &corev1.ConfigMapKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "root-ca"}, + Key: "ca.crt", + }, + }, + cert: prometheusoperatorv1.SecretOrConfigMap{}, + keySecret: nil, + expectedCA: certBasePath + "/root-ca/ca.crt", + expectedCert: "", + expectedKey: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + comp := &metricsproxybin.ComponentFileConfig{} + populateTLSFilePaths(comp, tc.ca, tc.cert, tc.keySecret) + g.Expect(comp.CAFile).To(Equal(tc.expectedCA)) + g.Expect(comp.CertFile).To(Equal(tc.expectedCert)) + g.Expect(comp.KeyFile).To(Equal(tc.expectedKey)) + }) + } +} + +func TestPopulateMetricsLabelsFromAnnotations(t *testing.T) { + t.Parallel() + tests := []struct { + name string + annotations map[string]string + expectedJob string + expectedNamespace string + expectedService string + expectedEndpoint string + }{ + { + name: "When no annotations are present, it should leave fields empty", + annotations: map[string]string{}, + expectedJob: "", + expectedNamespace: "", + expectedService: "", + expectedEndpoint: "", + }, + { + name: "When annotations is nil, it should leave fields empty", + annotations: nil, + expectedJob: "", + expectedNamespace: "", + expectedService: "", + expectedEndpoint: "", + }, + { + name: "When all metrics annotations are present, it should populate all fields", + annotations: map[string]string{ + "hypershift.openshift.io/metrics-job": "kube-apiserver", + "hypershift.openshift.io/metrics-namespace": "openshift-kube-apiserver", + "hypershift.openshift.io/metrics-service": "kube-apiserver", + "hypershift.openshift.io/metrics-endpoint": "https", + }, + expectedJob: "kube-apiserver", + expectedNamespace: "openshift-kube-apiserver", + expectedService: "kube-apiserver", + expectedEndpoint: "https", + }, + { + name: "When only some annotations are present, it should populate only matching fields", + annotations: map[string]string{ + "hypershift.openshift.io/metrics-job": "etcd", + "hypershift.openshift.io/metrics-namespace": "openshift-etcd", + }, + expectedJob: "etcd", + expectedNamespace: "openshift-etcd", + expectedService: "", + expectedEndpoint: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + comp := &metricsproxybin.ComponentFileConfig{} + populateMetricsLabelsFromAnnotations(comp, tc.annotations) + g.Expect(comp.MetricsJob).To(Equal(tc.expectedJob)) + g.Expect(comp.MetricsNamespace).To(Equal(tc.expectedNamespace)) + g.Expect(comp.MetricsService).To(Equal(tc.expectedService)) + g.Expect(comp.MetricsEndpoint).To(Equal(tc.expectedEndpoint)) + }) + } +} + +func TestFindServiceForMonitor(t *testing.T) { + t.Parallel() + + namespace := "test-ns" + + tests := []struct { + name string + objects []runtime.Object + smName string + selector metav1.LabelSelector + expectedSvc string + expectedErr bool + expectedSelKeys []string + }{ + { + name: "When service exists with same name as ServiceMonitor, it should find it by direct lookup", + objects: []runtime.Object{ + newService("kube-apiserver", namespace, "kube-apiserver", "client", 6443), + }, + smName: "kube-apiserver", + selector: metav1.LabelSelector{MatchLabels: map[string]string{"app": "kube-apiserver"}}, + expectedSvc: "kube-apiserver", + expectedErr: false, + expectedSelKeys: []string{"app"}, + }, + { + name: "When no service matches name but label selector matches, it should fall back to label selector", + objects: []runtime.Object{ + newService("etcd-client", namespace, "etcd", "metrics", 2381), + }, + smName: "etcd", + selector: metav1.LabelSelector{MatchLabels: map[string]string{"app": "etcd"}}, + expectedSvc: "etcd-client", + expectedErr: false, + expectedSelKeys: []string{"app"}, + }, + { + name: "When no service exists at all, it should return an error", + objects: []runtime.Object{}, + smName: "missing", + selector: metav1.LabelSelector{MatchLabels: map[string]string{"app": "missing"}}, + expectedErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(tc.objects...).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + svcName, podSelector, err := findServiceForMonitor(cpContext, namespace, tc.smName, tc.selector) + if tc.expectedErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(svcName).To(Equal(tc.expectedSvc)) + for _, key := range tc.expectedSelKeys { + g.Expect(podSelector).To(HaveKey(key)) + } + } + }) + } +} + +func TestResolveServicePort(t *testing.T) { + t.Parallel() + + namespace := "test-ns" + + tests := []struct { + name string + objects []runtime.Object + serviceName string + portName string + podSelector map[string]string + expectedPort int32 + expectedErr bool + }{ + { + name: "When service has a numeric targetPort, it should resolve directly", + objects: []runtime.Object{ + newService("my-svc", namespace, "app", "https", 6443), + }, + serviceName: "my-svc", + portName: "https", + podSelector: map[string]string{"app": "app"}, + expectedPort: 6443, + expectedErr: false, + }, + { + name: "When service has a named targetPort, it should resolve from a pod", + objects: []runtime.Object{ + newServiceWithTargetPort("my-svc", namespace, "app", "client", 443, intstr.FromString("https")), + newPodWithLabels("my-pod", namespace, "https", 6443, map[string]string{"app": "app"}), + }, + serviceName: "my-svc", + portName: "client", + podSelector: map[string]string{"app": "app"}, + expectedPort: 6443, + expectedErr: false, + }, + { + name: "When service port name does not match, it should return an error", + objects: []runtime.Object{ + newService("my-svc", namespace, "app", "metrics", 8080), + }, + serviceName: "my-svc", + portName: "nonexistent", + podSelector: map[string]string{"app": "app"}, + expectedErr: true, + }, + { + name: "When service does not exist, it should return an error", + objects: []runtime.Object{}, + serviceName: "missing-svc", + portName: "https", + podSelector: map[string]string{"app": "app"}, + expectedErr: true, + }, + { + name: "When service has no targetPort, it should default to the service port value", + objects: []runtime.Object{ + &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "my-svc", Namespace: namespace}, + Spec: corev1.ServiceSpec{ + Selector: map[string]string{"app": "app"}, + Ports: []corev1.ServicePort{{Name: "web", Port: 8443}}, + }, + }, + }, + serviceName: "my-svc", + portName: "web", + podSelector: map[string]string{"app": "app"}, + expectedPort: 8443, + expectedErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(tc.objects...).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + port, err := resolveServicePort(cpContext, namespace, tc.serviceName, tc.portName, tc.podSelector) + if tc.expectedErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(port).To(Equal(tc.expectedPort)) + } + }) + } +} + +func TestComponentFromServiceMonitor_EdgeCases(t *testing.T) { + t.Parallel() + + namespace := "test-ns" + + t.Run("When ServiceMonitor has no endpoints, it should return false", func(t *testing.T) { + g := NewGomegaWithT(t) + sm := &prometheusoperatorv1.ServiceMonitor{ + ObjectMeta: metav1.ObjectMeta{Name: "empty-sm", Namespace: namespace}, + Spec: prometheusoperatorv1.ServiceMonitorSpec{}, + } + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = prometheusoperatorv1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + _, ok := componentFromServiceMonitor(cpContext, logr.Discard(), namespace, sm) + g.Expect(ok).To(BeFalse()) + }) + + t.Run("When endpoint has no port reference, it should return false", func(t *testing.T) { + g := NewGomegaWithT(t) + sm := &prometheusoperatorv1.ServiceMonitor{ + ObjectMeta: metav1.ObjectMeta{Name: "no-port-sm", Namespace: namespace}, + Spec: prometheusoperatorv1.ServiceMonitorSpec{ + Endpoints: []prometheusoperatorv1.Endpoint{{}}, + Selector: metav1.LabelSelector{MatchLabels: map[string]string{"app": "test"}}, + }, + } + svc := newService("no-port-sm", namespace, "test", "metrics", 8080) + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = prometheusoperatorv1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(svc).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + _, ok := componentFromServiceMonitor(cpContext, logr.Discard(), namespace, sm) + g.Expect(ok).To(BeFalse()) + }) +} + +func TestComponentFromPodMonitor_EdgeCases(t *testing.T) { + t.Parallel() + + namespace := "test-ns" + + t.Run("When PodMonitor has no endpoints, it should return false", func(t *testing.T) { + g := NewGomegaWithT(t) + pm := &prometheusoperatorv1.PodMonitor{ + ObjectMeta: metav1.ObjectMeta{Name: "empty-pm", Namespace: namespace}, + Spec: prometheusoperatorv1.PodMonitorSpec{}, + } + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = prometheusoperatorv1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + _, ok := componentFromPodMonitor(cpContext, logr.Discard(), namespace, pm) + g.Expect(ok).To(BeFalse()) + }) + + t.Run("When PodMonitor endpoint has nil port, it should return false", func(t *testing.T) { + g := NewGomegaWithT(t) + pm := &prometheusoperatorv1.PodMonitor{ + ObjectMeta: metav1.ObjectMeta{Name: "nil-port-pm", Namespace: namespace}, + Spec: prometheusoperatorv1.PodMonitorSpec{ + PodMetricsEndpoints: []prometheusoperatorv1.PodMetricsEndpoint{ + {Port: nil}, + }, + Selector: metav1.LabelSelector{MatchLabels: map[string]string{"app": "test"}}, + }, + } + + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = prometheusoperatorv1.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() + + cpContext := component.WorkloadContext{ + Context: context.Background(), + Client: fakeClient, + } + + _, ok := componentFromPodMonitor(cpContext, logr.Discard(), namespace, pm) + g.Expect(ok).To(BeFalse()) + }) +} + func TestResolvePodPort(t *testing.T) { t.Parallel() diff --git a/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert.go b/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert.go index f41a705612e..d4e12f63669 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert.go +++ b/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert.go @@ -147,291 +147,322 @@ func convertProviderConfigToIDPData( namespace string, skipKonnectivityDialer bool, ) (*idpData, error) { - const missingProviderFmt string = "type %s was specified, but its configuration is missing" - - data := &idpData{login: true} - switch providerConfig.Type { case configv1.IdentityProviderTypeBasicAuth: - basicAuthConfig := providerConfig.BasicAuth - if basicAuthConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - provider := &osinv1.BasicAuthPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "BasicAuthPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - RemoteConnectionInfo: configv1.RemoteConnectionInfo{ - URL: basicAuthConfig.URL, - }, - } - if basicAuthConfig.CA.Name != "" { - provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, basicAuthConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - if basicAuthConfig.TLSClientCert.Name != "" { - provider.RemoteConnectionInfo.CertFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientCert.Name, "tls-client-key", corev1.TLSCertKey) - } - if basicAuthConfig.TLSClientKey.Name != "" { - provider.RemoteConnectionInfo.KeyFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) - } - - data.provider = provider - data.challenge = true - + return convertBasicAuthIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeGitHub: - githubConfig := providerConfig.GitHub - if githubConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } - provider := &osinv1.GitHubIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GitHubIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: githubConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, githubConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, - }, - Organizations: githubConfig.Organizations, - Teams: githubConfig.Teams, - Hostname: githubConfig.Hostname, - } - if githubConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, githubConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = provider - data.challenge = false - + return convertGitHubIDP(providerConfig, i, idpVolumeMounts) case configv1.IdentityProviderTypeGitLab: - gitlabConfig := providerConfig.GitLab - if gitlabConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } + return convertGitLabIDP(providerConfig, i, idpVolumeMounts) + case configv1.IdentityProviderTypeGoogle: + return convertGoogleIDP(providerConfig, i, idpVolumeMounts) + case configv1.IdentityProviderTypeHTPasswd: + return convertHTPasswdIDP(providerConfig, i, idpVolumeMounts) + case configv1.IdentityProviderTypeKeystone: + return convertKeystoneIDP(providerConfig, i, idpVolumeMounts) + case configv1.IdentityProviderTypeLDAP: + return convertLDAPIDP(providerConfig, i, idpVolumeMounts) + case configv1.IdentityProviderTypeOpenID: + return convertOpenIDIDP(ctx, providerConfig, configOverride, i, idpVolumeMounts, kclient, namespace, skipKonnectivityDialer) + case configv1.IdentityProviderTypeRequestHeader: + return convertRequestHeaderIDP(providerConfig, i, idpVolumeMounts) + default: + return nil, fmt.Errorf("the identity provider type '%s' is not supported", providerConfig.Type) + } +} - provider := &osinv1.GitLabIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GitLabIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, +func convertBasicAuthIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + basicAuthConfig := providerConfig.BasicAuth + if basicAuthConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.BasicAuthPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "BasicAuthPasswordIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + RemoteConnectionInfo: configv1.RemoteConnectionInfo{ + URL: basicAuthConfig.URL, + }, + } + if basicAuthConfig.CA.Name != "" { + provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, basicAuthConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + if basicAuthConfig.TLSClientCert.Name != "" { + provider.RemoteConnectionInfo.CertFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) + } + if basicAuthConfig.TLSClientKey.Name != "" { + provider.RemoteConnectionInfo.KeyFile = idpVolumeMounts.SecretPath(i, basicAuthConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} - URL: gitlabConfig.URL, - ClientID: gitlabConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, gitlabConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, +func convertGitHubIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + githubConfig := providerConfig.GitHub + if githubConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.GitHubIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "GitHubIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + ClientID: githubConfig.ClientID, + ClientSecret: configv1.StringSource{ + StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, githubConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), }, - Legacy: new(bool), // we require OIDC for GitLab now - } - if gitlabConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, gitlabConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - - data.provider = provider - data.challenge = true - - case configv1.IdentityProviderTypeGoogle: - googleConfig := providerConfig.Google - if googleConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } + }, + Organizations: githubConfig.Organizations, + Teams: githubConfig.Teams, + Hostname: githubConfig.Hostname, + } + if githubConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, githubConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: false, login: true}, nil +} - data.provider = &osinv1.GoogleIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "GoogleIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: googleConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, googleConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, +func convertGitLabIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + gitlabConfig := providerConfig.GitLab + if gitlabConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.GitLabIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "GitLabIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + URL: gitlabConfig.URL, + ClientID: gitlabConfig.ClientID, + ClientSecret: configv1.StringSource{ + StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, gitlabConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), }, - HostedDomain: googleConfig.HostedDomain, - } - data.challenge = false - - case configv1.IdentityProviderTypeHTPasswd: - if providerConfig.HTPasswd == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } + }, + Legacy: new(bool), // we require OIDC for GitLab now + } + if gitlabConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, gitlabConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} - data.provider = &osinv1.HTPasswdPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "HTPasswdPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), +func convertGoogleIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + googleConfig := providerConfig.Google + if googleConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.GoogleIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "GoogleIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + ClientID: googleConfig.ClientID, + ClientSecret: configv1.StringSource{ + StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, googleConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), }, - File: idpVolumeMounts.SecretPath(i, providerConfig.HTPasswd.FileData.Name, "file-data", configv1.HTPasswdDataKey), - } - data.challenge = true - - case configv1.IdentityProviderTypeKeystone: - keystoneConfig := providerConfig.Keystone - if keystoneConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } + }, + HostedDomain: googleConfig.HostedDomain, + } + return &idpData{provider: provider, challenge: false, login: true}, nil +} - provider := &osinv1.KeystonePasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "KeystonePasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - RemoteConnectionInfo: configv1.RemoteConnectionInfo{ - URL: keystoneConfig.URL, - }, - DomainName: keystoneConfig.DomainName, - UseKeystoneIdentity: true, // force use of keystone ID - } - if keystoneConfig.CA.Name != "" { - provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, keystoneConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - if keystoneConfig.TLSClientCert.Name != "" { - provider.RemoteConnectionInfo.CertInfo.CertFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) - } - if keystoneConfig.TLSClientKey.Name != "" { - provider.RemoteConnectionInfo.CertInfo.KeyFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) - } - data.provider = provider - data.challenge = true +func convertHTPasswdIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + if providerConfig.HTPasswd == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.HTPasswdPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "HTPasswdPasswordIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + File: idpVolumeMounts.SecretPath(i, providerConfig.HTPasswd.FileData.Name, "file-data", configv1.HTPasswdDataKey), + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} - case configv1.IdentityProviderTypeLDAP: - ldapConfig := providerConfig.LDAP - if ldapConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } +func convertKeystoneIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + keystoneConfig := providerConfig.Keystone + if keystoneConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.KeystonePasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "KeystonePasswordIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + RemoteConnectionInfo: configv1.RemoteConnectionInfo{ + URL: keystoneConfig.URL, + }, + DomainName: keystoneConfig.DomainName, + UseKeystoneIdentity: true, // force use of keystone ID + } + if keystoneConfig.CA.Name != "" { + provider.RemoteConnectionInfo.CA = idpVolumeMounts.ConfigMapPath(i, keystoneConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + if keystoneConfig.TLSClientCert.Name != "" { + provider.RemoteConnectionInfo.CertInfo.CertFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientCert.Name, "tls-client-cert", corev1.TLSCertKey) + } + if keystoneConfig.TLSClientKey.Name != "" { + provider.RemoteConnectionInfo.CertInfo.KeyFile = idpVolumeMounts.SecretPath(i, keystoneConfig.TLSClientKey.Name, "tls-client-key", corev1.TLSPrivateKeyKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} - provider := &osinv1.LDAPPasswordIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "LDAPPasswordIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - URL: ldapConfig.URL, - BindDN: ldapConfig.BindDN, - Insecure: ldapConfig.Insecure, - Attributes: osinv1.LDAPAttributeMapping{ - ID: ldapConfig.Attributes.ID, - PreferredUsername: ldapConfig.Attributes.PreferredUsername, - Name: ldapConfig.Attributes.Name, - Email: ldapConfig.Attributes.Email, +func convertLDAPIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + ldapConfig := providerConfig.LDAP + if ldapConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.LDAPPasswordIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "LDAPPasswordIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + URL: ldapConfig.URL, + BindDN: ldapConfig.BindDN, + Insecure: ldapConfig.Insecure, + Attributes: osinv1.LDAPAttributeMapping{ + ID: ldapConfig.Attributes.ID, + PreferredUsername: ldapConfig.Attributes.PreferredUsername, + Name: ldapConfig.Attributes.Name, + Email: ldapConfig.Attributes.Email, + }, + } + if ldapConfig.BindPassword.Name != "" { + provider.BindPassword = configv1.StringSource{ + StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, ldapConfig.BindPassword.Name, "bind-password", configv1.BindPasswordKey), }, } - if ldapConfig.BindPassword.Name != "" { - provider.BindPassword = configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, ldapConfig.BindPassword.Name, "bind-password", configv1.BindPasswordKey), - }, - } - } - if ldapConfig.CA.Name != "" { - provider.CA = idpVolumeMounts.ConfigMapPath(i, ldapConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = provider - data.challenge = true + } + if ldapConfig.CA.Name != "" { + provider.CA = idpVolumeMounts.ConfigMapPath(i, ldapConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + return &idpData{provider: provider, challenge: true, login: true}, nil +} - case configv1.IdentityProviderTypeOpenID: - openIDConfig := providerConfig.OpenID - if openIDConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) - } +func convertOpenIDIDP( + ctx context.Context, + providerConfig *configv1.IdentityProviderConfig, + configOverride *ConfigOverride, + i int, + idpVolumeMounts *IDPVolumeMountInfo, + kclient crclient.Reader, + namespace string, + skipKonnectivityDialer bool, +) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + openIDConfig := providerConfig.OpenID + if openIDConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } - openIDProvider := &osinv1.OpenIDIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "OpenIDIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - ClientID: openIDConfig.ClientID, - ClientSecret: configv1.StringSource{ - StringSourceSpec: configv1.StringSourceSpec{ - File: idpVolumeMounts.SecretPath(i, openIDConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), - }, + openIDProvider := &osinv1.OpenIDIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "OpenIDIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + ClientID: openIDConfig.ClientID, + ClientSecret: configv1.StringSource{ + StringSourceSpec: configv1.StringSourceSpec{ + File: idpVolumeMounts.SecretPath(i, openIDConfig.ClientSecret.Name, "client-secret", configv1.ClientSecretKey), }, - ExtraScopes: openIDConfig.ExtraScopes, - ExtraAuthorizeParameters: openIDConfig.ExtraAuthorizeParameters, - } - //Handle special case for IBM Cloud's OIDC provider (need to override some fields not available in public api) - if configOverride != nil { - openIDProvider.URLs = configOverride.URLs - openIDProvider.Claims = configOverride.Claims - } else { - urls, err := discoverOpenIDURLs(ctx, kclient, openIDConfig.Issuer, corev1.ServiceAccountRootCAKey, namespace, openIDConfig.CA, skipKonnectivityDialer) - if err != nil { - return nil, err - } - openIDProvider.URLs = *urls - var groups []string - if len(openIDConfig.Claims.Groups) > 0 { - groups = make([]string, len(openIDConfig.Claims.Groups)) - for i, group := range openIDConfig.Claims.Groups { - groups[i] = string(group) - } - } - openIDProvider.Claims = osinv1.OpenIDClaims{ - // There is no longer a user-facing setting for ID as it is considered unsafe - ID: []string{configv1.UserIDClaim}, - PreferredUsername: openIDConfig.Claims.PreferredUsername, - Name: openIDConfig.Claims.Name, - Email: openIDConfig.Claims.Email, - Groups: groups, - } - } - if len(openIDConfig.CA.Name) > 0 { - openIDProvider.CA = idpVolumeMounts.ConfigMapPath(i, openIDConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) - } - data.provider = openIDProvider - - if configOverride != nil && configOverride.Challenge != nil { - data.challenge = *configOverride.Challenge - } else { - // openshift CR validating in kube-apiserver does not allow - // challenge-redirecting IdPs to be configured with OIDC so it is safe - // to allow challenge-issuing flow if it's available on the OIDC side - challengeFlowsAllowed, err := checkOIDCPasswordGrantFlow( - ctx, - kclient, - openIDProvider.URLs.Token, - openIDConfig.ClientID, - namespace, - openIDConfig.CA, - openIDConfig.ClientSecret, - skipKonnectivityDialer, - ) - if err != nil { - return nil, fmt.Errorf("error attempting password grant flow: %v", err) + }, + ExtraScopes: openIDConfig.ExtraScopes, + ExtraAuthorizeParameters: openIDConfig.ExtraAuthorizeParameters, + } + //Handle special case for IBM Cloud's OIDC provider (need to override some fields not available in public api) + if configOverride != nil { + openIDProvider.URLs = configOverride.URLs + openIDProvider.Claims = configOverride.Claims + } else { + urls, err := discoverOpenIDURLs(ctx, kclient, openIDConfig.Issuer, corev1.ServiceAccountRootCAKey, namespace, openIDConfig.CA, skipKonnectivityDialer) + if err != nil { + return nil, err + } + openIDProvider.URLs = *urls + var groups []string + if len(openIDConfig.Claims.Groups) > 0 { + groups = make([]string, len(openIDConfig.Claims.Groups)) + for i, group := range openIDConfig.Claims.Groups { + groups[i] = string(group) } - data.challenge = challengeFlowsAllowed } - case configv1.IdentityProviderTypeRequestHeader: - requestHeaderConfig := providerConfig.RequestHeader - if requestHeaderConfig == nil { - return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + openIDProvider.Claims = osinv1.OpenIDClaims{ + // There is no longer a user-facing setting for ID as it is considered unsafe + ID: []string{configv1.UserIDClaim}, + PreferredUsername: openIDConfig.Claims.PreferredUsername, + Name: openIDConfig.Claims.Name, + Email: openIDConfig.Claims.Email, + Groups: groups, } - data.provider = &osinv1.RequestHeaderIdentityProvider{ - TypeMeta: metav1.TypeMeta{ - Kind: "RequestHeaderIdentityProvider", - APIVersion: osinv1.GroupVersion.String(), - }, - LoginURL: requestHeaderConfig.LoginURL, - ChallengeURL: requestHeaderConfig.ChallengeURL, - ClientCA: idpVolumeMounts.ConfigMapPath(i, requestHeaderConfig.ClientCA.Name, "ca", corev1.ServiceAccountRootCAKey), - ClientCommonNames: requestHeaderConfig.ClientCommonNames, - Headers: requestHeaderConfig.Headers, - PreferredUsernameHeaders: requestHeaderConfig.PreferredUsernameHeaders, - NameHeaders: requestHeaderConfig.NameHeaders, - EmailHeaders: requestHeaderConfig.EmailHeaders, - } - data.challenge = len(requestHeaderConfig.ChallengeURL) > 0 - data.login = len(requestHeaderConfig.LoginURL) > 0 - - default: - return nil, fmt.Errorf("the identity provider type '%s' is not supported", providerConfig.Type) - } // switch + } + if len(openIDConfig.CA.Name) > 0 { + openIDProvider.CA = idpVolumeMounts.ConfigMapPath(i, openIDConfig.CA.Name, "ca", corev1.ServiceAccountRootCAKey) + } + data := &idpData{provider: openIDProvider, login: true} + if configOverride != nil && configOverride.Challenge != nil { + data.challenge = *configOverride.Challenge + } else { + // openshift CR validating in kube-apiserver does not allow + // challenge-redirecting IdPs to be configured with OIDC so it is safe + // to allow challenge-issuing flow if it's available on the OIDC side + challengeFlowsAllowed, err := checkOIDCPasswordGrantFlow( + ctx, + kclient, + openIDProvider.URLs.Token, + openIDConfig.ClientID, + namespace, + openIDConfig.CA, + openIDConfig.ClientSecret, + skipKonnectivityDialer, + ) + if err != nil { + return nil, fmt.Errorf("error attempting password grant flow: %v", err) + } + data.challenge = challengeFlowsAllowed + } return data, nil } +func convertRequestHeaderIDP(providerConfig *configv1.IdentityProviderConfig, i int, idpVolumeMounts *IDPVolumeMountInfo) (*idpData, error) { + const missingProviderFmt string = "type %s was specified, but its configuration is missing" + requestHeaderConfig := providerConfig.RequestHeader + if requestHeaderConfig == nil { + return nil, fmt.Errorf(missingProviderFmt, providerConfig.Type) + } + provider := &osinv1.RequestHeaderIdentityProvider{ + TypeMeta: metav1.TypeMeta{ + Kind: "RequestHeaderIdentityProvider", + APIVersion: osinv1.GroupVersion.String(), + }, + LoginURL: requestHeaderConfig.LoginURL, + ChallengeURL: requestHeaderConfig.ChallengeURL, + ClientCA: idpVolumeMounts.ConfigMapPath(i, requestHeaderConfig.ClientCA.Name, "ca", corev1.ServiceAccountRootCAKey), + ClientCommonNames: requestHeaderConfig.ClientCommonNames, + Headers: requestHeaderConfig.Headers, + PreferredUsernameHeaders: requestHeaderConfig.PreferredUsernameHeaders, + NameHeaders: requestHeaderConfig.NameHeaders, + EmailHeaders: requestHeaderConfig.EmailHeaders, + } + return &idpData{ + provider: provider, + challenge: len(requestHeaderConfig.ChallengeURL) > 0, + login: len(requestHeaderConfig.LoginURL) > 0, + }, nil +} + const ( konnectivityClientDataCertKey = "tls.crt" konnectivityClientDataKey = "tls.key" diff --git a/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert_test.go b/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert_test.go index 78b458a6c96..24adb544a93 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert_test.go +++ b/control-plane-operator/controllers/hostedcontrolplane/v2/oauth/idp_convert_test.go @@ -210,6 +210,721 @@ func TestOpenIDProviderConversion(t *testing.T) { } } +func TestDefaultIDPMappingMethods(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input []configv1.IdentityProvider + expected []configv1.MappingMethodType + }{ + { + name: "When no identity providers are given, it should return an empty slice", + input: []configv1.IdentityProvider{}, + expected: []configv1.MappingMethodType{}, + }, + { + name: "When mapping method is empty, it should default to claim", + input: []configv1.IdentityProvider{ + {Name: "test", MappingMethod: ""}, + }, + expected: []configv1.MappingMethodType{configv1.MappingMethodClaim}, + }, + { + name: "When mapping method is already set, it should preserve it", + input: []configv1.IdentityProvider{ + {Name: "test", MappingMethod: configv1.MappingMethodAdd}, + }, + expected: []configv1.MappingMethodType{configv1.MappingMethodAdd}, + }, + { + name: "When multiple providers have mixed mapping methods, it should default only empty ones", + input: []configv1.IdentityProvider{ + {Name: "a", MappingMethod: ""}, + {Name: "b", MappingMethod: configv1.MappingMethodLookup}, + {Name: "c", MappingMethod: ""}, + }, + expected: []configv1.MappingMethodType{ + configv1.MappingMethodClaim, + configv1.MappingMethodLookup, + configv1.MappingMethodClaim, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + result := defaultIDPMappingMethods(tc.input) + g.Expect(result).To(HaveLen(len(tc.expected))) + for i, idp := range result { + g.Expect(idp.MappingMethod).To(Equal(tc.expected[i])) + } + }) + } +} + +func newTestVolumeMountInfo() *IDPVolumeMountInfo { + return &IDPVolumeMountInfo{ + Container: ComponentName, + VolumeMounts: podspec.VolumeMounts{ + ComponentName: podspec.ContainerMounts{}, + }, + } +} + +func TestConvertBasicAuthIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When BasicAuth config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + }, + expectErr: true, + }, + { + name: "When BasicAuth config has only URL, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://auth.example.com", + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When BasicAuth config has CA and TLS certs, it should set volume mount paths", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://auth.example.com", + CA: configv1.ConfigMapNameReference{Name: "my-ca"}, + TLSClientCert: configv1.SecretNameReference{Name: "my-cert"}, + TLSClientKey: configv1.SecretNameReference{Name: "my-key"}, + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertBasicAuthIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("configuration is missing")) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.BasicAuthPasswordIdentityProvider) + g.Expect(provider.RemoteConnectionInfo.URL).To(Equal("https://auth.example.com")) + if tc.config.BasicAuth.CA.Name != "" { + g.Expect(provider.RemoteConnectionInfo.CA).To(ContainSubstring("idp_cm_0_ca")) + } + if tc.config.BasicAuth.TLSClientCert.Name != "" { + g.Expect(provider.RemoteConnectionInfo.CertFile).To(ContainSubstring("idp_secret_0_tls-client-cert")) + } + if tc.config.BasicAuth.TLSClientKey.Name != "" { + g.Expect(provider.RemoteConnectionInfo.KeyFile).To(ContainSubstring("idp_secret_0_tls-client-key")) + } + } + }) + } +} + +func TestConvertGitHubIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When GitHub config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + }, + expectErr: true, + }, + { + name: "When GitHub config is provided with organizations, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "my-client-id", + ClientSecret: configv1.SecretNameReference{Name: "gh-secret"}, + Organizations: []string{"org1", "org2"}, + Teams: []string{"org1/team1"}, + Hostname: "github.example.com", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + { + name: "When GitHub config has a CA, it should set the CA path", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "my-client-id", + ClientSecret: configv1.SecretNameReference{Name: "gh-secret"}, + CA: configv1.ConfigMapNameReference{Name: "gh-ca"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGitHubIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GitHubIdentityProvider) + g.Expect(provider.ClientID).To(Equal(tc.config.GitHub.ClientID)) + g.Expect(provider.Organizations).To(Equal(tc.config.GitHub.Organizations)) + if tc.config.GitHub.CA.Name != "" { + g.Expect(provider.CA).To(ContainSubstring("idp_cm_0_ca")) + } + } + }) + } +} + +func TestConvertGitLabIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When GitLab config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitLab, + }, + expectErr: true, + }, + { + name: "When GitLab config is provided, it should produce a valid provider with legacy set", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitLab, + GitLab: &configv1.GitLabIdentityProvider{ + URL: "https://gitlab.example.com", + ClientID: "gl-client", + ClientSecret: configv1.SecretNameReference{Name: "gl-secret"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGitLabIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GitLabIdentityProvider) + g.Expect(provider.URL).To(Equal("https://gitlab.example.com")) + g.Expect(provider.Legacy).ToNot(BeNil()) + g.Expect(*provider.Legacy).To(BeFalse()) + } + }) + } +} + +func TestConvertGoogleIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When Google config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGoogle, + }, + expectErr: true, + }, + { + name: "When Google config is provided with hosted domain, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGoogle, + Google: &configv1.GoogleIdentityProvider{ + ClientID: "google-client", + ClientSecret: configv1.SecretNameReference{Name: "google-secret"}, + HostedDomain: "example.com", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertGoogleIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.GoogleIdentityProvider) + g.Expect(provider.ClientID).To(Equal("google-client")) + g.Expect(provider.HostedDomain).To(Equal("example.com")) + } + }) + } +} + +func TestConvertHTPasswdIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When HTPasswd config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + }, + expectErr: true, + }, + { + name: "When HTPasswd config is provided, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + HTPasswd: &configv1.HTPasswdIdentityProvider{ + FileData: configv1.SecretNameReference{Name: "htpasswd-secret"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertHTPasswdIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.HTPasswdPasswordIdentityProvider) + g.Expect(provider.File).To(ContainSubstring("idp_secret_0_file-data")) + } + }) + } +} + +func TestConvertKeystoneIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When Keystone config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeKeystone, + }, + expectErr: true, + }, + { + name: "When Keystone config is provided with TLS certs, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeKeystone, + Keystone: &configv1.KeystoneIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{ + URL: "https://keystone.example.com", + CA: configv1.ConfigMapNameReference{Name: "ks-ca"}, + TLSClientCert: configv1.SecretNameReference{Name: "ks-cert"}, + TLSClientKey: configv1.SecretNameReference{Name: "ks-key"}, + }, + DomainName: "my-domain", + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertKeystoneIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.KeystonePasswordIdentityProvider) + g.Expect(provider.RemoteConnectionInfo.URL).To(Equal("https://keystone.example.com")) + g.Expect(provider.DomainName).To(Equal("my-domain")) + g.Expect(provider.UseKeystoneIdentity).To(BeTrue()) + g.Expect(provider.RemoteConnectionInfo.CA).To(ContainSubstring("idp_cm_0_ca")) + g.Expect(provider.RemoteConnectionInfo.CertInfo.CertFile).To(ContainSubstring("idp_secret_0_tls-client-cert")) + g.Expect(provider.RemoteConnectionInfo.CertInfo.KeyFile).To(ContainSubstring("idp_secret_0_tls-client-key")) + } + }) + } +} + +func TestConvertLDAPIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When LDAP config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + }, + expectErr: true, + }, + { + name: "When LDAP config is provided with bind password, it should produce a valid provider", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + LDAP: &configv1.LDAPIdentityProvider{ + URL: "ldap://ldap.example.com", + BindDN: "cn=admin,dc=example,dc=com", + BindPassword: configv1.SecretNameReference{Name: "ldap-bind-pw"}, + Insecure: true, + Attributes: configv1.LDAPAttributeMapping{ + ID: []string{"dn"}, + PreferredUsername: []string{"uid"}, + Name: []string{"cn"}, + Email: []string{"mail"}, + }, + CA: configv1.ConfigMapNameReference{Name: "ldap-ca"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When LDAP config has no bind password, it should not set the bind password field", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeLDAP, + LDAP: &configv1.LDAPIdentityProvider{ + URL: "ldap://ldap.example.com", + Insecure: false, + Attributes: configv1.LDAPAttributeMapping{ + ID: []string{"dn"}, + }, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertLDAPIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.LDAPPasswordIdentityProvider) + g.Expect(provider.URL).To(Equal(tc.config.LDAP.URL)) + g.Expect(provider.BindDN).To(Equal(tc.config.LDAP.BindDN)) + g.Expect(provider.Insecure).To(Equal(tc.config.LDAP.Insecure)) + g.Expect(provider.Attributes.ID).To(Equal(tc.config.LDAP.Attributes.ID)) + if tc.config.LDAP.BindPassword.Name != "" { + g.Expect(provider.BindPassword.StringSourceSpec.File).To(ContainSubstring("idp_secret_0_bind-password")) + } else { + g.Expect(provider.BindPassword.StringSourceSpec.File).To(BeEmpty()) + } + if tc.config.LDAP.CA.Name != "" { + g.Expect(provider.CA).To(ContainSubstring("idp_cm_0_ca")) + } + } + }) + } +} + +func TestConvertRequestHeaderIDP(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config *configv1.IdentityProviderConfig + expectErr bool + expectLogin bool + expectChall bool + }{ + { + name: "When RequestHeader config is nil, it should return an error", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + }, + expectErr: true, + }, + { + name: "When RequestHeader has login and challenge URLs, it should set login and challenge to true", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + RequestHeader: &configv1.RequestHeaderIdentityProvider{ + LoginURL: "https://login.example.com", + ChallengeURL: "https://challenge.example.com", + ClientCA: configv1.ConfigMapNameReference{Name: "rh-ca"}, + ClientCommonNames: []string{"client1"}, + Headers: []string{"X-Remote-User"}, + }, + }, + expectErr: false, + expectLogin: true, + expectChall: true, + }, + { + name: "When RequestHeader has empty login and challenge URLs, it should set login and challenge to false", + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeRequestHeader, + RequestHeader: &configv1.RequestHeaderIdentityProvider{ + ClientCA: configv1.ConfigMapNameReference{Name: "rh-ca"}, + Headers: []string{"X-Remote-User"}, + }, + }, + expectErr: false, + expectLogin: false, + expectChall: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertRequestHeaderIDP(tc.config, 0, vmi) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.login).To(Equal(tc.expectLogin)) + g.Expect(data.challenge).To(Equal(tc.expectChall)) + provider := data.provider.(*osinv1.RequestHeaderIdentityProvider) + g.Expect(provider.Headers).To(Equal(tc.config.RequestHeader.Headers)) + g.Expect(provider.ClientCA).To(ContainSubstring("idp_cm_0_ca")) + } + }) + } +} + +func TestConvertProviderConfigToIDPData_UnsupportedType(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + config := &configv1.IdentityProviderConfig{ + Type: "UnsupportedType", + } + _, err := convertProviderConfigToIDPData(t.Context(), config, nil, 0, vmi, nil, "test", true) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("is not supported")) +} + +func TestConvertProviderConfigToIDPData_Routing(t *testing.T) { + t.Parallel() + tests := []struct { + name string + providerType configv1.IdentityProviderType + config *configv1.IdentityProviderConfig + expectedKind string + }{ + { + name: "When type is BasicAuth, it should route to BasicAuth converter", + providerType: configv1.IdentityProviderTypeBasicAuth, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeBasicAuth, + BasicAuth: &configv1.BasicAuthIdentityProvider{ + OAuthRemoteConnectionInfo: configv1.OAuthRemoteConnectionInfo{URL: "https://example.com"}, + }, + }, + expectedKind: "BasicAuthPasswordIdentityProvider", + }, + { + name: "When type is GitHub, it should route to GitHub converter", + providerType: configv1.IdentityProviderTypeGitHub, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeGitHub, + GitHub: &configv1.GitHubIdentityProvider{ + ClientID: "id", + ClientSecret: configv1.SecretNameReference{Name: "s"}, + }, + }, + expectedKind: "GitHubIdentityProvider", + }, + { + name: "When type is HTPasswd, it should route to HTPasswd converter", + providerType: configv1.IdentityProviderTypeHTPasswd, + config: &configv1.IdentityProviderConfig{ + Type: configv1.IdentityProviderTypeHTPasswd, + HTPasswd: &configv1.HTPasswdIdentityProvider{ + FileData: configv1.SecretNameReference{Name: "htpasswd"}, + }, + }, + expectedKind: "HTPasswdPasswordIdentityProvider", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + data, err := convertProviderConfigToIDPData(t.Context(), tc.config, nil, 0, vmi, nil, "test", true) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(data.provider.GetObjectKind().GroupVersionKind().Kind).To(Equal(tc.expectedKind)) + }) + } +} + +func TestIDPVolumeMountInfo_ConfigMapPath(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + result := vmi.ConfigMapPath(2, "my-configmap", "ca", "ca.crt") + g.Expect(result).To(Equal("/etc/oauth/idp/idp_cm_2_ca/ca.crt")) + g.Expect(vmi.Volumes).To(HaveLen(1)) + g.Expect(vmi.Volumes[0].Name).To(Equal("idp-cm-2-ca")) + g.Expect(vmi.Volumes[0].ConfigMap.Name).To(Equal("my-configmap")) +} + +func TestIDPVolumeMountInfo_SecretPath(t *testing.T) { + t.Parallel() + g := NewWithT(t) + vmi := newTestVolumeMountInfo() + result := vmi.SecretPath(3, "my-secret", "client-secret", "clientSecret") + g.Expect(result).To(Equal("/etc/oauth/idp/idp_secret_3_client-secret/clientSecret")) + g.Expect(vmi.Volumes).To(HaveLen(1)) + g.Expect(vmi.Volumes[0].Name).To(Equal("idp-secret-3-client-secret")) + g.Expect(vmi.Volumes[0].Secret.SecretName).To(Equal("my-secret")) + g.Expect(vmi.Volumes[0].Secret.DefaultMode).ToNot(BeNil()) + g.Expect(*vmi.Volumes[0].Secret.DefaultMode).To(Equal(int32(0640))) +} + +func TestIsValidURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + rawurl string + optional bool + expected bool + }{ + { + name: "When URL is empty and optional, it should return true", + rawurl: "", + optional: true, + expected: true, + }, + { + name: "When URL is empty and required, it should return false", + rawurl: "", + optional: false, + expected: false, + }, + { + name: "When URL is a valid https URL, it should return true", + rawurl: "https://example.com/auth", + optional: false, + expected: true, + }, + { + name: "When URL uses http scheme, it should return false", + rawurl: "http://example.com/auth", + optional: false, + expected: false, + }, + { + name: "When URL has a fragment, it should return false", + rawurl: "https://example.com/auth#fragment", + optional: false, + expected: false, + }, + { + name: "When URL has no host, it should return false", + rawurl: "https://", + optional: false, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + g.Expect(isValidURL(tc.rawurl, tc.optional)).To(Equal(tc.expected)) + }) + } +} + func TestTransportForCARef(t *testing.T) { namespace := "test" diff --git a/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources.go b/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources.go index 89077b4daa1..3f25380fba6 100644 --- a/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources.go +++ b/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources.go @@ -343,31 +343,7 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result } if !hcp.DeletionTimestamp.IsZero() { - // Delete admission policies during cluster deletion to allow HCCO cleanup operations for ARO HCP - if hcp.Spec.Platform.Type == hyperv1.AzurePlatform { - registryConfigManagementStateAdmissionPolicy := registry.AdmissionPolicy{Name: registry.AdmissionPolicyNameManagementState} - // During cluster deletion, delete the admission policy and its binding to allow CIRO cleanup - log.Info("Cluster is being deleted, deleting registry management state admission policy and binding to allow cleanup") - - // Delete binding first to avoid dangling reference - binding := manifests.ValidatingAdmissionPolicyBinding(fmt.Sprintf("%s-binding", registryConfigManagementStateAdmissionPolicy.Name)) - _, err := k8sutil.DeleteIfNeeded(ctx, r.client, binding) - if err != nil { - return ctrl.Result{}, fmt.Errorf("failed to delete ValidatingAdmissionPolicyBinding %s: %v", binding.Name, err) - } - - // Delete policy - vap := manifests.ValidatingAdmissionPolicy(registryConfigManagementStateAdmissionPolicy.Name) - if _, err := k8sutil.DeleteIfNeeded(ctx, r.client, vap); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to delete ValidatingAdmissionPolicy %s: %v", vap.Name, err) - } - } - - if shouldCleanupCloudResources(hcp) { - log.Info("Cleaning up hosted cluster cloud resources") - return r.destroyCloudResources(ctx, hcp) - } - return ctrl.Result{}, nil + return r.reconcileDeletion(ctx, log, hcp) } if isPaused, duration := util.IsReconciliationPaused(log, hcp.Spec.PausedUntil); isPaused { @@ -453,21 +429,184 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result errs = append(errs, fmt.Errorf("failed to reconcile rbac: %w", err)) } + errs = append(errs, r.reconcileRegistryAndIngress(ctx, hcp, log)...) + + errs = append(errs, r.reconcileAPIServicesAndOAuth(ctx, hcp, log, releaseImage)...) + + errs = append(errs, r.reconcileNetworkingAndSecrets(ctx, hcp, log, pullSecret)...) + + log.Info("reconciling olm resources") + errs = append(errs, r.reconcileOLM(ctx, hcp, pullSecret)...) + + errs = append(errs, r.reconcileStorageAndMisc(ctx, log, hcp, releaseImage)...) + r.cleanupLegacyResources(ctx, log, hcp, releaseImage, &errs) + + errs = append(errs, r.reconcilePlatformSpecificResources(ctx, log, hcp, releaseImage)...) + + if result, err := r.reconcileClusterRecovery(ctx, log, hcp, errs); err != nil || result.RequeueAfter > 0 { + return result, utilerrors.NewAggregate(append(errs, err)) + } + + return ctrl.Result{}, utilerrors.NewAggregate(errs) +} + +func (r *reconciler) reconcileStorageAndMisc(ctx context.Context, log logr.Logger, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) []error { + var errs []error + + log.Info("reconciling kubelet configs") + if err := r.reconcileKubeletConfig(ctx); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile kubelet config: %w", err)) + } + + if hostedcontrolplane.IsStorageAndCSIManaged(hcp) { + log.Info("reconciling storage resources") + errs = append(errs, r.reconcileStorage(ctx, hcp)...) + + log.Info("reconciling node level csi configuration") + if err := r.reconcileCSIDriver(ctx, hcp, releaseImage); err != nil { + errs = append(errs, err) + } + } + + recyclerServiceAccount := manifests.RecyclerServiceAccount() + if _, err := r.CreateOrUpdate(ctx, r.client, recyclerServiceAccount, func() error { + return nil + }); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile pv recycler service account: %w", err)) + } + + log.Info("reconciling observed configuration") + errs = append(errs, r.reconcileObservedConfiguration(ctx, hcp)...) + + errs = append(errs, r.ensureGuestAdmissionWebhooksAreValid(ctx)) + return errs +} + +func (r *reconciler) cleanupLegacyResources(ctx context.Context, log logr.Logger, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage, errs *[]error) { + if !r.isClusterVersionUpdated(ctx, releaseImage.Version()) { + return + } + deleteDNSOperatorDeploymentOnce.Do(func() { + dnsOperatorDeployment := manifests.DNSOperatorDeployment() + log.Info("removing any existing DNS operator deployment") + if err := r.uncachedClient.Delete(ctx, dnsOperatorDeployment); err != nil && !apierrors.IsNotFound(err) { + *errs = append(*errs, err) + } + }) + deleteCVORemovedResourcesOnce.Do(func() { + resources := cvo.ResourcesToRemove(hcp.Spec.Platform.Type) + for _, resource := range resources { + log.Info("removing existing resources", "resource", resource) + if err := r.uncachedClient.Delete(ctx, resource); err != nil && !apierrors.IsNotFound(err) { + *errs = append(*errs, err) + } + } + }) +} + +func (r *reconciler) reconcileDeletion(ctx context.Context, log logr.Logger, hcp *hyperv1.HostedControlPlane) (ctrl.Result, error) { + if hcp.Spec.Platform.Type == hyperv1.AzurePlatform { + registryConfigManagementStateAdmissionPolicy := registry.AdmissionPolicy{Name: registry.AdmissionPolicyNameManagementState} + log.Info("Cluster is being deleted, deleting registry management state admission policy and binding to allow cleanup") + + binding := manifests.ValidatingAdmissionPolicyBinding(fmt.Sprintf("%s-binding", registryConfigManagementStateAdmissionPolicy.Name)) + if _, err := k8sutil.DeleteIfNeeded(ctx, r.client, binding); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to delete ValidatingAdmissionPolicyBinding %s: %v", binding.Name, err) + } + + vap := manifests.ValidatingAdmissionPolicy(registryConfigManagementStateAdmissionPolicy.Name) + if _, err := k8sutil.DeleteIfNeeded(ctx, r.client, vap); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to delete ValidatingAdmissionPolicy %s: %v", vap.Name, err) + } + } + + if shouldCleanupCloudResources(hcp) { + log.Info("Cleaning up hosted cluster cloud resources") + return r.destroyCloudResources(ctx, hcp) + } + return ctrl.Result{}, nil +} + +func (r *reconciler) reconcilePlatformSpecificResources(ctx context.Context, log logr.Logger, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) []error { + var errs []error + switch hcp.Spec.Platform.Type { + case hyperv1.AWSPlatform: + log.Info("reconciling AWS specific resources") + errs = append(errs, r.reconcileAWSIdentityWebhook(ctx)...) + case hyperv1.AzurePlatform: + log.Info("reconciling Azure specific resources") + errs = append(errs, r.reconcileAzureCloudNodeManager(ctx, releaseImage.ComponentImages()["azure-cloud-node-manager"])...) + errs = append(errs, r.reconcileAzureIdentityWebhook(ctx)...) + } + return errs +} + +func (r *reconciler) reconcileClusterRecovery(ctx context.Context, log logr.Logger, hcp *hyperv1.HostedControlPlane, existingErrs []error) (ctrl.Result, error) { + if _, exists := hcp.Annotations[hyperv1.HostedClusterRestoredFromBackupAnnotation]; !exists { + return ctrl.Result{}, nil + } + + condition := &metav1.Condition{ + Type: string(hyperv1.HostedClusterRestoredFromBackup), + Reason: hyperv1.RecoveryFinishedReason, + } + + finished, err := r.reconcileRestoredCluster(ctx, hcp) + if err != nil { + log.Error(err, "failed to reconcile hosted cluster recovery") + return ctrl.Result{}, utilerrors.NewAggregate(append(existingErrs, err)) + } + + if !finished { + log.Info("hosted cluster recovery not finished yet") + condition.Status = metav1.ConditionFalse + condition.Message = "Hosted cluster recovery not finished yet" + } else { + log.Info("hosted cluster recovery finished") + condition.Status = metav1.ConditionTrue + condition.Message = "Hosted cluster recovery finished" + } + + meta.SetStatusCondition(&hcp.Status.Conditions, *condition) + log.Info("setting condition", "type", condition.Type, "status", condition.Status, "message", condition.Message) + if err := r.cpClient.Status().Update(ctx, hcp); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for hosted cluster recovery: %w. Condition error message: %v", err, condition.Message) + } + log.Info("successfully updated hcp status with recovery condition") + + if !finished { + return ctrl.Result{RequeueAfter: 120 * time.Second}, nil + } + return ctrl.Result{}, nil +} + +func (r *reconciler) reconcileCSIDriver(ctx context.Context, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) error { + switch hcp.Spec.Platform.Type { + case hyperv1.KubevirtPlatform: + // Most csi drivers should be laid down by the Cluster Storage Operator (CSO) instead of + // the hcco operator. Only KubeVirt is unique at the moment. + err := kubevirtcsi.ReconcileTenant(r.client, hcp, ctx, r.CreateOrUpdate, releaseImage.ComponentImages()) + if err != nil { + return err + } + } + + return nil +} + +func (r *reconciler) reconcileRegistryAndIngress(ctx context.Context, hcp *hyperv1.HostedControlPlane, log logr.Logger) []error { + var errs []error + registryConfig := manifests.Registry() var registryConfigExists bool - // Check if the registry config exists if err := r.client.Get(ctx, client.ObjectKeyFromObject(registryConfig), registryConfig); err != nil { if !apierrors.IsNotFound(err) { - return ctrl.Result{}, fmt.Errorf("failed to get registry config: %w", err) + errs = append(errs, fmt.Errorf("failed to get registry config: %w", err)) } } else { registryConfigExists = true } - // For platforms where cluster-image-registry-operator (CIRO) needs a PVC to be created, bootstrap needs to happen - // in CIRO before the registry config is created. For now, this is the case for the OpenStack platform. - // If the object exist, we reconcile the registry config for other fields as it should be fine since the PVC would - // exist at this point. if capabilities.IsImageRegistryCapabilityEnabled(hcp.Spec.Capabilities) { if imageRegistryPlatformWithPVC(hcp.Spec.Platform.Type) && (!registryConfigExists || registryConfig == nil) { log.Info("skipping registry config to let CIRO bootstrap") @@ -480,23 +619,16 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result } log.Info("reconciling registry config") if _, err := r.CreateOrUpdate(ctx, r.client, registryConfig, func() error { - err = registry.ReconcileRegistryConfig(registryConfig, r.platformType, hcp.Spec.InfrastructureAvailabilityPolicy) - if err != nil { - return err - } - return nil + return registry.ReconcileRegistryConfig(registryConfig, r.platformType, hcp.Spec.InfrastructureAvailabilityPolicy) }); err != nil { errs = append(errs, fmt.Errorf("failed to reconcile imageregistry config: %w", err)) } - // TODO: remove this when ROSA HCP stops setting the managementState to Removed to disable the Image Registry if registryConfig.Spec.ManagementState == operatorv1.Removed && r.platformType != hyperv1.IBMCloudPlatform && r.platformType != hyperv1.AzurePlatform { log.Info("imageregistry operator managementstate is removed, disabling openshift-controller-manager controllers and cleaning up resources") ocmConfigMap := cpomanifests.OpenShiftControllerManagerConfig(r.hcpNamespace) if _, err := r.CreateOrUpdate(ctx, r.cpClient, ocmConfigMap, func() error { if ocmConfigMap.Data == nil { - // CPO has not created the configmap yet, wait for create - // This should not happen as we are started by the CPO after the configmap should be created return nil } config := &openshiftcpv1.OpenShiftControllerManagerConfig{} @@ -536,8 +668,7 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result } } } - // Reconcile the IngressController resource only if the ingress capability is enabled. - // Skip this step if the user explicitly disabled ingress. + if capabilities.IsIngressCapabilityEnabled(hcp.Spec.Capabilities) { log.Info("reconciling ingress controller") if err := r.reconcileIngressController(ctx, hcp); err != nil { @@ -550,6 +681,12 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result errs = append(errs, fmt.Errorf("failed to reconcile oauth client secrets: %w", err)) } + return errs +} + +func (r *reconciler) reconcileAPIServicesAndOAuth(ctx context.Context, hcp *hyperv1.HostedControlPlane, log logr.Logger, releaseImage *releaseinfo.ReleaseImage) []error { + var errs []error + log.Info("reconciling kube control plane signer secret") kubeControlPlaneSignerSecret := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ @@ -629,29 +766,7 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result } if util.HCPOAuthEnabled(hcp) { - log.Info("reconciling openshift oauth apiserver apiservices") - if err := r.reconcileOpenshiftOAuthAPIServerAPIServices(ctx, hcp); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile openshift apiserver service: %w", err)) - } - - log.Info("reconciling openshift oauth apiserver service") - openshiftOAuthAPIServerService := manifests.OpenShiftOAuthAPIServerClusterService() - if _, err := r.CreateOrUpdate(ctx, r.client, openshiftOAuthAPIServerService, func() error { - oapi.ReconcileClusterService(openshiftOAuthAPIServerService) - return nil - }); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile openshift oauth apiserver service: %w", err)) - } - - log.Info("reconciling openshift oauth apiserver endpoints") - if err := r.reconcileOpenshiftOAuthAPIServerEndpoints(ctx, hcp); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile openshift apiserver endpoints: %w", err)) - } - - log.Info("reconciling kubeadmin password hash secret") - if err := r.reconcileKubeadminPasswordHashSecret(ctx, hcp); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile kubeadmin password hash secret: %w", err)) - } + errs = append(errs, r.reconcileOAuthAPIServerResources(ctx, hcp, log)...) } log.Info("reconciling kube apiserver service monitor") @@ -667,6 +782,42 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result errs = append(errs, fmt.Errorf("failed to reconcile metrics forwarder: %w", err)) } + return errs +} + +func (r *reconciler) reconcileOAuthAPIServerResources(ctx context.Context, hcp *hyperv1.HostedControlPlane, log logr.Logger) []error { + var errs []error + + log.Info("reconciling openshift oauth apiserver apiservices") + if err := r.reconcileOpenshiftOAuthAPIServerAPIServices(ctx, hcp); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile openshift oauth apiserver apiservices: %w", err)) + } + + log.Info("reconciling openshift oauth apiserver service") + openshiftOAuthAPIServerService := manifests.OpenShiftOAuthAPIServerClusterService() + if _, err := r.CreateOrUpdate(ctx, r.client, openshiftOAuthAPIServerService, func() error { + oapi.ReconcileClusterService(openshiftOAuthAPIServerService) + return nil + }); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile openshift oauth apiserver service: %w", err)) + } + + log.Info("reconciling openshift oauth apiserver endpoints") + if err := r.reconcileOpenshiftOAuthAPIServerEndpoints(ctx, hcp); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile openshift oauth apiserver endpoints: %w", err)) + } + + log.Info("reconciling kubeadmin password hash secret") + if err := r.reconcileKubeadminPasswordHashSecret(ctx, hcp); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile kubeadmin password hash secret: %w", err)) + } + + return errs +} + +func (r *reconciler) reconcileNetworkingAndSecrets(ctx context.Context, hcp *hyperv1.HostedControlPlane, log logr.Logger, pullSecret *corev1.Secret) []error { + var errs []error + log.Info("reconciling network operator") networkOperator := networkoperator.NetworkOperator() var ovnConfig *hyperv1.OVNKubernetesConfig @@ -679,12 +830,10 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result }); err != nil { errs = append(errs, fmt.Errorf("failed to reconcile network operator: %w", err)) } - // Detect suboptimal MTU size on kubevirt hosted cluster with ovn-k and raise a condition in such a case if err := networkoperator.DetectSuboptimalMTU(ctx, r.cpClient, networkOperator, hcp); err != nil { errs = append(errs, err) } - // this allows users to disable data collection in sensitive environments - // solves https://issues.redhat.com/browse/OCPBUGS-12208 + ensureExistsReconciliationStrategy := false if _, exists := hcp.Annotations[hyperv1.EnsureExistsPullSecretReconciliation]; exists { ensureExistsReconciliationStrategy = true @@ -779,131 +928,7 @@ func (r *reconciler) Reconcile(ctx context.Context, _ ctrl.Request) (ctrl.Result errs = append(errs, fmt.Errorf("failed to reconcile openshift controller manager service ca bundle: %w", err)) } - log.Info("reconciling olm resources") - errs = append(errs, r.reconcileOLM(ctx, hcp, pullSecret)...) - - log.Info("reconciling kubelet configs") - if err := r.reconcileKubeletConfig(ctx); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile kubelet config: %w", err)) - } - - if hostedcontrolplane.IsStorageAndCSIManaged(hcp) { - log.Info("reconciling storage resources") - errs = append(errs, r.reconcileStorage(ctx, hcp)...) - - log.Info("reconciling node level csi configuration") - if err := r.reconcileCSIDriver(ctx, hcp, releaseImage); err != nil { - errs = append(errs, err) - } - } - - recyclerServiceAccount := manifests.RecyclerServiceAccount() - if _, err := r.CreateOrUpdate(ctx, r.client, recyclerServiceAccount, func() error { - return nil - }); err != nil { - errs = append(errs, fmt.Errorf("failed to reconcile pv recycler service account: %w", err)) - } - - log.Info("reconciling observed configuration") - errs = append(errs, r.reconcileObservedConfiguration(ctx, hcp)...) - - errs = append(errs, r.ensureGuestAdmissionWebhooksAreValid(ctx)) - - // Delete the DNS operator deployment in the hosted cluster, if it is - // present there. A separate DNS operator deployment runs as part of - // the hosted control-plane, but an upgraded cluster might still have - // an old DNS operator deployment in the hosted cluster. The caching - // client has a label selector that doesn't match the deployment, - // so we must use the uncached client for this delete call. To avoid - // excessive API calls using the uncached client, the delete call is - // guarded using a sync.Once. - if r.isClusterVersionUpdated(ctx, releaseImage.Version()) { - deleteDNSOperatorDeploymentOnce.Do(func() { - dnsOperatorDeployment := manifests.DNSOperatorDeployment() - log.Info("removing any existing DNS operator deployment") - if err := r.uncachedClient.Delete(ctx, dnsOperatorDeployment); err != nil && !apierrors.IsNotFound(err) { - errs = append(errs, err) - } - }) - deleteCVORemovedResourcesOnce.Do(func() { - resources := cvo.ResourcesToRemove(hcp.Spec.Platform.Type) - for _, resource := range resources { - log.Info("removing existing resources", "resource", resource) - if err := r.uncachedClient.Delete(ctx, resource); err != nil && !apierrors.IsNotFound(err) { - errs = append(errs, err) - } - } - }) - } - - // Reconcile platform specific resources - switch hcp.Spec.Platform.Type { - case hyperv1.AWSPlatform: - log.Info("reconciling AWS specific resources") - errs = append(errs, r.reconcileAWSIdentityWebhook(ctx)...) - case hyperv1.AzurePlatform: - log.Info("reconciling Azure specific resources") - errs = append(errs, r.reconcileAzureCloudNodeManager(ctx, releaseImage.ComponentImages()["azure-cloud-node-manager"])...) - errs = append(errs, r.reconcileAzureIdentityWebhook(ctx)...) - } - - // Reconcile hostedCluster recovery if the hosted cluster was restored from backup - if _, exists := hcp.Annotations[hyperv1.HostedClusterRestoredFromBackupAnnotation]; exists { - var ( - finished bool - err error - ) - condition := &metav1.Condition{ - Type: string(hyperv1.HostedClusterRestoredFromBackup), - Reason: hyperv1.RecoveryFinishedReason, - } - - finished, err = r.reconcileRestoredCluster(ctx, hcp) - if err != nil { - log.Error(err, "failed to reconcile hosted cluster recovery") - return ctrl.Result{}, utilerrors.NewAggregate(append(errs, err)) - } - - if !finished { - log.Info("hosted cluster recovery not finished yet") - condition.Status = metav1.ConditionFalse - condition.Message = "Hosted cluster recovery not finished yet" - meta.SetStatusCondition(&hcp.Status.Conditions, *condition) - log.Info("setting condition", "type", condition.Type, "status", condition.Status, "message", condition.Message) - if err := r.cpClient.Status().Update(ctx, hcp); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for hosted cluster recovery: %w. Condition error message: %v", err, condition.Message) - } - log.Info("successfully updated hcp status with recovery not finished condition") - - return ctrl.Result{RequeueAfter: 120 * time.Second}, nil - } - - log.Info("hosted cluster recovery finished") - condition.Status = metav1.ConditionTrue - condition.Message = "Hosted cluster recovery finished" - meta.SetStatusCondition(&hcp.Status.Conditions, *condition) - log.Info("setting condition", "type", condition.Type, "status", condition.Status, "message", condition.Message) - if err := r.cpClient.Status().Update(ctx, hcp); err != nil { - return ctrl.Result{}, fmt.Errorf("failed to update status on hcp for hosted cluster recovery: %w. Condition error message: %v", err, condition.Message) - } - log.Info("successfully updated hcp status with recovery finished condition") - } - - return ctrl.Result{}, utilerrors.NewAggregate(errs) -} - -func (r *reconciler) reconcileCSIDriver(ctx context.Context, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) error { - switch hcp.Spec.Platform.Type { - case hyperv1.KubevirtPlatform: - // Most csi drivers should be laid down by the Cluster Storage Operator (CSO) instead of - // the hcco operator. Only KubeVirt is unique at the moment. - err := kubevirtcsi.ReconcileTenant(r.client, hcp, ctx, r.CreateOrUpdate, releaseImage.ComponentImages()) - if err != nil { - return err - } - } - - return nil + return errs } func (r *reconciler) reconcileMetricsForwarder(ctx context.Context, hcp *hyperv1.HostedControlPlane, releaseImage *releaseinfo.ReleaseImage) error { diff --git a/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources_test.go b/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources_test.go index def47e1bd6e..611f39abe08 100644 --- a/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources_test.go +++ b/control-plane-operator/hostedclusterconfigoperator/controllers/resources/resources_test.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "math/rand" + "reflect" "strings" + "sync" "testing" "time" @@ -15,15 +17,18 @@ import ( "github.com/openshift/hypershift/control-plane-operator/hostedclusterconfigoperator/api" "github.com/openshift/hypershift/control-plane-operator/hostedclusterconfigoperator/controllers/resources/kas" "github.com/openshift/hypershift/control-plane-operator/hostedclusterconfigoperator/controllers/resources/manifests" + "github.com/openshift/hypershift/control-plane-operator/hostedclusterconfigoperator/controllers/resources/registry" "github.com/openshift/hypershift/hypershift-operator/controllers/nodepool" "github.com/openshift/hypershift/support/azureutil" "github.com/openshift/hypershift/support/globalconfig" "github.com/openshift/hypershift/support/netutil" + "github.com/openshift/hypershift/support/releaseinfo" fakereleaseprovider "github.com/openshift/hypershift/support/releaseinfo/fake" supportutil "github.com/openshift/hypershift/support/util" "github.com/openshift/hypershift/support/util/fakeimagemetadataprovider" configv1 "github.com/openshift/api/config/v1" + imageapi "github.com/openshift/api/image/v1" operatorv1 "github.com/openshift/api/operator/v1" appsv1 "k8s.io/api/apps/v1" @@ -2699,6 +2704,140 @@ func Test_reconciler_reconcileControlPlaneConnectionAvailable(t *testing.T) { } } +func verifyKASCheckerLabelsAndSelectors(t *testing.T, dep *appsv1.Deployment) { + t.Helper() + if dep.Spec.Selector == nil || dep.Spec.Selector.MatchLabels["app"] != manifests.KASConnectionCheckerName { + t.Error("Selector labels not set correctly") + } + if dep.Spec.Template.ObjectMeta.Labels["app"] != manifests.KASConnectionCheckerName { + t.Error("Pod template labels not set correctly") + } +} + +func verifyKASCheckerContainerBasics(t *testing.T, dep *appsv1.Deployment, expectedImage string) corev1.Container { + t.Helper() + if len(dep.Spec.Template.Spec.Containers) != 1 { + t.Fatalf("Expected 1 container, got %d", len(dep.Spec.Template.Spec.Containers)) + } + container := dep.Spec.Template.Spec.Containers[0] + if container.Name != "connection-checker" { + t.Errorf("Expected container name 'connection-checker', got %s", container.Name) + } + if container.Image != expectedImage { + t.Errorf("Expected cli image %s, got %s", expectedImage, container.Image) + } + return container +} + +func verifyKASCheckerScript(t *testing.T, container corev1.Container) { + t.Helper() + if len(container.Command) != 3 || container.Command[0] != "/bin/sh" || container.Command[1] != "-c" { + t.Fatalf("Expected command [/bin/sh -c