diff --git a/vpn/ipc/endpoints.go b/vpn/ipc/endpoints.go index b55c43d2..590a5320 100644 --- a/vpn/ipc/endpoints.go +++ b/vpn/ipc/endpoints.go @@ -16,4 +16,5 @@ const ( connectionsEndpoint = "/connections" closeConnectionsEndpoint = "/connections/close" setSettingsPathEndpoint = "/set" + statusEventsEndpoint = "/status/events" ) diff --git a/vpn/ipc/events.go b/vpn/ipc/events.go new file mode 100644 index 00000000..f9ca27ad --- /dev/null +++ b/vpn/ipc/events.go @@ -0,0 +1,58 @@ +package ipc + +import ( + "encoding/json" + "fmt" + "log/slog" + "net/http" + + "github.com/getlantern/radiance/events" +) + +// StatusUpdateEvent is emitted when the VPN status changes. +type StatusUpdateEvent struct { + events.Event + Status VPNStatus `json:"status"` + Error string `json:"error,omitempty"` +} + +func (s *Server) statusEventsHandler(w http.ResponseWriter, r *http.Request) { + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming not supported", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + ch := make(chan StatusUpdateEvent, 8) + + // Send the current status immediately so the client doesn't have to wait for a change. + ch <- StatusUpdateEvent{Status: s.service.Status()} + + sub := events.Subscribe(func(evt StatusUpdateEvent) { + select { + case ch <- evt: + default: // drop if client is slow + } + }) + defer sub.Unsubscribe() + + for { + select { + case evt := <-ch: + buf, err := json.Marshal(evt) + if err != nil { + slog.Error("failed to marshal event", "error", err) + continue + } + fmt.Fprintf(w, "%s\r\n", buf) + flusher.Flush() + case <-r.Context().Done(): + slog.Debug("client disconnected") + return + } + } +} diff --git a/vpn/ipc/events_client.go b/vpn/ipc/events_client.go new file mode 100644 index 00000000..9af92b1a --- /dev/null +++ b/vpn/ipc/events_client.go @@ -0,0 +1,94 @@ +package ipc + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/getlantern/radiance/events" +) + +// StartStatusStream starts streaming status updates from the server and emits received +// [StatusUpdateEvent] events until the context is cancelled. If waitForConnect is true, it +// polls in a background goroutine until the server is reachable. When the stream is lost +// (server restart, network error, clean EOF), a [StatusUpdateEvent] with [Disconnected] status +// is emitted. The retry loop continues until a connection is established, the context is cancelled, +// or a non-recoverable error occurs (e.g. connection refused, invalid response). +func StartStatusStream(ctx context.Context, waitForConnect bool) error { + if !waitForConnect { + return startStream(ctx) + } + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(1 * time.Second): + serverListening, err := tryDial(ctx) + if err != nil { + events.Emit(StatusUpdateEvent{ + Status: ErrorStatus, + Error: fmt.Sprintf("connection error: %v", err), + }) + return + } + if !serverListening { + continue // we started trying to connect before the server is ready + } + err = startStream(ctx) + if ctx.Err() != nil { + return + } + evt := StatusUpdateEvent{Status: Disconnected} + if err != nil { + slog.Warn("status stream disconnected", "error", err) + evt.Error = fmt.Sprintf("stream disconnected: %v", err) + } + // Stream ended cleanly (EOF) — server likely shut down. + events.Emit(evt) + return + } + } + }() + return nil +} + +func startStream(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, "GET", apiURL+statusEventsEndpoint, nil) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + client := &http.Client{ + Transport: &http.Transport{ + DialContext: dialContext, + Protocols: protocols, + }, + } + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("connecting: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status %s", resp.Status) + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if line == "" { + continue + } + var evt StatusUpdateEvent + if err := json.Unmarshal([]byte(line), &evt); err != nil { + continue + } + events.Emit(evt) + } + return scanner.Err() +} diff --git a/vpn/ipc/events_test.go b/vpn/ipc/events_test.go new file mode 100644 index 00000000..585804ff --- /dev/null +++ b/vpn/ipc/events_test.go @@ -0,0 +1,77 @@ +package ipc + +import ( + "bytes" + "context" + "encoding/json" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/sagernet/sing-box/experimental/clashapi" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getlantern/radiance/events" + "github.com/getlantern/radiance/servers" +) + +func TestStatusEventsHandler(t *testing.T) { + svc := &mockService{status: Disconnected} + s := &Server{service: svc} + + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", statusEventsEndpoint, nil) + + done := make(chan struct{}) + go func() { + defer close(done) + s.statusEventsHandler(rec, req) + }() + + waitAssert := func(want StatusUpdateEvent, msg string) { + require.Eventually(t, func() bool { + return strings.Contains(rec.Body.String(), "\r\n") + }, time.Second, 10*time.Millisecond, msg) + evt := parseEventLine(t, rec.Body) + rec.Body.Reset() + assert.Equal(t, want, evt, msg) + } + waitAssert(StatusUpdateEvent{Status: Disconnected}, "initial event not received") + + // Emit a status change and wait for it to arrive. + evt := StatusUpdateEvent{Status: Connected} + events.Emit(evt) + waitAssert(evt, "connected event not received") + + // Emit an error status + evt = StatusUpdateEvent{Status: ErrorStatus, Error: "something went wrong"} + events.Emit(evt) + waitAssert(evt, "error event not received") +} + +func parseEventLine(t *testing.T, body *bytes.Buffer) StatusUpdateEvent { + line, err := body.ReadBytes('\n') + require.NoError(t, err) + + var evt StatusUpdateEvent + line = bytes.TrimSpace(line) + require.NoError(t, json.Unmarshal(line, &evt)) + return evt +} + +type mockService struct { + status VPNStatus +} + +func (m *mockService) Ctx() context.Context { return nil } +func (m *mockService) Status() VPNStatus { return m.status } +func (m *mockService) Start(context.Context, string) error { return nil } +func (m *mockService) Restart(context.Context, string) error { return nil } +func (m *mockService) ClashServer() *clashapi.Server { return nil } +func (m *mockService) Close() error { return nil } +func (m *mockService) UpdateOutbounds(options servers.Servers) error { return nil } +func (m *mockService) AddOutbounds(group string, options servers.Options) error { return nil } +func (m *mockService) RemoveOutbounds(group string, tags []string) error { return nil } diff --git a/vpn/ipc/http.go b/vpn/ipc/http.go index f2af307d..3167a559 100644 --- a/vpn/ipc/http.go +++ b/vpn/ipc/http.go @@ -19,6 +19,12 @@ import ( const tracerName = "github.com/getlantern/radiance/vpn/ipc" +var protocols = func() *http.Protocols { + p := &http.Protocols{} + p.SetUnencryptedHTTP2(true) + return p +}() + // empty is a placeholder type for requests that do not expect a response body. type empty struct{} @@ -40,7 +46,9 @@ func sendRequest[T any](ctx context.Context, method, endpoint string, data any) } client := &http.Client{ Transport: &http.Transport{ - DialContext: dialContext, + DialContext: dialContext, + Protocols: protocols, + ForceAttemptHTTP2: true, }, } resp, err := client.Do(req) diff --git a/vpn/ipc/server.go b/vpn/ipc/server.go index 187aa57d..534ee5ec 100644 --- a/vpn/ipc/server.go +++ b/vpn/ipc/server.go @@ -19,7 +19,6 @@ import ( "go.opentelemetry.io/otel" "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/events" "github.com/getlantern/radiance/servers" ) @@ -44,28 +43,21 @@ type Service interface { // Server represents the IPC server that communicates over a Unix domain socket for Unix-like // systems, and a named pipe for Windows. type Server struct { - svr *http.Server - service Service - router chi.Router - vpnStatus atomic.Value // string - closed atomic.Bool -} - -// StatusUpdateEvent is emitted when the VPN status changes. -type StatusUpdateEvent struct { - events.Event - Status VPNStatus - Error error + svr *http.Server + service Service + router chi.Router + closed atomic.Bool } type VPNStatus string // Possible VPN statuses const ( - Connected VPNStatus = "connected" - Disconnected VPNStatus = "disconnected" Connecting VPNStatus = "connecting" + Connected VPNStatus = "connected" Disconnecting VPNStatus = "disconnecting" + Disconnected VPNStatus = "disconnected" + Restarting VPNStatus = "restarting" ErrorStatus VPNStatus = "error" ) @@ -79,8 +71,7 @@ func NewServer(service Service) *Server { service: service, router: chi.NewMux(), } - s.vpnStatus.Store(Disconnected) - s.router.Use(log, tracer) + s.router.Use(log) // Only add auth middleware if not running on mobile, since mobile platforms have their own // sandboxing and permission models. @@ -89,30 +80,37 @@ func NewServer(service Service) *Server { s.router.Use(authPeer) } - s.router.Get("/", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) + // Standard routes use the tracer middleware which buffers response bodies for error recording. + s.router.Group(func(r chi.Router) { + r.Use(tracer) + r.Get("/", func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + }) + r.Get(statusEndpoint, s.statusHandler) + r.Get(metricsEndpoint, s.metricsHandler) + r.Get(groupsEndpoint, s.groupHandler) + r.Get(connectionsEndpoint, s.connectionsHandler) + r.Get(selectEndpoint, s.selectedHandler) + r.Get(activeEndpoint, s.activeOutboundHandler) + r.Post(selectEndpoint, s.selectHandler) + r.Get(clashModeEndpoint, s.clashModeHandler) + r.Post(clashModeEndpoint, s.clashModeHandler) + r.Post(startServiceEndpoint, s.startServiceHandler) + r.Post(stopServiceEndpoint, s.stopServiceHandler) + r.Post(restartServiceEndpoint, s.restartServiceHandler) + r.Post(closeConnectionsEndpoint, s.closeConnectionHandler) }) - s.router.Get(statusEndpoint, s.statusHandler) - s.router.Get(metricsEndpoint, s.metricsHandler) - s.router.Get(groupsEndpoint, s.groupHandler) - s.router.Get(connectionsEndpoint, s.connectionsHandler) - s.router.Get(selectEndpoint, s.selectedHandler) - s.router.Get(activeEndpoint, s.activeOutboundHandler) - s.router.Post(selectEndpoint, s.selectHandler) - s.router.Get(clashModeEndpoint, s.clashModeHandler) - s.router.Post(clashModeEndpoint, s.clashModeHandler) - s.router.Post(startServiceEndpoint, s.startServiceHandler) - s.router.Post(stopServiceEndpoint, s.stopServiceHandler) - s.router.Post(restartServiceEndpoint, s.restartServiceHandler) - s.router.Post(updateOutboundsEndpoint, s.updateOutboundsHandler) - s.router.Post(addOutboundsEndpoint, s.addOutboundsHandler) - s.router.Post(removeOutboundsEndpoint, s.removeOutboundsHandler) - s.router.Post(closeConnectionsEndpoint, s.closeConnectionHandler) + + // SSE routes skip the tracer middleware since it buffers the entire response body + // and holds the span open for the lifetime of the connection. + s.router.Get(statusEventsEndpoint, s.statusEventsHandler) svr := &http.Server{ - Handler: s.router, - ReadTimeout: time.Second * 5, - WriteTimeout: time.Second * 5, + Handler: s.router, + ReadTimeout: time.Second * 5, + // WriteTimeout is 0 (unlimited) to support long-lived SSE connections. + // Non-streaming handlers return quickly so this is safe. + Protocols: protocols, } if addAuth { svr.ConnContext = func(ctx context.Context, c net.Conn) context.Context { @@ -146,7 +144,6 @@ func (s *Server) Start() error { if s.service.Status() != Disconnected { slog.Warn("IPC server stopped unexpectedly, closing service") s.service.Close() - s.setVPNStatus(ErrorStatus, errors.New("IPC server stopped unexpectedly")) } }() @@ -200,13 +197,10 @@ func (s *Server) startServiceHandler(w http.ResponseWriter, r *http.Request) { return } - s.setVPNStatus(Connecting, nil) if err := s.service.Start(ctx, p.Options); err != nil { - s.setVPNStatus(ErrorStatus, err) http.Error(w, err.Error(), http.StatusServiceUnavailable) return } - s.setVPNStatus(Connected, nil) w.WriteHeader(http.StatusOK) } @@ -218,13 +212,10 @@ func StopService(ctx context.Context) error { func (s *Server) stopServiceHandler(w http.ResponseWriter, r *http.Request) { slog.Debug("Received request to stop service via IPC") - s.setVPNStatus(Disconnecting, nil) if err := s.service.Close(); err != nil { - s.setVPNStatus(ErrorStatus, err) http.Error(w, err.Error(), http.StatusInternalServerError) return } - s.setVPNStatus(Disconnected, nil) w.WriteHeader(http.StatusOK) } @@ -247,17 +238,9 @@ func (s *Server) restartServiceHandler(w http.ResponseWriter, r *http.Request) { return } - s.setVPNStatus(Disconnected, nil) if err := s.service.Restart(ctx, p.Options); err != nil { - s.setVPNStatus(ErrorStatus, err) http.Error(w, err.Error(), http.StatusInternalServerError) return } - s.setVPNStatus(Connected, nil) w.WriteHeader(http.StatusOK) } - -func (s *Server) setVPNStatus(status VPNStatus, err error) { - s.vpnStatus.Store(status) - events.Emit(StatusUpdateEvent{Status: status, Error: err}) -} diff --git a/vpn/ipc/status.go b/vpn/ipc/status.go index aef35029..28f6087b 100644 --- a/vpn/ipc/status.go +++ b/vpn/ipc/status.go @@ -54,8 +54,8 @@ func (s *Server) metricsHandler(w http.ResponseWriter, r *http.Request) { } } -type state struct { - State VPNStatus `json:"state"` +type vpnStatus struct { + Status VPNStatus `json:"status"` } // GetStatus retrieves the current status of the service. @@ -65,14 +65,14 @@ func GetStatus(ctx context.Context) (VPNStatus, error) { return Disconnected, err } - res, err := sendRequest[state](ctx, "GET", statusEndpoint, nil) + res, err := sendRequest[vpnStatus](ctx, "GET", statusEndpoint, nil) if errors.Is(err, ErrIPCNotRunning) || errors.Is(err, ErrServiceIsNotReady) { return Disconnected, nil } if err != nil { - return "", fmt.Errorf("error getting status: %w", err) + return ErrorStatus, fmt.Errorf("error getting status: %w", err) } - return res.State, nil + return res.Status, nil } func tryDial(ctx context.Context) (bool, error) { @@ -90,9 +90,9 @@ func tryDial(ctx context.Context) (bool, error) { func (s *Server) statusHandler(w http.ResponseWriter, r *http.Request) { span := trace.SpanFromContext(r.Context()) status := s.service.Status() - span.SetAttributes(attribute.String("status", string(status))) + span.SetAttributes(attribute.String("status", status.String())) w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(state{status}); err != nil { + if err := json.NewEncoder(w).Encode(vpnStatus{Status: status}); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } diff --git a/vpn/service.go b/vpn/service.go index 4df2f133..09ad4d2f 100644 --- a/vpn/service.go +++ b/vpn/service.go @@ -133,6 +133,7 @@ func (s *TunnelService) Restart(ctx context.Context, options string) error { } s.logger.Info("Restarting tunnel") + s.tunnel.setStatus(ipc.Restarting, nil) if s.platformIfce != nil { s.mu.Unlock() if err := s.platformIfce.RestartService(); err != nil { diff --git a/vpn/tunnel.go b/vpn/tunnel.go index 3dec359e..d6f597de 100644 --- a/vpn/tunnel.go +++ b/vpn/tunnel.go @@ -23,6 +23,7 @@ import ( "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/events" "github.com/getlantern/radiance/internal" "github.com/getlantern/radiance/servers" "github.com/getlantern/radiance/vpn/ipc" @@ -52,14 +53,21 @@ type tunnel struct { clientContextTracker *clientcontext.ClientContextInjector - status atomic.Value + status atomic.Value // ipc.VPNStatus cancel context.CancelFunc closers []io.Closer } -func (t *tunnel) start(options string, platformIfce libbox.PlatformInterface) error { - t.status.Store(ipc.Connecting) +func (t *tunnel) start(options string, platformIfce libbox.PlatformInterface) (err error) { + if t.status.Load() != ipc.Restarting { + t.setStatus(ipc.Connecting, nil) + } t.ctx, t.cancel = context.WithCancel(box.BaseContext()) + defer func() { + if err != nil { + t.setStatus(ipc.ErrorStatus, err) + } + }() if err := t.init(options, platformIfce); err != nil { t.close() @@ -72,7 +80,7 @@ func (t *tunnel) start(options string, platformIfce libbox.PlatformInterface) er slog.Error("Failed to connect tunnel", "error", err) return fmt.Errorf("connecting tunnel: %w", err) } - t.status.Store(ipc.Connected) + t.setStatus(ipc.Connected, nil) t.optsMap = makeOutboundOptsMap(t.ctx, options) return nil } @@ -227,6 +235,9 @@ func (t *tunnel) selectOutbound(group, tag string) error { } func (t *tunnel) close() error { + if t.status.Load() != ipc.Restarting { + t.setStatus(ipc.Disconnecting, nil) + } if t.cancel != nil { t.cancel() } @@ -249,7 +260,9 @@ func (t *tunnel) close() error { t.closers = nil t.lbService = nil - t.status.Store(ipc.Disconnected) + if t.status.Load() != ipc.Restarting { + t.setStatus(ipc.Disconnected, nil) + } return err } @@ -257,6 +270,15 @@ func (t *tunnel) Status() ipc.VPNStatus { return t.status.Load().(ipc.VPNStatus) } +func (t *tunnel) setStatus(status ipc.VPNStatus, err error) { + t.status.Store(status) + evt := ipc.StatusUpdateEvent{Status: status} + if err != nil { + evt.Error = err.Error() + } + events.Emit(evt) +} + var errLibboxClosed = errors.New("libbox closed") func (t *tunnel) addOutbounds(group string, options servers.Options) (err error) {