Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 35 additions & 27 deletions internal/plugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"path"
"path/filepath"
"strings"
"sync"
"time"

cdiapi "tags.cncf.io/container-device-interface/pkg/cdi"
Expand Down Expand Up @@ -62,7 +63,12 @@ type nvidiaDevicePlugin struct {
socket string
server *grpc.Server
health chan *rm.Device
stop chan interface{}

// healthCtx and healthCancel control the health check goroutine lifecycle.
healthCtx context.Context
healthCancel context.CancelFunc
// healthWg is used to wait for the health check goroutine to complete during cleanup.
healthWg sync.WaitGroup

imexChannels imex.Channels

Expand Down Expand Up @@ -90,11 +96,6 @@ func (o *options) devicePluginForResource(ctx context.Context, resourceManager r
mps: mpsOptions,

socket: getPluginSocketPath(resourceManager.Resource()),
// These will be reinitialized every
// time the plugin server is restarted.
server: nil,
health: nil,
stop: nil,
}
return &plugin, nil
}
Expand All @@ -106,19 +107,6 @@ func getPluginSocketPath(resource spec.ResourceName) string {
return filepath.Join(pluginapi.DevicePluginPath, pluginName) + ".sock"
}

func (plugin *nvidiaDevicePlugin) initialize() {
plugin.server = grpc.NewServer([]grpc.ServerOption{}...)
plugin.health = make(chan *rm.Device)
plugin.stop = make(chan interface{})
}

func (plugin *nvidiaDevicePlugin) cleanup() {
close(plugin.stop)
plugin.server = nil
plugin.health = nil
plugin.stop = nil
}

