diff --git a/controllers/object_controls.go b/controllers/object_controls.go
index f26a0c063..00e5f717d 100644
--- a/controllers/object_controls.go
+++ b/controllers/object_controls.go
@@ -5174,12 +5174,23 @@ func RuntimeClasses(n ClusterPolicyController) (gpuv1.State, error) {
 		return transformKataRuntimeClasses(n)
 	}
 
+	nvidiaRuntimeClasses := n.resources[state].RuntimeClasses
+	// isStateEnabled reports "pre-requisites" as disabled when the NRI plugin is
+	// enabled; in that case delete the nvidia runtime classes instead of creating them.
+	if n.stateNames[state] == "pre-requisites" && !n.isStateEnabled(n.stateNames[state]) {
+		err := clearRuntimeClasses(n, nvidiaRuntimeClasses)
+		if err != nil {
+			return gpuv1.NotReady, fmt.Errorf("error clearing nvidia runtime classes: %w", err)
+		}
+		return gpuv1.Ready, nil
+	}
+
 	createRuntimeClassFunc := transformRuntimeClass
 	if semver.Compare(n.k8sVersion, nodev1MinimumAPIVersion) <= 0 {
 		createRuntimeClassFunc = transformRuntimeClassLegacy
 	}
 
-	for _, obj := range n.resources[state].RuntimeClasses {
+	for _, obj := range nvidiaRuntimeClasses {
 		obj := obj
 		// When CDI is disabled, do not create the additional 'nvidia-cdi' and
 		// 'nvidia-legacy' runtime classes. Delete these objects if they were
@@ -5240,3 +5251,22 @@ func PrometheusRule(n ClusterPolicyController) (gpuv1.State, error) {
 	}
 	return gpuv1.Ready, nil
 }
+
+// clearRuntimeClasses deletes each of the given RuntimeClass objects through the
+// controller's client. A "FILLED_BY_OPERATOR" placeholder name is resolved from the
+// ClusterPolicy spec first. NotFound errors are ignored; any other error is returned.
+func clearRuntimeClasses(n ClusterPolicyController, runtimeClasses []nodev1.RuntimeClass) error {
+	for _, obj := range runtimeClasses {
+		// apply runtime class name as per ClusterPolicy
+		if obj.Name == "FILLED_BY_OPERATOR" {
+			obj.Name = getRuntimeClassName(&n.singleton.Spec)
+		}
+		logger := n.logger.WithValues("RuntimeClass", obj.Name)
+		err := n.client.Delete(n.ctx, &obj)
+		if err != nil && !apierrors.IsNotFound(err) {
+			logger.Info("Couldn't delete", "Error", err)
+			return err
+		}
+	}
+	return nil
+}
diff --git a/controllers/object_controls_test.go b/controllers/object_controls_test.go
index 2e5ea8285..c47288079 100644
--- a/controllers/object_controls_test.go
+++ b/controllers/object_controls_test.go
@@ -32,6 +32,7 @@ import (
 	appsv1 "k8s.io/api/apps/v1"
 	corev1 "k8s.io/api/core/v1"
 	nodev1 "k8s.io/api/node/v1"
+	nodev1beta1 "k8s.io/api/node/v1beta1"
 	rbacv1 "k8s.io/api/rbac/v1"
 	schedv1 "k8s.io/api/scheduling/v1beta1"
 	apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
@@ -1161,7 +1162,10 @@ func TestServiceMonitor(t *testing.T) {
 	}
 
 	// CRD object for tests that need ServiceMonitor CRD present
-	serviceMonitorCRD := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: ServiceMonitorCRDName}}
+	serviceMonitorCRD := &apiextensionsv1.CustomResourceDefinition{
+		TypeMeta:   metav1.TypeMeta{Kind: "CustomResourceDefinition"},
+		ObjectMeta: metav1.ObjectMeta{Name: ServiceMonitorCRDName},
+	}
 
 	tests := []struct {
 		description string
@@ -1577,3 +1581,160 @@ func TestKernelFullVersion(t *testing.T) {
 		require.Equal(t, test.expected["osVersionMajor"], osVersion)
 	}
 }
