diff --git a/cmd/cluster/aws/create.go b/cmd/cluster/aws/create.go index 38c6f3623cb..83ee1f59b10 100644 --- a/cmd/cluster/aws/create.go +++ b/cmd/cluster/aws/create.go @@ -126,7 +126,7 @@ func (o *ValidatedCreateOptions) Complete(ctx context.Context, opts *core.Create opts.EtcdStorageClass = "gp3-csi" } - client, err := util.GetClient() + client, err := opts.GetClient() if err != nil { return nil, err } diff --git a/cmd/cluster/aws/create_test.go b/cmd/cluster/aws/create_test.go index 4d5b60bd522..d997bd32f93 100644 --- a/cmd/cluster/aws/create_test.go +++ b/cmd/cluster/aws/create_test.go @@ -13,12 +13,16 @@ import ( awsinfra "github.com/openshift/hypershift/cmd/infra/aws" awsutil "github.com/openshift/hypershift/cmd/infra/aws/util" "github.com/openshift/hypershift/cmd/util" + hyperapi "github.com/openshift/hypershift/support/api" "github.com/openshift/hypershift/support/certs" "github.com/openshift/hypershift/support/testutil" "github.com/openshift/hypershift/test/integration/framework" utilrand "k8s.io/apimachinery/pkg/util/rand" + crclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "github.com/spf13/pflag" ) @@ -90,7 +94,6 @@ func TestCreateCluster(t *testing.T) { certs.UnsafeSeed(1234567890) ctx := framework.InterruptableContext(t.Context()) tempDir := t.TempDir() - t.Setenv("FAKE_CLIENT", "true") rawCreds, err := json.Marshal(&awsutil.STSCreds{ Credentials: awsutil.Credentials{ @@ -235,6 +238,9 @@ func TestCreateCluster(t *testing.T) { flags := pflag.NewFlagSet(testCase.name, pflag.ContinueOnError) coreOpts := core.DefaultOptions() core.BindDeveloperOptions(coreOpts, flags) + coreOpts.ClientFn = func() (crclient.Client, error) { + return fake.NewClientBuilder().WithScheme(hyperapi.Scheme).Build(), nil + } awsOpts := DefaultOptions() BindDeveloperOptions(awsOpts, flags) if err := flags.Parse(testCase.args); err != nil { diff --git a/cmd/cluster/azure/create_test.go b/cmd/cluster/azure/create_test.go index 4f73b8d7646..1361c5fd335 100644 --- a/cmd/cluster/azure/create_test.go +++ b/cmd/cluster/azure/create_test.go @@ -14,12 +14,15 @@ import ( azureinfra "github.com/openshift/hypershift/cmd/infra/azure" azurenodepool "github.com/openshift/hypershift/cmd/nodepool/azure" "github.com/openshift/hypershift/cmd/util" + hyperapi "github.com/openshift/hypershift/support/api" "github.com/openshift/hypershift/support/certs" "github.com/openshift/hypershift/support/testutil" "github.com/openshift/hypershift/test/integration/framework" utilrand "k8s.io/apimachinery/pkg/util/rand" + crclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/yaml" "github.com/spf13/pflag" @@ -87,8 +90,6 @@ func TestCreateCluster(t *testing.T) { certs.UnsafeSeed(1234567890) ctx := framework.InterruptableContext(t.Context()) tempDir := t.TempDir() - t.Setenv("FAKE_CLIENT", "true") - rawCreds, err := yaml.Marshal(&util.AzureCreds{ SubscriptionID: "fakeSubscriptionID", ClientID: "fakeClientID", @@ -309,6 +310,9 @@ func TestCreateCluster(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { flags := pflag.NewFlagSet(testCase.name, pflag.ContinueOnError) coreOpts := core.DefaultOptions() + coreOpts.ClientFn = func() (crclient.Client, error) { + return fake.NewClientBuilder().WithScheme(hyperapi.Scheme).Build(), nil + } core.BindDeveloperOptions(coreOpts, flags) azureOpts := DefaultOptions() azurenodepool.BindOptions(azureOpts.NodePoolOpts, flags) diff --git a/cmd/cluster/core/create.go b/cmd/cluster/core/create.go index 4c3ccc771e0..fbaede1fa7c 100644 --- a/cmd/cluster/core/create.go +++ b/cmd/cluster/core/create.go @@ -188,11 +188,24 @@ type RawCreateOptions struct { // This is intended primarily for e2e testing and should be used with care. BeforeApply func(crclient.Object) `json:"-"` + // ClientFn returns a Kubernetes client. When nil, util.GetClient is used. + // This enables dependency injection for testing without requiring a live cluster. + ClientFn func() (crclient.Client, error) `json:"-"` + // These fields are reverse-completed by the aws CLI since we support a flag that projects // them back up here PublicKey, PrivateKey, PullSecret []byte } +// GetClient returns a Kubernetes client using the configured ClientFn, +// falling back to util.GetClient when ClientFn is nil. +func (o *RawCreateOptions) GetClient() (crclient.Client, error) { + if o.ClientFn != nil { + return o.ClientFn() + } + return util.GetClient() +} + type resources struct { AdditionalTrustBundle *corev1.ConfigMap Namespace *corev1.Namespace @@ -358,7 +371,7 @@ func resolveReleaseImage(ctx context.Context, opts *CreateOptions) error { if len(opts.ReleaseImage) != 0 || len(opts.ReleaseStream) == 0 { return nil } - client, err := util.GetClient() + client, err := opts.GetClient() if err != nil { return fmt.Errorf("failed to get client: %w", err) } @@ -612,8 +625,8 @@ func applyFeatureSet(cluster *hyperv1.HostedCluster, opts *CreateOptions) { } } -func apply(ctx context.Context, l logr.Logger, infraID string, objects []crclient.Object, waitForRollout bool, mutate func(crclient.Object)) error { - client, err := util.GetClient() +func apply(ctx context.Context, l logr.Logger, infraID string, objects []crclient.Object, waitForRollout bool, mutate func(crclient.Object), clientFn func() (crclient.Client, error)) error { + client, err := clientFn() if err != nil { return err } @@ -753,7 +766,7 @@ func (opts *RawCreateOptions) Validate(ctx context.Context) (*ValidatedCreateOpt func (opts *RawCreateOptions) validateVersionAndWait(ctx context.Context) error { if opts.VersionCheck { versionCLI := supportedversion.GetRevision() - client, err := util.GetClient() + client, err := opts.GetClient() if err != nil { return fmt.Errorf("failed to get client: %w", err) } @@ -771,7 +784,7 @@ func (opts *RawCreateOptions) validateClusterExistence(ctx context.Context) erro if opts.Render { return nil } - client, err := util.GetClient() + client, err := opts.GetClient() if err != nil { return err } @@ -1005,7 +1018,7 @@ func CreateCluster(ctx context.Context, rawOpts *RawCreateOptions, rawPlatform P } // Otherwise, apply the objects - return apply(ctx, opts.Log, resources.Cluster.Spec.InfraID, resources.asObjects(), opts.Wait, opts.BeforeApply) + return apply(ctx, opts.Log, resources.Cluster.Spec.InfraID, resources.asObjects(), opts.Wait, opts.BeforeApply, opts.GetClient) } type DefaultNodePoolConstructor func(platformType hyperv1.PlatformType, suffix string) *hyperv1.NodePool diff --git a/cmd/cluster/core/create_test.go b/cmd/cluster/core/create_test.go index 532f77e8ed6..49f020b088c 100644 --- a/cmd/cluster/core/create_test.go +++ b/cmd/cluster/core/create_test.go @@ -10,6 +10,7 @@ import ( . "github.com/onsi/gomega" hyperv1 "github.com/openshift/hypershift/api/hypershift/v1beta1" + hyperapi "github.com/openshift/hypershift/support/api" "github.com/openshift/hypershift/support/config" "github.com/openshift/hypershift/support/thirdparty/library-go/pkg/image/dockerv1client" "github.com/openshift/hypershift/support/util/fakeimagemetadataprovider" @@ -708,6 +709,111 @@ func TestAllocateNodeCIDRsFlag(t *testing.T) { } } +func TestCreateOptionsGetClient(t *testing.T) { + t.Run("When ClientFn is set it should use the provided function", func(t *testing.T) { + g := NewWithT(t) + expectedClient := fake.NewClientBuilder().WithScheme(hyperapi.Scheme).Build() + opts := &RawCreateOptions{ + ClientFn: func() (crclient.Client, error) { + return expectedClient, nil + }, + } + c, err := opts.GetClient() + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(c).To(Equal(expectedClient)) + }) + + t.Run("When ClientFn is set it should be accessible via completed options", func(t *testing.T) { + g := NewWithT(t) + expectedClient := fake.NewClientBuilder().WithScheme(hyperapi.Scheme).Build() + opts := &CreateOptions{ + completedCreateOptions: &completedCreateOptions{ + ValidatedCreateOptions: &ValidatedCreateOptions{ + validatedCreateOptions: &validatedCreateOptions{ + RawCreateOptions: &RawCreateOptions{ + ClientFn: func() (crclient.Client, error) { + return expectedClient, nil + }, + }, + }, + }, + }, + } + c, err := opts.GetClient() + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(c).To(Equal(expectedClient)) + }) +} + +func TestValidateWithInjectedClient(t *testing.T) { + t.Run("When a HostedCluster already exists it should return an error", func(t *testing.T) { + g := NewWithT(t) + ctx := t.Context() + tempDir := t.TempDir() + + pullSecretFile := filepath.Join(tempDir, "pull-secret.json") + if err := os.WriteFile(pullSecretFile, []byte(`fake`), 0600); err != nil { + t.Fatalf("failed to write pullSecret: %v", err) + } + + existingCluster := &hyperv1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "existing-cluster", + Namespace: "clusters", + }, + } + fakeClient := fake.NewClientBuilder(). + WithScheme(hyperapi.Scheme). + WithObjects(existingCluster). + Build() + + rawOpts := &RawCreateOptions{ + Name: "existing-cluster", + Namespace: "clusters", + PullSecretFile: pullSecretFile, + Arch: "amd64", + Render: false, + ClientFn: func() (crclient.Client, error) { + return fakeClient, nil + }, + } + + _, err := rawOpts.Validate(ctx) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("already exists")) + }) + + t.Run("When no HostedCluster exists and render is true it should pass validation", func(t *testing.T) { + g := NewWithT(t) + ctx := t.Context() + tempDir := t.TempDir() + + pullSecretFile := filepath.Join(tempDir, "pull-secret.json") + if err := os.WriteFile(pullSecretFile, []byte(`fake`), 0600); err != nil { + t.Fatalf("failed to write pullSecret: %v", err) + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(hyperapi.Scheme). + Build() + + rawOpts := &RawCreateOptions{ + Name: "new-cluster", + Namespace: "clusters", + PullSecretFile: pullSecretFile, + Arch: "amd64", + Render: true, + ClientFn: func() (crclient.Client, error) { + return fakeClient, nil + }, + } + + validated, err := rawOpts.Validate(ctx) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(validated).ToNot(BeNil()) + }) +} + func TestValidateVersion(t *testing.T) { tests := []struct { name string diff --git a/cmd/cluster/core/destroy.go b/cmd/cluster/core/destroy.go index 49020dd0e80..d8bbac4dd34 100644 --- a/cmd/cluster/core/destroy.go +++ b/cmd/cluster/core/destroy.go @@ -44,6 +44,19 @@ type DestroyOptions struct { Log logr.Logger CredentialSecretName string RedactBaseDomain bool + + // ClientFn returns a Kubernetes client. When nil, util.GetClient is used. + // This enables dependency injection for testing without requiring a live cluster. + ClientFn func() (client.Client, error) `json:"-"` +} + +// GetClient returns a Kubernetes client using the configured ClientFn, +// falling back to util.GetClient when ClientFn is nil. +func (o *DestroyOptions) GetClient() (client.Client, error) { + if o.ClientFn != nil { + return o.ClientFn() + } + return util.GetClient() } type AWSPlatformDestroyOptions struct { @@ -85,7 +98,7 @@ type PowerVSPlatformDestroyOptions struct { } func GetCluster(ctx context.Context, o *DestroyOptions) (*hyperv1.HostedCluster, error) { - c, err := util.GetClient() + c, err := o.GetClient() if err != nil { return nil, err } @@ -106,7 +119,7 @@ func GetCluster(ctx context.Context, o *DestroyOptions) (*hyperv1.HostedCluster, func DestroyCluster(ctx context.Context, hostedCluster *hyperv1.HostedCluster, o *DestroyOptions, destroyPlatformSpecifics DestroyPlatformSpecifics) error { hostedClusterExists := hostedCluster != nil shouldDestroyPlatformSpecifics := destroyPlatformSpecifics != nil - c, err := util.GetClient() + c, err := o.GetClient() if err != nil { return err } diff --git a/cmd/cluster/core/destroy_test.go b/cmd/cluster/core/destroy_test.go index a813b835764..3e8896620ea 100644 --- a/cmd/cluster/core/destroy_test.go +++ b/cmd/cluster/core/destroy_test.go @@ -2,18 +2,25 @@ package core import ( "context" + "fmt" "testing" "time" . "github.com/onsi/gomega" + hyperv1 "github.com/openshift/hypershift/api/hypershift/v1beta1" "github.com/openshift/hypershift/cmd/log" + hyperapi "github.com/openshift/hypershift/support/api" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" ) func TestDestroyCluster(t *testing.T) { t.Run("When HostedCluster is nil and platform specifics provided it should call destroyPlatformSpecifics", func(t *testing.T) { - g := NewGomegaWithT(t) - t.Setenv("FAKE_CLIENT", "true") + g := NewWithT(t) platformSpecificsCalled := false var receivedOpts *DestroyOptions @@ -33,6 +40,9 @@ func TestDestroyCluster(t *testing.T) { Cloud: "AzurePublicCloud", Location: "eastus", }, + ClientFn: func() (client.Client, error) { + return fake.NewClientBuilder().WithScheme(hyperapi.Scheme).Build(), nil + }, } err := DestroyCluster(context.Background(), nil, opts, mockPlatformSpecifics) @@ -43,3 +53,86 @@ func TestDestroyCluster(t *testing.T) { g.Expect(receivedOpts.AzurePlatform.Cloud).To(Equal("AzurePublicCloud")) }) } + +func TestGetCluster(t *testing.T) { + t.Run("When the HostedCluster exists it should return it", func(t *testing.T) { + g := NewWithT(t) + hc := &hyperv1.HostedCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "my-cluster", + Namespace: "clusters", + }, + } + fakeClient := fake.NewClientBuilder(). + WithScheme(hyperapi.Scheme). + WithObjects(hc). + Build() + + opts := &DestroyOptions{ + Name: "my-cluster", + Namespace: "clusters", + Log: log.Log, + ClientFn: func() (client.Client, error) { + return fakeClient, nil + }, + } + + result, err := GetCluster(context.Background(), opts) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(result).ToNot(BeNil()) + g.Expect(result.Name).To(Equal("my-cluster")) + }) + + t.Run("When the HostedCluster does not exist it should return nil", func(t *testing.T) { + g := NewWithT(t) + fakeClient := fake.NewClientBuilder(). + WithScheme(hyperapi.Scheme). + Build() + + opts := &DestroyOptions{ + Name: "nonexistent", + Namespace: "clusters", + InfraID: "test-infra", + Log: log.Log, + ClientFn: func() (client.Client, error) { + return fakeClient, nil + }, + } + + result, err := GetCluster(context.Background(), opts) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(result).To(BeNil()) + }) + + t.Run("When the client factory returns an error it should propagate", func(t *testing.T) { + g := NewWithT(t) + opts := &DestroyOptions{ + Name: "test", + Namespace: "clusters", + Log: log.Log, + ClientFn: func() (client.Client, error) { + return nil, fmt.Errorf("connection refused") + }, + } + + result, err := GetCluster(context.Background(), opts) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("connection refused")) + g.Expect(result).To(BeNil()) + }) +} + +func TestDestroyOptionsGetClient(t *testing.T) { + t.Run("When ClientFn is set it should use the provided function", func(t *testing.T) { + g := NewWithT(t) + expectedClient := fake.NewClientBuilder().WithScheme(hyperapi.Scheme).Build() + opts := &DestroyOptions{ + ClientFn: func() (client.Client, error) { + return expectedClient, nil + }, + } + c, err := opts.GetClient() + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(c).To(Equal(expectedClient)) + }) +} diff --git a/cmd/cluster/core/dump.go b/cmd/cluster/core/dump.go index 85f391ede3c..4a9ce8eba66 100644 --- a/cmd/cluster/core/dump.go +++ b/cmd/cluster/core/dump.go @@ -135,6 +135,19 @@ type DumpOptions struct { ImpersonateAs string Log logr.Logger + + // ClientFn returns a Kubernetes client. When nil, util.GetClient is used. + // This enables dependency injection for testing without requiring a live cluster. + ClientFn func() (client.Client, error) `json:"-"` +} + +// GetClient returns a Kubernetes client using the configured ClientFn, +// falling back to util.GetClient when ClientFn is nil. +func (o *DumpOptions) GetClient() (client.Client, error) { + if o.ClientFn != nil { + return o.ClientFn() + } + return util.GetClient() } func NewDumpCommand() *cobra.Command { @@ -184,7 +197,7 @@ func dumpGuestCluster(ctx context.Context, opts *DumpOptions) error { if len(opts.ImpersonateAs) > 0 { c, err = util.GetImpersonatedClient(opts.ImpersonateAs) } else { - c, err = util.GetClient() + c, err = opts.GetClient() } if err != nil { @@ -342,7 +355,7 @@ func DumpCluster(ctx context.Context, opts *DumpOptions) error { return err } } else { - c, err = util.GetClient() + c, err = opts.GetClient() if err != nil { return err } diff --git a/cmd/cluster/core/dump_test.go b/cmd/cluster/core/dump_test.go index 4c686e6dbff..30f21827a0f 100644 --- a/cmd/cluster/core/dump_test.go +++ b/cmd/cluster/core/dump_test.go @@ -4,10 +4,17 @@ import ( "fmt" "testing" + . "github.com/onsi/gomega" + + hyperapi "github.com/openshift/hypershift/support/api" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" fakediscovery "k8s.io/client-go/discovery/fake" clientgotesting "k8s.io/client-go/testing" + + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" ) func TestIsResourceRegistered(t *testing.T) { @@ -68,3 +75,18 @@ func TestIsResourceRegistered(t *testing.T) { }) } } + +func TestDumpOptionsGetClient(t *testing.T) { + t.Run("When ClientFn is set it should use the provided function", func(t *testing.T) { + g := NewWithT(t) + expectedClient := fake.NewClientBuilder().WithScheme(hyperapi.Scheme).Build() + opts := &DumpOptions{ + ClientFn: func() (client.Client, error) { + return expectedClient, nil + }, + } + c, err := opts.GetClient() + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(c).To(Equal(expectedClient)) + }) +} diff --git a/cmd/cluster/gcp/create_test.go b/cmd/cluster/gcp/create_test.go index 2720c446fc1..3ceaf906d6d 100644 --- a/cmd/cluster/gcp/create_test.go +++ b/cmd/cluster/gcp/create_test.go @@ -10,12 +10,16 @@ import ( hyperv1 "github.com/openshift/hypershift/api/hypershift/v1beta1" "github.com/openshift/hypershift/cmd/cluster/core" + hyperapi "github.com/openshift/hypershift/support/api" "github.com/openshift/hypershift/support/certs" "github.com/openshift/hypershift/support/testutil" "github.com/openshift/hypershift/test/integration/framework" utilrand "k8s.io/apimachinery/pkg/util/rand" + crclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "github.com/spf13/pflag" ) @@ -152,7 +156,6 @@ func TestCreateCluster(t *testing.T) { certs.UnsafeSeed(1234567890) ctx := framework.InterruptableContext(t.Context()) tempDir := t.TempDir() - t.Setenv("FAKE_CLIENT", "true") pullSecretFile := filepath.Join(tempDir, "pull-secret.json") if err := os.WriteFile(pullSecretFile, []byte(`fake`), 0600); err != nil { @@ -188,6 +191,9 @@ func TestCreateCluster(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { flags := pflag.NewFlagSet(testCase.name, pflag.ContinueOnError) coreOpts := core.DefaultOptions() + coreOpts.ClientFn = func() (crclient.Client, error) { + return fake.NewClientBuilder().WithScheme(hyperapi.Scheme).Build(), nil + } core.BindDeveloperOptions(coreOpts, flags) gcpOpts := DefaultOptions() BindOptions(gcpOpts, flags) diff --git a/cmd/cluster/kubevirt/create_test.go b/cmd/cluster/kubevirt/create_test.go index 768abffbac4..87688ef9875 100644 --- a/cmd/cluster/kubevirt/create_test.go +++ b/cmd/cluster/kubevirt/create_test.go @@ -8,12 +8,16 @@ import ( . "github.com/onsi/gomega" "github.com/openshift/hypershift/cmd/cluster/core" + hyperapi "github.com/openshift/hypershift/support/api" "github.com/openshift/hypershift/support/certs" "github.com/openshift/hypershift/support/testutil" "github.com/openshift/hypershift/test/integration/framework" utilrand "k8s.io/apimachinery/pkg/util/rand" + crclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "github.com/google/go-cmp/cmp" "github.com/spf13/pflag" ) @@ -121,8 +125,6 @@ func TestCreateCluster(t *testing.T) { certs.UnsafeSeed(1234567890) ctx := framework.InterruptableContext(t.Context()) tempDir := t.TempDir() - t.Setenv("FAKE_CLIENT", "true") - pullSecretFile := filepath.Join(tempDir, "pull-secret.json") if err := os.WriteFile(pullSecretFile, []byte(`fake`), 0600); err != nil { @@ -171,6 +173,9 @@ func TestCreateCluster(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { flags := pflag.NewFlagSet(testCase.name, pflag.ContinueOnError) coreOpts := core.DefaultOptions() + coreOpts.ClientFn = func() (crclient.Client, error) { + return fake.NewClientBuilder().WithScheme(hyperapi.Scheme).Build(), nil + } core.BindDeveloperOptions(coreOpts, flags) kubevirtOpts := DefaultOptions() BindDeveloperOptions(kubevirtOpts, flags) diff --git a/cmd/cluster/openstack/create_test.go b/cmd/cluster/openstack/create_test.go index 4d1aac0d6dd..ff54c229a9e 100644 --- a/cmd/cluster/openstack/create_test.go +++ b/cmd/cluster/openstack/create_test.go @@ -7,12 +7,16 @@ import ( "testing" "github.com/openshift/hypershift/cmd/cluster/core" + hyperapi "github.com/openshift/hypershift/support/api" "github.com/openshift/hypershift/support/certs" "github.com/openshift/hypershift/support/testutil" "github.com/openshift/hypershift/test/integration/framework" utilrand "k8s.io/apimachinery/pkg/util/rand" + crclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "github.com/spf13/pflag" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v3" @@ -47,7 +51,6 @@ func TestCreateCluster(t *testing.T) { certs.UnsafeSeed(1234567890) ctx := framework.InterruptableContext(t.Context()) tempDir := t.TempDir() - t.Setenv("FAKE_CLIENT", "true") cloudsYAML := map[string]interface{}{ "clouds": map[string]interface{}{ @@ -110,6 +113,9 @@ func TestCreateCluster(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { flags := pflag.NewFlagSet(testCase.name, pflag.ContinueOnError) coreOpts := core.DefaultOptions() + coreOpts.ClientFn = func() (crclient.Client, error) { + return fake.NewClientBuilder().WithScheme(hyperapi.Scheme).Build(), nil + } core.BindDeveloperOptions(coreOpts, flags) openstackOpts := DefaultOptions() BindOptions(openstackOpts, flags) diff --git a/cmd/cluster/powervs/create_test.go b/cmd/cluster/powervs/create_test.go index f97d93a3102..cf26de13c87 100644 --- a/cmd/cluster/powervs/create_test.go +++ b/cmd/cluster/powervs/create_test.go @@ -8,6 +8,7 @@ import ( "github.com/openshift/hypershift/cmd/cluster/core" powervsinfra "github.com/openshift/hypershift/cmd/infra/powervs" + hyperapi "github.com/openshift/hypershift/support/api" "github.com/openshift/hypershift/support/certs" "github.com/openshift/hypershift/support/testutil" "github.com/openshift/hypershift/test/integration/framework" @@ -16,6 +17,9 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" utilrand "k8s.io/apimachinery/pkg/util/rand" + crclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "github.com/spf13/pflag" ) @@ -24,7 +28,6 @@ func TestCreateCluster(t *testing.T) { certs.UnsafeSeed(1234567890) ctx := framework.InterruptableContext(t.Context()) tempDir := t.TempDir() - t.Setenv("FAKE_CLIENT", "true") rawInfra, err := json.Marshal(&powervsinfra.Infra{ ID: "fakeID", @@ -106,6 +109,9 @@ func TestCreateCluster(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { flags := pflag.NewFlagSet(testCase.name, pflag.ContinueOnError) coreOpts := core.DefaultOptions() + coreOpts.ClientFn = func() (crclient.Client, error) { + return fake.NewClientBuilder().WithScheme(hyperapi.Scheme).Build(), nil + } core.BindDeveloperOptions(coreOpts, flags) powerVSOpts := DefaultOptions() BindOptions(powerVSOpts, flags) diff --git a/cmd/nodepool/core/create.go b/cmd/nodepool/core/create.go index 199c32c86fa..438cc5cc517 100644 --- a/cmd/nodepool/core/create.go +++ b/cmd/nodepool/core/create.go @@ -34,6 +34,19 @@ type CreateNodePoolOptions struct { NodeUpgradeType hyperv1.UpgradeType Arch string AutoRepair bool + + // ClientFn returns a Kubernetes client. When nil, util.GetClient is used. + // This enables dependency injection for testing without requiring a live cluster. + ClientFn func() (crclient.Client, error) `json:"-"` +} + +// GetClient returns a Kubernetes client using the configured ClientFn, +// falling back to util.GetClient when ClientFn is nil. +func (o *CreateNodePoolOptions) GetClient() (crclient.Client, error) { + if o.ClientFn != nil { + return o.ClientFn() + } + return util.GetClient() } type PlatformOptions interface { @@ -87,7 +100,7 @@ func (o *CreateNodePoolOptions) Validate(ctx context.Context, c crclient.Client) } func (o *CreateNodePoolOptions) CreateNodePool(ctx context.Context, platformOpts PlatformOptions) error { - client, err := util.GetClient() + client, err := o.GetClient() if err != nil { return err } diff --git a/cmd/nodepool/core/create_test.go b/cmd/nodepool/core/create_test.go index 75c048d7257..e5cf8099144 100644 --- a/cmd/nodepool/core/create_test.go +++ b/cmd/nodepool/core/create_test.go @@ -262,3 +262,18 @@ func TestValidMinorVersionCompatibility(t *testing.T) { g.Expect(err.Error()).To(Equal("NodePool minor version 4.14 is less than 4.15, which is the minimum NodePool version compatible with the 4.18 HostedCluster")) }) } + +func TestCreateNodePoolOptionsGetClient(t *testing.T) { + t.Run("When ClientFn is set it should use the provided function", func(t *testing.T) { + g := NewWithT(t) + expectedClient := fake.NewClientBuilder().WithScheme(api.Scheme).Build() + opts := &CreateNodePoolOptions{ + ClientFn: func() (client.Client, error) { + return expectedClient, nil + }, + } + c, err := opts.GetClient() + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(c).To(Equal(expectedClient)) + }) +} diff --git a/cmd/util/client.go b/cmd/util/client.go index 1f382ef692d..93b46370e11 100644 --- a/cmd/util/client.go +++ b/cmd/util/client.go @@ -2,7 +2,6 @@ package util import ( "fmt" - "os" "strings" hyperapi "github.com/openshift/hypershift/support/api" @@ -11,7 +10,6 @@ import ( cr "sigs.k8s.io/controller-runtime" crclient "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/client/fake" ) const ( @@ -33,10 +31,6 @@ func GetConfig() (*rest.Config, error) { // GetClient creates a controller-runtime client for Kubernetes func GetClient() (crclient.Client, error) { - if os.Getenv("FAKE_CLIENT") == "true" { - return fake.NewFakeClient(), nil - } - config, err := GetConfig() if err != nil { return nil, fmt.Errorf("unable to get kubernetes config: %w", err)