Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions controllers/object_controls.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,10 +710,15 @@ func kernelFullVersion(n ClusterPolicyController) (string, string, string) {
if !ok {
return kFVersion, "", ""
}
osMajorVersion := strings.Split(osVersion, ".")[0]
osMajorNumber, err := strconv.Atoi(osMajorVersion)
if err != nil {
return kFVersion, "", ""
}

if osName == "rocky" {
// If the OS is RockyLinux, we will omit the RockyLinux minor version when constructing the os image tag
osVersion = strings.Split(osVersion, ".")[0]
// If the OS is RockyLinux or RHEL 10 & above, we will omit the minor version when constructing the os image tag
if osName == "rocky" || (osName == "rhel" && osMajorNumber >= 10) {
osVersion = osMajorVersion
}

osTag := fmt.Sprintf("%s%s", osName, osVersion)
Expand Down
42 changes: 39 additions & 3 deletions controllers/object_controls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,42 @@ func TestKernelFullVersion(t *testing.T) {
node *corev1.Node
expected map[string]string
}{
{
node: &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "test-node",
Labels: map[string]string{
nfdOSReleaseIDLabelKey: "rhel",
nfdOSVersionIDLabelKey: "10.1",
nfdKernelLabelKey: "6.12.0-124.8.1.el10_1.x86_64",
commonGPULabelKey: "true",
},
},
},
expected: map[string]string{
"kernelFullVersion": "6.12.0-124.8.1.el10_1.x86_64",
"imageTagSuffix": "rhel10",
"osVersion": "10",
},
},
{
node: &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "test-node",
Labels: map[string]string{
nfdOSReleaseIDLabelKey: "rhel",
nfdOSVersionIDLabelKey: "9.6",
nfdKernelLabelKey: "5.14.0-570.78.1.el9_6.x86_64",
commonGPULabelKey: "true",
},
},
},
expected: map[string]string{
"kernelFullVersion": "5.14.0-570.78.1.el9_6.x86_64",
"imageTagSuffix": "rhel9.6",
"osVersion": "9.6",
},
},
{
node: &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Expand All @@ -1539,7 +1575,7 @@ func TestKernelFullVersion(t *testing.T) {
expected: map[string]string{
"kernelFullVersion": "5.14.0-611.5.1.el9_7.x86_64",
"imageTagSuffix": "rocky9",
"osVersionMajor": "9",
"osVersion": "9",
},
},
{
Expand All @@ -1557,7 +1593,7 @@ func TestKernelFullVersion(t *testing.T) {
expected: map[string]string{
"kernelFullVersion": "6.8.0-60-generic",
"imageTagSuffix": "ubuntu24.04",
"osVersionMajor": "24.04",
"osVersion": "24.04",
},
},
}
Expand All @@ -1574,6 +1610,6 @@ func TestKernelFullVersion(t *testing.T) {

require.Equal(t, test.expected["kernelFullVersion"], kFVersion)
require.Equal(t, test.expected["imageTagSuffix"], osTag)
require.Equal(t, test.expected["osVersionMajor"], osVersion)
require.Equal(t, test.expected["osVersion"], osVersion)
}
}
12 changes: 6 additions & 6 deletions internal/state/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ func getDriverAppName(cr *nvidiav1alpha1.NVIDIADriver, pool nodePool) string {

var hashBuilder strings.Builder

appNamePrefix := fmt.Sprintf(appNamePrefixFormat, cr.Spec.DriverType, pool.getOS())
appNamePrefix := fmt.Sprintf(appNamePrefixFormat, cr.Spec.DriverType, pool.osTag)
uid := string(cr.UID)

hashBuilder.WriteString(uid)
Expand Down Expand Up @@ -519,7 +519,7 @@ func getDefaultStartupProbe(spec *nvidiav1alpha1.NVIDIADriverSpec) *nvidiav1alph
}