+
+// TestRuntimeClasses checks which nvidia RuntimeClass objects exist after running
+// RuntimeClasses for the "pre-requisites" state, with and without the NRI plugin.
+func TestRuntimeClasses(t *testing.T) {
+	const (
+		testNamespace = "test-namespace"
+	)
+
+	// Create scheme with required types
+	scheme := runtime.NewScheme()
+	require.NoError(t, nodev1.AddToScheme(scheme))
+	require.NoError(t, nodev1beta1.AddToScheme(scheme))
+	require.NoError(t, apiextensionsv1.AddToScheme(scheme))
+	require.NoError(t, gpuv1.AddToScheme(scheme))
+
+	// Create controller with given spec and state
+	newController := func(k8s client.Client, scheme *runtime.Scheme, spec gpuv1.ClusterPolicySpec, state string) ClusterPolicyController {
+		clusterPolicy := &gpuv1.ClusterPolicy{Spec: spec}
+		resources := []Resources{
+			{
+				RuntimeClasses: []nodev1.RuntimeClass{
+					{
+						ObjectMeta: metav1.ObjectMeta{
+							Name: "nvidia",
+						},
+						TypeMeta: metav1.TypeMeta{
+							Kind:       "RuntimeClass",
+							APIVersion: "node.k8s.io/v1",
+						},
+					},
+					{
+						ObjectMeta: metav1.ObjectMeta{
+							Name: "nvidia-cdi",
+						},
+						TypeMeta: metav1.TypeMeta{
+							Kind:       "RuntimeClass",
+							APIVersion: "node.k8s.io/v1",
+						},
+					},
+					{
+						ObjectMeta: metav1.ObjectMeta{
+							Name: "nvidia-legacy",
+						},
+						TypeMeta: metav1.TypeMeta{
+							Kind:       "RuntimeClass",
+							APIVersion: "node.k8s.io/v1",
+						},
+					},
+				},
+			},
+		}
+
+		return ClusterPolicyController{
+			client:            k8s,
+			ctx:               context.Background(),
+			singleton:         clusterPolicy,
+			scheme:            scheme,
+			operatorNamespace: testNamespace,
+			resources:         resources,
+			stateNames:        []string{state},
+			idx:               0,
+			logger:            ctrl.Log.WithName("test"),
+		}
+	}
+
+	tests := []struct {
+		description            string
+		stateName              string
+		k8sVersion             string
+		k8sObjects             []client.Object
+		clusterPolicySpec      gpuv1.ClusterPolicySpec
+		expectedState          gpuv1.State
+		expectedRuntimeClasses []string
+	}{
+		{
+			description: "CDI enabled",
+			stateName:   "pre-requisites",
+			k8sVersion:  "v1.33.0",
+			k8sObjects:  nil,
+			clusterPolicySpec: gpuv1.ClusterPolicySpec{
+				CDI: gpuv1.CDIConfigSpec{Enabled: ptr.To(true)},
+			},
+			expectedState:          gpuv1.Ready,
+			expectedRuntimeClasses: []string{"nvidia", "nvidia-legacy", "nvidia-cdi"},
+		},
+		{
+			description: "CDI and NRI Plugin Enabled",
+			stateName:   "pre-requisites",
+			k8sVersion:  "v1.33.0",
+			k8sObjects:  nil,
+			clusterPolicySpec: gpuv1.ClusterPolicySpec{
+				CDI: gpuv1.CDIConfigSpec{
+					Enabled:          ptr.To(true),
+					NRIPluginEnabled: ptr.To(true),
+				},
+			},
+			expectedState:          gpuv1.Ready,
+			expectedRuntimeClasses: []string{},
+		},
+		{
+			description: "CDI and NRI Plugin Enabled with pre-existing runtime class",
+			stateName:   "pre-requisites",
+			k8sVersion:  "v1.33.0",
+			k8sObjects: []client.Object{
+				&nodev1.RuntimeClass{
+					ObjectMeta: metav1.ObjectMeta{
+						Name: "nvidia",
+					},
+				},
+				&nodev1.RuntimeClass{
+					ObjectMeta: metav1.ObjectMeta{
+						Name: "nvidia-legacy",
+					},
+				},
+				&nodev1.RuntimeClass{
+					ObjectMeta: metav1.ObjectMeta{
+						Name: "nvidia-cdi",
+					},
+				},
+			},
+			clusterPolicySpec: gpuv1.ClusterPolicySpec{
+				CDI: gpuv1.CDIConfigSpec{
+					Enabled:          ptr.To(true),
+					NRIPluginEnabled: ptr.To(true),
+				},
+			},
+			expectedState:          gpuv1.Ready,
+			expectedRuntimeClasses: []string{},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.description, func(t *testing.T) {
+			k8sClient := fake.NewClientBuilder().
+				WithScheme(scheme).
+				WithObjects(test.k8sObjects...).
+				Build()
+
+			controller := newController(k8sClient, scheme, test.clusterPolicySpec, test.stateName)
+			controller.k8sVersion = test.k8sVersion
+
+			state, err := RuntimeClasses(controller)
+			require.NoError(t, err)
+			require.Equal(t, test.expectedState, state)
+
+			// List every RuntimeClass rather than fetching only the expected names,
+			// so cases that expect none also prove pre-existing objects were deleted.
+			rcList := &nodev1.RuntimeClassList{}
+			require.NoError(t, k8sClient.List(t.Context(), rcList))
+			actualRuntimeClasses := make([]string, 0, len(rcList.Items))
+			for _, rc := range rcList.Items {
+				actualRuntimeClasses = append(actualRuntimeClasses, rc.Name)
+			}
+			require.ElementsMatch(t, test.expectedRuntimeClasses, actualRuntimeClasses)
+		})
+	}
+}
diff --git a/controllers/state_manager.go b/controllers/state_manager.go
index 4ea634ebe..5969f9980 100644
--- a/controllers/state_manager.go
+++ b/controllers/state_manager.go
@@ -984,6 +984,8 @@ func (n ClusterPolicyController) isStateEnabled(stateName string) bool {
 	clusterPolicySpec := &n.singleton.Spec
 
 	switch stateName {
+	case "pre-requisites":
+		return !clusterPolicySpec.CDI.IsNRIPluginEnabled()
 	case "state-driver":
 		return clusterPolicySpec.Driver.IsEnabled()
 	case "state-container-toolkit":