diff --git a/internal/server/register.go b/internal/server/register.go index c86565f..a45b178 100644 --- a/internal/server/register.go +++ b/internal/server/register.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 The HAMi Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package server import ( @@ -19,6 +35,8 @@ import ( ) func (ps *PluginServer) watchAndRegister() { + ps.wg.Add(1) + defer ps.wg.Done() timer := time.After(1 * time.Second) for { select { @@ -110,6 +128,9 @@ func (ps *PluginServer) getDeviceNetworkID(idx int, deviceType string) (int, err } func (ps *PluginServer) registerKubelet() error { + if ps.registerKubeletFunc != nil { + return ps.registerKubeletFunc() + } conn, err := ps.dial(v1beta1.KubeletSocket, 5*time.Second) if err != nil { return err @@ -135,6 +156,9 @@ func (ps *PluginServer) registerKubelet() error { } func (ps *PluginServer) dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) { + if ps.dialFunc != nil { + return ps.dialFunc(unixSocketPath, timeout) + } ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() c, _ := grpc.NewClient(unixSocketPath, diff --git a/internal/server/server.go b/internal/server/server.go index da07524..4462f27 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -23,6 +23,7 @@ import ( "net" "os" "path" + "sync" "time" "google.golang.org/grpc" @@ -64,6 +65,12 @@ type PluginServer struct { stopCh chan interface{} healthCh 
chan int32 checkIdleVNPUInterval int + wg sync.WaitGroup + + // test hooks — injected by tests to avoid real socket/kubelet dependencies + dialFunc func(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) + registerKubeletFunc func() error + prepareHostResourcesFunc func() error } type RuntimeInfo struct { @@ -82,7 +89,6 @@ func NewPluginServer(mgr manager.Manager, nodeName string, checkIdleVNPUInterval handshakeAnno: fmt.Sprintf("hami.io/node-handshake-%s", commonWord), allocAnno: fmt.Sprintf("huawei.com/%s", commonWord), toAllocDeviceAnno: fmt.Sprintf("hami.io/%s-devices-to-allocate", commonWord), - grpcServer: grpc.NewServer(), mgr: mgr, socket: path.Join(v1beta1.DevicePluginPath, fmt.Sprintf("%s.sock", commonWord)), stopCh: make(chan interface{}), @@ -94,14 +100,25 @@ func NewPluginServer(mgr manager.Manager, nodeName string, checkIdleVNPUInterval return server, nil } +// prepareHostResources wraps the package-level prepareHostResources() to +// allow test injection via the prepareHostResourcesFunc hook. +func (ps *PluginServer) prepareHostResources() error { + if ps.prepareHostResourcesFunc != nil { + return ps.prepareHostResourcesFunc() + } + return prepareHostResources() +} + func (ps *PluginServer) Start() error { // Automatically prepare host environment when the plugin starts - if err := prepareHostResources(); err != nil { + if err := ps.prepareHostResources(); err != nil { klog.Errorf("Failed to prepare host resources: %v. 
vNPU core functionality will be impaired.", err) return err } ps.stopCh = make(chan interface{}) + ps.grpcServer = grpc.NewServer() + err := ps.mgr.UpdateDevice() if err != nil { return err @@ -120,6 +137,8 @@ func (ps *PluginServer) Start() error { } func (ps *PluginServer) startPeriodicCheckIdleVNPUs() { + ps.wg.Add(1) + defer ps.wg.Done() ticker := time.NewTicker(time.Duration(ps.checkIdleVNPUInterval) * time.Second) defer ticker.Stop() for { @@ -137,8 +156,19 @@ func (ps *PluginServer) startPeriodicCheckIdleVNPUs() { } func (ps *PluginServer) Stop() error { - close(ps.stopCh) - ps.grpcServer.Stop() + if ps.stopCh != nil { + select { + case <-ps.stopCh: + // already closed; no-op + default: + close(ps.stopCh) + } + } + if ps.grpcServer != nil { + ps.grpcServer.Stop() + } + ps.wg.Wait() + _ = os.Remove(ps.socket) return nil } @@ -158,7 +188,9 @@ func (ps *PluginServer) serve() error { } v1beta1.RegisterDevicePluginServer(ps.grpcServer, ps) resourceName := ps.mgr.ResourceName() + ps.wg.Add(1) go func() { + defer ps.wg.Done() lastCrashTime := time.Now() restartCount := 0 for { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index c65783d..1072bf8 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -20,9 +20,12 @@ import ( "context" "encoding/json" "fmt" + "os" + "path" "strings" "testing" + "google.golang.org/grpc/grpclog" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes/fake" @@ -762,3 +765,182 @@ func TestCleanupIdleVNPUs(t *testing.T) { }) } } + +// ============================================================================ +// gRPC restart tests +// ============================================================================ + +// panicOnFatalLogger is a gRPC logger that converts Fatalf calls to panics. +// This allows tests to verify that gRPC does NOT call Fatalf (which would +// otherwise call os.Exit(1) and abort the test process). 
+// +// Usage: +// +// defer grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr)) +// grpclog.SetLoggerV2(newPanicOnFatalLogger()) +type panicOnFatalLogger struct { + inner grpclog.LoggerV2 +} + +func newPanicOnFatalLogger() *panicOnFatalLogger { + return &panicOnFatalLogger{ + inner: grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr), + } +} + +var _ grpclog.LoggerV2 = (*panicOnFatalLogger)(nil) + +func (l *panicOnFatalLogger) Info(args ...interface{}) { l.inner.Info(args...) } +func (l *panicOnFatalLogger) Infoln(args ...interface{}) { l.inner.Infoln(args...) } +func (l *panicOnFatalLogger) Infof(format string, args ...interface{}) { + l.inner.Infof(format, args...) +} +func (l *panicOnFatalLogger) Warning(args ...interface{}) { l.inner.Warning(args...) } +func (l *panicOnFatalLogger) Warningln(args ...interface{}) { l.inner.Warningln(args...) } +func (l *panicOnFatalLogger) Warningf(format string, args ...interface{}) { + l.inner.Warningf(format, args...) +} +func (l *panicOnFatalLogger) Error(args ...interface{}) { l.inner.Error(args...) } +func (l *panicOnFatalLogger) Errorln(args ...interface{}) { l.inner.Errorln(args...) } +func (l *panicOnFatalLogger) Errorf(format string, args ...interface{}) { + l.inner.Errorf(format, args...) +} +func (l *panicOnFatalLogger) V(level int) bool { return l.inner.V(level) } + +func (l *panicOnFatalLogger) Fatalf(format string, args ...interface{}) { + panic(fmt.Sprintf("grpc FATAL: "+format, args...)) +} + +func (l *panicOnFatalLogger) Fatalln(args ...interface{}) { + panic(fmt.Sprintf("grpc FATAL: %v", fmt.Sprintln(args...))) +} + +func (l *panicOnFatalLogger) Fatal(args ...interface{}) { + panic(fmt.Sprintf("grpc FATAL: %v", fmt.Sprint(args...))) +} + +// setupRestartablePluginServer creates a PluginServer with all test hooks +// injected so that Start()/Stop() work without real socket files or a kubelet. 
+func setupRestartablePluginServer(t *testing.T) *PluginServer { + t.Helper() + + ps := &PluginServer{ + commonWord: "test-ascend", + registerAnno: "hami.io/node-register-test-ascend", + handshakeAnno: "hami.io/node-handshake-test-ascend", + allocAnno: "huawei.com/test-ascend", + toAllocDeviceAnno: "hami.io/test-ascend-devices-to-allocate", + mgr: &FakeManager{ResourceNameFunc: func() string { return "test-ascend" }}, + socket: path.Join(t.TempDir(), "test-ascend.sock"), + stopCh: make(chan interface{}), + healthCh: make(chan int32), + checkIdleVNPUInterval: 3600, + dialFunc: nil, + registerKubeletFunc: func() error { + return nil + }, + prepareHostResourcesFunc: func() error { + return nil + }, + } + return ps +} + +// TestGrpcServer_RestartDoesNotPanic verifies that a single Stop+Start cycle +// does not trigger the gRPC "RegisterService after Serve" fatal error. +func TestGrpcServer_RestartDoesNotPanic(t *testing.T) { + defer grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr)) + grpclog.SetLoggerV2(newPanicOnFatalLogger()) + + ps := setupRestartablePluginServer(t) + + // First Start + if err := ps.Start(); err != nil { + t.Fatalf("first Start() failed: %v", err) + } + + // Stop + if err := ps.Stop(); err != nil { + t.Fatalf("Stop() failed: %v", err) + } + + // Second Start — this must not trigger grpc Fatalf + if err := ps.Start(); err != nil { + t.Fatalf("second Start() after restart failed: %v", err) + } + + // Cleanup + if err := ps.Stop(); err != nil { + t.Fatalf("final Stop() failed: %v", err) + } +} + +// TestGrpcServer_MultipleRestarts verifies that the server can survive +// multiple Stop+Start cycles without panic. 
+func TestGrpcServer_MultipleRestarts(t *testing.T) { + defer grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr)) + grpclog.SetLoggerV2(newPanicOnFatalLogger()) + + ps := setupRestartablePluginServer(t) + + for i := 0; i < 5; i++ { + if err := ps.Start(); err != nil { + t.Fatalf("Start() iteration %d failed: %v", i, err) + } + if err := ps.Stop(); err != nil { + t.Fatalf("Stop() iteration %d failed: %v", i, err) + } + } +} + +// TestGrpcServer_StopWithoutStart verifies that Stop() is safe when +// Start() was never called (no goroutines to wait for). +func TestGrpcServer_StopWithoutStart(t *testing.T) { + ps := setupRestartablePluginServer(t) + if err := ps.Stop(); err != nil { + t.Fatalf("Stop() without Start() should be safe: %v", err) + } +} + +// TestGrpcServer_DoubleStop verifies that calling Stop() twice is safe. +func TestGrpcServer_DoubleStop(t *testing.T) { + ps := setupRestartablePluginServer(t) + + if err := ps.Start(); err != nil { + t.Fatalf("Start() failed: %v", err) + } + + if err := ps.Stop(); err != nil { + t.Fatalf("first Stop() failed: %v", err) + } + + if err := ps.Stop(); err != nil { + t.Fatalf("second Stop() should be safe: %v", err) + } +} + +// TestGrpcServer_StopWaitForAllGoroutines verifies that Stop() returns +// only after all goroutines have exited. We verify this indirectly by +// checking that Start() after Stop() does not race (goroutine leak would +// manifest as stale channel reads). +func TestGrpcServer_StopWaitForAllGoroutines(t *testing.T) { + ps := setupRestartablePluginServer(t) + + if err := ps.Start(); err != nil { + t.Fatalf("Start() failed: %v", err) + } + + if err := ps.Stop(); err != nil { + t.Fatalf("Stop() failed: %v", err) + } + + // The ps.wg counter should be zero after Stop() returns. + // A new cycle confirms there is no deadlock or hang. 
+ if err := ps.Start(); err != nil { + t.Fatalf("Start() after Stop() failed: %v", err) + } + + if err := ps.Stop(); err != nil { + t.Fatalf("Stop() failed: %v", err) + } +}