Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion controllers/object_controls.go
Original file line number Diff line number Diff line change
Expand Up @@ -5174,12 +5174,21 @@ func RuntimeClasses(n ClusterPolicyController) (gpuv1.State, error) {
return transformKataRuntimeClasses(n)
}

nvidiaRuntimeClasses := n.resources[state].RuntimeClasses
if n.stateNames[state] == "pre-requisites" && !n.isStateEnabled(n.stateNames[state]) {
err := clearRuntimeClasses(n, nvidiaRuntimeClasses)
if err != nil {
return gpuv1.NotReady, fmt.Errorf("error clearing nvidia runtime classes: %w", err)
}
return gpuv1.Ready, nil
}

createRuntimeClassFunc := transformRuntimeClass
if semver.Compare(n.k8sVersion, nodev1MinimumAPIVersion) <= 0 {
createRuntimeClassFunc = transformRuntimeClassLegacy
}

for _, obj := range n.resources[state].RuntimeClasses {
for _, obj := range nvidiaRuntimeClasses {
obj := obj
// When CDI is disabled, do not create the additional 'nvidia-cdi' and
// 'nvidia-legacy' runtime classes. Delete these objects if they were
Expand Down Expand Up @@ -5240,3 +5249,19 @@ func PrometheusRule(n ClusterPolicyController) (gpuv1.State, error) {
}
return gpuv1.Ready, nil
}

// clearRuntimeClasses deletes every runtime class in runtimeClasses through the
// controller's client.
//
// An entry whose name is the "FILLED_BY_OPERATOR" placeholder is deleted under
// the name returned by getRuntimeClassName for the current ClusterPolicy spec.
// A "not found" response from the API is treated as success. Any other error
// is logged and returned immediately, leaving later entries unprocessed.
func clearRuntimeClasses(n ClusterPolicyController, runtimeClasses []nodev1.RuntimeClass) error {
	for i := range runtimeClasses {
		// Operate on a copy so the caller's slice element is never modified.
		rc := runtimeClasses[i]
		if rc.Name == "FILLED_BY_OPERATOR" {
			// apply runtime class name as per ClusterPolicy
			rc.Name = getRuntimeClassName(&n.singleton.Spec)
		}
		rcLogger := n.logger.WithValues("RuntimeClass", rc.Name)

		err := n.client.Delete(n.ctx, &rc)
		if err == nil || apierrors.IsNotFound(err) {
			// Deleted, or it was never there: nothing left to do for this one.
			continue
		}
		rcLogger.Info("Couldn't delete", "Error", err)
		return err
	}
	return nil
}
159 changes: 158 additions & 1 deletion controllers/object_controls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
nodev1 "k8s.io/api/node/v1"
nodev1beta1 "k8s.io/api/node/v1beta1"
rbacv1 "k8s.io/api/rbac/v1"
schedv1 "k8s.io/api/scheduling/v1beta1"
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
Expand Down Expand Up @@ -1161,7 +1162,10 @@ func TestServiceMonitor(t *testing.T) {
}

// CRD object for tests that need ServiceMonitor CRD present
serviceMonitorCRD := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: ServiceMonitorCRDName}}
serviceMonitorCRD := &apiextensionsv1.CustomResourceDefinition{
TypeMeta: metav1.TypeMeta{Kind: "CustomResourceDefinition"},
ObjectMeta: metav1.ObjectMeta{Name: ServiceMonitorCRDName},
}

tests := []struct {
description string
Expand Down Expand Up @@ -1577,3 +1581,156 @@ func TestKernelFullVersion(t *testing.T) {
require.Equal(t, test.expected["osVersionMajor"], osVersion)
}
}

