diff --git a/cmd/daemon_nix.go b/cmd/daemon_nix.go index 84a53d5..34e8091 100644 --- a/cmd/daemon_nix.go +++ b/cmd/daemon_nix.go @@ -17,5 +17,8 @@ func runService(service *service.Service) error { } <-service.Done() + // Wait for the service goroutines to finish so the runner's process group is + // torn down before we exit; otherwise the runner is orphaned on shutdown. + service.Wait() return nil } diff --git a/service/service.go b/service/service.go index 22d5052..6b355f8 100644 --- a/service/service.go +++ b/service/service.go @@ -67,7 +67,15 @@ type Service struct { cliMux sync.Mutex running bool done chan struct{} - + // wg tracks the keepAliveLoop, loop and keepRunnerAlive goroutines so that + // shutdown can wait for them to finish — in particular the runner cleanup in + // keepRunnerAlive that terminates the runner's process group — before the + // process exits and orphans the runner. + wg sync.WaitGroup + + // connMux guards the connecting/connected handshake channels, which are + // closed and re-created from both keepAliveLoop and loop. + connMux sync.Mutex connecting chan struct{} connected chan struct{} @@ -78,6 +86,15 @@ func (s *Service) Done() chan struct{} { return s.done } +// Wait blocks until the service's goroutines have exited. This includes the +// runner cleanup in keepRunnerAlive that terminates the runner's process group, +// so callers should Wait after Done() fires to ensure the runner is torn down +// before the process exits. The goroutines all return on ctx cancellation or +// Stop(), so Wait does not block indefinitely. +func (s *Service) Wait() { + s.wg.Wait() +} + func (s *Service) getClient() (*garmWs.Reader, error) { s.cliMux.Lock() cli := s.cli @@ -301,6 +318,7 @@ func (s *Service) Start() error { s.running = true s.done = make(chan struct{}) + s.wg.Add(3) go s.keepAliveLoop() go s.loop() go s.keepRunnerAlive() @@ -318,9 +336,12 @@ func (s *Service) Stop() error { close(s.done) s.running = false + + s.cliMux.Lock() if s.cli != nil { s.cli.Stop() } + s.cliMux.Unlock() return nil } @@ -503,6 +524,7 @@ func (s *Service) sleepWithCancel(d time.Duration) (shouldQuit bool) { } func (s *Service) keepRunnerAlive() { + defer s.wg.Done() retryCreate: state := s.determineRunnerState(s.isRunnerAlive()) if state == params.RunnerTerminated { @@ -584,7 +606,28 @@ retryStart: } } +// connectionUp is called by keepAliveLoop once a websocket connection is +// established: it re-arms the connected channel (so the next select blocks until +// the connection drops) and closes connecting to wake loop(). +func (s *Service) connectionUp() { + s.connMux.Lock() + defer s.connMux.Unlock() + s.connected = make(chan struct{}) + close(s.connecting) +} + +// connectionDown is called by loop() when the connection drops: it re-arms the +// connecting channel and closes connected to wake keepAliveLoop() for a +// reconnect. +func (s *Service) connectionDown() { + s.connMux.Lock() + defer s.connMux.Unlock() + s.connecting = make(chan struct{}) + close(s.connected) +} + func (s *Service) keepAliveLoop() { + defer s.wg.Done() var sleepTime time.Duration retryConnecting: if sleepTime > 0 { @@ -593,12 +636,15 @@ retryConnecting: } } for { + s.connMux.Lock() + connected := s.connected + s.connMux.Unlock() select { case <-s.done: return case <-s.ctx.Done(): return - case <-s.connected: + case <-connected: slog.InfoContext(s.ctx, "attempting to connect to GARM server", "server", s.cfg.ServerURL) sleepTime = 5 * time.Second parsed, err := url.ParseRequestURI(s.cfg.ServerURL) @@ -617,18 +663,18 @@ retryConnecting: s.cli = cli s.cliMux.Unlock() - if err := s.cli.Start(); err != nil { + if err := cli.Start(); err != nil { slog.WarnContext(s.ctx, "failed to start websocket connection", "error", err) goto retryConnecting } slog.InfoContext(s.ctx, "successfully connected to GARM", "server", s.cfg.ServerURL) - s.connected = make(chan struct{}) - close(s.connecting) + s.connectionUp() } } } func (s *Service) loop() { + defer s.wg.Done() heartbeatTicker := time.NewTicker(30 * time.Second) defer func() { slog.InfoContext(s.ctx, "daemon is shutting down") @@ -638,38 +684,52 @@ func (s *Service) loop() { heartbeatTicker.Stop() }() -connecting: - select { - case <-s.done: - return - case <-s.ctx.Done(): - return - case <-s.connecting: - } - // send initial heartbeat - if id, alive, ok := s.snapshot(); ok { - if err := s.sendHeartbeat(id); err != nil { - slog.ErrorContext(s.ctx, "failed to send heartbeat", "error", err) - } - s.sendRunnerStatus(id, alive) - } - for { + // Wait until keepAliveLoop signals that a connection is up. + s.connMux.Lock() + connectingCh := s.connecting + s.connMux.Unlock() select { case <-s.done: return case <-s.ctx.Done(): return - case <-s.cli.Done(): - slog.InfoContext(s.ctx, "remote host closed WS connection") - s.connecting = make(chan struct{}) - close(s.connected) - goto connecting - case <-heartbeatTicker.C: - // send heartbeat - if id, _, ok := s.snapshot(); ok { - if err := s.sendHeartbeat(id); err != nil { - slog.ErrorContext(s.ctx, "failed to send heartbeat", "error", err) + case <-connectingCh: + } + + cli, err := s.getClient() + if err != nil { + // Signalled connected but the client is gone; ask for a reconnect. + slog.ErrorContext(s.ctx, "no websocket client after connect", "error", err) + s.connectionDown() + continue + } + + // send initial heartbeat + if id, alive, ok := s.snapshot(); ok { + if err := s.sendHeartbeat(id); err != nil { + slog.ErrorContext(s.ctx, "failed to send heartbeat", "error", err) + } + s.sendRunnerStatus(id, alive) + } + + online := true + for online { + select { + case <-s.done: + return + case <-s.ctx.Done(): + return + case <-cli.Done(): + slog.InfoContext(s.ctx, "remote host closed WS connection") + s.connectionDown() + online = false + case <-heartbeatTicker.C: + // send heartbeat + if id, _, ok := s.snapshot(); ok { + if err := s.sendHeartbeat(id); err != nil { + slog.ErrorContext(s.ctx, "failed to send heartbeat", "error", err) + } } } } diff --git a/service/service_test.go b/service/service_test.go index 4ecb16b..ca117f0 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -2,11 +2,16 @@ package service import ( "context" + "net/http" + "net/http/httptest" "os" "path/filepath" + "strings" "testing" "time" + "github.com/gorilla/websocket" + "github.com/cloudbase/garm-agent/config" "github.com/cloudbase/garm/params" ) @@ -439,3 +444,72 @@ func indexOf(s, substr string) int { } return -1 } + +// TestServiceReconnect drives the keepAliveLoop/loop connect handshake against a +// real websocket server: it connects, the server drops the connection, and the +// agent must reconnect. Run under -race, it exercises the connecting/connected +// and cli synchronization. +func TestServiceReconnect(t *testing.T) { + upgrader := websocket.Upgrader{} + conns := make(chan *websocket.Conn, 8) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + conns <- c + for { + if _, _, err := c.ReadMessage(); err != nil { + return + } + } + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := &Service{ + ctx: ctx, + cfg: &config.Agent{ServerURL: wsURL, Token: "test-token"}, + done: make(chan struct{}), + connecting: make(chan struct{}), + connected: closed, + running: true, + } + s.wg.Add(2) + go s.keepAliveLoop() + go s.loop() + + waitConn := func(what string) *websocket.Conn { + t.Helper() + select { + case c := <-conns: + return c + case <-time.After(5 * time.Second): + cancel() + t.Fatalf("timed out waiting for the agent to %s", what) + return nil + } + } + + c1 := waitConn("connect") + // Drop the connection from the server side; the agent should reconnect. + c1.Close() + waitConn("reconnect") + + // Shut down and make sure both goroutines exit. + cancel() + stopped := make(chan struct{}) + go func() { + s.wg.Wait() + close(stopped) + }() + select { + case <-stopped: + case <-time.After(5 * time.Second): + t.Fatal("service goroutines did not exit after shutdown") + } +} diff --git a/service/service_windows.go b/service/service_windows.go index c2f170f..d6c4bc0 100644 --- a/service/service_windows.go +++ b/service/service_windows.go @@ -10,7 +10,6 @@ func (s *Service) Execute(args []string, r <-chan svc.ChangeRequest, status chan if err := s.Start(); err != nil { return false, 11 } - defer s.Stop() const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown status <- svc.Status{State: svc.StartPending} @@ -39,5 +38,11 @@ loop: } status <- svc.Status{State: svc.StopPending} + if stopErr := s.Stop(); stopErr != nil { + slog.ErrorContext(s.ctx, "failed to stop service", "error", stopErr) + } + // Wait for the service goroutines (including the runner cleanup) to finish + // before reporting the service stopped. + s.Wait() return false, 0 }