// Devices returns the full set of devices associated with the plugin.
func (plugin *nvidiaDevicePlugin) Devices() rm.Devices {
return plugin.rm.Devices()
Expand All @@ -127,16 +115,22 @@ func (plugin *nvidiaDevicePlugin) Devices() rm.Devices {
// Start starts the gRPC server, registers the device plugin with the Kubelet,
// and starts the device healthchecks.
func (plugin *nvidiaDevicePlugin) Start(kubeletSocket string) error {
plugin.initialize()

if err := plugin.mps.waitForDaemon(); err != nil {
return fmt.Errorf("error waiting for MPS daemon: %w", err)
}

plugin.health = make(chan *rm.Device)
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plugin.health is created as an unbuffered channel and the health-check goroutine can send to it even when no ListAndWatch stream is active. Since Stop() now cancels the context and then waits for the health goroutine to exit, a blocked send on plugin.health can deadlock Stop() (context cancellation won’t unblock a channel send). Consider making the channel buffered (e.g., sized to number of devices) and/or ensuring all sends in health-check code are non-blocking/select on ctx.Done().

Suggested change
plugin.health = make(chan *rm.Device)
// Use a buffered channel for health notifications to avoid blocking when
// no ListAndWatch stream is currently consuming from plugin.health.
healthBufSize := len(plugin.rm.Devices())
if healthBufSize < 1 {
healthBufSize = 1
}
plugin.health = make(chan *rm.Device, healthBufSize)

Copilot uses AI. Check for mistakes.
plugin.healthCtx, plugin.healthCancel = context.WithCancel(plugin.ctx)

err := plugin.Serve()
if err != nil {
klog.Errorf("Could not start device plugin for '%s': %s", plugin.rm.Resource(), err)
plugin.cleanup()
plugin.healthCancel()
if plugin.server != nil {
plugin.server.Stop()
plugin.server = nil
}
plugin.health = nil
return err
}
Comment on lines 125 to 135
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On Serve() failure, Start() returns immediately after creating plugin.health and plugin.healthCtx but does not cancel/close them. If the same plugin instance is retried or kept around, this leaks resources and leaves fields in a partially-initialized state. Consider deferring a cleanup (cancel + close channel) on the Serve() error path.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArangoGutierrez (nit) though healthCancel() properly cancels things, should these still be set to nil?

plugin.health = nil
plugin.server = nil

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Now the error path also calls plugin.server.Stop() and nils both plugin.server and plugin.health to leave the struct in a clean state for potential retry. See latest commit.

klog.Infof("Starting to serve '%s' on %s", plugin.rm.Resource(), plugin.socket)
Expand All @@ -148,10 +142,17 @@ func (plugin *nvidiaDevicePlugin) Start(kubeletSocket string) error {
}
klog.Infof("Registered device plugin for '%s' with Kubelet", plugin.rm.Resource())

plugin.healthWg.Add(1)
go func() {
defer plugin.healthWg.Done()
// TODO: add MPS health check
err := plugin.rm.CheckHealth(plugin.stop, plugin.health)
if err != nil {
err := plugin.rm.CheckHealth(plugin.healthCtx, plugin.health)
switch {
case err == nil:
klog.Infof("Health check completed successfully for '%s'", plugin.rm.Resource())
case errors.Is(err, context.Canceled):
klog.V(4).Infof("Health check canceled for '%s' (plugin shutdown)", plugin.rm.Resource())
default:
klog.Errorf("Failed to start health check: %v; continuing with health checks disabled", err)
}
}()
Expand All @@ -164,12 +165,17 @@ func (plugin *nvidiaDevicePlugin) Stop() error {
if plugin == nil || plugin.server == nil {
return nil
}
// Stop health checks if they were started.
if plugin.healthCancel != nil {
plugin.healthCancel()
plugin.healthWg.Wait()
}
klog.Infof("Stopping to serve '%s' on %s", plugin.rm.Resource(), plugin.socket)
plugin.server.Stop()
plugin.server = nil
if err := os.Remove(plugin.socket); err != nil && !os.IsNotExist(err) {
return err
}
plugin.cleanup()
return nil
}

Expand All @@ -181,6 +187,7 @@ func (plugin *nvidiaDevicePlugin) Serve() error {
return err
}

plugin.server = grpc.NewServer([]grpc.ServerOption{}...)
pluginapi.RegisterDevicePluginServer(plugin.server, plugin)

go func() {
Expand Down Expand Up @@ -271,7 +278,8 @@ func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.D

for {
select {
case <-plugin.stop:
case <-plugin.healthCtx.Done():
Comment thread
guptaNswati marked this conversation as resolved.
klog.V(4).Infof("Stopping health checks for '%s'", plugin.rm.Resource())
return nil
case d := <-plugin.health:
// FIXME: there is no way to recover from the Unhealthy state.
Expand Down Expand Up @@ -368,7 +376,7 @@ func (plugin *nvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*plu
// updateResponseForMPS ensures that the ContainerAllocate response contains the information required to use MPS.
// This includes per-resource pipe and log directories as well as a global daemon-specific shm
// and assumes that an MPS control daemon has already been started.
func (plugin nvidiaDevicePlugin) updateResponseForMPS(response *pluginapi.ContainerAllocateResponse) {
func (plugin *nvidiaDevicePlugin) updateResponseForMPS(response *pluginapi.ContainerAllocateResponse) {
plugin.mps.updateReponse(response)
}

Expand Down
128 changes: 102 additions & 26 deletions internal/rm/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package rm

import (
"context"
"fmt"
"os"
"strconv"
Expand All @@ -40,8 +41,44 @@ const (
envEnableHealthChecks = "DP_ENABLE_HEALTHCHECKS"
)

type nvmlDeviceHealthChecker struct {
// nvmllib is the NVML interface used to query device handles during event
// monitoring. Stored here rather than accessed via nvmlResourceManager to
// keep the health checker decoupled and independently testable.
nvmllib nvml.Interface
devices Devices
parentToDeviceMap map[string]*Device
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make it a type similar to type devicePlacementMap map[string]map[uint32]map[uint32]*Device
https://github.com/NVIDIA/k8s-dra-driver-gpu/blob/main/cmd/gpu-kubelet-plugin/device_health.go#L36

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea — the DRA pattern with type devicePlacementMap map[string]map[uint32]map[uint32]*Device is cleaner. That's a bigger refactor that changes how device lookup works in runEventMonitor. I'll create a follow-up PR to align DP with DRA's data model here.

deviceIDToGiMap map[string]uint32
deviceIDToCiMap map[string]uint32

xidsDisabled disabledXIDs
unhealthy chan<- *Device
}

// markUnhealthy delivers d on the unhealthy channel unless ctx is canceled
// first, reporting whether the send completed. A receiver that is ready
// immediately is always served, even when ctx has already been canceled.
func (h *nvmlDeviceHealthChecker) markUnhealthy(ctx context.Context, d *Device) bool {
	// Fast path: if a receiver is waiting right now, hand the device off
	// without consulting the context at all.
	select {
	case h.unhealthy <- d:
		return true
	default:
	}
	// Slow path: no receiver was ready, so block until either the send
	// succeeds or the context is canceled, whichever comes first.
	select {
	case <-ctx.Done():
		return false
	case h.unhealthy <- d:
		return true
	}
}

// CheckHealth performs health checks on a set of devices, writing to the 'unhealthy' channel with any unhealthy devices
func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devices, unhealthy chan<- *Device) error {
func (r *nvmlResourceManager) checkHealth(ctx context.Context, devices Devices, unhealthy chan<- *Device) error {
xids := getDisabledHealthCheckXids()
if xids.IsAllDisabled() {
return nil
Expand Down Expand Up @@ -71,33 +108,55 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
_ = eventSet.Free()
}()

// Construct the device maps.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh you already called it out. Yes, we should refractor it into helpers

// TODO: This should be factored out. The main issue is marking the devices
// unhealthy as part of this loop.
parentToDeviceMap := make(map[string]*Device)
deviceIDToGiMap := make(map[string]uint32)
deviceIDToCiMap := make(map[string]uint32)

eventMask := uint64(nvml.EventTypeXidCriticalError | nvml.EventTypeDoubleBitEccError | nvml.EventTypeSingleBitEccError)
for _, d := range devices {
uuid, gi, ci, err := r.getDevicePlacement(d)
uuid, gi, ci, err := (&withDevicePlacements{r.nvml}).getDevicePlacement(d)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide some context as to why we landed on this change / refactor?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The goal of this refactor is to take the DevicePlugin to a similar state as our DRA after @guptaNswati Health check work.
This in order to prepare both DRA and DevicePlugin for NVsentinel integration via the Device-API

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets add a detailed comment on how it will help. I also had to look into it. Because rn when we are only using NVML, its not making sense to decouple DevicePlacement logic from NVMLResourceManager. We need to note how this placement code can be reused with multiple providers and not just NVML.

type nvmlResourceManager struct {
	resourceManager
	nvml nvml.Interface

to having a new wrapper

type withDevicePlacements struct {
	nvml.Interface
}

if err != nil {
klog.Warningf("Could not determine device placement for %v: %v; Marking it unhealthy.", d.ID, err)
unhealthy <- d
select {
case unhealthy <- d:
case <-ctx.Done():
return ctx.Err()
}
continue
}
deviceIDToGiMap[d.ID] = gi
deviceIDToCiMap[d.ID] = ci
parentToDeviceMap[uuid] = d
Comment on lines 128 to 130
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parentToDeviceMap is keyed by uuid returned from getDevicePlacement(), which is the parent UUID for MIG devices. If multiple MIG devices share the same parent GPU, this assignment overwrites earlier entries, so events for that parent can only ever map to one MIG device. To correctly handle multiple MIG instances per parent, use a mapping that can represent multiple devices per parent (e.g., parent UUID -> list, or parent UUID -> (GI,CI)->device).

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — this is a pre-existing issue. When multiple MIG devices share the same parent GPU, only the last one gets stored. I'll track this as a separate fix (needs its own tests for the multi-MIG case) in a follow-up PR.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. sounds good.

}

p := nvmlDeviceHealthChecker{
nvmllib: r.nvml,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed for tests?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArangoGutierrez why do we need to have the lib here? nvmlResourceManager.nvml is used everywhere

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment on the struct field explaining this. The NVML interface is stored on the health checker (rather than accessing it via nvmlResourceManager) to keep it decoupled and independently testable. The health checker methods (registerDeviceEvents, runEventMonitor) need DeviceGetHandleByUUID during event registration and processing.

devices: devices,
unhealthy: unhealthy,
parentToDeviceMap: parentToDeviceMap,
deviceIDToGiMap: deviceIDToGiMap,
deviceIDToCiMap: deviceIDToCiMap,
xidsDisabled: xids,
}
p.registerDeviceEvents(ctx, eventSet)

return p.runEventMonitor(ctx, eventSet)
}

gpu, ret := r.nvml.DeviceGetHandleByUUID(uuid)
func (h *nvmlDeviceHealthChecker) registerDeviceEvents(ctx context.Context, eventSet nvml.EventSet) {
eventMask := uint64(nvml.EventTypeXidCriticalError | nvml.EventTypeDoubleBitEccError | nvml.EventTypeSingleBitEccError)
for uuid, d := range h.parentToDeviceMap {
gpu, ret := h.nvmllib.DeviceGetHandleByUUID(uuid)
if ret != nvml.SUCCESS {
klog.Infof("unable to get device handle from UUID: %v; marking it as unhealthy", ret)
unhealthy <- d
h.markUnhealthy(ctx, d)
continue
}

supportedEvents, ret := gpu.GetSupportedEventTypes()
if ret != nvml.SUCCESS {
klog.Infof("unable to determine the supported events for %v: %v; marking it as unhealthy", d.ID, ret)
unhealthy <- d
h.markUnhealthy(ctx, d)
continue
}

Expand All @@ -106,15 +165,17 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
case ret == nvml.ERROR_NOT_SUPPORTED:
klog.Warningf("Device %v is too old to support healthchecking.", d.ID)
case ret != nvml.SUCCESS:
klog.Infof("Marking device %v as unhealthy: %v", d.ID, ret)
unhealthy <- d
klog.Infof("Unable to register events for %v: %v; marking it as unhealthy", d.ID, ret)
h.markUnhealthy(ctx, d)
}
}
}

func (h *nvmlDeviceHealthChecker) runEventMonitor(ctx context.Context, eventSet nvml.EventSet) error {
for {
select {
case <-stop:
return nil
case <-ctx.Done():
return ctx.Err()
default:
}

Expand All @@ -124,18 +185,22 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
}
if ret != nvml.SUCCESS {
klog.Infof("Error waiting for event: %v; Marking all devices as unhealthy", ret)
for _, d := range devices {
unhealthy <- d
for _, d := range h.devices {
if !h.markUnhealthy(ctx, d) {
return ctx.Err()
}
}
continue
Comment on lines 186 to 193
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When eventSet.Wait() returns a non-timeout error, the loop marks all devices unhealthy and continues. If the error is persistent, this can spin forever and repeatedly attempt blocking sends to h.unhealthy, preventing shutdown if the receiver isn’t draining fast enough. Consider returning the error (or backing off) and make unhealthy notifications context-aware (select on ctx.Done()) to avoid blocking indefinitely.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pre-existing behavior — the busy loop on persistent NVML errors existed before this refactor. The new code does improve the situation: markUnhealthy is now context-aware, so the loop can be interrupted by cancellation even during sends. Adding backoff or error-count limits would be a good follow-up improvement.

}

// TODO: We create an event mask for other event types but don't handle
// them here.
if e.EventType != nvml.EventTypeXidCriticalError {
klog.Infof("Skipping non-nvmlEventTypeXidCriticalError event: %+v", e)
continue
}

if xids.IsDisabled(e.EventData) {
if h.xidsDisabled.IsDisabled(e.EventData) {
klog.Infof("Skipping event %+v", e)
continue
}
Expand All @@ -145,29 +210,31 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
if ret != nvml.SUCCESS {
// If we cannot reliably determine the device UUID, we mark all devices as unhealthy.
klog.Infof("Failed to determine uuid for event %v: %v; Marking all devices as unhealthy.", e, ret)
for _, d := range devices {
unhealthy <- d
for _, d := range h.devices {
if !h.markUnhealthy(ctx, d) {
return ctx.Err()
}
}
continue
}

d, exists := parentToDeviceMap[eventUUID]
d, exists := h.parentToDeviceMap[eventUUID]
if !exists {
klog.Infof("Ignoring event for unexpected device: %v", eventUUID)
continue
}

if d.IsMigDevice() && e.GpuInstanceId != 0xFFFFFFFF && e.ComputeInstanceId != 0xFFFFFFFF {
gi := deviceIDToGiMap[d.ID]
ci := deviceIDToCiMap[d.ID]
gi := h.deviceIDToGiMap[d.ID]
ci := h.deviceIDToCiMap[d.ID]
if gi != e.GpuInstanceId || ci != e.ComputeInstanceId {
continue
}
klog.Infof("Event for mig device %v (gi=%v, ci=%v)", d.ID, gi, ci)
}

klog.Infof("XidCriticalError: Xid=%d on Device=%s; marking device as unhealthy.", e.EventData, d.ID)
unhealthy <- d
h.markUnhealthy(ctx, d)
}
}

Expand Down Expand Up @@ -276,25 +343,34 @@ func newHealthCheckXIDs(xids ...string) disabledXIDs {
return output
}

// withDevicePlacements wraps nvml.Interface to provide device placement
// resolution (parent UUID, GPU Instance, Compute Instance) independently of
// nvmlResourceManager. This decoupling allows placement logic to be reused
// across different health-check providers (e.g., NVsentinel, Device-API)
// without requiring access to the full resource manager.
type withDevicePlacements struct {
nvml.Interface
}

// getDevicePlacement returns the placement of the specified device.
// For a MIG device the placement is defined by the 3-tuple <parent UUID, GI, CI>
// For a full device the returned 3-tuple is the device's uuid and 0xFFFFFFFF for the other two elements.
func (r *nvmlResourceManager) getDevicePlacement(d *Device) (string, uint32, uint32, error) {
func (p *withDevicePlacements) getDevicePlacement(d *Device) (string, uint32, uint32, error) {
if !d.IsMigDevice() {
return d.GetUUID(), 0xFFFFFFFF, 0xFFFFFFFF, nil
}
return r.getMigDeviceParts(d)
return p.getMigDeviceParts(d)
}

// getMigDeviceParts returns the parent GI and CI ids of the MIG device.
func (r *nvmlResourceManager) getMigDeviceParts(d *Device) (string, uint32, uint32, error) {
func (p *withDevicePlacements) getMigDeviceParts(d *Device) (string, uint32, uint32, error) {
if !d.IsMigDevice() {
return "", 0, 0, fmt.Errorf("cannot get GI and CI of full device")
}

uuid := d.GetUUID()
// For older driver versions, the call to DeviceGetHandleByUUID will fail for MIG devices.
mig, ret := r.nvml.DeviceGetHandleByUUID(uuid)
mig, ret := p.DeviceGetHandleByUUID(uuid)
if ret == nvml.SUCCESS {
parentHandle, ret := mig.GetDeviceHandleFromMigDeviceHandle()
if ret != nvml.SUCCESS {
Expand Down
Loading
Loading