Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions internal/server/register.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
/*
* Copyright 2024 The HAMi Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package server

import (
Expand All @@ -19,6 +35,8 @@ import (
)

func (ps *PluginServer) watchAndRegister() {
ps.wg.Add(1)
defer ps.wg.Done()
timer := time.After(1 * time.Second)
for {
select {
Expand Down Expand Up @@ -110,6 +128,9 @@ func (ps *PluginServer) getDeviceNetworkID(idx int, deviceType string) (int, err
}

func (ps *PluginServer) registerKubelet() error {
if ps.registerKubeletFunc != nil {
return ps.registerKubeletFunc()
}
conn, err := ps.dial(v1beta1.KubeletSocket, 5*time.Second)
if err != nil {
return err
Expand All @@ -135,6 +156,9 @@ func (ps *PluginServer) registerKubelet() error {
}

func (ps *PluginServer) dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) {
if ps.dialFunc != nil {
return ps.dialFunc(unixSocketPath, timeout)
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
c, _ := grpc.NewClient(unixSocketPath,
Expand Down
40 changes: 36 additions & 4 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net"
"os"
"path"
"sync"
"time"

"google.golang.org/grpc"
Expand Down Expand Up @@ -64,6 +65,12 @@ type PluginServer struct {
stopCh chan interface{}
healthCh chan int32
checkIdleVNPUInterval int
wg sync.WaitGroup

// test hooks — injected by tests to avoid real socket/kubelet dependencies
dialFunc func(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error)
registerKubeletFunc func() error
prepareHostResourcesFunc func() error
}

type RuntimeInfo struct {
Expand All @@ -82,7 +89,6 @@ func NewPluginServer(mgr manager.Manager, nodeName string, checkIdleVNPUInterval
handshakeAnno: fmt.Sprintf("hami.io/node-handshake-%s", commonWord),
allocAnno: fmt.Sprintf("huawei.com/%s", commonWord),
toAllocDeviceAnno: fmt.Sprintf("hami.io/%s-devices-to-allocate", commonWord),
grpcServer: grpc.NewServer(),
mgr: mgr,
socket: path.Join(v1beta1.DevicePluginPath, fmt.Sprintf("%s.sock", commonWord)),
stopCh: make(chan interface{}),
Expand All @@ -94,14 +100,25 @@ func NewPluginServer(mgr manager.Manager, nodeName string, checkIdleVNPUInterval
return server, nil
}

// prepareHostResources wraps the package-level prepareHostResources() to
// allow test injection via the prepareHostResourcesFunc hook.
func (ps *PluginServer) prepareHostResources() error {
if ps.prepareHostResourcesFunc != nil {
return ps.prepareHostResourcesFunc()
}
return prepareHostResources()
}

func (ps *PluginServer) Start() error {
// Automatically prepare host environment when the plugin starts
if err := prepareHostResources(); err != nil {
if err := ps.prepareHostResources(); err != nil {
klog.Errorf("Failed to prepare host resources: %v. vNPU core functionality will be impaired.", err)
return err
}

ps.stopCh = make(chan interface{})
ps.grpcServer = grpc.NewServer()

err := ps.mgr.UpdateDevice()
if err != nil {
return err
Expand All @@ -120,6 +137,8 @@ func (ps *PluginServer) Start() error {
}

func (ps *PluginServer) startPeriodicCheckIdleVNPUs() {
ps.wg.Add(1)
defer ps.wg.Done()
ticker := time.NewTicker(time.Duration(ps.checkIdleVNPUInterval) * time.Second)
defer ticker.Stop()
for {
Expand All @@ -137,8 +156,19 @@ func (ps *PluginServer) startPeriodicCheckIdleVNPUs() {
}

func (ps *PluginServer) Stop() error {
close(ps.stopCh)
ps.grpcServer.Stop()
if ps.stopCh != nil {
select {
case <-ps.stopCh:
// already closed; no-op
default:
close(ps.stopCh)
}
}
if ps.grpcServer != nil {
ps.grpcServer.Stop()
}
ps.wg.Wait()
_ = os.Remove(ps.socket)
return nil
}

Expand All @@ -158,7 +188,9 @@ func (ps *PluginServer) serve() error {
}
v1beta1.RegisterDevicePluginServer(ps.grpcServer, ps)
resourceName := ps.mgr.ResourceName()
ps.wg.Add(1)
go func() {
defer ps.wg.Done()
lastCrashTime := time.Now()
restartCount := 0
for {
Expand Down
182 changes: 182 additions & 0 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ import (
"context"
"encoding/json"
"fmt"
"os"
"path"
"strings"
"testing"

"google.golang.org/grpc/grpclog"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"
Expand Down Expand Up @@ -762,3 +765,182 @@ func TestCleanupIdleVNPUs(t *testing.T) {
})
}
}

// ============================================================================
// gRPC restart tests
// ============================================================================

// panicOnFatalLogger is a gRPC logger that converts Fatalf calls to panics.
// This allows tests to verify that gRPC does NOT call Fatalf (which would
// otherwise call os.Exit(1) and abort the test process).
//
// Usage:
//
// defer grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr))
// grpclog.SetLoggerV2(newPanicOnFatalLogger())
type panicOnFatalLogger struct {
inner grpclog.LoggerV2
}

func newPanicOnFatalLogger() *panicOnFatalLogger {
return &panicOnFatalLogger{
inner: grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr),
}
}

var _ grpclog.LoggerV2 = (*panicOnFatalLogger)(nil)

func (l *panicOnFatalLogger) Info(args ...interface{}) { l.inner.Info(args...) }
func (l *panicOnFatalLogger) Infoln(args ...interface{}) { l.inner.Infoln(args...) }
func (l *panicOnFatalLogger) Infof(format string, args ...interface{}) {
l.inner.Infof(format, args...)
}
func (l *panicOnFatalLogger) Warning(args ...interface{}) { l.inner.Warning(args...) }
func (l *panicOnFatalLogger) Warningln(args ...interface{}) { l.inner.Warningln(args...) }
func (l *panicOnFatalLogger) Warningf(format string, args ...interface{}) {
l.inner.Warningf(format, args...)
}
func (l *panicOnFatalLogger) Error(args ...interface{}) { l.inner.Error(args...) }
func (l *panicOnFatalLogger) Errorln(args ...interface{}) { l.inner.Errorln(args...) }
func (l *panicOnFatalLogger) Errorf(format string, args ...interface{}) {
l.inner.Errorf(format, args...)
}
func (l *panicOnFatalLogger) V(level int) bool { return l.inner.V(level) }

func (l *panicOnFatalLogger) Fatalf(format string, args ...interface{}) {
panic(fmt.Sprintf("grpc FATAL: "+format, args...))
}

func (l *panicOnFatalLogger) Fatalln(args ...interface{}) {
panic(fmt.Sprintf("grpc FATAL: %v", fmt.Sprintln(args...)))
}

func (l *panicOnFatalLogger) Fatal(args ...interface{}) {
panic(fmt.Sprintf("grpc FATAL: %v", fmt.Sprint(args...)))
}

// setupRestartablePluginServer creates a PluginServer with all test hooks
// injected so that Start()/Stop() work without real socket files or a kubelet.
func setupRestartablePluginServer(t *testing.T) *PluginServer {
t.Helper()

ps := &PluginServer{
commonWord: "test-ascend",
registerAnno: "hami.io/node-register-test-ascend",
handshakeAnno: "hami.io/node-handshake-test-ascend",
allocAnno: "huawei.com/test-ascend",
toAllocDeviceAnno: "hami.io/test-ascend-devices-to-allocate",
mgr: &FakeManager{ResourceNameFunc: func() string { return "test-ascend" }},
socket: path.Join(t.TempDir(), "test-ascend.sock"),
stopCh: make(chan interface{}),
healthCh: make(chan int32),
checkIdleVNPUInterval: 3600,
dialFunc: nil,
registerKubeletFunc: func() error {
return nil
},
prepareHostResourcesFunc: func() error {
return nil
},
}
return ps
}

// TestGrpcServer_RestartDoesNotPanic verifies that a single Stop+Start cycle
// does not trigger the gRPC "RegisterService after Serve" fatal error.
func TestGrpcServer_RestartDoesNotPanic(t *testing.T) {
defer grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr))
grpclog.SetLoggerV2(newPanicOnFatalLogger())

ps := setupRestartablePluginServer(t)

// First Start
if err := ps.Start(); err != nil {
t.Fatalf("first Start() failed: %v", err)
}

// Stop
if err := ps.Stop(); err != nil {
t.Fatalf("Stop() failed: %v", err)
}

// Second Start — this must not trigger grpc Fatalf
if err := ps.Start(); err != nil {
t.Fatalf("second Start() after restart failed: %v", err)
}

// Cleanup
if err := ps.Stop(); err != nil {
t.Fatalf("final Stop() failed: %v", err)
}
}

// TestGrpcServer_MultipleRestarts verifies that the server can survive
// multiple Stop+Start cycles without panic.
func TestGrpcServer_MultipleRestarts(t *testing.T) {
defer grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr))
grpclog.SetLoggerV2(newPanicOnFatalLogger())

ps := setupRestartablePluginServer(t)

for i := 0; i < 5; i++ {
if err := ps.Start(); err != nil {
t.Fatalf("Start() iteration %d failed: %v", i, err)
}
if err := ps.Stop(); err != nil {
t.Fatalf("Stop() iteration %d failed: %v", i, err)
}
}
}

// TestGrpcServer_StopWithoutStart verifies that Stop() is safe when
// Start() was never called (no goroutines to wait for).
func TestGrpcServer_StopWithoutStart(t *testing.T) {
ps := setupRestartablePluginServer(t)
if err := ps.Stop(); err != nil {
t.Fatalf("Stop() without Start() should be safe: %v", err)
}
}

// TestGrpcServer_DoubleStop verifies that calling Stop() twice is safe.
func TestGrpcServer_DoubleStop(t *testing.T) {
ps := setupRestartablePluginServer(t)

if err := ps.Start(); err != nil {
t.Fatalf("Start() failed: %v", err)
}

if err := ps.Stop(); err != nil {
t.Fatalf("first Stop() failed: %v", err)
}

if err := ps.Stop(); err != nil {
t.Fatalf("second Stop() should be safe: %v", err)
}
}

// TestGrpcServer_StopWaitForAllGoroutines verifies that Stop() returns
// only after all goroutines have exited. We verify this indirectly by
// checking that Start() after Stop() does not race (goroutine leak would
// manifest as stale channel reads).
func TestGrpcServer_StopWaitForAllGoroutines(t *testing.T) {
ps := setupRestartablePluginServer(t)

if err := ps.Start(); err != nil {
t.Fatalf("Start() failed: %v", err)
}

if err := ps.Stop(); err != nil {
t.Fatalf("Stop() failed: %v", err)
}

// bgWG and serveWG should be zero after Stop() returns.
// A new cycle confirms there is no deadlock or hang.
if err := ps.Start(); err != nil {
t.Fatalf("Start() after Stop() failed: %v", err)
}

if err := ps.Stop(); err != nil {
t.Fatalf("Stop() failed: %v", err)
}
}
Loading