// TestRuntimeClasses exercises RuntimeClasses for the "pre-requisites" state
// against a fake Kubernetes client.
//
// Each case seeds the fake client with optional pre-existing objects, runs
// RuntimeClasses with the given ClusterPolicy spec, and then checks both the
// returned state and the exact set of RuntimeClass objects left in the
// cluster. Comparing the full set (rather than only looking up the expected
// names) is what lets the cases with an empty expectation verify that
// runtime classes were actually removed or never created.
func TestRuntimeClasses(t *testing.T) {
	const (
		testNamespace = "test-namespace"
	)

	// Create scheme with required types
	scheme := runtime.NewScheme()
	require.NoError(t, nodev1.AddToScheme(scheme))
	require.NoError(t, nodev1beta1.AddToScheme(scheme))
	require.NoError(t, apiextensionsv1.AddToScheme(scheme))
	require.NoError(t, gpuv1.AddToScheme(scheme))

	// newRuntimeClass builds a minimal node.k8s.io/v1 RuntimeClass manifest
	// with the given name, as it would appear in the controller's resources.
	newRuntimeClass := func(name string) nodev1.RuntimeClass {
		return nodev1.RuntimeClass{
			TypeMeta: metav1.TypeMeta{
				Kind:       "RuntimeClass",
				APIVersion: "node.k8s.io/v1",
			},
			ObjectMeta: metav1.ObjectMeta{
				Name: name,
			},
		}
	}

	// Create controller with given spec and state
	newController := func(k8s client.Client, scheme *runtime.Scheme, spec gpuv1.ClusterPolicySpec, state string) ClusterPolicyController {
		clusterPolicy := &gpuv1.ClusterPolicy{Spec: spec}
		resources := []Resources{
			{
				RuntimeClasses: []nodev1.RuntimeClass{
					newRuntimeClass("nvidia"),
					newRuntimeClass("nvidia-cdi"),
					newRuntimeClass("nvidia-legacy"),
				},
			},
		}

		return ClusterPolicyController{
			client:            k8s,
			ctx:               context.Background(),
			singleton:         clusterPolicy,
			scheme:            scheme,
			operatorNamespace: testNamespace,
			resources:         resources,
			stateNames:        []string{state},
			idx:               0,
			logger:            ctrl.Log.WithName("test"),
		}
	}

	tests := []struct {
		description            string
		stateName              string
		k8sVersion             string
		k8sObjects             []client.Object
		clusterPolicySpec      gpuv1.ClusterPolicySpec
		expectedState          gpuv1.State
		expectedRuntimeClasses []string
	}{
		{
			description: "CDI enabled",
			stateName:   "pre-requisites",
			k8sVersion:  "v1.33.0",
			k8sObjects:  nil,
			clusterPolicySpec: gpuv1.ClusterPolicySpec{
				CDI: gpuv1.CDIConfigSpec{Enabled: ptr.To(true)},
			},
			expectedState:          gpuv1.Ready,
			expectedRuntimeClasses: []string{"nvidia", "nvidia-legacy", "nvidia-cdi"},
		},
		{
			description: "CDI and NRI Plugin Enabled",
			stateName:   "pre-requisites",
			k8sVersion:  "v1.33.0",
			k8sObjects:  nil,
			clusterPolicySpec: gpuv1.ClusterPolicySpec{
				CDI: gpuv1.CDIConfigSpec{
					Enabled:          ptr.To(true),
					NRIPluginEnabled: ptr.To(true),
				},
			},
			expectedState:          gpuv1.Ready,
			expectedRuntimeClasses: []string{},
		},
		{
			description: "CDI and NRI Plugin Enabled with pre-existing runtime class",
			stateName:   "pre-requisites",
			k8sVersion:  "v1.33.0",
			k8sObjects: []client.Object{
				&nodev1.RuntimeClass{
					ObjectMeta: metav1.ObjectMeta{
						Name: "nvidia",
					},
				},
				&nodev1.RuntimeClass{
					ObjectMeta: metav1.ObjectMeta{
						Name: "nvidia-legacy",
					},
				},
				&nodev1.RuntimeClass{
					ObjectMeta: metav1.ObjectMeta{
						Name: "nvidia-cdi",
					},
				},
			},
			clusterPolicySpec: gpuv1.ClusterPolicySpec{
				CDI: gpuv1.CDIConfigSpec{
					Enabled:          ptr.To(true),
					NRIPluginEnabled: ptr.To(true),
				},
			},
			expectedState:          gpuv1.Ready,
			expectedRuntimeClasses: []string{},
		},
	}

	for _, test := range tests {
		t.Run(test.description, func(t *testing.T) {
			k8sClient := fake.NewClientBuilder().
				WithScheme(scheme).
				WithObjects(test.k8sObjects...).
				Build()

			controller := newController(k8sClient, scheme, test.clusterPolicySpec, test.stateName)
			controller.k8sVersion = test.k8sVersion

			state, err := RuntimeClasses(controller)
			require.NoError(t, err)
			require.Equal(t, test.expectedState, state)

			// List everything that remains instead of fetching only the
			// expected names: with an empty expectation a per-name lookup
			// would check nothing, and leftover objects would go unnoticed.
			remaining := &nodev1.RuntimeClassList{}
			require.NoError(t, k8sClient.List(t.Context(), remaining))

			actualRuntimeClasses := make([]string, 0, len(remaining.Items))
			for i := range remaining.Items {
				actualRuntimeClasses = append(actualRuntimeClasses, remaining.Items[i].Name)
			}
			require.ElementsMatch(t, test.expectedRuntimeClasses, actualRuntimeClasses)
		})
	}
}
2 changes: 2 additions & 0 deletions controllers/state_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,8 @@ func (n ClusterPolicyController) isStateEnabled(stateName string) bool {
clusterPolicySpec := &n.singleton.Spec

switch stateName {
case "pre-requisites":
return !clusterPolicySpec.CDI.IsNRIPluginEnabled()
case "state-driver":
return clusterPolicySpec.Driver.IsEnabled()
case "state-container-toolkit":
Expand Down