func getDriverImagePath(spec *nvidiav1alpha1.NVIDIADriverSpec, nodePool nodePool) (string, error) {
os := nodePool.getOS()
os := nodePool.osTag

if spec.UsePrecompiledDrivers() {
return spec.GetPrecompiledImagePath(os, nodePool.kernel)
Expand Down Expand Up @@ -547,7 +547,7 @@ func getDriverSpec(cr *nvidiav1alpha1.NVIDIADriver, nodePool nodePool) (*driverS
return nil, fmt.Errorf("no NVIDIADriver CR provided")
}

nvidiaDriverName := getDriverName(cr, nodePool.getOS())
nvidiaDriverName := getDriverName(cr, nodePool.osTag)
nvidiaDriverAppName := getDriverAppName(cr, nodePool)

spec := cr.Spec.DeepCopy()
Expand Down Expand Up @@ -575,7 +575,7 @@ func getDriverSpec(cr *nvidiav1alpha1.NVIDIADriver, nodePool nodePool) (*driverS
Name: nvidiaDriverName,
ImagePath: imagePath,
ManagerImagePath: managerImagePath,
OSVersion: nodePool.getOS(),
OSVersion: nodePool.osTag,
}, nil
}

Expand All @@ -585,7 +585,7 @@ func getGDSSpec(spec *nvidiav1alpha1.NVIDIADriverSpec, pool nodePool) (*gdsDrive
return nil, nil
}
gdsSpec := spec.GPUDirectStorage
imagePath, err := gdsSpec.GetImagePath(pool.getOS())
imagePath, err := gdsSpec.GetImagePath(pool.osTag)
if err != nil {
return nil, err
}
Expand All @@ -602,7 +602,7 @@ func getGDRCopySpec(spec *nvidiav1alpha1.NVIDIADriverSpec, pool nodePool) (*gdrc
return nil, nil
}
gdrcopySpec := spec.GDRCopy
imagePath, err := gdrcopySpec.GetImagePath(pool.getOS())
imagePath, err := gdrcopySpec.GetImagePath(pool.osTag)
if err != nil {
return nil, err
}
Expand Down
31 changes: 31 additions & 0 deletions internal/state/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,9 @@ func TestGetDriverAppName(t *testing.T) {
osRelease: "ubuntu",
osVersion: "20.04",
}
var err error
Comment thread
cdesiniotis marked this conversation as resolved.
pool.osTag, err = getOSTag(pool.osRelease, pool.osVersion)
assert.NoError(t, err)

actual := getDriverAppName(cr, pool)
expected := "nvidia-gpu-driver-ubuntu20.04-67cc6dbb79"
Expand All @@ -522,11 +525,29 @@ func TestGetDriverAppName(t *testing.T) {
// Now set the osVersion to a really long string
pool.osRelease = "redhatCoreOS"
pool.osVersion = "4.14-414.92.202309282257"
pool.osTag, err = getOSTag(pool.osRelease, pool.osVersion)
assert.NoError(t, err)

actual = getDriverAppName(cr, pool)
expected = "nvidia-gpu-driver-redhatCoreOS4.14-414.92.2023092822-59b779bcc5"
assert.Equal(t, expected, actual)
assert.Equal(t, 63, len(actual))

// RockyLinux
pool.osRelease = "rocky"
pool.osVersion = "9.6"
pool.osTag, err = getOSTag(pool.osRelease, pool.osVersion)
assert.NoError(t, err)
actual = getDriverAppName(cr, pool)
assert.Equal(t, "nvidia-gpu-driver-rocky9-59b779bcc5", actual)

// RHEL10
pool.osRelease = "rhel"
pool.osVersion = "10.1"
pool.osTag, err = getOSTag(pool.osRelease, pool.osVersion)
assert.NoError(t, err)
actual = getDriverAppName(cr, pool)
assert.Equal(t, "nvidia-gpu-driver-rhel10-59b779bcc5", actual)
}

func TestGetDriverAppNameRHCOS(t *testing.T) {
Expand All @@ -544,6 +565,9 @@ func TestGetDriverAppNameRHCOS(t *testing.T) {
osVersion: "4.14",
rhcosVersion: "414.92.202309282257",
}
var err error
pool.osTag, err = getOSTag(pool.osRelease, pool.osVersion)
assert.NoError(t, err)

actual := getDriverAppName(cr, pool)
expected := "nvidia-gpu-driver-rhcos4.14-6f4fc4fc6"
Expand Down Expand Up @@ -940,6 +964,10 @@ func TestGetDriverSpecMultipleNodePools(t *testing.T) {
},
}

var err error
pool1.osTag, err = getOSTag(pool1.osRelease, pool1.osVersion)
require.NoError(t, err)

pool2 := nodePool{
osRelease: "ubuntu",
osVersion: "20.04",
Expand All @@ -950,6 +978,9 @@ func TestGetDriverSpecMultipleNodePools(t *testing.T) {
},
}

pool2.osTag, err = getOSTag(pool2.osRelease, pool2.osVersion)
require.NoError(t, err)

spec1, err := getDriverSpec(cr, pool1)
require.NoError(t, err)
spec2, err := getDriverSpec(cr, pool2)
Expand Down
29 changes: 23 additions & 6 deletions internal/state/nodepool.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"maps"
"strconv"
"strings"

corev1 "k8s.io/api/core/v1"
Expand All @@ -38,6 +39,7 @@ type nodePool struct {
name string
osRelease string
osVersion string
osTag string
rhcosVersion string
kernel string
nodeSelector map[string]string
Expand Down Expand Up @@ -94,7 +96,13 @@ func getNodePools(ctx context.Context, k8sClient client.Client, selector map[str
nodePool.nodeSelector[nfdOSVersionIDLabelKey] = osVersion
nodePool.osRelease = osID
nodePool.osVersion = osVersion
nodePool.name = nodePool.getOS()
Comment thread
cdesiniotis marked this conversation as resolved.

osTag, err := getOSTag(osID, osVersion)
if err != nil {
return nil, fmt.Errorf("failed to get OS info for node %s: %w", node.Name, err)
}
nodePool.osTag = osTag
nodePool.name = osTag

if precompiled {
kernelVersion, ok := nodeLabels[nfdKernelLabelKey]
Expand Down Expand Up @@ -132,10 +140,19 @@ func getNodePools(ctx context.Context, k8sClient client.Client, selector map[str
return nodePools, nil
}

func (n nodePool) getOS() string {
if n.osRelease == "rocky" {
// If the OS is RockyLinux, we will omit the RockyLinux minor version when constructing the os image tag
n.osVersion = strings.Split(n.osVersion, ".")[0]
// getOSTag builds the OS portion of an image tag from an OS release ID and
// version ID (for example "ubuntu" and "24.04" give "ubuntu24.04").
//
// For RockyLinux, and for RHEL 10 and above, the minor version is omitted, so
// ("rocky", "9.6") gives "rocky9" and ("rhel", "10.1") gives "rhel10". Every
// other release keeps the full version string.
//
// The major version is parsed as an integer only for "rhel", because that is
// the only release whose tag depends on its numeric value. An error is
// returned when that parse fails; version strings of other releases are never
// parsed, so a non-numeric major component there does not cause a failure.
func getOSTag(osRelease, osVersion string) (string, error) {
	// Everything before the first "." is the major version; when there is no
	// "." the whole string is the major version.
	osMajorVersion, _, _ := strings.Cut(osVersion, ".")

	switch osRelease {
	case "rocky":
		return fmt.Sprintf("%s%s", osRelease, osMajorVersion), nil
	case "rhel":
		osMajorNumber, err := strconv.Atoi(osMajorVersion)
		if err != nil {
			return "", fmt.Errorf("failed to parse os version: %w", err)
		}
		if osMajorNumber >= 10 {
			return fmt.Sprintf("%s%s", osRelease, osMajorVersion), nil
		}
	}

	return fmt.Sprintf("%s%s", osRelease, osVersion), nil
}
83 changes: 83 additions & 0 deletions internal/state/nodepool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package state

import (
"testing"

"github.com/stretchr/testify/require"
)

// TestGetOSTag checks the tag returned by getOSTag for a table of OS release
// and version pairs, including one version that must be rejected.
func TestGetOSTag(t *testing.T) {
	type testCase struct {
		description  string
		osRelease    string
		osVersion    string
		expected     string
		expectError  bool
		errorMessage string
	}

	cases := []testCase{
		{description: "valid os release & version", osRelease: "rhel", osVersion: "9.4", expected: "rhel9.4"},
		{description: "valid os release & version - ubuntu", osRelease: "ubuntu", osVersion: "24.04", expected: "ubuntu24.04"},
		{description: "rocky linux", osRelease: "rocky", osVersion: "9.4", expected: "rocky9"},
		{description: "RHEL 10", osRelease: "rhel", osVersion: "10.1", expected: "rhel10"},
		{
			description:  "invalid os version",
			osRelease:    "rhel",
			osVersion:    "A.10",
			expectError:  true,
			errorMessage: "failed to parse os version: strconv.Atoi: parsing \"A\": invalid syntax",
		},
	}

	for _, tc := range cases {
		t.Run(tc.description, func(t *testing.T) {
			tag, err := getOSTag(tc.osRelease, tc.osVersion)

			if !tc.expectError {
				require.NoError(t, err)
				require.Equal(t, tc.expected, tag)
				return
			}

			// Failure path: the error text is pinned exactly, and the returned
			// tag must still equal the (empty) expected value.
			require.Error(t, err)
			require.Equal(t, tc.errorMessage, err.Error())
			require.Equal(t, tc.expected, tag)
		})
	}
}