diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 81f903dc..208aa7e3 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -21,6 +21,13 @@ jobs: with: go-version-file: "go.mod" + - name: Check for slog overwrite calls in tests + run: | + if grep -rn 'slog\.SetDefault\|slog\.SetLogLoggerLevel' --include='*_test.go' .; then + echo "::error::Test files should not upate the slog.Default logger or level. This pollutes the output." + exit 1 + fi + - name: Build run: go build -v -tags with_clash_api ./... - name: Test diff --git a/README.md b/README.md index ce688091..f49d3ed2 100644 --- a/README.md +++ b/README.md @@ -33,50 +33,105 @@ Available variables: * `RADIANCE_FEATURE_OVERRIDE`: Comma-separated list of feature flags to force-enable on the server side. If set, the value is sent as the `X-Lantern-Feature-Override` header on config requests in any environment, and it is recommended for testing/non-production use. For example, `RADIANCE_FEATURE_OVERRIDE=bandit_assignment` enables bandit-based proxy assignment during testing. -## Packages +## Architecture -Use `common.Init` to setup directories and configure loggers. -> [!note] -> This isn't necessary if `NewRadiance` was called as it will call `Init` for you. +Radiance is structured around a `LocalBackend` pattern that ties together all core functionality: configuration, servers, VPN connection, account management, issue reporting, and telemetry. The `LocalBackend` is the central coordinator and should be the primary interface for interacting with Radiance programmatically. -### `vpn` +In addition to being the core of the [Lantern client](https://github.com/getlantern/lantern), radiance also provides a standalone daemon and CLI: -The `vpn` package provides high-level functions for controlling the VPN tunnel. +- **`lanternd`** — the VPN daemon that runs the `LocalBackend` and exposes an IPC server. It can run in the foreground or be installed as a system service. +- **`lantern`** — a CLI client that communicates with the daemon over IPC. -To connect to the best available server, you can use the `QuickConnect` function. This function takes a server group (`servers.SGLantern`, `servers.SGUser`, or `"all"`) and a `PlatformInterface` as input. For example: +### Building CLI & Daemon -```go -err := vpn.QuickConnect(servers.SGLantern, platIfce) +From the `cmd/` directory: + +```sh +# Build the daemon +just build-daemon +# or +make build-daemon + +# Build the CLI +just build-cli +# or +make build-cli ``` -will connect to the best Lantern server, while: +Both binaries are output to `bin/`. You can also run the daemon directly with `make run-daemon`. -```go -err := vpn.QuickConnect("all", platIfce) +### Running + +```sh +# Start the daemon +lanternd run --data-path ~/data --log-path ~/logs + +# Install/uninstall as a system service +lanternd install --data-path ~/data --log-path ~/logs +lanternd uninstall + +# CLI commands (requires a running daemon) +lantern connect [--tag ] +lantern disconnect +lantern status +lantern servers +lantern account login +lantern subscription +lantern split-tunnel +lantern logs +lantern ip ``` -will connect to the best overall. +## Packages + +Use `common.Init` to setup directories and configure loggers. +> [!note] +> This isn't necessary if `NewLocalBackend` was called as it will call `Init` for you. + +### `backend` + +The `backend` package provides `LocalBackend`, the main entry point for all Radiance functionality. Create one with `NewLocalBackend(ctx, opts)` and call `Start()` to begin fetching configuration and serving requests. `LocalBackend` owns and coordinates the `VPNClient`, `ServerManager`, `ConfigHandler`, `AccountClient`, `IssueReporter`, and telemetry. + +### `vpn` -You can also connect to a specific server using `ConnectToServer`. This function requires a server group, a server tag, and a `PlatformInterface`. For example: +The `vpn` package provides `VPNClient`, which manages the lifecycle of the VPN tunnel. ```go -err := vpn.ConnectToServer(servers.SGUser, "my-server", platIfce) +client := vpn.NewVPNClient(dataPath, logger, platformIfce) +err := client.Connect(boxOptions) ``` -Both `QuickConnect` and `ConnectToServer` can be called without disconnecting first, allowing you to seamlessly switch between servers or connection modes. +`Connect` can be called without disconnecting first, allowing you to seamlessly switch between servers. Once connected, you can query status or view `Connections`. To stop the VPN, call `Disconnect`. -Once connected, you can check the `GetStatus` or view `ActiveConnections`. To stop the VPN, simply call `Disconnect`. The package also supports reconnecting to the last used server with `Reconnect`. +> [!note] +> In most cases, you should use the `LocalBackend` methods (`ConnectVPN`, `DisconnectVPN`, `RestartVPN`, `VPNStatus`) rather than using `VPNClient` directly. -This package also includes split tunneling capabilities, allowing you to include or exclude specific applications, domains, or IP addresses from the VPN tunnel. You can manage split tunneling by creating a `SplitTunnel` handler with `NewSplitTunnelHandler`. This handler allows you to `Enable` or `Disable` split tunneling, `AddItem` or `RemoveItem` from the filter, and view the current `Filters`. +This package also includes split tunneling capabilities via the `SplitTunnel` type, allowing you to include or exclude specific applications, domains, or IP addresses from the VPN tunnel. ### `servers` -The `servers` package is responsible for managing all VPN server configurations, separating them into two groups: `lantern` (official Lantern servers) and `user` (user-provided servers). +The `servers` package manages all VPN server configurations, separating them into two groups: `lantern` (official Lantern servers fetched from the config) and `user` (user-provided servers). -The `Manager` allows you to `AddServers` and `RemoveServer` configurations. You can retrieve the config for a specific server with `GetServerByTag` or use `Servers` to retrieve all configs. +The `Manager` allows you to `AddServers` and `RemoveServers` configurations. You can retrieve the config for a specific server with `GetServerByTag` or use `Servers` to retrieve all configs. > [!caution] -> While you can get a new `Manager` instance with `NewManager`, it is recommended to use `Radiance.ServerManager`. This will return the shared manager instance. `NewManager` can be useful for retrieving server information if you don't have access to the shared instance, but the new instance should not be kept as it won't stay in sync and adding server configs to it will overwrite existing configs if both manager instances are pointed to the same server file. +> While you can get a new `Manager` instance with `NewManager`, it is recommended to use the `LocalBackend`'s server methods (`Servers`, `AddServers`, `RemoveServers`, `GetServerByTag`). These use the shared manager instance. `NewManager` can be useful for retrieving server information if you don't have access to the shared instance, but the new instance should not be kept as it won't stay in sync. + +A key feature of this package is the ability to add private servers from a server manager via an access token using `AddPrivateServer`. This process uses Trust-on-first-use (TOFU) to securely add the server. Once a private server is added, you can invite other users with `InviteToPrivateServer` and revoke access with `RevokePrivateServerInvite`. + +### `ipc` + +The `ipc` package provides the communication layer between the `lantern` CLI and the `lanternd` daemon. The `ipc.Server` exposes an HTTP API backed by the `LocalBackend`, and the `ipc.Client` provides a typed Go client for calling it. All communication happens over a local socket. + +### `account` + +The `account` package handles user authentication (email/password and OAuth), signup, email verification, account recovery, device management, and subscription operations. It communicates with the Lantern account server and caches authentication state locally. + +### `config` + +The `config` package fetches proxy configuration from the Lantern API on a polling interval and emits `NewConfigEvent` events when the configuration changes. The `LocalBackend` subscribes to these events to update server configurations automatically. + +### `events` -A key feature of this package is the ability to add private servers from a server manager via an access token using `AddPrivateServer`. This process uses Trust-on-first-use (TOFU) to securely add the server. Once a private server is added, you can use the manager to invite other users to it with `InviteToPrivateServer` and revoke access with `RevokePrivateServerInvite`. +A generic pub-sub event system used throughout Radiance for decoupled communication between components (config changes, VPN status updates, log entries, etc.). diff --git a/account/auth.go b/account/auth.go new file mode 100644 index 00000000..3f9a7a81 --- /dev/null +++ b/account/auth.go @@ -0,0 +1,117 @@ +package account + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "math/big" + + "github.com/1Password/srp" + "golang.org/x/crypto/pbkdf2" + "google.golang.org/protobuf/proto" + + "github.com/getlantern/radiance/account/protos" +) + +func (a *Client) fetchSalt(ctx context.Context, email string) (*protos.GetSaltResponse, error) { + query := map[string]string{"email": email} + resp, err := a.sendRequest(ctx, "GET", "/users/salt", query, nil, nil) + if err != nil { + return nil, err + } + var salt protos.GetSaltResponse + if err := proto.Unmarshal(resp, &salt); err != nil { + return nil, fmt.Errorf("unmarshaling salt response: %w", err) + } + return &salt, nil +} + +// clientProof performs the SRP authentication flow to generate the client proof for the given email and password. +func (a *Client) clientProof(ctx context.Context, email, password string, salt []byte) ([]byte, error) { + srpClient, err := newSRPClient(email, password, salt) + if err != nil { + return nil, err + } + + A := srpClient.EphemeralPublic() + data := &protos.PrepareRequest{ + Email: email, + A: A.Bytes(), + } + resp, err := a.sendRequest(ctx, "POST", "/users/prepare", nil, nil, data) + if err != nil { + return nil, err + } + + var srpB protos.PrepareResponse + if err := proto.Unmarshal(resp, &srpB); err != nil { + return nil, fmt.Errorf("unmarshaling prepare response: %w", err) + } + B := big.NewInt(0).SetBytes(srpB.B) + if err = srpClient.SetOthersPublic(B); err != nil { + return nil, err + } + + key, err := srpClient.Key() + if err != nil || key == nil { + return nil, fmt.Errorf("generating Client key %w", err) + } + if !srpClient.GoodServerProof(salt, email, srpB.Proof) { + return nil, fmt.Errorf("checking server proof %w", err) + } + + proof, err := srpClient.ClientProof() + if err != nil { + return nil, fmt.Errorf("generating client proof %w", err) + } + return proof, nil +} + +// getSalt retrieves the salt for the given email address or it's cached value. +func (a *Client) getSalt(ctx context.Context, email string) ([]byte, error) { + if cached := a.getSaltCached(); cached != nil { + return cached, nil + } + resp, err := a.fetchSalt(ctx, email) + if err != nil { + return nil, err + } + return resp.Salt, nil +} + +const group = srp.RFC5054Group3072 + +func newSRPClient(email, password string, salt []byte) (*srp.SRP, error) { + if len(salt) == 0 || len(password) == 0 || len(email) == 0 { + return nil, errors.New("salt, password and email should not be empty") + } + + encryptedKey, err := generateEncryptedKey(password, email, salt) + if err != nil { + return nil, fmt.Errorf("failed to generate encrypted key: %w", err) + } + + return srp.NewSRPClient(srp.KnownGroups[group], encryptedKey, nil), nil +} + +func generateEncryptedKey(password, email string, salt []byte) (*big.Int, error) { + if len(salt) == 0 || len(password) == 0 || len(email) == 0 { + return nil, errors.New("salt or password or email is empty") + } + combinedInput := password + email + encryptedKey := pbkdf2.Key([]byte(combinedInput), salt, 4096, 32, sha256.New) + encryptedKeyBigInt := big.NewInt(0).SetBytes(encryptedKey) + return encryptedKeyBigInt, nil +} + +func generateSalt() ([]byte, error) { + salt := make([]byte, 16) + if n, err := rand.Read(salt); err != nil { + return nil, err + } else if n != 16 { + return nil, errors.New("failed to generate 16 byte salt") + } + return salt, nil +} diff --git a/account/client.go b/account/client.go new file mode 100644 index 00000000..bfb674c3 --- /dev/null +++ b/account/client.go @@ -0,0 +1,241 @@ +// Package account provides a client for communicating with the account server to perform operations +// such as user authentication, subscription management, and account information retrieval. +package account + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "maps" + "net/http" + "path/filepath" + "sort" + "strings" + "sync" + "unicode" + "unicode/utf8" + + "google.golang.org/protobuf/proto" + + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/env" + "github.com/getlantern/radiance/common/settings" +) + +const tracerName = "github.com/getlantern/radiance/account" + +// Client is an account client that communicates with the account server to perform operations such as +// user authentication, subscription management, and account information retrieval. +type Client struct { + httpClient *http.Client + // proURL and authURL override the default server URLs. Used for testing. + proURL string + authURL string + + salt []byte + saltPath string + mu sync.RWMutex +} + +// NewClient creates a new account client with the given HTTP client and data directory for caching +// the salt value. +func NewClient(httpClient *http.Client, dataDir string) *Client { + path := filepath.Join(dataDir, saltFileName) + salt, err := readSalt(path) + if err != nil { + slog.Warn("failed to read salt", "error", err) + } + return &Client{ + httpClient: httpClient, + salt: salt, + saltPath: path, + } +} + +func (a *Client) getSaltCached() []byte { + a.mu.RLock() + defer a.mu.RUnlock() + return a.salt +} + +func (a *Client) setSalt(salt []byte) { + a.mu.Lock() + defer a.mu.Unlock() + a.salt = salt +} + +func (a *Client) proBaseURL() string { + if a.proURL != "" { + return a.proURL + } + return common.GetProServerURL() +} + +func (a *Client) baseURL() string { + if a.authURL != "" { + return a.authURL + } + return common.GetBaseURL() +} + +// sendRequest sends an HTTP request to the specified URL with the given method, query parameters, +// headers, and body. If the URL is relative, the base URL will be prepended. +func (a *Client) sendRequest( + ctx context.Context, + method, url string, + queryParams, headers map[string]string, + body any, +) ([]byte, error) { + // check if url is absolute, if not prepend base URL + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + url = a.baseURL() + url + } + + var bodyReader io.Reader + contentType := "" + if body != nil { + if pb, ok := body.(proto.Message); ok { + data, err := proto.Marshal(pb) + if err != nil { + return nil, fmt.Errorf("marshaling protobuf request: %w", err) + } + bodyReader = bytes.NewReader(data) + contentType = "application/x-protobuf" + } else { + data, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshaling JSON request: %w", err) + } + bodyReader = bytes.NewReader(data) + contentType = "application/json" + } + } + req, err := http.NewRequestWithContext(ctx, method, url, bodyReader) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + for k, v := range headers { + req.Header.Set(k, v) + } + req.Header.Set(common.AppNameHeader, common.Name) + req.Header.Set(common.VersionHeader, common.Version) + req.Header.Set(common.PlatformHeader, common.Platform) + if contentType != "" { + req.Header.Set("Content-Type", contentType) + req.Header.Set("Accept", contentType) + } + if len(queryParams) > 0 { + q := req.URL.Query() + for k, v := range queryParams { + q.Set(k, v) + } + req.URL.RawQuery = q.Encode() + } + + if env.GetBool(env.PrintCurl) { + slog.Debug("CURL command", "curl", curlFromRequest(req)) + } + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("sending request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading response: %w", err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + sanitized := sanitizeResponseBody(respBody) + slog.Debug("error response", "path", req.URL.Path, "status", resp.StatusCode, "body", string(sanitized)) + return nil, fmt.Errorf("unexpected status %v body %s", resp.StatusCode, sanitized) + } + + if len(respBody) == 0 { + return nil, nil + } + if contentType := resp.Header.Get("Content-Type"); strings.Contains(contentType, "application/json") { + return sanitizeResponseBody(respBody), nil + } + return respBody, nil +} + +// sendProRequest sends a request to the Pro server, automatically adding the required headers, +// including the device ID, user ID, and Pro token from settings, if available. If the URL is relative, +// the Pro server base URL will be prepended. +func (a *Client) sendProRequest( + ctx context.Context, + method, url string, + queryParams, additionalheaders map[string]string, + body any, +) ([]byte, error) { + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + url = a.proBaseURL() + url + } + headers := map[string]string{ + common.DeviceIDHeader: settings.GetString(settings.DeviceIDKey), + } + if tok := settings.GetString(settings.TokenKey); tok != "" { + headers[common.ProTokenHeader] = tok + } + if uid := settings.GetString(settings.UserIDKey); uid != "" { + headers[common.UserIDHeader] = uid + } + maps.Copy(headers, additionalheaders) + return a.sendRequest(ctx, method, url, queryParams, headers, body) +} + +// curlFromRequest generates a curl command string from an [http.Request]. +func curlFromRequest(req *http.Request) string { + var b strings.Builder + fmt.Fprintf(&b, "curl -X %s", req.Method) + + keys := make([]string, 0, len(req.Header)) + for k := range req.Header { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + for _, v := range req.Header[k] { + fmt.Fprintf(&b, " -H '%s: %s'", k, v) + } + } + + if req.Body != nil { + buf, _ := io.ReadAll(req.Body) + // Important! we need to reset the body since it can only be read once. + req.Body = io.NopCloser(bytes.NewBuffer(buf)) + fmt.Fprintf(&b, " -d '%s'", shellEscape(string(buf))) + } + + fmt.Fprintf(&b, " '%s'", req.URL.String()) + return b.String() +} + +func shellEscape(s string) string { + return strings.ReplaceAll(s, "'", "'\\''") +} + +func sanitizeResponseBody(data []byte) []byte { + var out bytes.Buffer + r := bytes.NewReader(data) + for { + ch, size, err := r.ReadRune() + if err != nil { + break + } + if ch == utf8.RuneError && size == 1 { + continue + } + if unicode.IsControl(ch) && ch != '\n' && ch != '\r' && ch != '\t' { + continue + } + out.WriteRune(ch) + } + return out.Bytes() +} diff --git a/api/jwt.go b/account/jwt.go similarity index 79% rename from api/jwt.go rename to account/jwt.go index cb14f482..c381243a 100644 --- a/api/jwt.go +++ b/account/jwt.go @@ -1,4 +1,4 @@ -package api +package account import ( "encoding/json" @@ -10,13 +10,15 @@ import ( type JWTUserInfo struct { UserID string `json:"user_id"` Email string `json:"email"` - DeviceId string `json:"device_id"` + DeviceID string `json:"device_id"` LegacyUserID int64 `json:"legacy_user_id"` LegacyToken string `json:"legacy_token"` } func decodeJWT(tokenStr string) (*JWTUserInfo, error) { claims := jwt.MapClaims{} + // ParseUnverified is used intentionally: the JWT has already been validated + // server-side and the client only needs to extract claims for local use. token, _, err := new(jwt.Parser).ParseUnverified(tokenStr, &claims) if err != nil { return nil, err diff --git a/api/protos/auth.pb.go b/account/protos/auth.pb.go similarity index 100% rename from api/protos/auth.pb.go rename to account/protos/auth.pb.go diff --git a/api/protos/auth.proto b/account/protos/auth.proto similarity index 100% rename from api/protos/auth.proto rename to account/protos/auth.proto diff --git a/api/protos/subscription.pb.go b/account/protos/subscription.pb.go similarity index 100% rename from api/protos/subscription.pb.go rename to account/protos/subscription.pb.go diff --git a/api/protos/subscription.proto b/account/protos/subscription.proto similarity index 100% rename from api/protos/subscription.proto rename to account/protos/subscription.proto diff --git a/api/subscription.go b/account/subscription.go similarity index 60% rename from api/subscription.go rename to account/subscription.go index a4639529..6fdc2f95 100644 --- a/api/subscription.go +++ b/account/subscription.go @@ -1,19 +1,20 @@ -package api +package account import ( "context" + "encoding/json" "fmt" "log/slog" "net/url" "strconv" "time" - "github.com/getlantern/radiance/api/protos" - "github.com/getlantern/radiance/backend" + "go.opentelemetry.io/otel" + + "github.com/getlantern/radiance/account/protos" "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/traces" - "go.opentelemetry.io/otel" ) type ( @@ -48,39 +49,40 @@ type SubscriptionPlans struct { // SubscriptionResponse contains information about a created subscription. type SubscriptionResponse struct { - CustomerId string `json:"customerId"` - SubscriptionId string `json:"subscriptionId"` + CustomerID string `json:"customerId"` + SubscriptionID string `json:"subscriptionId"` ClientSecret string `json:"clientSecret"` PublishableKey string `json:"publishableKey"` } // SubscriptionPlans retrieves available subscription plans for a given channel. -func (ac *APIClient) SubscriptionPlans(ctx context.Context, channel string) (string, error) { +func (a *Client) SubscriptionPlans(ctx context.Context, channel string) (string, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "subscription_plans") defer span.End() - var resp SubscriptionPlans params := map[string]string{ "locale": settings.GetString(settings.LocaleKey), "distributionChannel": channel, } - proWC := ac.proWebClient() - req := proWC.NewRequest(params, nil, nil) - err := proWC.Get(ctx, "/plans-v5", req, &resp) + resp, err := a.sendProRequest(ctx, "GET", "/plans-v5", params, nil, nil) if err != nil { slog.Error("retrieving plans", "error", err) return "", traces.RecordError(ctx, err) } - if resp.BaseResponse != nil && resp.Error != "" { - err = fmt.Errorf("received bad response: %s", resp.Error) + var plans SubscriptionPlans + if err := json.Unmarshal(resp, &plans); err != nil { + return "", traces.RecordError(ctx, fmt.Errorf("unmarshaling plans response: %w", err)) + } + if plans.BaseResponse != nil && plans.Error != "" { + err = fmt.Errorf("received bad response: %s", plans.Error) slog.Error("retrieving plans", "error", err) return "", traces.RecordError(ctx, err) } - return withMarshalJsonString(resp, nil) + return string(resp), nil } // NewStripeSubscription creates a new Stripe subscription for the given email and plan ID. -func (ac *APIClient) NewStripeSubscription(ctx context.Context, email, planID string) (string, error) { +func (a *Client) NewStripeSubscription(ctx context.Context, email, planID string) (string, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "new_stripe_subscription") defer span.End() @@ -88,25 +90,25 @@ func (ac *APIClient) NewStripeSubscription(ctx context.Context, email, planID st "email": email, "planId": planID, } - proWC := ac.proWebClient() - req := proWC.NewRequest(nil, nil, data) - var resp SubscriptionResponse - err := proWC.Post(ctx, "/stripe-subscription", req, &resp) - return withMarshalJsonString(resp, err) + resp, err := a.sendProRequest(ctx, "POST", "/stripe-subscription", nil, nil, data) + if err != nil { + return "", traces.RecordError(ctx, fmt.Errorf("creating stripe subscription: %w", err)) + } + return string(resp), nil } type VerifySubscriptionResponse struct { Status string `json:"status"` - SubscriptionId string `json:"subscriptionId"` - ActualUserId int64 `json:"actualUserId" json:",omitempty"` - ActualUserToken string `json:"actualUserToken" json:",omitempty"` + SubscriptionID string `json:"subscriptionId"` + ActualUserID int64 `json:"actualUserId,omitempty"` + ActualUserToken string `json:"actualUserToken,omitempty"` } // VerifySubscription verifies a subscription for a given service (Google or Apple). data // should contain the information required by service to verify the subscription, such as the // purchase token for Google Play or the receipt for Apple. The status and subscription ID are returned // along with any error that occurred during the verification process. -func (ac *APIClient) VerifySubscription(ctx context.Context, service SubscriptionService, data map[string]string) (string, error) { +func (a *Client) VerifySubscription(ctx context.Context, service SubscriptionService, data map[string]string) (string, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "verify_subscription") defer span.End() @@ -121,46 +123,57 @@ func (ac *APIClient) VerifySubscription(ctx context.Context, service Subscriptio return "", traces.RecordError(ctx, fmt.Errorf("unsupported service: %s", service)) } - proWC := ac.proWebClient() - req := proWC.NewRequest(nil, nil, data) - var resp VerifySubscriptionResponse - err := proWC.Post(ctx, path, req, &resp) + resp, err := a.sendProRequest(ctx, "POST", path, nil, nil, data) if err != nil { slog.Error("verifying subscription", "error", err) return "", traces.RecordError(ctx, fmt.Errorf("verifying subscription: %w", err)) } - return withMarshalJsonString(resp, nil) + return string(resp), nil + } -// StripeBillingPortalUrl generates the Stripe billing portal URL for the given user ID. -func (ac *APIClient) StripeBillingPortalUrl(ctx context.Context) (string, error) { +// StripeBillingPortalURL generates the Stripe billing portal URL for the given user ID. +// baseURL = common.GetProServerURL +func (a *Client) StripeBillingPortalURL(ctx context.Context, baseURL, userID, proToken string) (string, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "stripe_billing_portal_url") defer span.End() - portalURL, err := url.Parse(fmt.Sprintf("%s/%s", common.GetProServerURL(), "stripe-billing-portal")) + portalURL, err := url.Parse(baseURL + "/stripe-billing-portal") if err != nil { slog.Error("parsing portal URL", "error", err) return "", traces.RecordError(ctx, fmt.Errorf("parsing portal URL: %w", err)) } query := portalURL.Query() query.Set("referer", "https://lantern.io/") - query.Set("userId", strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10)) - query.Set("proToken", settings.GetString(settings.TokenKey)) + query.Set("userId", userID) + query.Set("proToken", proToken) portalURL.RawQuery = query.Encode() return portalURL.String(), nil } -// SubscriptionPaymentRedirectURL generates a redirect URL for subscription payment. -func (ac *APIClient) SubscriptionPaymentRedirectURL(ctx context.Context, data PaymentRedirectData) (string, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "subscription_payment_redirect_url") - defer span.End() +type redirect struct { + Redirect string +} - type response struct { - Redirect string - } - var resp response +func (a *Client) paymentRedirect(ctx context.Context, path string, params map[string]string) (string, error) { headers := map[string]string{ - backend.RefererHeader: "https://lantern.io/", + common.RefererHeader: "https://lantern.io/", } + resp, err := a.sendProRequest(ctx, "GET", path, params, headers, nil) + if err != nil { + slog.Error("payment redirect", "error", err) + return "", traces.RecordError(ctx, fmt.Errorf("payment redirect: %w", err)) + } + var r redirect + if err := json.Unmarshal(resp, &r); err != nil { + return "", traces.RecordError(ctx, fmt.Errorf("unmarshaling payment redirect response: %w", err)) + } + return r.Redirect, nil +} + +// SubscriptionPaymentRedirectURL generates a redirect URL for subscription payment. +func (a *Client) SubscriptionPaymentRedirectURL(ctx context.Context, data PaymentRedirectData) (string, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "subscription_payment_redirect_url") + defer span.End() params := map[string]string{ "provider": data.Provider, "plan": data.Plan, @@ -168,43 +181,21 @@ func (ac *APIClient) SubscriptionPaymentRedirectURL(ctx context.Context, data Pa "email": data.Email, "billingType": string(data.BillingType), } - proWC := ac.proWebClient() - req := proWC.NewRequest(params, headers, nil) - err := proWC.Get(ctx, "/subscription-payment-redirect", req, &resp) - if err != nil { - slog.Error("subscription payment redirect", "error", err) - return "", traces.RecordError(ctx, fmt.Errorf("subscription payment redirect: %w", err)) - } - return resp.Redirect, traces.RecordError(ctx, err) + return a.paymentRedirect(ctx, "/subscription-payment-redirect", params) } -// PaymentRedirect is used to get the payment redirect URL with PaymentRedirectData -// this is used in desktop app and android app -func (ac *APIClient) PaymentRedirect(ctx context.Context, data PaymentRedirectData) (string, error) { +// PaymentRedirect is used to get the payment redirect URL with PaymentRedirectData. +// This is used in the desktop and android apps. +func (a *Client) PaymentRedirect(ctx context.Context, data PaymentRedirectData) (string, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "payment_redirect") defer span.End() - - type response struct { - Redirect string - } - var resp response - headers := map[string]string{ - backend.RefererHeader: "https://lantern.io/", - } - mapping := map[string]string{ + params := map[string]string{ "provider": data.Provider, "plan": data.Plan, "deviceName": data.DeviceName, "email": data.Email, } - proWC := ac.proWebClient() - req := proWC.NewRequest(mapping, headers, nil) - err := proWC.Get(ctx, "/payment-redirect", req, &resp) - if err != nil { - slog.Error("subscription payment redirect", "error", err) - return "", traces.RecordError(ctx, fmt.Errorf("subscription payment redirect: %w", err)) - } - return resp.Redirect, traces.RecordError(ctx, err) + return a.paymentRedirect(ctx, "/payment-redirect", params) } type PurchaseResponse struct { @@ -215,28 +206,29 @@ type PurchaseResponse struct { } // ActivationCode is used to purchase a subscription using a reseller code. -func (ac *APIClient) ActivationCode(ctx context.Context, email, resellerCode string) (*PurchaseResponse, error) { +func (a *Client) ActivationCode(ctx context.Context, email, resellerCode string) (*PurchaseResponse, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "activation_code") defer span.End() - data := map[string]interface{}{ + data := map[string]any{ "idempotencyKey": strconv.FormatInt(time.Now().UnixNano(), 10), "provider": "reseller-code", "email": email, "deviceName": settings.GetString(settings.DeviceIDKey), "resellerCode": resellerCode, } - var resp PurchaseResponse - proWC := ac.proWebClient() - req := proWC.NewRequest(nil, nil, data) - err := proWC.Post(ctx, "/purchase", req, &resp) + resp, err := a.sendProRequest(ctx, "POST", "/purchase", nil, nil, data) if err != nil { slog.Error("retrieving subscription status", "error", err) return nil, traces.RecordError(ctx, fmt.Errorf("retrieving subscription status: %w", err)) } - if resp.BaseResponse != nil && resp.Error != "" { - slog.Error("retrieving subscription status", "error", err) - return nil, traces.RecordError(ctx, fmt.Errorf("received bad response: %s", resp.Error)) + var purchase PurchaseResponse + if err := json.Unmarshal(resp, &purchase); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("unmarshaling purchase response: %w", err)) + } + if purchase.BaseResponse != nil && purchase.Error != "" { + slog.Error("retrieving subscription status", "error", purchase.Error) + return nil, traces.RecordError(ctx, fmt.Errorf("received bad response: %s", purchase.Error)) } - return &resp, nil + return &purchase, nil } diff --git a/account/subscription_test.go b/account/subscription_test.go new file mode 100644 index 00000000..cedd3ee3 --- /dev/null +++ b/account/subscription_test.go @@ -0,0 +1,61 @@ +package account + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSubscriptionPaymentRedirect(t *testing.T) { + ac, _ := newTestClient(t) + data := PaymentRedirectData{ + Provider: "stripe", + Plan: "pro", + DeviceName: "test-device", + Email: "", + BillingType: SubscriptionTypeOneTime, + } + url, err := ac.SubscriptionPaymentRedirectURL(context.Background(), data) + require.NoError(t, err) + assert.NotEmpty(t, url) +} + +func TestPaymentRedirect(t *testing.T) { + ac, _ := newTestClient(t) + data := PaymentRedirectData{ + Provider: "stripe", + Plan: "pro", + DeviceName: "test-device", + Email: "", + } + url, err := ac.PaymentRedirect(context.Background(), data) + require.NoError(t, err) + assert.NotEmpty(t, url) +} + +func TestNewUser(t *testing.T) { + ac, _ := newTestClient(t) + resp, err := ac.NewUser(context.Background()) + require.NoError(t, err) + assert.NotNil(t, resp) +} + +func TestVerifySubscription(t *testing.T) { + ac, _ := newTestClient(t) + data := map[string]string{ + "email": "test@getlantern.org", + "planID": "1y-usd-10", + } + resp, err := ac.VerifySubscription(context.Background(), AppleService, data) + require.NoError(t, err) + assert.NotEmpty(t, resp) +} + +func TestPlans(t *testing.T) { + ac, _ := newTestClient(t) + resp, err := ac.SubscriptionPlans(context.Background(), "store") + require.NoError(t, err) + assert.NotEmpty(t, resp) +} diff --git a/account/user.go b/account/user.go new file mode 100644 index 00000000..f0a685c7 --- /dev/null +++ b/account/user.go @@ -0,0 +1,672 @@ +package account + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/url" + "os" + "strings" + + "go.opentelemetry.io/otel" + "google.golang.org/protobuf/proto" + + "github.com/getlantern/radiance/account/protos" + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/events" + "github.com/getlantern/radiance/traces" +) + +const saltFileName = ".salt" + +// UserDataResponse represents the response from pro server +type UserDataResponse struct { + *protos.BaseResponse `json:",inline"` + *protos.LoginResponse_UserData `json:",inline"` +} + +type SignupResponse = protos.SignupResponse +type UserData = protos.LoginResponse + +// NewUser creates a new user account +func (a *Client) NewUser(ctx context.Context) (*UserData, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "new_user") + defer span.End() + + resp, err := a.sendProRequest(ctx, "POST", "/user-create", nil, nil, nil) + if err != nil { + slog.Error("creating new user", "error", err) + return nil, traces.RecordError(ctx, err) + } + var userResp UserDataResponse + if err := json.Unmarshal(resp, &userResp); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling new user response: %w", err)) + } + userData, err := a.storeData(ctx, userResp) + if err != nil { + return nil, err + } + return userData, nil +} + +// FetchUserData fetches user data from the server. +func (a *Client) FetchUserData(ctx context.Context) (*UserData, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "fetch_user_data") + defer span.End() + return a.fetchUserData(ctx) +} + +// fetchUserData calls the /user-data endpoint and stores the result via storeData. +func (a *Client) fetchUserData(ctx context.Context) (*UserData, error) { + resp, err := a.sendProRequest(ctx, "GET", "/user-data", nil, nil, nil) + if err != nil { + slog.Error("user data", "error", err) + return nil, traces.RecordError(ctx, fmt.Errorf("getting user data: %w", err)) + } + var userResp UserDataResponse + if err := json.Unmarshal(resp, &userResp); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling new user response: %w", err)) + } + return a.storeData(ctx, userResp) +} + +func (a *Client) storeData(ctx context.Context, resp UserDataResponse) (*UserData, error) { + if resp.BaseResponse != nil && resp.Error != "" { + err := fmt.Errorf("received bad response: %s", resp.Error) + slog.Error("user data", "error", err) + return nil, traces.RecordError(ctx, err) + } + if resp.LoginResponse_UserData == nil { + slog.Error("user data", "error", "no user data in response") + return nil, traces.RecordError(ctx, fmt.Errorf("no user data in response")) + } + resp.DeviceID = settings.GetString(settings.DeviceIDKey) + login := &UserData{ + LegacyID: resp.UserId, + LegacyToken: resp.Token, + LegacyUserData: resp.LoginResponse_UserData, + } + a.setData(login) + return login, nil +} + +// DataCapInfo represents the data cap info +type DataCapInfo struct { + // Whether data cap is enabled for this device/user + Enabled bool `json:"enabled"` + // Data cap usage details (only populated if enabled is true) + Usage *DataCapUsageDetails `json:"usage,omitempty"` +} + +// DataCapUsageDetails contains details of the data cap usage +type DataCapUsageDetails struct { + BytesAllotted string `json:"bytesAllotted"` + BytesUsed string `json:"bytesUsed"` + AllotmentStartTime string `json:"allotmentStartTime"` + AllotmentEndTime string `json:"allotmentEndTime"` +} + +// DataCapInfo returns information about this user's data cap +func (a *Client) DataCapInfo(ctx context.Context) (*DataCapInfo, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "data_cap_info") + defer span.End() + + getURL := "/datacap/" + settings.GetString(settings.DeviceIDKey) + resp, err := a.sendRequest(ctx, "GET", getURL, nil, nil, nil) + if err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("getting datacap info: %w", err)) + } + var usage *DataCapInfo + if err := json.Unmarshal(resp, &usage); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling datacap info response: %w", err)) + } + return usage, nil +} + +// SignUp signs the user up for an account. +func (a *Client) SignUp(ctx context.Context, email, password string) ([]byte, *protos.SignupResponse, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up") + defer span.End() + + lowerCaseEmail := strings.ToLower(email) + salt, err := generateSalt() + if err != nil { + return nil, nil, traces.RecordError(ctx, err) + } + srpClient, err := newSRPClient(lowerCaseEmail, password, salt) + if err != nil { + return nil, nil, traces.RecordError(ctx, err) + } + verifierKey, err := srpClient.Verifier() + if err != nil { + return nil, nil, traces.RecordError(ctx, err) + } + data := &protos.SignupRequest{ + Email: lowerCaseEmail, + Salt: salt, + Verifier: verifierKey.Bytes(), + SkipEmailConfirmation: true, + // Set temp always to true for now + // If new user faces any issue while sign up user can sign up again + Temp: true, + } + + resp, err := a.sendRequest(ctx, "POST", "/users/signup", nil, nil, data) + if err != nil { + return nil, nil, traces.RecordError(ctx, err) + } + a.setSalt(salt) + + var signupData protos.SignupResponse + if err := proto.Unmarshal(resp, &signupData); err != nil { + return nil, nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling sign up response: %w", err)) + } + idErr := settings.Set(settings.UserIDKey, signupData.LegacyID) + if idErr != nil { + return nil, nil, traces.RecordError(ctx, fmt.Errorf("could not save user id: %w", idErr)) + } + proTokenErr := settings.Set(settings.TokenKey, signupData.ProToken) + if proTokenErr != nil { + return nil, nil, traces.RecordError(ctx, fmt.Errorf("could not save token: %w", proTokenErr)) + } + jwtTokenErr := settings.Set(settings.JwtTokenKey, signupData.Token) + if jwtTokenErr != nil { + return nil, nil, traces.RecordError(ctx, fmt.Errorf("could not save JWT token: %w", jwtTokenErr)) + } + + return salt, &signupData, nil +} + +var ErrNoSalt = errors.New("no salt available") +var ErrNotLoggedIn = errors.New("not logged in") +var ErrInvalidCode = errors.New("invalid code") + +// SignupEmailResendCode requests that the sign-up code be resent via email. +func (a *Client) SignupEmailResendCode(ctx context.Context, email string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up_email_resend_code") + defer span.End() + + salt := a.getSaltCached() + if salt == nil { + return traces.RecordError(ctx, ErrNoSalt) + } + data := &protos.SignupEmailResendRequest{ + Email: email, + Salt: salt, + } + _, err := a.sendRequest(ctx, "POST", "/users/signup/resend/email", nil, nil, data) + return traces.RecordError(ctx, err) +} + +// SignupEmailConfirmation confirms the new account using the sign-up code received via email. +func (a *Client) SignupEmailConfirmation(ctx context.Context, email, code string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up_email_confirmation") + defer span.End() + + data := &protos.ConfirmSignupRequest{ + Email: email, + Code: code, + } + _, err := a.sendRequest(ctx, "POST", "/users/signup/complete/email", nil, nil, data) + return traces.RecordError(ctx, err) +} + +func writeSalt(salt []byte, path string) error { + if err := os.WriteFile(path, salt, 0600); err != nil { + return fmt.Errorf("writing salt to %s: %w", path, err) + } + return nil +} + +func readSalt(path string) ([]byte, error) { + buf, err := os.ReadFile(path) + if err != nil && !os.IsNotExist(err) { + return nil, fmt.Errorf("reading salt from %s: %w", path, err) + } + if len(buf) == 0 { + return nil, nil + } + return buf, nil +} + +// Login logs the user in. +func (a *Client) Login(ctx context.Context, email, password string) (*UserData, error) { + // clear any previous salt value + a.setSalt(nil) + ctx, span := otel.Tracer(tracerName).Start(ctx, "login") + defer span.End() + + lowerCaseEmail := strings.ToLower(email) + salt, err := a.getSalt(ctx, lowerCaseEmail) + if err != nil { + return nil, traces.RecordError(ctx, err) + } + + deviceID := settings.GetString(settings.DeviceIDKey) + proof, err := a.clientProof(ctx, lowerCaseEmail, password, salt) + if err != nil { + return nil, err + } + + loginData := &protos.LoginRequest{ + Email: lowerCaseEmail, + DeviceId: deviceID, + Proof: proof, + } + resp, err := a.sendRequest(ctx, "POST", "/users/login", nil, nil, loginData) + if err != nil { + return nil, traces.RecordError(ctx, err) + } + + var loginResp UserData + if err := proto.Unmarshal(resp, &loginResp); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling login response: %w", err)) + } + //this can be nil if the user has reached the device limit + if loginResp.LegacyUserData != nil { + loginResp.LegacyUserData.DeviceID = deviceID + } + + // regardless of state we need to save login information + // We have device flow limit on login + a.setData(&loginResp) + a.setSalt(salt) + if saltErr := writeSalt(salt, a.saltPath); saltErr != nil { + return nil, traces.RecordError(ctx, saltErr) + } + settings.Set(settings.OAuthLoginKey, false) + return &loginResp, nil +} + +// Logout logs the user out. No-op if there is no user account logged in. +func (a *Client) Logout(ctx context.Context, email string) (*UserData, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "logout") + defer span.End() + logout := &protos.LogoutRequest{ + Email: email, + DeviceId: settings.GetString(settings.DeviceIDKey), + LegacyUserID: settings.GetInt64(settings.UserIDKey), + LegacyToken: settings.GetString(settings.TokenKey), + Token: settings.GetString(settings.JwtTokenKey), + } + _, err := a.sendRequest(ctx, "POST", "/users/logout", nil, nil, logout) + if err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("logging out: %w", err)) + } + a.ClearUser() + a.setSalt(nil) + if err := writeSalt(nil, a.saltPath); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("writing salt after logout: %w", err)) + } + return a.NewUser(ctx) +} + +// StartRecoveryByEmail initializes the account recovery process for the provided email. +func (a *Client) StartRecoveryByEmail(ctx context.Context, email string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "start_recovery_by_email") + defer span.End() + + data := &protos.StartRecoveryByEmailRequest{Email: email} + _, err := a.sendRequest(ctx, "POST", "/users/recovery/start/email", nil, nil, data) + return traces.RecordError(ctx, err) +} + +// CompleteRecoveryByEmail completes account recovery using the code received via email. +func (a *Client) CompleteRecoveryByEmail(ctx context.Context, email, newPassword, code string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "complete_recovery_by_email") + defer span.End() + lowerCaseEmail := strings.ToLower(email) + newSalt, err := generateSalt() + if err != nil { + return traces.RecordError(ctx, err) + } + srpClient, err := newSRPClient(lowerCaseEmail, newPassword, newSalt) + if err != nil { + return traces.RecordError(ctx, err) + } + verifierKey, err := srpClient.Verifier() + if err != nil { + return traces.RecordError(ctx, err) + } + + data := &protos.CompleteRecoveryByEmailRequest{ + Email: email, + Code: code, + NewSalt: newSalt, + NewVerifier: verifierKey.Bytes(), + } + _, err = a.sendRequest(ctx, "POST", "/users/recovery/complete/email", nil, nil, data) + if err != nil { + return traces.RecordError(ctx, fmt.Errorf("failed to complete recovery by email: %w", err)) + } + if err = writeSalt(newSalt, a.saltPath); err != nil { + return traces.RecordError(ctx, fmt.Errorf("failed to write new salt: %w", err)) + } + return nil +} + +// ValidateEmailRecoveryCode validates the recovery code received via email. +func (a *Client) ValidateEmailRecoveryCode(ctx context.Context, email, code string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "validate_email_recovery_code") + defer span.End() + + data := &protos.ValidateRecoveryCodeRequest{ + Email: email, + Code: code, + } + resp, err := a.sendRequest(ctx, "POST", "/users/recovery/validate/email", nil, nil, data) + if err != nil { + return traces.RecordError(ctx, err) + } + var codeResp protos.ValidateRecoveryCodeResponse + if err := proto.Unmarshal(resp, &codeResp); err != nil { + return traces.RecordError(ctx, fmt.Errorf("error unmarshalling validate recovery code response: %w", err)) + } + if !codeResp.Valid { + return traces.RecordError(ctx, ErrInvalidCode) + } + return nil +} + +// StartChangeEmail initializes a change of the email address associated with this user account. +func (a *Client) StartChangeEmail(ctx context.Context, newEmail, password string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "start_change_email") + defer span.End() + + lowerCaseEmail := strings.ToLower(settings.GetString(settings.EmailKey)) + lowerCaseNewEmail := strings.ToLower(newEmail) + + salt, err := a.getSalt(ctx, lowerCaseEmail) + if err != nil { + return traces.RecordError(ctx, err) + } + proof, err := a.clientProof(ctx, lowerCaseEmail, password, salt) + if err != nil { + return traces.RecordError(ctx, err) + } + + data := &protos.ChangeEmailRequest{ + OldEmail: lowerCaseEmail, + NewEmail: lowerCaseNewEmail, + Proof: proof, + } + _, err = a.sendRequest(ctx, "POST", "/users/change_email", nil, nil, data) + return traces.RecordError(ctx, err) +} + +// CompleteChangeEmail completes a change of the email address associated with this user account, +// using the code received via email. +func (a *Client) CompleteChangeEmail(ctx context.Context, newEmail, password, code string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "complete_change_email") + defer span.End() + + newSalt, err := generateSalt() + if err != nil { + return traces.RecordError(ctx, err) + } + + srpClient, err := newSRPClient(newEmail, password, newSalt) + if err != nil { + return traces.RecordError(ctx, err) + } + verifierKey, err := srpClient.Verifier() + if err != nil { + return traces.RecordError(ctx, err) + } + + data := &protos.CompleteChangeEmailRequest{ + OldEmail: settings.GetString(settings.EmailKey), + NewEmail: newEmail, + Code: code, + NewSalt: newSalt, + NewVerifier: verifierKey.Bytes(), + } + _, err = a.sendRequest(ctx, "POST", "/users/change_email/complete/email", nil, nil, data) + if err != nil { + return traces.RecordError(ctx, err) + } + if err := writeSalt(newSalt, a.saltPath); err != nil { + return traces.RecordError(ctx, err) + } + if err := settings.Set(settings.EmailKey, newEmail); err != nil { + return traces.RecordError(ctx, err) + } + + a.setSalt(newSalt) + return nil +} + +// DeleteAccount deletes this user account. +func (a *Client) DeleteAccount(ctx context.Context, email, password string) (*UserData, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "delete_account") + defer span.End() + + lowerCaseEmail := strings.ToLower(email) + data := &protos.DeleteUserRequest{ + Email: lowerCaseEmail, + Permanent: true, + DeviceId: settings.GetString(settings.DeviceIDKey), + Token: settings.GetString(settings.JwtTokenKey), + } + if !settings.GetBool(settings.OAuthLoginKey) { + salt, err := a.getSalt(ctx, lowerCaseEmail) + if err != nil { + return nil, traces.RecordError(ctx, err) + } + proof, err := a.clientProof(ctx, lowerCaseEmail, password, salt) + if err != nil { + return nil, err + } + data.Proof = proof + } else { + if data.Token == "" { + return nil, traces.RecordError(ctx, errors.New("jwt token is required for OAuth account deletion")) + } + } + + _, err := a.sendRequest(ctx, "POST", "/users/delete", nil, nil, data) + if err != nil { + return nil, traces.RecordError(ctx, err) + } + + a.ClearUser() + a.setSalt(nil) + if err := writeSalt(nil, a.saltPath); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("failed to write salt during account deletion cleanup: %w", err)) + } + + return a.NewUser(ctx) +} + +// OAuthLoginURL initiates the OAuth login process for the specified provider. +func (a *Client) OAuthLoginURL(ctx context.Context, provider string) (string, error) { + authURL := a.authURL + if authURL == "" { + authURL = common.GetBaseURL() + } + loginURL, err := url.Parse(authURL + "/users/oauth2/" + provider) + if err != nil { + return "", fmt.Errorf("failed to parse URL: %w", err) + } + query := loginURL.Query() + query.Set("deviceId", settings.GetString(settings.DeviceIDKey)) + query.Set("userId", settings.GetString(settings.UserIDKey)) + query.Set("proToken", settings.GetString(settings.TokenKey)) + query.Set("returnTo", "lantern://auth") + loginURL.RawQuery = query.Encode() + return loginURL.String(), nil +} + +func (a *Client) OAuthLoginCallback(ctx context.Context, oAuthToken string) (*UserData, error) { + slog.Debug("Getting OAuth login callback") + jwtUserInfo, err := decodeJWT(oAuthToken) + if err != nil { + return nil, fmt.Errorf("error decoding JWT: %w", err) + } + + // Temporary set user data to so api can read it + login := &UserData{ + LegacyID: jwtUserInfo.LegacyUserID, + LegacyToken: jwtUserInfo.LegacyToken, + LegacyUserData: &protos.LoginResponse_UserData{ + UserId: jwtUserInfo.LegacyUserID, + Token: jwtUserInfo.LegacyToken, + DeviceID: jwtUserInfo.DeviceID, + Email: jwtUserInfo.Email, + }, + } + a.setData(login) + // Get user data from api this will also save data in user config + user, err := a.fetchUserData(ctx) + if err != nil { + return nil, fmt.Errorf("error getting user data: %w", err) + } + + if err := settings.Set(settings.JwtTokenKey, oAuthToken); err != nil { + slog.Error("Failed to persist JWT token", "error", err) + return nil, fmt.Errorf("failed to persist JWT token: %w", err) + } + settings.Set(settings.OAuthLoginKey, true) + user.Id = jwtUserInfo.Email + user.EmailConfirmed = true + a.setData(user) + return user, nil +} + +type LinkResponse struct { + *protos.BaseResponse `json:",inline"` + UserID int `json:"userID"` + ProToken string `json:"token"` +} + +// RemoveDevice removes a device from the user's account. +func (a *Client) RemoveDevice(ctx context.Context, deviceID string) (*LinkResponse, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "remove_device") + defer span.End() + + data := map[string]string{ + "deviceId": deviceID, + } + resp, err := a.sendProRequest(ctx, "POST", "/user-link-remove", nil, nil, data) + if err != nil { + return nil, traces.RecordError(ctx, err) + } + var link LinkResponse + if err := json.Unmarshal(resp, &link); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling remove device response: %w", err)) + } + if link.BaseResponse != nil && link.BaseResponse.Error != "" { + return nil, traces.RecordError(ctx, fmt.Errorf("failed to remove device: %s", link.BaseResponse.Error)) + } + return &link, nil +} + +func (a *Client) ReferralAttach(ctx context.Context, code string) (bool, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "referral_attach") + defer span.End() + + data := map[string]string{ + "code": code, + } + resp, err := a.sendProRequest(ctx, "POST", "/referral-attach", nil, nil, data) + if err != nil { + return false, traces.RecordError(ctx, err) + } + var baseResp protos.BaseResponse + if err := proto.Unmarshal(resp, &baseResp); err != nil { + return false, traces.RecordError(ctx, fmt.Errorf("error unmarshalling referral attach response: %w", err)) + } + if baseResp.Error != "" { + return false, traces.RecordError(ctx, errors.New(baseResp.Error)) + } + return true, nil +} + +type UserChangeEvent struct { + events.Event +} + +func (a *Client) setData(data *UserData) { + a.mu.Lock() + defer a.mu.Unlock() + if data == nil { + a.ClearUser() + return + } + if data.LegacyUserData == nil { + slog.Info("no user data to set") + return + } + + existingUser := settings.GetInt64(settings.UserIDKey) != 0 + + var changed bool + if data.LegacyUserData.UserLevel != "" { + oldUserLevel := settings.GetString(settings.UserLevelKey) + changed = changed || oldUserLevel != data.LegacyUserData.UserLevel + if err := settings.Set(settings.UserLevelKey, data.LegacyUserData.UserLevel); err != nil { + slog.Error("failed to set user level in settings", "error", err) + } + } + if data.LegacyUserData.Email != "" { + oldEmail := settings.GetString(settings.EmailKey) + changed = changed || oldEmail != data.LegacyUserData.Email + if err := settings.Set(settings.EmailKey, data.LegacyUserData.Email); err != nil { + slog.Error("failed to set email in settings", "error", err) + } + } + if data.LegacyID != 0 { + oldUserID := settings.GetInt64(settings.UserIDKey) + changed = changed || oldUserID != data.LegacyID + if err := settings.Set(settings.UserIDKey, data.LegacyID); err != nil { + slog.Error("failed to set user ID in settings", "error", err) + } + } + if data.LegacyToken != "" { + oldToken := settings.GetString(settings.TokenKey) + changed = changed || oldToken != data.LegacyToken + if err := settings.Set(settings.TokenKey, data.LegacyToken); err != nil { + slog.Error("failed to set token in settings", "error", err) + } + } + if data.Token != "" { + oldJwtToken := settings.GetString(settings.JwtTokenKey) + changed = changed || oldJwtToken != data.Token + if err := settings.Set(settings.JwtTokenKey, data.Token); err != nil { + slog.Error("failed to set JWT token in settings", "error", err) + } + } + + devices := []settings.Device{} + for _, d := range data.Devices { + devices = append(devices, settings.Device{ + Name: d.Name, + ID: d.Id, + }) + } + if err := settings.Set(settings.DevicesKey, devices); err != nil { + slog.Error("failed to set devices in settings", "error", err) + } + + if err := settings.Set(settings.UserDataKey, data); err != nil { + slog.Error("failed to set login response in settings", "error", err) + } + + // We only consider the user to have changed if there was a previous user. + if existingUser && changed { + events.Emit(UserChangeEvent{}) + } +} + +func (a *Client) ClearUser() { + settings.Clear(settings.UserIDKey) + settings.Clear(settings.TokenKey) + settings.Clear(settings.UserLevelKey) + settings.Clear(settings.EmailKey) + settings.Clear(settings.DevicesKey) + settings.Clear(settings.JwtTokenKey) + settings.Clear(settings.UserDataKey) +} diff --git a/account/user_test.go b/account/user_test.go new file mode 100644 index 00000000..87ee1f1c --- /dev/null +++ b/account/user_test.go @@ -0,0 +1,353 @@ +package account + +import ( + "context" + "encoding/hex" + "encoding/json" + "io" + "math/big" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + + "github.com/1Password/srp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/getlantern/radiance/account/protos" + "github.com/getlantern/radiance/common/settings" +) + +// testServer holds server-side SRP state for the mock auth server. +type testServer struct { + salt map[string][]byte + verifier []byte + cache map[string]string +} + +func writeProtoResponse(w http.ResponseWriter, msg proto.Message) { + data, err := proto.Marshal(msg) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + w.Header().Set("Content-Type", "application/x-protobuf") + w.Write(data) +} + +func readProtoRequest(r *http.Request, msg proto.Message) error { + data, err := io.ReadAll(r.Body) + if err != nil { + return err + } + return proto.Unmarshal(data, msg) +} + +func writeJSONResponse(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(v) +} + +func newTestServer(t *testing.T) (*httptest.Server, *testServer) { + state := &testServer{ + salt: make(map[string][]byte), + cache: make(map[string]string), + } + mux := http.NewServeMux() + + // Auth endpoints + mux.HandleFunc("/users/salt", func(w http.ResponseWriter, r *http.Request) { + email := r.URL.Query().Get("email") + salt := state.salt[email] + if salt == nil { + salt = []byte("salt") + } + writeProtoResponse(w, &protos.GetSaltResponse{Salt: salt}) + }) + + mux.HandleFunc("/users/signup", func(w http.ResponseWriter, r *http.Request) { + var req protos.SignupRequest + if err := readProtoRequest(r, &req); err != nil { + http.Error(w, err.Error(), 500) + return + } + state.salt[req.Email] = req.Salt + state.verifier = req.Verifier + writeProtoResponse(w, &protos.SignupResponse{}) + }) + + mux.HandleFunc("/users/prepare", func(w http.ResponseWriter, r *http.Request) { + var req protos.PrepareRequest + if err := readProtoRequest(r, &req); err != nil { + http.Error(w, err.Error(), 500) + return + } + A := big.NewInt(0).SetBytes(req.A) + verifier := big.NewInt(0).SetBytes(state.verifier) + server := srp.NewSRPServer(srp.KnownGroups[srp.RFC5054Group3072], verifier, nil) + if err := server.SetOthersPublic(A); err != nil { + http.Error(w, err.Error(), 500) + return + } + B := server.EphemeralPublic() + if B == nil { + http.Error(w, "cannot generate B", 500) + return + } + if _, err := server.Key(); err != nil { + http.Error(w, "cannot generate key", 500) + return + } + proof, err := server.M(state.salt[req.Email], req.Email) + if err != nil { + http.Error(w, "cannot generate proof", 500) + return + } + serverState, _ := server.MarshalBinary() + state.cache[req.Email] = hex.EncodeToString(serverState) + writeProtoResponse(w, &protos.PrepareResponse{B: B.Bytes(), Proof: proof}) + }) + + mux.HandleFunc("/users/login", func(w http.ResponseWriter, r *http.Request) { + writeProtoResponse(w, &protos.LoginResponse{ + LegacyUserData: &protos.LoginResponse_UserData{ + DeviceID: "deviceId", + }, + }) + }) + + // Simple auth endpoints that return empty responses + for _, path := range []string{ + "/users/signup/resend/email", + "/users/signup/complete/email", + "/users/recovery/start/email", + "/users/recovery/complete/email", + "/users/change_email", + "/users/change_email/complete/email", + "/users/delete", + "/users/logout", + } { + mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + writeProtoResponse(w, &protos.EmptyResponse{}) + }) + } + + mux.HandleFunc("/users/recovery/validate/email", func(w http.ResponseWriter, r *http.Request) { + writeProtoResponse(w, &protos.ValidateRecoveryCodeResponse{Valid: true}) + }) + + // Pro server endpoints + mux.HandleFunc("/user-create", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, UserDataResponse{ + BaseResponse: &protos.BaseResponse{}, + LoginResponse_UserData: &protos.LoginResponse_UserData{ + UserId: 123, + Token: "test-token", + }, + }) + }) + + mux.HandleFunc("/user-data", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, UserDataResponse{ + BaseResponse: &protos.BaseResponse{}, + LoginResponse_UserData: &protos.LoginResponse_UserData{ + UserId: 123, + Token: "test-token", + }, + }) + }) + + mux.HandleFunc("/user-link-remove", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, LinkResponse{ + BaseResponse: &protos.BaseResponse{}, + UserID: 123, + ProToken: "token", + }) + }) + + mux.HandleFunc("/referral-attach", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, protos.BaseResponse{}) + }) + + // Subscription endpoints + mux.HandleFunc("/plans-v5", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, SubscriptionPlans{ + BaseResponse: &protos.BaseResponse{}, + Plans: []*protos.Plan{{Id: "1y-usd-10", Description: "Pro Plan"}}, + }) + }) + + mux.HandleFunc("/subscription-payment-redirect", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, map[string]string{"Redirect": "https://example.com/redirect"}) + }) + + mux.HandleFunc("/payment-redirect", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, map[string]string{"Redirect": "https://example.com/redirect"}) + }) + + mux.HandleFunc("/stripe-subscription", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, SubscriptionResponse{ + CustomerID: "cus_123", + SubscriptionID: "sub_123", + ClientSecret: "secret", + }) + }) + + mux.HandleFunc("/purchase-apple-subscription-v2", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, VerifySubscriptionResponse{ + Status: "active", + SubscriptionID: "sub_1234567890", + }) + }) + + mux.HandleFunc("/purchase", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, PurchaseResponse{ + BaseResponse: &protos.BaseResponse{}, + PaymentStatus: "completed", + }) + }) + + ts := httptest.NewServer(mux) + t.Cleanup(ts.Close) + return ts, state +} + +func newTestClient(t *testing.T) (*Client, *testServer) { + ts, state := newTestServer(t) + settings.InitSettings(t.TempDir()) + t.Cleanup(settings.Reset) + return &Client{ + httpClient: ts.Client(), + proURL: ts.URL, + authURL: ts.URL, + saltPath: filepath.Join(t.TempDir(), saltFileName), + }, state +} + +// newTestClientWithSRP creates a test client and pre-registers an email/password on the mock server. +func newTestClientWithSRP(t *testing.T, email, password string) (*Client, *testServer) { + ac, state := newTestClient(t) + + salt, err := generateSalt() + require.NoError(t, err) + + encKey, err := generateEncryptedKey(password, email, salt) + require.NoError(t, err) + + srpClient := srp.NewSRPClient(srp.KnownGroups[group], encKey, nil) + verifierKey, err := srpClient.Verifier() + require.NoError(t, err) + + state.salt[email] = salt + state.verifier = verifierKey.Bytes() + ac.salt = salt + + return ac, state +} + +func TestSignUp(t *testing.T) { + ac, _ := newTestClient(t) + salt, signupResponse, err := ac.SignUp(context.Background(), "test@example.com", "password") + assert.NoError(t, err) + assert.NotNil(t, salt) + assert.NotNil(t, signupResponse) +} + +func TestSignupEmailResendCode(t *testing.T) { + ac, _ := newTestClient(t) + ac.salt = []byte("salt") + err := ac.SignupEmailResendCode(context.Background(), "test@example.com") + assert.NoError(t, err) +} + +func TestSignupEmailConfirmation(t *testing.T) { + ac, _ := newTestClient(t) + err := ac.SignupEmailConfirmation(context.Background(), "test@example.com", "code") + assert.NoError(t, err) +} + +func TestLogin(t *testing.T) { + email := "test@example.com" + ac, _ := newTestClientWithSRP(t, email, "password") + // Clear cached salt to test the full flow (getSalt → srpLogin) + ac.salt = nil + _, err := ac.Login(context.Background(), email, "password") + assert.NoError(t, err) +} + +func TestLogout(t *testing.T) { + ac, _ := newTestClient(t) + settings.Set(settings.DeviceIDKey, "deviceId") + _, err := ac.Logout(context.Background(), "test@example.com") + assert.NoError(t, err) +} + +func TestStartRecoveryByEmail(t *testing.T) { + ac, _ := newTestClient(t) + err := ac.StartRecoveryByEmail(context.Background(), "test@example.com") + assert.NoError(t, err) +} + +func TestCompleteRecoveryByEmail(t *testing.T) { + ac, _ := newTestClient(t) + err := ac.CompleteRecoveryByEmail(context.Background(), "test@example.com", "newPassword", "code") + assert.NoError(t, err) +} + +func TestValidateEmailRecoveryCode(t *testing.T) { + ac, _ := newTestClient(t) + err := ac.ValidateEmailRecoveryCode(context.Background(), "test@example.com", "code") + assert.NoError(t, err) +} + +func TestStartChangeEmail(t *testing.T) { + email := "test@example.com" + ac, _ := newTestClientWithSRP(t, email, "password") + settings.Set(settings.EmailKey, email) + err := ac.StartChangeEmail(context.Background(), "new@example.com", "password") + assert.NoError(t, err) +} + +func TestCompleteChangeEmail(t *testing.T) { + ac, _ := newTestClient(t) + settings.Set(settings.EmailKey, "old@example.com") + err := ac.CompleteChangeEmail(context.Background(), "new@example.com", "password", "code") + assert.NoError(t, err) +} + +func TestDeleteAccount(t *testing.T) { + email := "test@example.com" + ac, _ := newTestClientWithSRP(t, email, "password") + settings.Set(settings.DeviceIDKey, "deviceId") + _, err := ac.DeleteAccount(context.Background(), email, "password") + assert.NoError(t, err) +} + +func TestOAuthLoginUrl(t *testing.T) { + ac, _ := newTestClient(t) + url, err := ac.OAuthLoginURL(context.Background(), "google") + assert.NoError(t, err) + assert.NotEmpty(t, url) +} + +func TestOAuthLoginCallback(t *testing.T) { + ac, _ := newTestClient(t) + settings.Set(settings.DeviceIDKey, "deviceId") + + // Mock JWT with unverified signature — decodeJWT uses ParseUnverified so this succeeds. + mockToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20iLCJsZWdhY3lfdXNlcl9pZCI6MTIzNDUsImxlZ2FjeV90b2tlbiI6InRlc3QtdG9rZW4ifQ.test" + + data, err := ac.OAuthLoginCallback(context.Background(), mockToken) + assert.NoError(t, err) + assert.NotEmpty(t, data) +} + +func TestOAuthLoginCallback_InvalidToken(t *testing.T) { + ac, _ := newTestClient(t) + + _, err := ac.OAuthLoginCallback(context.Background(), "invalid-token") + assert.Error(t, err) + assert.Contains(t, err.Error(), "error decoding JWT") +} diff --git a/api/api.go b/api/api.go deleted file mode 100644 index 5245dca6..00000000 --- a/api/api.go +++ /dev/null @@ -1,59 +0,0 @@ -package api - -import ( - "log/slog" - "path/filepath" - "strconv" - "sync" - - "github.com/go-resty/resty/v2" - - "github.com/getlantern/radiance/backend" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/kindling" -) - -const tracerName = "github.com/getlantern/radiance/api" - -type APIClient struct { - salt []byte - saltPath string - authClient AuthClient - mu sync.RWMutex -} - -func NewAPIClient(dataDir string) *APIClient { - path := filepath.Join(dataDir, saltFileName) - salt, err := readSalt(path) - if err != nil { - slog.Warn("failed to read salt", "error", err) - } - - cli := &APIClient{ - salt: salt, - saltPath: path, - authClient: &authClient{}, - } - return cli -} - -func (a *APIClient) proWebClient() *webClient { - httpClient := kindling.HTTPClient() - proWC := newWebClient(httpClient, common.GetProServerURL()) - proWC.client.OnBeforeRequest(func(client *resty.Client, req *resty.Request) error { - req.Header.Set(backend.DeviceIDHeader, settings.GetString(settings.DeviceIDKey)) - if settings.GetString(settings.TokenKey) != "" { - req.Header.Set(backend.ProTokenHeader, settings.GetString(settings.TokenKey)) - } - if settings.GetInt64(settings.UserIDKey) != 0 { - req.Header.Set(backend.UserIDHeader, strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10)) - } - return nil - }) - return proWC -} - -func authWebClient() *webClient { - return newWebClient(kindling.HTTPClient(), common.GetBaseURL()) -} diff --git a/api/auth.go b/api/auth.go deleted file mode 100644 index 0949c1a2..00000000 --- a/api/auth.go +++ /dev/null @@ -1,178 +0,0 @@ -package api - -import ( - "context" - "fmt" - "strconv" - - "github.com/getlantern/radiance/api/protos" - "github.com/getlantern/radiance/backend" - "github.com/getlantern/radiance/common/settings" -) - -type AuthClient interface { - // Sign up methods - SignUp(ctx context.Context, email string, password string) ([]byte, *protos.SignupResponse, error) - SignupEmailResendCode(ctx context.Context, data *protos.SignupEmailResendRequest) error - SignupEmailConfirmation(ctx context.Context, data *protos.ConfirmSignupRequest) error - // Login methods - GetSalt(ctx context.Context, email string) (*protos.GetSaltResponse, error) - LoginPrepare(ctx context.Context, loginData *protos.PrepareRequest) (*protos.PrepareResponse, error) - Login(ctx context.Context, email, password, deviceID string, salt []byte) (*protos.LoginResponse, error) - // Recovery methods - StartRecoveryByEmail(ctx context.Context, loginData *protos.StartRecoveryByEmailRequest) error - CompleteRecoveryByEmail(ctx context.Context, loginData *protos.CompleteRecoveryByEmailRequest) error - ValidateEmailRecoveryCode(ctx context.Context, loginData *protos.ValidateRecoveryCodeRequest) (*protos.ValidateRecoveryCodeResponse, error) - // Change email methods - ChangeEmail(ctx context.Context, loginData *protos.ChangeEmailRequest) error - // Complete change email methods - CompleteChangeEmail(ctx context.Context, loginData *protos.CompleteChangeEmailRequest) error - DeleteAccount(ctc context.Context, loginData *protos.DeleteUserRequest) error - // Logout - SignOut(ctx context.Context, logoutData *protos.LogoutRequest) error -} - -type authClient struct{} - -// Auth APIS -// GetSalt is used to get the salt for a given email address -func (c *authClient) GetSalt(ctx context.Context, email string) (*protos.GetSaltResponse, error) { - var resp protos.GetSaltResponse - query := map[string]string{ - "email": email, - } - header := map[string]string{ - "Content-Type": "application/x-protobuf", - "Accept": "application/x-protobuf", - } - wc := authWebClient() - req := wc.NewRequest(query, header, nil) - if err := wc.Get(ctx, "/users/salt", req, &resp); err != nil { - return nil, err - } - return &resp, nil -} - -// Sign up API -// SignUp is used to sign up a new user with the SignupRequest -func (c *authClient) signUp(ctx context.Context, signupData *protos.SignupRequest) (*protos.SignupResponse, error) { - var resp protos.SignupResponse - header := map[string]string{ - backend.DeviceIDHeader: settings.GetString(settings.DeviceIDKey), - backend.UserIDHeader: strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10), - backend.ProTokenHeader: settings.GetString(settings.TokenKey), - } - wc := authWebClient() - req := wc.NewRequest(nil, header, signupData) - if err := wc.Post(ctx, "/users/signup", req, &resp); err != nil { - return nil, err - } - return &resp, nil -} - -// SignupEmailResendCode is used to resend the email confirmation code -// Params: ctx context.Context, data *SignupEmailResendRequest -func (c *authClient) SignupEmailResendCode(ctx context.Context, data *protos.SignupEmailResendRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, data) - return wc.Post(ctx, "/users/signup/resend/email", req, &resp) -} - -// SignupEmailConfirmation is used to confirm the email address once user enter code -// Params: ctx context.Context, data *ConfirmSignupRequest -func (c *authClient) SignupEmailConfirmation(ctx context.Context, data *protos.ConfirmSignupRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, data) - return wc.Post(ctx, "/users/signup/complete/email", req, &resp) -} - -// LoginPrepare does the initial login preparation with come make sure the user exists and match user salt -func (c *authClient) LoginPrepare(ctx context.Context, loginData *protos.PrepareRequest) (*protos.PrepareResponse, error) { - var model protos.PrepareResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - if err := wc.Post(ctx, "/users/prepare", req, &model); err != nil { - // Send custom error to show error on client side - return nil, fmt.Errorf("user_not_found %w", err) - } - return &model, nil -} - -// Login is used to login a user with the LoginRequest -func (c *authClient) login(ctx context.Context, loginData *protos.LoginRequest) (*protos.LoginResponse, error) { - var resp protos.LoginResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - if err := wc.Post(ctx, "/users/login", req, &resp); err != nil { - return nil, err - } - - return &resp, nil -} - -// StartRecoveryByEmail is used to start the recovery process by sending a recovery code to the user's email -func (c *authClient) StartRecoveryByEmail(ctx context.Context, loginData *protos.StartRecoveryByEmailRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - return wc.Post(ctx, "/users/recovery/start/email", req, &resp) -} - -// CompleteRecoveryByEmail is used to complete the recovery process by validating the recovery code -func (c *authClient) CompleteRecoveryByEmail(ctx context.Context, loginData *protos.CompleteRecoveryByEmailRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - return wc.Post(ctx, "/users/recovery/complete/email", req, &resp) -} - -// // ValidateEmailRecoveryCode is used to validate the recovery code -func (c *authClient) ValidateEmailRecoveryCode(ctx context.Context, recoveryData *protos.ValidateRecoveryCodeRequest) (*protos.ValidateRecoveryCodeResponse, error) { - var resp protos.ValidateRecoveryCodeResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, recoveryData) - err := wc.Post(ctx, "/users/recovery/validate/email", req, &resp) - if err != nil { - return nil, err - } - if !resp.Valid { - return nil, fmt.Errorf("invalid_code Error decoding response body: %w", err) - } - return &resp, nil -} - -// ChangeEmail is used to change the email address of a user -func (c *authClient) ChangeEmail(ctx context.Context, loginData *protos.ChangeEmailRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - return wc.Post(ctx, "/users/change_email", req, &resp) -} - -// CompleteChangeEmail is used to complete the email change process -func (c *authClient) CompleteChangeEmail(ctx context.Context, loginData *protos.CompleteChangeEmailRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - return wc.Post(ctx, "/users/change_email/complete/email", req, &resp) -} - -// DeleteAccount is used to delete the account of a user -// Once account is delete make sure to create new account -func (c *authClient) DeleteAccount(ctx context.Context, accountData *protos.DeleteUserRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, accountData) - return wc.Post(ctx, "/users/delete", req, &resp) -} - -// DeleteAccount is used to delete the account of a user -// Once account is delete make sure to create new account -func (c *authClient) SignOut(ctx context.Context, logoutData *protos.LogoutRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, logoutData) - return wc.Post(ctx, "/users/logout", req, &resp) -} diff --git a/api/srp.go b/api/srp.go deleted file mode 100644 index 8d0d60c7..00000000 --- a/api/srp.go +++ /dev/null @@ -1,138 +0,0 @@ -package api - -import ( - "context" - "crypto/rand" - "crypto/sha256" - "errors" - "fmt" - "math/big" - "strings" - - "github.com/1Password/srp" - "golang.org/x/crypto/pbkdf2" - - "github.com/getlantern/radiance/api/protos" -) - -func newSRPClient(email string, password string, salt []byte) (*srp.SRP, error) { - if len(salt) == 0 || len(password) == 0 || len(email) == 0 { - return nil, errors.New("salt, password and email should not be empty") - } - - lowerCaseEmail := strings.ToLower(email) - encryptedKey, err := generateEncryptedKey(password, lowerCaseEmail, salt) - if err != nil { - return nil, fmt.Errorf("failed to generate encrypted key: %w", err) - } - - return srp.NewSRPClient(srp.KnownGroups[group], encryptedKey, nil), nil -} - -// Takes password and email, salt and returns encrypted key -func generateEncryptedKey(password string, email string, salt []byte) (*big.Int, error) { - if len(salt) == 0 || len(password) == 0 || len(email) == 0 { - return nil, errors.New("salt or password or email is empty") - } - lowerCaseEmail := strings.ToLower(email) - combinedInput := password + lowerCaseEmail - encryptedKey := pbkdf2.Key([]byte(combinedInput), salt, 4096, 32, sha256.New) - encryptedKeyBigInt := big.NewInt(0).SetBytes(encryptedKey) - return encryptedKeyBigInt, nil -} - -func generateSalt() ([]byte, error) { - salt := make([]byte, 16) - if n, err := rand.Read(salt); err != nil { - return nil, err - } else if n != 16 { - return nil, errors.New("failed to generate 16 byte salt") - } - return salt, nil -} - -func (c *authClient) SignUp(ctx context.Context, email string, password string) ([]byte, *protos.SignupResponse, error) { - lowerCaseEmail := strings.ToLower(email) - salt, err := generateSalt() - if err != nil { - return nil, nil, err - } - srpClient, err := newSRPClient(lowerCaseEmail, password, salt) - if err != nil { - return nil, nil, err - } - verifierKey, err := srpClient.Verifier() - if err != nil { - return nil, nil, err - } - signUpRequestBody := &protos.SignupRequest{ - Email: lowerCaseEmail, - Salt: salt, - Verifier: verifierKey.Bytes(), - SkipEmailConfirmation: true, - // Set temp always to true for now - // If new user faces any issue while sign up user can sign up again - Temp: true, - } - - body, err := c.signUp(ctx, signUpRequestBody) - if err != nil { - return salt, nil, err - } - return salt, body, nil -} - -// Todo find way to optimize this method -func (c *authClient) Login(ctx context.Context, email string, password string, deviceId string, salt []byte) (*protos.LoginResponse, error) { - lowerCaseEmail := strings.ToLower(email) - - // Prepare login request body - client, err := newSRPClient(lowerCaseEmail, password, salt) - if err != nil { - return nil, err - } - //Send this key to client - A := client.EphemeralPublic() - //Create body - prepareRequestBody := &protos.PrepareRequest{ - Email: lowerCaseEmail, - A: A.Bytes(), - } - - srpB, err := c.LoginPrepare(ctx, prepareRequestBody) - if err != nil { - return nil, err - } - - // // Once the client receives B from the server Client should check error status here as defense against - // // a malicious B sent from server - B := big.NewInt(0).SetBytes(srpB.B) - - if err = client.SetOthersPublic(B); err != nil { - return nil, err - } - - // client can now make the session key - clientKey, err := client.Key() - if err != nil || clientKey == nil { - return nil, fmt.Errorf("user_not_found error while generating Client key %w", err) - } - - // Step 3 - - // check if the server proof is valid - if !client.GoodServerProof(salt, lowerCaseEmail, srpB.Proof) { - return nil, fmt.Errorf("user_not_found error while checking server proof %w", err) - } - - clientProof, err := client.ClientProof() - if err != nil { - return nil, fmt.Errorf("user_not_found error while generating client proof %w", err) - } - loginRequestBody := &protos.LoginRequest{ - Email: lowerCaseEmail, - Proof: clientProof, - DeviceId: deviceId, - } - return c.login(ctx, loginRequestBody) -} diff --git a/api/subscription_test.go b/api/subscription_test.go deleted file mode 100644 index 0fa2c052..00000000 --- a/api/subscription_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package api - -import ( - "context" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/getlantern/radiance/api/protos" -) - -func TestSubscriptionPaymentRedirect(t *testing.T) { - ac := mockAPIClient(t) - data := PaymentRedirectData{ - Provider: "stripe", - Plan: "pro", - DeviceName: "test-device", - Email: "", - BillingType: SubscriptionTypeOneTime, - } - url, err := ac.SubscriptionPaymentRedirectURL(context.Background(), data) - require.NoError(t, err) - assert.NotEmpty(t, url) -} -func TestPaymentRedirect(t *testing.T) { - ac := mockAPIClient(t) - data := PaymentRedirectData{ - Provider: "stripe", - Plan: "pro", - DeviceName: "test-device", - Email: "", - } - url, err := ac.PaymentRedirect(context.Background(), data) - require.NoError(t, err) - assert.NotEmpty(t, url) -} - -func TestNewUser(t *testing.T) { - ac := mockAPIClient(t) - resp, err := ac.NewUser(context.Background()) - require.NoError(t, err) - assert.NotNil(t, resp) -} - -func TestVerifySubscription(t *testing.T) { - ac := mockAPIClient(t) - email := "test@getlantern.org" - planID := "1y-usd-10" - data := map[string]string{ - "email": email, - "planID": planID, - } - status, subID, err := ac.VerifySubscription(context.Background(), AppleService, data) - require.NoError(t, err) - assert.NotEmpty(t, status) - assert.NotEmpty(t, subID) -} - -func TestPlans(t *testing.T) { - ac := mockAPIClient(t) - resp, err := ac.SubscriptionPlans(context.Background(), "store") - require.NoError(t, err) - assert.NotNil(t, resp) - assert.NotNil(t, resp.Plans) -} - -type MockAPIClient struct { - *APIClient -} - -func mockAPIClient(t *testing.T) *MockAPIClient { - return &MockAPIClient{ - APIClient: &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - salt: []byte{1, 2, 3, 4, 5}, - }, - } -} - -func (m *MockAPIClient) VerifySubscription(ctx context.Context, service SubscriptionService, data map[string]string) (status, subID string, err error) { - return "active", "sub_1234567890", nil -} - -func (m *MockAPIClient) SubscriptionPlans(ctx context.Context, channel string) (*SubscriptionPlans, error) { - resp := &SubscriptionPlans{ - BaseResponse: &protos.BaseResponse{}, - Plans: []*protos.Plan{ - {Id: "1y-usd-10", Description: "Pro Plan", Price: map[string]int64{}}, - }, - } - return resp, nil -} -func (m *MockAPIClient) SubscriptionPaymentRedirectURL(ctx context.Context, data PaymentRedirectData) (string, error) { - return "https://example.com/redirect", nil -} - -func (m *MockAPIClient) PaymentRedirect(ctx context.Context, data PaymentRedirectData) (string, error) { - return "https://example.com/redirect", nil -} -func (m *MockAPIClient) NewUser(ctx context.Context) (*protos.LoginResponse, error) { - return &protos.LoginResponse{}, nil -} diff --git a/api/user.go b/api/user.go deleted file mode 100644 index e3094d76..00000000 --- a/api/user.go +++ /dev/null @@ -1,818 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "math/big" - "net/url" - "os" - "strconv" - "strings" - "time" - - "github.com/1Password/srp" - - "go.opentelemetry.io/otel" - "google.golang.org/protobuf/proto" - - "github.com/getlantern/radiance/api/protos" - "github.com/getlantern/radiance/backend" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/events" - "github.com/getlantern/radiance/traces" -) - -// The main output of this file is Radiance.GetUser, which provides a hook into all user account -// functionality. - -const saltFileName = ".salt" - -// pro-server requests -type UserDataResponse struct { - *protos.BaseResponse `json:",inline"` - *protos.LoginResponse_UserData `json:",inline"` -} - -// NewUser creates a new user account -func (ac *APIClient) NewUser(ctx context.Context) (*protos.LoginResponse, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "new_user") - defer span.End() - - var resp UserDataResponse - header := map[string]string{ - backend.ContentTypeHeader: "application/json", - } - req := ac.proWebClient().NewRequest(nil, header, nil) - err := ac.proWebClient().Post(ctx, "/user-create", req, &resp) - if err != nil { - slog.Error("creating new user", "error", err) - return nil, traces.RecordError(ctx, err) - } - loginResponse, err := ac.storeData(ctx, resp) - if err != nil { - return nil, err - } - return loginResponse, nil -} - -func (ac *APIClient) UserData() ([]byte, error) { - return common.RunOffCgoStack(func() ([]byte, error) { - slog.Debug("Getting user data") - user := &protos.LoginResponse{} - err := settings.GetStruct(settings.LoginResponseKey, user) - return withMarshalProto(user, err) - }) -} - -// FetchUserData fetches user data from the server. -func (ac *APIClient) FetchUserData(ctx context.Context) ([]byte, error) { - return common.RunOffCgoStack(func() ([]byte, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "fetch_user_data") - defer span.End() - return withMarshalProto(ac.fetchUserData(ctx)) - }) -} - -// fetchUserData calls the /user-data endpoint and stores the result via storeData. -func (ac *APIClient) fetchUserData(ctx context.Context) (*protos.LoginResponse, error) { - var resp UserDataResponse - err := ac.proWebClient().Get(ctx, "/user-data", nil, &resp) - if err != nil { - slog.Error("user data", "error", err) - return nil, traces.RecordError(ctx, fmt.Errorf("getting user data: %w", err)) - } - return ac.storeData(ctx, resp) -} - -func (a *APIClient) storeData(ctx context.Context, resp UserDataResponse) (*protos.LoginResponse, error) { - if resp.BaseResponse != nil && resp.Error != "" { - err := fmt.Errorf("received bad response: %s", resp.Error) - slog.Error("user data", "error", err) - return nil, traces.RecordError(ctx, err) - } - if resp.LoginResponse_UserData == nil { - slog.Error("user data", "error", "no user data in response") - return nil, traces.RecordError(ctx, fmt.Errorf("no user data in response")) - } - // Append device ID to user data - resp.LoginResponse_UserData.DeviceID = settings.GetString(settings.DeviceIDKey) - login := &protos.LoginResponse{ - LegacyID: resp.UserId, - LegacyToken: resp.Token, - LegacyUserData: resp.LoginResponse_UserData, - } - a.setData(login) - return login, nil -} - -// user-server requests - -// Devices returns a list of devices associated with this user account. -func (a *APIClient) Devices() ([]settings.Device, error) { - return settings.Devices() -} - -// DataCapUsageResponse represents the data cap usage response -type DataCapUsageResponse struct { - // Whether data cap is enabled for this device/user - Enabled bool `json:"enabled"` - // Data cap usage details (only populated if enabled is true) - Usage *DataCapUsageDetails `json:"usage,omitempty"` -} - -// DataCapUsageDetails contains details of the data cap usage -type DataCapUsageDetails struct { - BytesAllotted string `json:"bytesAllotted"` - BytesUsed string `json:"bytesUsed"` - AllotmentStartTime string `json:"allotmentStartTime"` - AllotmentEndTime string `json:"allotmentEndTime"` -} - -// fetchDataCap fetches the current datacap usage from the server. -func (a *APIClient) fetchDataCap(ctx context.Context) (*DataCapUsageResponse, error) { - datacap := &DataCapUsageResponse{} - headers := map[string]string{ - backend.ContentTypeHeader: "application/json", - } - getURL := fmt.Sprintf("/datacap/%s", settings.GetString(settings.DeviceIDKey)) - authWc := authWebClient() - newReq := authWc.NewRequest(nil, headers, nil) - err := authWc.Get(ctx, getURL, newReq, datacap) - return datacap, err -} - -// DataCapInfo returns information about this user's data cap -func (a *APIClient) DataCapInfo(ctx context.Context) (string, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "data_cap_info") - defer span.End() - datacap, err := a.fetchDataCap(ctx) - return withMarshalJsonString(datacap, err) -} - -type DataCapChangeEvent struct { - events.Event - *DataCapUsageResponse -} - -// DataCapPollInterval controls how often the datacap polling loop fetches updated usage. -// The server previously sent updates every 30s via SSE, so we match that cadence. -var DataCapPollInterval = 30 * time.Second - -// DataCapStream polls the datacap endpoint periodically and emits DataCapChangeEvent -// whenever the usage data changes. It blocks until ctx is cancelled. -// -// This replaces the previous SSE-based implementation which was incompatible with -// domain-fronted connections (CDNs buffer SSE responses, causing 60s timeouts). -func (a *APIClient) DataCapStream(ctx context.Context) error { - ticker := time.NewTicker(DataCapPollInterval) - defer ticker.Stop() - var last string - // Perform an initial poll before entering the ticker loop. - a.pollDataCap(ctx, &last) - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - a.pollDataCap(ctx, &last) - } - } -} - -func (a *APIClient) pollDataCap(ctx context.Context, last *string) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "data_cap_poll") - defer span.End() - - datacap, err := a.fetchDataCap(ctx) - if err != nil { - slog.Debug("datacap poll error", "error", err) - traces.RecordError(ctx, err) - return - } - - jsonBytes, err := json.Marshal(datacap) - if err != nil { - slog.Debug("datacap poll marshal error", "error", err) - traces.RecordError(ctx, err) - return - } - current := string(jsonBytes) - if current != *last { - *last = current - events.Emit(DataCapChangeEvent{DataCapUsageResponse: datacap}) - if datacap.Usage != nil { - slog.Debug("datacap updated", "bytesUsed", datacap.Usage.BytesUsed) - } - } -} - -// SignUp signs the user up for an account. -func (a *APIClient) SignUp(ctx context.Context, email, password string) ([]byte, *protos.SignupResponse, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up") - defer span.End() - - salt, signupResponse, err := a.authClient.SignUp(ctx, email, password) - if err != nil { - return nil, nil, traces.RecordError(ctx, err) - } - a.salt = salt - - idErr := settings.Set(settings.UserIDKey, signupResponse.LegacyID) - if idErr != nil { - return nil, nil, fmt.Errorf("could not save user id: %w", idErr) - } - proTokenErr := settings.Set(settings.TokenKey, signupResponse.ProToken) - if proTokenErr != nil { - return nil, nil, fmt.Errorf("could not save token: %w", proTokenErr) - } - jwtTokenErr := settings.Set(settings.JwtTokenKey, signupResponse.Token) - if jwtTokenErr != nil { - return nil, nil, fmt.Errorf("could not save JWT token: %w", jwtTokenErr) - } - - return salt, signupResponse, nil -} - -var ErrNoSalt = errors.New("not salt available, call GetSalt/Signup first") -var ErrNotLoggedIn = errors.New("not logged in") -var ErrInvalidCode = errors.New("invalid code") - -// SignupEmailResendCode requests that the sign-up code be resent via email. -func (a *APIClient) SignupEmailResendCode(ctx context.Context, email string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up_email_resend_code") - defer span.End() - - if a.salt == nil { - return traces.RecordError(ctx, ErrNoSalt) - } - return traces.RecordError(ctx, a.authClient.SignupEmailResendCode(ctx, &protos.SignupEmailResendRequest{ - Email: email, - Salt: a.salt, - })) -} - -// SignupEmailConfirmation confirms the new account using the sign-up code received via email. -func (a *APIClient) SignupEmailConfirmation(ctx context.Context, email, code string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up_email_confirmation") - defer span.End() - - return traces.RecordError(ctx, a.authClient.SignupEmailConfirmation(ctx, &protos.ConfirmSignupRequest{ - Email: email, - Code: code, - })) -} - -func writeSalt(salt []byte, path string) error { - if err := os.WriteFile(path, salt, 0600); err != nil { - return fmt.Errorf("writing salt to %s: %w", path, err) - } - return nil -} - -func readSalt(path string) ([]byte, error) { - buf, err := os.ReadFile(path) - if err != nil && !os.IsNotExist(err) { - return nil, fmt.Errorf("reading salt from %s: %w", path, err) - } - if len(buf) == 0 { - return nil, nil - } - return buf, nil -} - -// getSalt retrieves the salt for the given email address or it's cached value. -func (a *APIClient) getSalt(ctx context.Context, email string) ([]byte, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "get_salt") - defer span.End() - - if a.salt != nil { - return a.salt, nil // use cached value - } - resp, err := a.authClient.GetSalt(ctx, email) - if err != nil { - return nil, traces.RecordError(ctx, err) - } - return resp.Salt, nil -} - -// Login logs the user in. -func (a *APIClient) Login(ctx context.Context, email string, password string) ([]byte, error) { - return common.RunOffCgoStack(func() ([]byte, error) { - // clear any previous salt value - a.salt = nil - ctx, span := otel.Tracer(tracerName).Start(ctx, "login") - defer span.End() - - salt, err := a.getSalt(ctx, email) - if err != nil { - return nil, err - } - - deviceId := settings.GetString(settings.DeviceIDKey) - resp, err := a.authClient.Login(ctx, email, password, deviceId, salt) - if err != nil { - return nil, traces.RecordError(ctx, err) - } - - //this can be nil if the user has reached the device limit - if resp.LegacyUserData != nil { - // Append device ID to user data - resp.LegacyUserData.DeviceID = deviceId - } - - // regardless of state we need to save login information - // We have device flow limit on login - a.setData(resp) - a.salt = salt - if saltErr := writeSalt(salt, a.saltPath); saltErr != nil { - return nil, traces.RecordError(ctx, saltErr) - } - return withMarshalProto(resp, nil) - }) -} - -// Logout logs the user out. No-op if there is no user account logged in. -func (a *APIClient) Logout(ctx context.Context, email string) ([]byte, error) { - return common.RunOffCgoStack(func() ([]byte, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "logout") - defer span.End() - logout := &protos.LogoutRequest{ - Email: email, - DeviceId: settings.GetString(settings.DeviceIDKey), - LegacyUserID: settings.GetInt64(settings.UserIDKey), - LegacyToken: settings.GetString(settings.TokenKey), - Token: settings.GetString(settings.JwtTokenKey), - } - if err := a.authClient.SignOut(ctx, logout); err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("logging out: %w", err)) - } - a.Reset() - a.salt = nil - if err := writeSalt(nil, a.saltPath); err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("writing salt after logout: %w", err)) - } - return withMarshalProto(a.NewUser(context.Background())) - }) -} - -func withMarshalProto(resp *protos.LoginResponse, err error) ([]byte, error) { - if err != nil { - return nil, err - } - protoUserData, err := proto.Marshal(resp) - if err != nil { - return nil, fmt.Errorf("error marshalling login response: %w", err) - } - return protoUserData, nil -} - -func withMarshalJson(data any, err error) ([]byte, error) { - if err != nil { - return nil, err - } - jsonData, err := json.Marshal(data) - if err != nil { - return nil, fmt.Errorf("error marshalling user data: %w", err) - } - return jsonData, nil -} - -func withMarshalJsonString(data any, err error) (string, error) { - raw, err := withMarshalJson(data, err) - return string(raw), err -} - -// StartRecoveryByEmail initializes the account recovery process for the provided email. -func (a *APIClient) StartRecoveryByEmail(ctx context.Context, email string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "start_recovery_by_email") - defer span.End() - - return traces.RecordError(ctx, a.authClient.StartRecoveryByEmail(ctx, &protos.StartRecoveryByEmailRequest{ - Email: email, - })) -} - -// CompleteRecoveryByEmail completes account recovery using the code received via email. -func (a *APIClient) CompleteRecoveryByEmail(ctx context.Context, email, newPassword, code string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "complete_recovery_by_email") - defer span.End() - lowerCaseEmail := strings.ToLower(email) - newSalt, err := generateSalt() - if err != nil { - return traces.RecordError(ctx, err) - } - srpClient, err := newSRPClient(lowerCaseEmail, newPassword, newSalt) - if err != nil { - return traces.RecordError(ctx, err) - } - verifierKey, err := srpClient.Verifier() - if err != nil { - return traces.RecordError(ctx, err) - } - - err = a.authClient.CompleteRecoveryByEmail(ctx, &protos.CompleteRecoveryByEmailRequest{ - Email: email, - Code: code, - NewSalt: newSalt, - NewVerifier: verifierKey.Bytes(), - }) - if err != nil { - return traces.RecordError(ctx, fmt.Errorf("failed to complete recovery by email: %w", err)) - } - if err = writeSalt(newSalt, a.saltPath); err != nil { - return traces.RecordError(ctx, fmt.Errorf("failed to write new salt: %w", err)) - } - return nil -} - -// ValidateEmailRecoveryCode validates the recovery code received via email. -func (a *APIClient) ValidateEmailRecoveryCode(ctx context.Context, email, code string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "validate_email_recovery_code") - defer span.End() - - resp, err := a.authClient.ValidateEmailRecoveryCode(ctx, &protos.ValidateRecoveryCodeRequest{ - Email: email, - Code: code, - }) - if err != nil { - return traces.RecordError(ctx, err) - } - if !resp.Valid { - return traces.RecordError(ctx, ErrInvalidCode) - } - return nil -} - -const group = srp.RFC5054Group3072 - -// StartChangeEmail initializes a change of the email address associated with this user account. -func (a *APIClient) StartChangeEmail(ctx context.Context, newEmail string, password string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "start_change_email") - defer span.End() - lowerCaseEmail := strings.ToLower(settings.GetString(settings.EmailKey)) - lowerCaseNewEmail := strings.ToLower(newEmail) - salt, err := a.getSalt(ctx, lowerCaseEmail) - if err != nil { - return traces.RecordError(ctx, err) - } - // Prepare login request body - encKey, err := generateEncryptedKey(password, lowerCaseEmail, salt) - if err != nil { - return traces.RecordError(ctx, err) - } - client := srp.NewSRPClient(srp.KnownGroups[group], encKey, nil) - - //Send this key to client - A := client.EphemeralPublic() - - //Create body - prepareRequestBody := &protos.PrepareRequest{ - Email: lowerCaseEmail, - A: A.Bytes(), - } - - srpB, err := a.authClient.LoginPrepare(ctx, prepareRequestBody) - if err != nil { - return traces.RecordError(ctx, err) - } - // Once the client receives B from the server Client should check error status here as defense against - // a malicious B sent from server - B := big.NewInt(0).SetBytes(srpB.B) - - if err = client.SetOthersPublic(B); err != nil { - return traces.RecordError(ctx, err) - } - - // client can now make the session key - clientKey, err := client.Key() - if err != nil || clientKey == nil { - return traces.RecordError(ctx, fmt.Errorf("user_not_found error while generating Client key %w", err)) - } - - // // check if the server proof is valid - if !client.GoodServerProof(salt, lowerCaseEmail, srpB.Proof) { - return traces.RecordError(ctx, fmt.Errorf("user_not_found error while checking server proof %w", err)) - } - - clientProof, err := client.ClientProof() - if err != nil { - return traces.RecordError(ctx, fmt.Errorf("user_not_found error while generating client proof %w", err)) - } - - changeEmailRequestBody := &protos.ChangeEmailRequest{ - OldEmail: lowerCaseEmail, - NewEmail: lowerCaseNewEmail, - Proof: clientProof, - } - - return traces.RecordError(ctx, a.authClient.ChangeEmail(ctx, changeEmailRequestBody)) -} - -// CompleteChangeEmail completes a change of the email address associated with this user account, -// using the code recieved via email. -func (a *APIClient) CompleteChangeEmail(ctx context.Context, newEmail, password, code string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "complete_change_email") - defer span.End() - newSalt, err := generateSalt() - if err != nil { - return traces.RecordError(ctx, err) - } - - encKey, err := generateEncryptedKey(password, newEmail, newSalt) - if err != nil { - return traces.RecordError(ctx, err) - } - - srpClient := srp.NewSRPClient(srp.KnownGroups[group], encKey, nil) - verifierKey, err := srpClient.Verifier() - if err != nil { - return traces.RecordError(ctx, err) - } - if err := a.authClient.CompleteChangeEmail(ctx, &protos.CompleteChangeEmailRequest{ - OldEmail: settings.GetString(settings.EmailKey), - NewEmail: newEmail, - Code: code, - NewSalt: newSalt, - NewVerifier: verifierKey.Bytes(), - }); err != nil { - return traces.RecordError(ctx, err) - } - if err := writeSalt(newSalt, a.saltPath); err != nil { - return traces.RecordError(ctx, err) - } - if err := settings.Set(settings.EmailKey, newEmail); err != nil { - return traces.RecordError(ctx, err) - } - - a.salt = newSalt - return nil -} - -// DeleteAccount deletes this user account. -func (a *APIClient) DeleteAccount(ctx context.Context, email, password string, isOAuthUser bool) ([]byte, error) { - return common.RunOffCgoStack(func() ([]byte, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "delete_account") - defer span.End() - var deleteRequestBody *protos.DeleteUserRequest - lowerCaseEmail := strings.ToLower(email) - if !isOAuthUser { - salt, err := a.getSalt(ctx, lowerCaseEmail) - if err != nil { - return nil, traces.RecordError(ctx, err) - } - - // Prepare login request body - encKey, err := generateEncryptedKey(password, lowerCaseEmail, salt) - if err != nil { - return nil, traces.RecordError(ctx, err) - } - client := srp.NewSRPClient(srp.KnownGroups[group], encKey, nil) - - //Send this key to client - A := client.EphemeralPublic() - - //Create body - prepareRequestBody := &protos.PrepareRequest{ - Email: lowerCaseEmail, - A: A.Bytes(), - } - - srpB, err := a.authClient.LoginPrepare(ctx, prepareRequestBody) - if err != nil { - return nil, traces.RecordError(ctx, err) - } - - B := big.NewInt(0).SetBytes(srpB.B) - - if err = client.SetOthersPublic(B); err != nil { - return nil, traces.RecordError(ctx, err) - } - - clientKey, err := client.Key() - if err != nil || clientKey == nil { - return nil, traces.RecordError(ctx, fmt.Errorf("user_not_found error while generating Client key %w", err)) - } - - // check if the server proof is valid - if !client.GoodServerProof(salt, lowerCaseEmail, srpB.Proof) { - return nil, traces.RecordError(ctx, errors.New("user_not_found error while checking server proof")) - } - - clientProof, err := client.ClientProof() - if err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("user_not_found error while generating client proof %w", err)) - } - deleteRequestBody = &protos.DeleteUserRequest{ - Email: lowerCaseEmail, - Proof: clientProof, - Permanent: true, - DeviceId: settings.GetString(settings.DeviceIDKey), - Token: settings.GetString(settings.JwtTokenKey), - } - } else { - jwtToken := settings.GetString(settings.JwtTokenKey) - if jwtToken == "" { - return nil, traces.RecordError(ctx, errors.New("jwt token is required for OAuth account deletion")) - } - deleteRequestBody = &protos.DeleteUserRequest{ - Email: lowerCaseEmail, - Permanent: true, - Token: jwtToken, - DeviceId: settings.GetString(settings.DeviceIDKey), - } - } - if err := a.authClient.DeleteAccount(ctx, deleteRequestBody); err != nil { - return nil, traces.RecordError(ctx, err) - } - // clean up local data - a.Reset() - a.salt = nil - if err := writeSalt(nil, a.saltPath); err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("failed to write salt during account deletion cleanup: %w", err)) - } - - return withMarshalProto(a.NewUser(context.Background())) - }) -} - -// OAuthLoginUrl initiates the OAuth login process for the specified provider. -func (a *APIClient) OAuthLoginUrl(ctx context.Context, provider string) (string, error) { - loginURL, err := url.Parse(fmt.Sprintf("%s/%s/%s", common.GetBaseURL(), "users/oauth2", provider)) - if err != nil { - return "", fmt.Errorf("failed to parse URL: %w", err) - } - query := loginURL.Query() - query.Set("deviceId", settings.GetString(settings.DeviceIDKey)) - query.Set("userId", strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10)) - query.Set("proToken", settings.GetString(settings.TokenKey)) - query.Set("returnTo", "lantern://auth") - loginURL.RawQuery = query.Encode() - return loginURL.String(), nil -} - -func (a *APIClient) OAuthLoginCallback(ctx context.Context, oAuthToken string) ([]byte, error) { - return common.RunOffCgoStack(func() ([]byte, error) { - slog.Debug("Getting OAuth login callback") - jwtUserInfo, err := decodeJWT(oAuthToken) - if err != nil { - return nil, fmt.Errorf("error decoding JWT: %w", err) - } - - // Temporary set user data to so api can read it - login := &protos.LoginResponse{ - LegacyID: jwtUserInfo.LegacyUserID, - LegacyToken: jwtUserInfo.LegacyToken, - LegacyUserData: &protos.LoginResponse_UserData{ - UserId: jwtUserInfo.LegacyUserID, - Token: jwtUserInfo.LegacyToken, - DeviceID: jwtUserInfo.DeviceId, - Email: jwtUserInfo.Email, - }, - } - a.setData(login) - // Get user data from api this will also save data in user config - user, err := a.fetchUserData(context.Background()) - if err != nil { - return nil, fmt.Errorf("error getting user data: %w", err) - } - - if err := settings.Set(settings.JwtTokenKey, oAuthToken); err != nil { - slog.Error("Failed to persist JWT token", "error", err) - return nil, fmt.Errorf("failed to persist JWT token: %w", err) - } - user.Id = jwtUserInfo.Email - user.EmailConfirmed = true - a.setData(user) - return withMarshalProto(user, nil) - }) -} - -type LinkResponse struct { - *protos.BaseResponse `json:",inline"` - UserID int `json:"userID"` - ProToken string `json:"token"` -} - -// RemoveDevice removes a device from the user's account. -func (a *APIClient) RemoveDevice(ctx context.Context, deviceID string) (*LinkResponse, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "remove_device") - defer span.End() - - data := map[string]string{ - "deviceId": deviceID, - } - proWC := a.proWebClient() - req := proWC.NewRequest(nil, nil, data) - resp := &LinkResponse{} - if err := proWC.Post(ctx, "/user-link-remove", req, resp); err != nil { - return nil, traces.RecordError(ctx, err) - } - if resp.BaseResponse != nil && resp.BaseResponse.Error != "" { - return nil, traces.RecordError(ctx, fmt.Errorf("failed to remove device: %s", resp.BaseResponse.Error)) - } - return resp, nil -} - -func (a *APIClient) ReferralAttach(ctx context.Context, code string) (bool, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "referral_attach") - defer span.End() - - data := map[string]string{ - "code": code, - } - proWC := a.proWebClient() - req := proWC.NewRequest(nil, nil, data) - resp := &protos.BaseResponse{} - if err := proWC.Post(ctx, "/referral-attach", req, resp); err != nil { - return false, traces.RecordError(ctx, err) - } - if resp.Error != "" { - return false, traces.RecordError(ctx, fmt.Errorf("%s", resp.Error)) - } - return true, nil -} - -func (a *APIClient) setData(data *protos.LoginResponse) { - a.mu.Lock() - defer a.mu.Unlock() - if data == nil { - a.Reset() - return - } - var changed bool - if data.LegacyUserData == nil { - slog.Info("no user data to set") - return - } - - existingUser := settings.GetInt64(settings.UserIDKey) != 0 - - if data.LegacyUserData.UserLevel != "" { - oldUserLevel := settings.GetString(settings.UserLevelKey) - changed = changed || oldUserLevel != data.LegacyUserData.UserLevel - if err := settings.Set(settings.UserLevelKey, data.LegacyUserData.UserLevel); err != nil { - slog.Error("failed to set user level in settings", "error", err) - } - } - if data.LegacyUserData.Email != "" { - oldEmail := settings.GetString(settings.EmailKey) - changed = changed && oldEmail != data.LegacyUserData.Email - if err := settings.Set(settings.EmailKey, data.LegacyUserData.Email); err != nil { - slog.Error("failed to set email in settings", "error", err) - } - } - if data.LegacyID != 0 { - oldUserID := settings.GetInt64(settings.UserIDKey) - changed = changed && oldUserID != data.LegacyID - if err := settings.Set(settings.UserIDKey, data.LegacyID); err != nil { - slog.Error("failed to set user ID in settings", "error", err) - } - } - if data.LegacyToken != "" { - oldToken := settings.GetString(settings.TokenKey) - changed = changed && oldToken != data.LegacyToken - if err := settings.Set(settings.TokenKey, data.LegacyToken); err != nil { - slog.Error("failed to set token in settings", "error", err) - } - } - if data.Token != "" { - oldJwtToken := settings.GetString(settings.JwtTokenKey) - changed = changed && oldJwtToken != data.Token - if err := settings.Set(settings.JwtTokenKey, data.Token); err != nil { - slog.Error("failed to set JWT token in settings", "error", err) - } - } - - devices := []settings.Device{} - for _, d := range data.Devices { - devices = append(devices, settings.Device{ - Name: d.Name, - ID: d.Id, - }) - } - if err := settings.Set(settings.DevicesKey, devices); err != nil { - slog.Error("failed to set devices in settings", "error", err) - } - - if err := settings.Set(settings.LoginResponseKey, data); err != nil { - slog.Error("failed to set login response in settings", "error", err) - } - - // We only consider the user to have changed if there was a previous user. - if existingUser && changed { - events.Emit(settings.UserChangeEvent{}) - } -} - -func (a *APIClient) Reset() { - // Clear user data - settings.Set(settings.UserIDKey, int64(0)) - settings.Set(settings.TokenKey, "") - settings.Set(settings.UserLevelKey, "") - settings.Set(settings.EmailKey, "") - settings.Set(settings.DevicesKey, []settings.Device{}) -} diff --git a/api/user_test.go b/api/user_test.go deleted file mode 100644 index f688be36..00000000 --- a/api/user_test.go +++ /dev/null @@ -1,314 +0,0 @@ -package api - -import ( - "context" - "encoding/hex" - "errors" - "math/big" - "path/filepath" - "testing" - - "github.com/1Password/srp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/getlantern/radiance/api/protos" - "github.com/getlantern/radiance/common/settings" -) - -func TestSignUp(t *testing.T) { - settings.InitSettings(t.TempDir()) - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - salt, signupResponse, err := ac.SignUp(context.Background(), "test@example.com", "password") - assert.NoError(t, err) - assert.NotNil(t, salt) - assert.NotNil(t, signupResponse) -} - -func TestSignupEmailResendCode(t *testing.T) { - ac := &APIClient{ - salt: []byte("salt"), - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err := ac.SignupEmailResendCode(context.Background(), "test@example.com") - assert.NoError(t, err) -} - -func TestSignupEmailConfirmation(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err := ac.SignupEmailConfirmation(context.Background(), "test@example.com", "code") - assert.NoError(t, err) -} - -func TestLogin(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - _, err := ac.Login(context.Background(), "test@example.com", "password") - assert.NoError(t, err) -} - -func TestLogout(t *testing.T) { - settings.InitSettings(t.TempDir()) - settings.Set(settings.DeviceIDKey, "deviceId") - t.Cleanup(settings.Reset) - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - _, err := ac.Logout(context.Background(), "test@example.com") - assert.NoError(t, err) -} - -func TestStartRecoveryByEmail(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err := ac.StartRecoveryByEmail(context.Background(), "test@example.com") - assert.NoError(t, err) -} - -func TestCompleteRecoveryByEmail(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err := ac.CompleteRecoveryByEmail(context.Background(), "test@example.com", "newPassword", "code") - assert.NoError(t, err) -} - -func TestValidateEmailRecoveryCode(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err := ac.ValidateEmailRecoveryCode(context.Background(), "test@example.com", "code") - assert.NoError(t, err) -} - -func TestStartChangeEmail(t *testing.T) { - email := "test@example.com" - settings.Set(settings.EmailKey, email) - authClient := mockAuthClientNew(t, email, "password") - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: authClient, - salt: authClient.salt[email], - } - err := ac.StartChangeEmail(context.Background(), "new@example.com", "password") - assert.NoError(t, err) -} - -func TestCompleteChangeEmail(t *testing.T) { - old := "old@example.com" - tmp := t.TempDir() - err := settings.InitSettings(tmp) - require.NoError(t, err) - settings.Set(settings.EmailKey, old) - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err = ac.CompleteChangeEmail(context.Background(), "new@example.com", "password", "code") - assert.NoError(t, err) -} - -func TestDeleteAccount(t *testing.T) { - settings.InitSettings(t.TempDir()) - settings.Set(settings.DeviceIDKey, "deviceId") - t.Cleanup(settings.Reset) - email := "test@example.com" - authClient := mockAuthClientNew(t, email, "password") - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: authClient, - salt: authClient.salt[email], - } - _, err := ac.DeleteAccount(context.Background(), "test@example.com", "password", false) - assert.NoError(t, err) -} - -func TestDeleteAccount_OAuthUser(t *testing.T) { - settings.InitSettings(t.TempDir()) - settings.Set(settings.DeviceIDKey, "deviceId") - settings.Set(settings.JwtTokenKey, "jwt-token") - t.Cleanup(settings.Reset) - email := "test@example.com" - authClient := mockAuthClientNew(t, email, "password") - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: authClient, - salt: authClient.salt[email], - } - _, err := ac.DeleteAccount(context.Background(), "test@example.com", "password", true) - assert.NoError(t, err) -} -func TestDeleteAccount_OAuthUser_MissingJwtToken(t *testing.T) { - settings.InitSettings(t.TempDir()) - settings.Set(settings.DeviceIDKey, "deviceId") - t.Cleanup(settings.Reset) - email := "test@example.com" - authClient := mockAuthClientNew(t, email, "password") - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: authClient, - salt: authClient.salt[email], - } - _, err := ac.DeleteAccount(context.Background(), "test@example.com", "password", true) - assert.Error(t, err) -} - -func TestOAuthLoginUrl(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - } - url, err := ac.OAuthLoginUrl(context.Background(), "google") - assert.NoError(t, err) - assert.NotEmpty(t, url) -} - -func TestOAuthLoginCallback(t *testing.T) { - settings.InitSettings(t.TempDir()) - settings.Set(settings.DeviceIDKey, "deviceId") - t.Cleanup(settings.Reset) - - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - - // Create a mock JWT token - mockToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20iLCJsZWdhY3lVc2VySUQiOjEyMzQ1LCJsZWdhY3lUb2tlbiI6InRlc3QtdG9rZW4ifQ.test" - - _, err := ac.OAuthLoginCallback(context.Background(), mockToken) - // This will fail because decodeJWT is not mocked, but demonstrates the test structure - assert.Error(t, err) -} - -func TestOAuthLoginCallback_InvalidToken(t *testing.T) { - settings.InitSettings(t.TempDir()) - t.Cleanup(settings.Reset) - - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - - _, err := ac.OAuthLoginCallback(context.Background(), "invalid-token") - assert.Error(t, err) - assert.Contains(t, err.Error(), "error decoding JWT") -} - -// Mock implementation of AuthClient for testing purposes -type mockAuthClient struct { - cache map[string]string - salt map[string][]byte - verifier []byte -} - -func mockAuthClientNew(t *testing.T, email, password string) *mockAuthClient { - salt, err := generateSalt() - require.NoError(t, err) - - encKey, err := generateEncryptedKey(password, email, salt) - require.NoError(t, err) - - srpClient := srp.NewSRPClient(srp.KnownGroups[group], encKey, nil) - verifierKey, err := srpClient.Verifier() - require.NoError(t, err) - - m := &mockAuthClient{ - salt: map[string][]byte{email: salt}, - verifier: verifierKey.Bytes(), - cache: make(map[string]string), - } - return m -} - -func (m *mockAuthClient) SignUp(ctx context.Context, email, password string) ([]byte, *protos.SignupResponse, error) { - return []byte("salt"), &protos.SignupResponse{}, nil -} - -func (m *mockAuthClient) SignupEmailResendCode(ctx context.Context, req *protos.SignupEmailResendRequest) error { - return nil -} - -func (m *mockAuthClient) SignupEmailConfirmation(ctx context.Context, req *protos.ConfirmSignupRequest) error { - return nil -} - -func (m *mockAuthClient) GetSalt(ctx context.Context, email string) (*protos.GetSaltResponse, error) { - return &protos.GetSaltResponse{Salt: []byte("salt")}, nil -} - -func (m *mockAuthClient) Login(ctx context.Context, email, password, deviceId string, salt []byte) (*protos.LoginResponse, error) { - return &protos.LoginResponse{ - LegacyUserData: &protos.LoginResponse_UserData{ - DeviceID: "deviceId", - }, - }, nil -} - -func (m *mockAuthClient) SignOut(ctx context.Context, req *protos.LogoutRequest) error { - return nil -} - -func (m *mockAuthClient) StartRecoveryByEmail(ctx context.Context, req *protos.StartRecoveryByEmailRequest) error { - return nil -} - -func (m *mockAuthClient) CompleteRecoveryByEmail(ctx context.Context, req *protos.CompleteRecoveryByEmailRequest) error { - return nil -} - -func (m *mockAuthClient) ValidateEmailRecoveryCode(ctx context.Context, req *protos.ValidateRecoveryCodeRequest) (*protos.ValidateRecoveryCodeResponse, error) { - return &protos.ValidateRecoveryCodeResponse{Valid: true}, nil -} - -func (m *mockAuthClient) ChangeEmail(ctx context.Context, req *protos.ChangeEmailRequest) error { - return nil -} - -func (m *mockAuthClient) CompleteChangeEmail(ctx context.Context, req *protos.CompleteChangeEmailRequest) error { - return nil -} - -func (m *mockAuthClient) DeleteAccount(ctx context.Context, req *protos.DeleteUserRequest) error { - return nil -} - -func (m *mockAuthClient) LoginPrepare(ctx context.Context, req *protos.PrepareRequest) (*protos.PrepareResponse, error) { - A := big.NewInt(0).SetBytes(req.A) - verifier := big.NewInt(0).SetBytes(m.verifier) - - server := srp.NewSRPServer(srp.KnownGroups[srp.RFC5054Group3072], verifier, nil) - if err := server.SetOthersPublic(A); err != nil { - return nil, err - } - B := server.EphemeralPublic() - if B == nil { - return nil, errors.New("cannot generate B") - } - if _, err := server.Key(); err != nil { - return nil, errors.New("cannot generate key") - } - proof, err := server.M(m.salt[req.Email], req.Email) - if err != nil { - return nil, errors.New("cannot generate Proof") - } - state, err := server.MarshalBinary() - if err != nil { - return nil, err - } - m.cache[req.Email] = hex.EncodeToString(state) - return &protos.PrepareResponse{B: B.Bytes(), Proof: proof}, nil -} diff --git a/api/webclient.go b/api/webclient.go deleted file mode 100644 index 2d00213e..00000000 --- a/api/webclient.go +++ /dev/null @@ -1,148 +0,0 @@ -package api - -import ( - "bytes" - "context" - "encoding/json" - "log/slog" - "unicode" - "unicode/utf8" - - "fmt" - "net/http" - - "github.com/go-resty/resty/v2" - "google.golang.org/protobuf/proto" - - "github.com/getlantern/radiance/backend" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/env" -) - -type webClient struct { - client *resty.Client -} - -func newWebClient(httpClient *http.Client, baseURL string) *webClient { - if httpClient == nil { - httpClient = &http.Client{ - Timeout: common.DefaultHTTPTimeout, - } - } - client := resty.NewWithClient(httpClient) - if baseURL != "" { - client.SetBaseURL(baseURL) - } - - client.SetHeaders(map[string]string{ - backend.AppNameHeader: common.Name, - backend.VersionHeader: common.Version, - backend.PlatformHeader: common.Platform, - }) - - // Add a request middleware to marshal the request body to protobuf or JSON - client.OnBeforeRequest(func(c *resty.Client, req *resty.Request) error { - if req.Body == nil { - return nil - } - if pb, ok := req.Body.(proto.Message); ok { - data, err := proto.Marshal(pb) - if err != nil { - return err - } - req.Body = data - req.Header.Set("Content-Type", "application/x-protobuf") - req.Header.Set("Accept", "application/x-protobuf") - } else { - data, err := json.Marshal(req.Body) - if err != nil { - return err - } - req.Body = data - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - } - - return nil - }) - - // Add a response middleware to unmarshal the response body from protobuf or JSON - client.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error { - if len(resp.Body()) == 0 || resp.Request.Result == nil { - return nil - } - switch ct := resp.RawResponse.Header.Get("Content-Type"); ct { - case "application/x-protobuf": - pb, ok := resp.Request.Result.(proto.Message) - if !ok { - return fmt.Errorf("response body is not a protobuf message") - } - return proto.Unmarshal(resp.Body(), pb) - case "application/json": - body := sanitizeResponseBody(resp.Body()) - return json.Unmarshal(body, resp.Request.Result) - } - return nil - }) - return &webClient{client: client} -} - -func (wc *webClient) NewRequest(queryParams, headers map[string]string, body any) *resty.Request { - req := wc.client.NewRequest().SetQueryParams(queryParams).SetHeaders(headers).SetBody(body) - if curl, _ := env.Get[bool](env.PrintCurl); curl { - req = req.SetDebug(true).EnableGenerateCurlOnDebug() - } - return req -} - -func (wc *webClient) Get(ctx context.Context, path string, req *resty.Request, res any) error { - return wc.send(ctx, resty.MethodGet, path, req, res) -} - -func (wc *webClient) Post(ctx context.Context, path string, req *resty.Request, res any) error { - return wc.send(ctx, resty.MethodPost, path, req, res) -} - -func (wc *webClient) send(ctx context.Context, method, path string, req *resty.Request, res any) error { - if req == nil { - req = wc.client.NewRequest() - } - req.SetContext(ctx) - if res != nil { - req.SetResult(res) - } - - resp, err := req.Execute(method, path) - if err != nil { - return fmt.Errorf("error sending request: %w", err) - } - // print curl command for debugging - slog.Debug("CURL command", "curl", req.GenerateCurlCommand()) - if resp.StatusCode() < 200 || resp.StatusCode() >= 300 { - sanitizedBody := sanitizeResponseBody(resp.Body()) - slog.Debug("error sending request", "path", path, "status", resp.StatusCode(), "body", string(sanitizedBody)) - return fmt.Errorf("unexpected status %v body %s ", resp.StatusCode(), sanitizedBody) - } - return nil -} - -func sanitizeResponseBody(data []byte) []byte { - var out bytes.Buffer - r := bytes.NewReader(data) - for { - ch, size, err := r.ReadRune() - if err != nil { - break - } - // Skip invalid UTF-8 sequences - if ch == utf8.RuneError && size == 1 { - continue - } - // Skip control characters (optional) - if unicode.IsControl(ch) && ch != '\n' && ch != '\r' && ch != '\t' { - continue - } - out.WriteRune(ch) - } - return out.Bytes() -} diff --git a/backend/radiance.go b/backend/radiance.go new file mode 100644 index 00000000..0ccf93b0 --- /dev/null +++ b/backend/radiance.go @@ -0,0 +1,820 @@ +// Package backend provides the main interface for all the major components of Radiance. +package backend + +import ( + "context" + "errors" + "fmt" + "log/slog" + "maps" + "path/filepath" + "reflect" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Xuanwo/go-locale" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + C "github.com/getlantern/common" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/deviceid" + "github.com/getlantern/radiance/common/env" + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/config" + "github.com/getlantern/radiance/events" + "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/issue" + "github.com/getlantern/radiance/kindling" + "github.com/getlantern/radiance/log" + "github.com/getlantern/radiance/servers" + "github.com/getlantern/radiance/telemetry" + "github.com/getlantern/radiance/traces" + "github.com/getlantern/radiance/vpn" +) + +const tracerName = "github.com/getlantern/backend" + +// LocalBackend ties all the core functionality of Radiance together. It manages the configuration, +// servers, VPN connection, account management, issue reporting, and telemetry for the application. +type LocalBackend struct { + ctx context.Context + confHandler *config.ConfigHandler + issueReporter *issue.IssueReporter + accountClient *account.Client + + srvManager *servers.Manager + vpnClient *vpn.VPNClient + splitTunnelMgr *vpn.SplitTunnel + + shutdownFuncs []func() error + closeOnce sync.Once + stopChan chan struct{} + + deviceID string + + telemetryCfgSub atomic.Pointer[events.Subscription[config.NewConfigEvent]] + stopConnMetrics func() + connMetricsMu sync.Mutex + vpnStatusSub *events.Subscription[vpn.StatusUpdateEvent] +} + +type Options struct { + DataDir string + LogDir string + Locale string + LogLevel string + // this should be the platform device ID on mobile devices, desktop platforms will generate their + // own device ID and ignore this value + DeviceID string + // User choice for telemetry consent + TelemetryConsent bool + PlatformInterface vpn.PlatformInterface +} + +// NewLocalBackend performs global initialization and returns a new LocalBackend instance. +// It should be called once at the start of the application. +func NewLocalBackend(ctx context.Context, opts Options) (*LocalBackend, error) { + if err := common.Init(opts.DataDir, opts.LogDir, opts.LogLevel); err != nil { + return nil, fmt.Errorf("failed to initialize common components: %w", err) + } + if opts.Locale == "" { + if tag, err := locale.Detect(); err != nil { + opts.Locale = "en-US" + } else { + opts.Locale = tag.String() + } + } + + var platformDeviceID string + switch common.Platform { + case "ios", "android": + platformDeviceID = opts.DeviceID + default: + platformDeviceID = deviceid.Get() + } + + dataDir := settings.GetString(settings.DataPathKey) + disableFetch := env.GetBool(env.DisableFetch) + settings.Patch(settings.Settings{ + settings.LocaleKey: opts.Locale, + settings.DeviceIDKey: platformDeviceID, + settings.ConfigFetchDisabledKey: disableFetch, + settings.TelemetryKey: opts.TelemetryConsent, + }) + + kindling.SetKindling(kindling.NewKindling(dataDir)) + accountClient := account.NewClient(kindling.HTTPClient(), dataDir) + + svrMgr, err := servers.NewManager( + dataDir, slog.Default().With("service", "server_manager"), + ) + if err != nil { + return nil, fmt.Errorf("failed to create server manager: %w", err) + } + + splitTunnelMgr, err := vpn.NewSplitTunnelHandler( + dataDir, slog.Default().With("service", "split_tunnel"), + ) + if err != nil { + return nil, fmt.Errorf("failed to create split tunnel manager: %w", err) + } + + cOpts := config.Options{ + DataPath: dataDir, + Locale: opts.Locale, + AccountClient: accountClient, + HTTPClient: kindling.HTTPClient(), + Logger: slog.Default().With("service", "config_handler"), + } + if disableFetch { + cOpts.PollInterval = -1 + slog.Info("Config fetch disabled via environment variable", "env_var", env.DisableFetch) + } + + vpnClient := vpn.NewVPNClient(dataDir, slog.Default().With("service", "vpn"), opts.PlatformInterface) + r := &LocalBackend{ + ctx: ctx, + issueReporter: issue.NewIssueReporter(kindling.HTTPClient()), + accountClient: accountClient, + confHandler: config.NewConfigHandler(ctx, cOpts), + srvManager: svrMgr, + vpnClient: vpnClient, + splitTunnelMgr: splitTunnelMgr, + shutdownFuncs: []func() error{ + telemetry.Close, kindling.Close, vpnClient.Close, + }, + stopChan: make(chan struct{}), + closeOnce: sync.Once{}, + deviceID: platformDeviceID, + } + return r, nil +} + +func (r *LocalBackend) Start() { + // set country code in settings when new config is received so it can be included in issue reports + events.SubscribeOnce(func(evt config.NewConfigEvent) { + if evt.New != nil && evt.New.Country != "" { + if err := settings.Set(settings.CountryCodeKey, evt.New.Country); err != nil { + slog.Error("failed to set country code in settings", "error", err) + } + slog.Info("Set country code from config response", "country_code", evt.New.Country) + } + }) + // update VPN outbounds when new config is received + events.Subscribe(func(evt config.NewConfigEvent) { + if evt.New == nil { + return + } + cfg := evt.New + locs := make(map[string]C.ServerLocation, len(cfg.OutboundLocations)) + // Track which cities are already covered by active outbounds. + coveredCities := make(map[string]bool, len(cfg.OutboundLocations)) + for k, v := range cfg.OutboundLocations { + if v == nil { + slog.Warn("Server location is nil, skipping", "tag", k) + continue + } + locs[k] = *v + coveredCities[v.City+"|"+v.CountryCode] = true + } + // Include available server locations not already covered by active + // outbounds so the client's location picker shows every location. + for _, sl := range cfg.Servers { + if coveredCities[sl.City+"|"+sl.CountryCode] { + continue + } + key := strings.ToLower(strings.ReplaceAll(sl.City, " ", "-") + "-" + sl.CountryCode) + locs[key] = sl + } + opts := servers.Options{ + Outbounds: cfg.Options.Outbounds, + Endpoints: cfg.Options.Endpoints, + Locations: locs, + URLOverrides: cfg.BanditURLOverrides, + } + if len(cfg.BanditURLOverrides) > 0 { + // Create a marker span linked to the API's bandit trace so the + // config fetch appears in the same distributed trace as the callback. + if ctx, ok := traces.ExtractBanditTraceContext(cfg.BanditURLOverrides); ok { + _, span := otel.Tracer(tracerName).Start(ctx, "radiance.config_received", + trace.WithAttributes( + attribute.Int("bandit.override_count", len(cfg.BanditURLOverrides)), + attribute.Int("bandit.outbound_count", len(cfg.Options.Outbounds)), + ), + ) + span.End() // point-in-time marker — config was received at this timestamp + } + } + if err := r.setServers(servers.SGLantern, opts); err != nil { + slog.Error("setting servers in manager", "error", err) + } + if err := r.RunOfflineURLTests(); err != nil { + slog.Error("Failed to run offline URL tests after config update", "error", err) + } + }) + r.confHandler.Start() +} + +// addShutdownFunc adds a shutdown function(s) to the Radiance instance. +// This function is called when the Radiance instance is closed to ensure that all +// resources are cleaned up properly. +func (r *LocalBackend) addShutdownFunc(fns ...func() error) { + for _, fn := range fns { + if fn != nil { + r.shutdownFuncs = append(r.shutdownFuncs, fn) + } + } +} + +func (r *LocalBackend) Close() { + r.closeOnce.Do(func() { + slog.Debug("Closing Radiance") + r.confHandler.Stop() + close(r.stopChan) + for _, shutdown := range r.shutdownFuncs { + if err := shutdown(); err != nil { + slog.Error("Failed to shutdown", "error", err) + } + } + }) + <-r.stopChan +} + +////////////////// +// Issue Report // +////////////////// + +// ReportIssue allows the user to report an issue with the application. It collects relevant +// information about the user's environment such as country, device ID, user ID, subscription level, +// and locale, and log files to include in the report. The additionalAttachments parameter allows +// the caller to include any extra files they want to attach to the issue report. +func (r *LocalBackend) ReportIssue(issueType issue.IssueType, description, email string, additionalAttachments []string) error { + ctx, span := otel.Tracer(tracerName).Start(context.Background(), "report_issue") + defer span.End() + // get country from the config returned by the backend + var country string + cfg, err := r.confHandler.GetConfig() + if err != nil { + slog.Warn("Failed to get config", "error", err) + } else { + country = cfg.Country + } + + report := issue.IssueReport{ + Type: issueType, + Description: description, + Email: email, + CountryCode: country, + DeviceID: r.deviceID, + UserID: settings.GetString(settings.UserIDKey), + SubscriptionLevel: settings.GetString(settings.UserLevelKey), + Locale: settings.GetString(settings.LocaleKey), + AdditionalAttachments: append(baseIssueAttachments(), additionalAttachments...), + } + err = r.issueReporter.Report(ctx, report) + if err != nil { + slog.Error("Failed to report issue", "error", err) + return traces.RecordError(ctx, fmt.Errorf("failed to report issue: %w", err)) + } + slog.Info("Issue reported successfully") + return nil +} + +// baseIssueAttachments returns a list of file paths to include as attachments in every issue report +// in order of importance. +func baseIssueAttachments() []string { + logPath := settings.GetString(settings.LogPathKey) + dataPath := settings.GetString(settings.DataPathKey) + // TODO: any other files we want to include?? split-tunnel config? + return []string{ + filepath.Join(logPath, internal.CrashLogFileName), + filepath.Join(dataPath, internal.ConfigFileName), + filepath.Join(dataPath, internal.ServersFileName), + filepath.Join(dataPath, internal.DebugBoxOptionsFileName), + } +} + +///////////////// +// Settings // +///////////////// + +// Features returns the features available in the current configuration, returned from the server in the +// config response. +func (r *LocalBackend) Features() map[string]bool { + _, span := otel.Tracer(tracerName).Start(context.Background(), "features") + defer span.End() + cfg, err := r.confHandler.GetConfig() + if err != nil { + slog.Info("Failed to get config for features", "error", err) + return map[string]bool{} + } + if cfg == nil { + slog.Info("No config available for features, returning empty map") + return map[string]bool{} + } + slog.Debug("Returning features from config", "features", cfg.Features) + // Return the features from the config + if cfg.Features == nil { + slog.Info("No features available in config, returning empty map") + return map[string]bool{} + } + return cfg.Features +} + +func (r *LocalBackend) PatchSettings(updates settings.Settings) error { + curr := settings.GetAllFor(slices.Collect(maps.Keys(updates))...) + diff := updates.Diff(curr) + slog.Log(nil, log.LevelTrace, "Patching settings", "updates", updates, "current", curr, "diff", diff) + if len(diff) == 0 { + return nil + } + if err := settings.Patch(diff); err != nil { + return fmt.Errorf("failed to update settings: %w", err) + } + // telemetry settings + if _, ok := diff[settings.TelemetryKey]; ok { + if settings.GetBool(settings.TelemetryKey) { + if err := r.startTelemetry(); err != nil { + slog.Error("Failed to start telemetry", "error", err) + } + } else { + r.stopTelemetry() + } + } + + // vpn settings + k := settings.SplitTunnelKey + if _, ok := diff[k]; ok { + r.splitTunnelMgr.SetEnabled(settings.GetBool(k)) + } + r.maybeRestartVPN(diff) + + return nil +} + +// maybeRestartVPN restarts the VPN connection if either the ad block or smart routing settings +// were changed and the VPN is currently connected. +func (r *LocalBackend) maybeRestartVPN(updates settings.Settings) { + _, adBlockChanged := updates[settings.AdBlockKey] + _, smartRoutingChanged := updates[settings.SmartRoutingKey] + if (adBlockChanged || smartRoutingChanged) && r.vpnClient.Status() == vpn.Connected { + bOptions := r.getBoxOptions() + go r.vpnClient.Restart(bOptions) + } +} + +///////////////// +// telemetry // +///////////////// + +func (r *LocalBackend) startTelemetry() error { + cfg, err := r.confHandler.GetConfig() + if err == nil { + if err := telemetry.Initialize(r.deviceID, *cfg, settings.IsPro()); err != nil { + return fmt.Errorf("failed to initialize telemetry: %w", err) + } + } + if r.telemetryCfgSub.Load() != nil { + return nil + } + // subscribe to config changes to update telemetry config + sub := events.Subscribe(func(evt config.NewConfigEvent) { + if !settings.GetBool(settings.TelemetryKey) { + return + } + if evt.Old != nil && reflect.DeepEqual(evt.Old.OTEL, evt.New.OTEL) { + // no changes to telemetry config, no need to update + return + } + if err := telemetry.Initialize(r.deviceID, *evt.New, settings.IsPro()); err != nil { + slog.Error("Failed to update telemetry config", "error", err) + } + }) + r.telemetryCfgSub.Store(sub) + + // subscribe to VPN status events to start/stop connection metrics collection + r.vpnStatusSub = events.Subscribe(func(evt vpn.StatusUpdateEvent) { + r.updateConnMetrics(evt.Status) + }) + return nil +} + +func (r *LocalBackend) stopTelemetry() { + if sub := r.telemetryCfgSub.Swap(nil); sub != nil { + sub.Unsubscribe() + } + if r.vpnStatusSub != nil { + r.vpnStatusSub.Unsubscribe() + r.vpnStatusSub = nil + } + r.stopConnMetricsIfRunning() + telemetry.Close() +} + +// updateConnMetrics starts or stops connection metrics collection based on VPN status. +// Metrics are only collected when the VPN is connected and telemetry is enabled. +func (r *LocalBackend) updateConnMetrics(status vpn.VPNStatus) { + if status == vpn.Connected { + r.startConnMetrics() + } else { + r.stopConnMetricsIfRunning() + } +} + +func (r *LocalBackend) startConnMetrics() { + r.connMetricsMu.Lock() + defer r.connMetricsMu.Unlock() + if r.stopConnMetrics != nil { + return // already running + } + r.stopConnMetrics = telemetry.StartConnectionMetrics(r.ctx, r.vpnClient, 1*time.Minute) + slog.Debug("Started connection metrics collection") +} + +func (r *LocalBackend) stopConnMetricsIfRunning() { + r.connMetricsMu.Lock() + defer r.connMetricsMu.Unlock() + if r.stopConnMetrics != nil { + r.stopConnMetrics() + r.stopConnMetrics = nil + slog.Debug("Stopped connection metrics collection") + } +} + +/////////////////////// +// Server management // +/////////////////////// + +func (r *LocalBackend) Servers() servers.Servers { + return r.srvManager.Servers() +} + +func (r *LocalBackend) GetServerByTag(tag string) (servers.Server, bool) { + return r.srvManager.GetServerByTag(tag) +} + +func (r *LocalBackend) AddServers(group servers.ServerGroup, options servers.Options) error { + if err := r.srvManager.AddServers(group, options, true); err != nil { + return fmt.Errorf("failed to add servers to ServerManager: %w", err) + } + if err := r.vpnClient.AddOutbounds(group, options); err != nil && !errors.Is(err, vpn.ErrTunnelNotConnected) { + return fmt.Errorf("failed to add outbounds to VPN client: %w", err) + } + return nil +} + +func (r *LocalBackend) RemoveServers(tags []string) error { + removed, err := r.srvManager.RemoveServers(tags) + if err != nil { + return fmt.Errorf("failed to remove servers from ServerManager: %w", err) + } + servers := make(map[string][]string) + for _, srv := range removed { + servers[srv.Group] = append(servers[srv.Group], srv.Tag) + } + for group, tags := range servers { + if err := r.vpnClient.RemoveOutbounds(group, tags); err != nil && !errors.Is(err, vpn.ErrTunnelNotConnected) { + return fmt.Errorf("failed to remove outbounds from VPN client: %w", err) + } + } + return nil +} + +func (r *LocalBackend) setServers(group servers.ServerGroup, options servers.Options) error { + if err := r.srvManager.SetServers(group, options); err != nil { + return fmt.Errorf("failed to set servers in ServerManager: %w", err) + } + err := r.vpnClient.UpdateOutbounds(group, options) + if err != nil && !errors.Is(err, vpn.ErrTunnelNotConnected) { + slog.Error("Failed to update VPN outbounds after config change", "error", err) + } + return nil +} + +func (r *LocalBackend) AddServersByJSON(config string) error { + return r.srvManager.AddServersByJSON(context.Background(), []byte(config)) +} + +func (r *LocalBackend) AddServersByURL(urls []string, skipCertVerification bool) error { + return r.srvManager.AddServersByURL(context.Background(), urls, skipCertVerification) +} + +func (r *LocalBackend) AddPrivateServer(tag, ip string, port int, accessToken string, loc C.ServerLocation, joined bool) error { + return r.srvManager.AddPrivateServer(tag, ip, port, accessToken, loc, joined) +} + +func (r *LocalBackend) InviteToPrivateServer(ip string, port int, accessToken string, inviteName string) (string, error) { + return r.srvManager.InviteToPrivateServer(ip, port, accessToken, inviteName) +} + +func (r *LocalBackend) RevokePrivateServerInvite(ip string, port int, accessToken string, inviteName string) error { + return r.srvManager.RevokePrivateServerInvite(ip, port, accessToken, inviteName) +} + +///////////////// +// VPN // +///////////////// + +func (r *LocalBackend) VPNStatus() vpn.VPNStatus { + return r.vpnClient.Status() +} + +func (r *LocalBackend) ConnectVPN(tag string) error { + if tag == "" { + tag = vpn.AutoSelectTag + } + if tag != vpn.AutoSelectTag { + if _, found := r.srvManager.GetServerByTag(tag); !found { + return fmt.Errorf("no server found with tag %s", tag) + } + } + bOptions := r.getBoxOptions() + if err := r.vpnClient.Connect(bOptions); err != nil { + return fmt.Errorf("failed to connect VPN: %w", err) + } + if err := r.selectServer(tag); err != nil { + return fmt.Errorf("failed to select server: %w", err) + } + return nil +} + +func (r *LocalBackend) getBoxOptions() vpn.BoxOptions { + // ignore error, we can still connect with default options if config is not available for some reason + cfg, _ := r.confHandler.GetConfig() + bOptions := vpn.BoxOptions{ + BasePath: settings.GetString(settings.DataPathKey), + } + if cfg != nil { + bOptions.Options = cfg.Options + bOptions.BanditURLOverrides = cfg.BanditURLOverrides + bOptions.BanditThroughputURL = cfg.BanditThroughputURL + if settings.GetBool(settings.SmartRoutingKey) { + bOptions.SmartRouting = cfg.SmartRouting + } + if settings.GetBool(settings.AdBlockKey) { + bOptions.AdBlock = cfg.AdBlock + } + } + if userServers, ok := r.srvManager.Servers()[servers.SGUser]; ok { + bOptions.Options.Outbounds = append(bOptions.Options.Outbounds, userServers.Outbounds...) + bOptions.Options.Endpoints = append(bOptions.Options.Endpoints, userServers.Endpoints...) + } + return bOptions +} + +func (r *LocalBackend) DisconnectVPN() error { + return r.vpnClient.Disconnect() +} + +func (r *LocalBackend) RestartVPN() error { + bOptions := r.getBoxOptions() + return r.vpnClient.Restart(bOptions) +} + +func (r *LocalBackend) SelectServer(tag string) error { + return r.selectServer(tag) +} + +func (r *LocalBackend) selectServer(tag string) error { + if err := r.vpnClient.SelectServer(tag); err != nil { + return fmt.Errorf("failed to select server: %w", err) + } + if tag == vpn.AutoSelectTag { + err := settings.Patch(settings.Settings{ + settings.AutoConnectKey: true, + settings.SelectedServerKey: nil, + }) + if err != nil { + slog.Warn("failed to update settings", "error", err) + } + return nil + } + + server, found := r.srvManager.GetServerByTag(tag) + if !found { // sanity check, the vpn should have errored if this were the case + return fmt.Errorf("no server found with tag %s", tag) + } + server.Options = nil + err := settings.Patch(settings.Settings{ + settings.AutoConnectKey: false, + settings.SelectedServerKey: server, + }) + if err != nil { + slog.Warn("Failed to save selected server in settings", "error", err) + } + slog.Info("Selected server", "tag", tag, "type", server.Type) + return nil +} + +// Connections returns a list of all connections, both active and recently closed. If there are no +// connections and the tunnel is open, an empty slice is returned without an error. +func (r *LocalBackend) VPNConnections() ([]vpn.Connection, error) { + return r.vpnClient.Connections() +} + +// ActiveConnections returns a list of currently active connections, ordered from newest to oldest. +func (r *LocalBackend) ActiveVPNConnections() ([]vpn.Connection, error) { + connections, err := r.vpnClient.Connections() + if err != nil { + return nil, fmt.Errorf("failed to get VPN connections: %w", err) + } + connections = slices.DeleteFunc(connections, func(conn vpn.Connection) bool { + return conn.ClosedAt != 0 + }) + slices.SortFunc(connections, func(a, b vpn.Connection) int { + return int(b.CreatedAt - a.CreatedAt) + }) + return connections, nil +} + +// TODO: handle case where selected server is no longer available (e.g. removed from manager) more +// gracefully, currently we just return that the server is no longer available but maybe we should +// also clear the selected server from settings and select a new server in the VPN client. +// should we not remove a lantern server if it's currently selected in the VPN client and instead +// mark it as unavailable in the manager until it's no longer selected in the VPN client? + +// SelectedServer returns the currently selected server and whether the server is still available. +// The server may no longer be available if it was removed from the manager since it was selected. +func (r *LocalBackend) SelectedServer() (servers.Server, bool, error) { + if !settings.Exists(settings.SelectedServerKey) { + return servers.Server{}, false, fmt.Errorf("no selected server") + } + var selected servers.Server + if err := settings.GetStruct(settings.SelectedServerKey, &selected); err != nil { + return servers.Server{}, false, fmt.Errorf("failed to get selected server from settings: %w", err) + } + server, found := r.srvManager.GetServerByTag(selected.Tag) + stillExists := found && + server.Group == selected.Group && + server.Type == selected.Type && + server.Location == selected.Location + return selected, stillExists, nil +} + +// CurrentAutoSelectedServer returns the tag of the server that is currently auto-selected by the +// VPN client. +func (r *LocalBackend) CurrentAutoSelectedServer() (string, error) { + return r.vpnClient.CurrentAutoSelectedServer() +} + +// StartAutoSelectionsListener starts polling for auto-selection changes and emitting events. +func (r *LocalBackend) StartAutoSelectedListener() { + r.vpnClient.AutoSelectedChangeListener(r.ctx) +} + +func (r *LocalBackend) RunOfflineURLTests() error { + cfg, err := r.confHandler.GetConfig() + if err != nil { + return fmt.Errorf("no config available: %w", err) + } + return r.vpnClient.RunOfflineURLTests( + settings.GetString(settings.DataPathKey), + cfg.Options.Outbounds, + cfg.BanditURLOverrides, + ) +} + +////////////////// +// Split Tunnel // +///////////////// + +func (r *LocalBackend) SplitTunnelFilters() vpn.SplitTunnelFilter { + return r.splitTunnelMgr.Filters() +} + +func (r *LocalBackend) AddSplitTunnelItems(items vpn.SplitTunnelFilter) error { + return r.splitTunnelMgr.AddItems(items) +} + +func (r *LocalBackend) RemoveSplitTunnelItems(items vpn.SplitTunnelFilter) error { + return r.splitTunnelMgr.RemoveItems(items) +} + +///////////// +// Account // +///////////// + +func (r *LocalBackend) NewUser(ctx context.Context) (*account.UserData, error) { + return r.accountClient.NewUser(ctx) +} + +func (r *LocalBackend) Login(ctx context.Context, email, password string) (*account.UserData, error) { + return r.accountClient.Login(ctx, email, password) +} + +func (r *LocalBackend) Logout(ctx context.Context, email string) (*account.UserData, error) { + return r.accountClient.Logout(ctx, email) +} + +func (r *LocalBackend) FetchUserData(ctx context.Context) (*account.UserData, error) { + return r.accountClient.FetchUserData(ctx) +} + +func (r *LocalBackend) StartChangeEmail(ctx context.Context, newEmail, password string) error { + return r.accountClient.StartChangeEmail(ctx, newEmail, password) +} + +func (r *LocalBackend) CompleteChangeEmail(ctx context.Context, newEmail, password, code string) error { + return r.accountClient.CompleteChangeEmail(ctx, newEmail, password, code) +} + +func (r *LocalBackend) StartRecoveryByEmail(ctx context.Context, email string) error { + return r.accountClient.StartRecoveryByEmail(ctx, email) +} + +func (r *LocalBackend) CompleteRecoveryByEmail(ctx context.Context, email, newPassword, code string) error { + return r.accountClient.CompleteRecoveryByEmail(ctx, email, newPassword, code) +} + +func (r *LocalBackend) DeleteAccount(ctx context.Context, email, password string) (*account.UserData, error) { + return r.accountClient.DeleteAccount(ctx, email, password) +} + +func (r *LocalBackend) SignUp(ctx context.Context, email, password string) ([]byte, *account.SignupResponse, error) { + return r.accountClient.SignUp(ctx, email, password) +} + +func (r *LocalBackend) SignupEmailConfirmation(ctx context.Context, email, code string) error { + return r.accountClient.SignupEmailConfirmation(ctx, email, code) +} + +func (r *LocalBackend) SignupEmailResendCode(ctx context.Context, email string) error { + return r.accountClient.SignupEmailResendCode(ctx, email) +} + +func (r *LocalBackend) ValidateEmailRecoveryCode(ctx context.Context, email, code string) error { + return r.accountClient.ValidateEmailRecoveryCode(ctx, email, code) +} + +func (r *LocalBackend) DataCapInfo(ctx context.Context) (*account.DataCapInfo, error) { + return r.accountClient.DataCapInfo(ctx) +} + +func (r *LocalBackend) RemoveDevice(ctx context.Context, deviceID string) (*account.LinkResponse, error) { + return r.accountClient.RemoveDevice(ctx, deviceID) +} + +func (r *LocalBackend) OAuthLoginCallback(ctx context.Context, oAuthToken string) (*account.UserData, error) { + return r.accountClient.OAuthLoginCallback(ctx, oAuthToken) +} + +func (r *LocalBackend) OAuthLoginUrl(ctx context.Context, provider string) (string, error) { + return r.accountClient.OAuthLoginURL(ctx, provider) +} + +func (r *LocalBackend) UserDevices() ([]settings.Device, error) { + return settings.Devices() +} + +func (r *LocalBackend) UserData() (*account.UserData, error) { + var userData account.UserData + if err := settings.GetStruct(settings.UserDataKey, &userData); err != nil { + return nil, fmt.Errorf("failed to get user data from settings: %w", err) + } + return &userData, nil +} + +/////////////////// +// Subscriptions // +/////////////////// + +func (r *LocalBackend) ActivationCode(ctx context.Context, email, resellerCode string) (*account.PurchaseResponse, error) { + return r.accountClient.ActivationCode(ctx, email, resellerCode) +} + +func (r *LocalBackend) NewStripeSubscription(ctx context.Context, email, planID string) (string, error) { + return r.accountClient.NewStripeSubscription(ctx, email, planID) +} + +func (r *LocalBackend) PaymentRedirect(ctx context.Context, data account.PaymentRedirectData) (string, error) { + return r.accountClient.PaymentRedirect(ctx, data) +} + +func (r *LocalBackend) ReferralAttach(ctx context.Context, code string) (bool, error) { + return r.accountClient.ReferralAttach(ctx, code) +} + +func (r *LocalBackend) StripeBillingPortalURL(ctx context.Context) (string, error) { + return r.accountClient.StripeBillingPortalURL(ctx, + common.GetProServerURL(), settings.GetString(settings.UserIDKey), settings.GetString(settings.TokenKey), + ) +} + +func (r *LocalBackend) SubscriptionPaymentRedirectURL(ctx context.Context, data account.PaymentRedirectData) (string, error) { + return r.accountClient.SubscriptionPaymentRedirectURL(ctx, data) +} + +func (r *LocalBackend) SubscriptionPlans(ctx context.Context, channel string) (string, error) { + return r.accountClient.SubscriptionPlans(ctx, channel) +} + +func (r *LocalBackend) VerifySubscription(ctx context.Context, service account.SubscriptionService, data map[string]string) (string, error) { + return r.accountClient.VerifySubscription(ctx, service, data) +} diff --git a/backend/radiance_test.go b/backend/radiance_test.go new file mode 100644 index 00000000..dd6eaa62 --- /dev/null +++ b/backend/radiance_test.go @@ -0,0 +1,8 @@ +package backend + +import ( + "testing" +) + +func TestBackend(t *testing.T) { +} diff --git a/cmd/Makefile b/cmd/Makefile index e39104dd..599f5d91 100644 --- a/cmd/Makefile +++ b/cmd/Makefile @@ -1,7 +1,33 @@ -TAGS=with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale +TAGS=with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_conntrack +UNAME_S := $(shell uname -s) +ifeq ($(UNAME_S),Darwin) + TAGS := standalone,$(TAGS) +endif + +ifeq ($(OS),Windows_NT) + LANTERND := lanternd.exe + LANTERN := lantern.exe +else + LANTERND := lanternd + LANTERN := lantern +endif + +.PHONY: build-daemon build-daemon: - go build -tags "$(TAGS)" -o ../bin/lanternd ./lanternd/lanternd.go + go build -tags "$(TAGS)" -o ../bin/$(LANTERND) ./lanternd +.PHONY: run-daemon run-daemon: - go run -tags=$(TAGS) ./lanternd/lanternd.go $(args) + go run -tags=$(TAGS) ./lanternd run \ + $(if $(data-path),--data-path=$(data-path)) \ + $(if $(log-path),--log-path=$(log-path)) \ + $(if $(log-level),--log-level=$(log-level)) + +.PHONY: build-cli +build-cli: +ifeq ($(UNAME_S),Darwin) + go build -tags "standalone" -o ../bin/$(LANTERN) ./lantern +else + go build -o ../bin/$(LANTERN) ./lantern +endif diff --git a/cmd/justfile b/cmd/justfile new file mode 100644 index 00000000..a4a4a1e2 --- /dev/null +++ b/cmd/justfile @@ -0,0 +1,15 @@ +base_tags := "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_conntrack" +tags := if os() == "macos" { "standalone," + base_tags } else { base_tags } +lanternd := if os() == "windows" { "lanternd.exe" } else { "lanternd" } +lantern := if os() == "windows" { "lantern.exe" } else { "lantern" } + +build-daemon: + go build -tags "{{tags}}" -o ../bin/{{lanternd}} ./lanternd + +run-daemon *args: + go run -tags={{tags}} ./lanternd run {{args}} + +cli_tags := if os() == "macos" { "standalone" } else { "" } + +build-cli: + go build {{ if cli_tags != "" { "-tags " + cli_tags } else { "" } }} -o ../bin/{{lantern}} ./lantern diff --git a/cmd/kindling-tester/main.go b/cmd/kindling-tester/main.go index 39675f31..e910e4fa 100644 --- a/cmd/kindling-tester/main.go +++ b/cmd/kindling-tester/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "fmt" "io" "log/slog" @@ -13,7 +12,7 @@ import ( "github.com/getlantern/radiance/kindling" ) -func performKindlingPing(ctx context.Context, urlToHit string, runID string, deviceID string, userID int64, token string, dataDir string) error { +func performKindlingPing(urlToHit string, runID string, deviceID string, userID int64, token string, dataDir string) error { os.MkdirAll(dataDir, 0o755) settings.Set(settings.DataPathKey, dataDir) settings.Set(settings.UserIDKey, userID) @@ -28,8 +27,8 @@ func performKindlingPing(ctx context.Context, urlToHit string, runID string, dev }) t1 := time.Now() - kindling.SetKindling(kindling.NewKindling()) - defer kindling.Close(ctx) + kindling.SetKindling(kindling.NewKindling(dataDir)) + defer kindling.Close() cli := kindling.HTTPClient() t2 := time.Now() @@ -55,7 +54,7 @@ func performKindlingPing(ctx context.Context, urlToHit string, runID string, dev if err := os.WriteFile(dataDir+"/output.txt", responseBody, 0o644); err != nil { slog.Error("failed to write output file", slog.Any("error", err), slog.String("path", dataDir+"/output.txt")) } - return os.WriteFile(dataDir+"/timing.txt", []byte(fmt.Sprintf(` + return os.WriteFile(dataDir+"/timing.txt", fmt.Appendf([]byte{}, ` result: %v run-id: %s err: %v @@ -63,7 +62,7 @@ func performKindlingPing(ctx context.Context, urlToHit string, runID string, dev connected: %d fetched: %d url: %s`, - true, runID, nil, t1, int32(t2.Sub(t1).Milliseconds()), int32(t3.Sub(t1).Milliseconds()), urlToHit)), 0o644) + true, runID, nil, t1, int32(t2.Sub(t1).Milliseconds()), int32(t3.Sub(t1).Milliseconds()), urlToHit), 0o644) } func main() { @@ -95,8 +94,6 @@ func main() { } } - ctx := context.Background() - // disabling all other transports before enabling the selected for name := range kindling.EnabledTransports { kindling.EnabledTransports[name] = false @@ -104,7 +101,7 @@ func main() { kindling.EnabledTransports[transport] = true slog.Debug("enabled transports", slog.Any("enabled_transports", kindling.EnabledTransports)) - if err := performKindlingPing(ctx, targetURL, runID, deviceID, uid, token, data); err != nil { + if err := performKindlingPing(targetURL, runID, deviceID, uid, token, data); err != nil { slog.Error("failed to perform kindling ping", slog.Any("error", err)) os.Exit(1) } diff --git a/cmd/lantern/account.go b/cmd/lantern/account.go new file mode 100644 index 00000000..2c0fe63f --- /dev/null +++ b/cmd/lantern/account.go @@ -0,0 +1,328 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + "syscall" + + "golang.org/x/term" + + "github.com/getlantern/radiance/ipc" +) + +type AccountCmd struct { + Login *LoginCmd `arg:"subcommand:login" help:"log in to your account"` + Logout *LogoutCmd `arg:"subcommand:logout" help:"log out of your account"` + Signup *SignupCmd `arg:"subcommand:signup" help:"create a new account"` + Recover *RecoverAccountCmd `arg:"subcommand:recover" help:"recover existing account"` + + Usage *UsageCmd `arg:"subcommand:usage" help:"view data usage"` + Devices *DevicesCmd `arg:"subcommand:devices" help:"manage user devices"` + SetEmail *SetEmailCmd `arg:"subcommand:set-email" help:"change account email"` +} + +type LoginCmd struct { + OAuth bool `arg:"--oauth" help:"log in with OAuth provider"` + Provider string `arg:"--provider" help:"OAuth provider"` +} + +type LogoutCmd struct{} + +type SignupCmd struct{} + +type RecoverAccountCmd struct{} + +type SetEmailCmd struct{} + +type UsageCmd struct{} + +type DevicesCmd struct { + List bool `arg:"--list" help:"list user devices"` + Remove string `arg:"--remove" help:"remove a device by ID"` +} + +func runAccount(ctx context.Context, c *ipc.Client, cmd *AccountCmd) error { + switch { + case cmd.Login != nil: + return accountLogin(ctx, c, cmd.Login) + case cmd.Logout != nil: + return accountLogout(ctx, c) + case cmd.Signup != nil: + return accountSignup(ctx, c) + case cmd.Recover != nil: + return accountRecover(ctx, c) + case cmd.Usage != nil: + return accountDataUsage(ctx, c) + case cmd.Devices != nil: + return accountDevices(ctx, c, cmd.Devices) + case cmd.SetEmail != nil: + return accountSetEmail(ctx, c) + default: + return fmt.Errorf("no subcommand specified") + } +} + +// isLoggedIn returns the current user's email if logged in, or empty string if not. +func isLoggedIn(ctx context.Context, c *ipc.Client) (string, error) { + userData, err := c.UserData(ctx) + if err != nil { + return "", err + } + return userData.GetLegacyUserData().GetEmail(), nil +} + +func requireLoggedOut(ctx context.Context, c *ipc.Client) error { + email, err := isLoggedIn(ctx, c) + if err != nil { + return fmt.Errorf("failed to check login status: %w", err) + } + if email != "" { + return fmt.Errorf("already logged in as %s — log out first", email) + } + return nil +} + +func requireLoggedIn(ctx context.Context, c *ipc.Client) (string, error) { + email, err := isLoggedIn(ctx, c) + if err != nil { + return "", fmt.Errorf("failed to check login status: %w", err) + } + if email == "" { + return "", fmt.Errorf("no user is currently logged in") + } + return email, nil +} + +func accountLogin(ctx context.Context, c *ipc.Client, cmd *LoginCmd) error { + if err := requireLoggedOut(ctx, c); err != nil { + return err + } + + if cmd.OAuth { + provider := cmd.Provider + if provider == "" { + provider = "google" + } + url, err := c.OAuthLoginUrl(ctx, provider) + if err != nil { + return err + } + fmt.Println("Open this URL in your browser to log in:") + fmt.Println(url) + fmt.Print("Enter OAuth token: ") + token, err := readLine() + if err != nil { + return err + } + userData, err := c.OAuthLoginCallback(ctx, token) + if err != nil { + return err + } + return printJSON(userData) + } + + email, err := prompt("Email: ") + if err != nil { + return err + } + password, err := promptPassword("Password: ") + if err != nil { + return err + } + + userData, err := c.Login(ctx, email, password) + if err != nil { + return err + } + fmt.Println("Logged in successfully.") + return printJSON(userData) +} + +func accountLogout(ctx context.Context, c *ipc.Client) error { + email, err := requireLoggedIn(ctx, c) + if err != nil { + return err + } + _, err = c.Logout(ctx, email) + if err != nil { + return err + } + fmt.Println("Logged out successfully.") + return nil +} + +func accountSignup(ctx context.Context, c *ipc.Client) error { + if err := requireLoggedOut(ctx, c); err != nil { + return err + } + + email, err := prompt("Email: ") + if err != nil { + return err + } + password, err := promptPassword("Password: ") + if err != nil { + return err + } + confirm, err := promptPassword("Confirm password: ") + if err != nil { + return err + } + if password != confirm { + return fmt.Errorf("passwords do not match") + } + + _, resp, err := c.SignUp(ctx, email, password) + if err != nil { + return err + } + fmt.Println("Account created successfully.") + + fmt.Println("A confirmation code has been sent to your email.") + code, err := prompt("Confirmation code: ") + if err != nil { + return err + } + if err := c.SignupEmailConfirmation(ctx, email, code); err != nil { + return fmt.Errorf("email confirmation failed: %w", err) + } + fmt.Println("Email confirmed.") + _ = resp + return nil +} + +func accountRecover(ctx context.Context, c *ipc.Client) error { + if _, err := requireLoggedIn(ctx, c); err != nil { + return err + } + + email, err := prompt("Email: ") + if err != nil { + return err + } + + if err := c.StartRecoveryByEmail(ctx, email); err != nil { + return err + } + fmt.Println("A recovery code has been sent to your email.") + + code, err := prompt("Recovery code: ") + if err != nil { + return err + } + if err := c.ValidateEmailRecoveryCode(ctx, email, code); err != nil { + return fmt.Errorf("invalid recovery code: %w", err) + } + + newPassword, err := promptPassword("New password: ") + if err != nil { + return err + } + confirm, err := promptPassword("Confirm new password: ") + if err != nil { + return err + } + if newPassword != confirm { + return fmt.Errorf("passwords do not match") + } + + if err := c.CompleteRecoveryByEmail(ctx, email, newPassword, code); err != nil { + return err + } + fmt.Println("Account recovered successfully. You can now log in with your new password.") + return nil +} + +func accountSetEmail(ctx context.Context, c *ipc.Client) error { + if _, err := requireLoggedIn(ctx, c); err != nil { + return err + } + + newEmail, err := prompt("New email: ") + if err != nil { + return err + } + password, err := promptPassword("Password: ") + if err != nil { + return err + } + + if err := c.StartChangeEmail(ctx, newEmail, password); err != nil { + return err + } + fmt.Println("A confirmation code has been sent to your new email.") + + code, err := prompt("Confirmation code: ") + if err != nil { + return err + } + if err := c.CompleteChangeEmail(ctx, newEmail, password, code); err != nil { + return err + } + fmt.Println("Email changed successfully.") + return nil +} + +func accountDataUsage(ctx context.Context, c *ipc.Client) error { + info, err := c.DataCapInfo(ctx) + if err != nil { + return err + } + fmt.Println(info) + return nil +} + +func accountDevices(ctx context.Context, c *ipc.Client, cmd *DevicesCmd) error { + if _, err := requireLoggedIn(ctx, c); err != nil { + return err + } + + switch { + case cmd.Remove != "": + resp, err := c.RemoveDevice(ctx, cmd.Remove) + if err != nil { + return err + } + fmt.Println("Device removed.") + return printJSON(resp) + default: + // Default to listing devices + devices, err := c.UserDevices(ctx) + if err != nil { + return err + } + return printJSON(devices) + } +} + +// prompt prints a prompt and reads a line of input from stdin. +func prompt(label string) (string, error) { + fmt.Print(label) + return readLine() +} + +// readLine reads a single line from stdin, trimming the trailing newline. +func readLine() (string, error) { + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return "", err + } + return "", fmt.Errorf("unexpected end of input") + } + return strings.TrimSpace(scanner.Text()), nil +} + +// promptPassword prints a prompt and reads a password without echoing it. +func promptPassword(label string) (string, error) { + fmt.Print(label) + password, err := term.ReadPassword(int(syscall.Stdin)) + fmt.Println() // newline after hidden input + if err != nil { + return "", fmt.Errorf("failed to read password: %w", err) + } + return string(password), nil +} diff --git a/cmd/lantern/ip.go b/cmd/lantern/ip.go new file mode 100644 index 00000000..c352fc73 --- /dev/null +++ b/cmd/lantern/ip.go @@ -0,0 +1,110 @@ +package main + +import ( + "context" + "fmt" + "io" + "net/http" + "net/netip" + "strings" + "time" +) + +// list of URLs to fetch the public IP address, just in case one is down or blocked +var ipURLs = []string{ + "https://ip.me", + "https://ifconfig.me/ip", + "https://checkip.amazonaws.com", + "https://ifconfig.io/ip", + "https://ident.me", + "https://ipinfo.io/ip", + "https://api.ipify.org", +} + +// GetPublicIP fetches the public IP address +func GetPublicIP(ctx context.Context) (string, error) { + return getPublicIP(ctx, ipURLs) +} + +func getPublicIP(ctx context.Context, urls []string) (string, error) { + if len(urls) == 0 { + urls = ipURLs + } + type result struct { + ip string + err error + } + results := make(chan result, len(urls)) + sem := make(chan struct{}, 3) + + client := &http.Client{} + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + for _, url := range urls { + go func() { + // limit number of concurrent requests + sem <- struct{}{} + defer func() { <-sem }() + ip, err := fetchIP(ctx, client, url) + results <- result{ip, err} + }() + } + + var lastErr error + for i := 0; i < len(urls); i++ { + res := <-results + if res.err == nil { + return res.ip, nil + } + lastErr = res.err + } + return "", fmt.Errorf("failed to get public IP, error: %w", lastErr) +} + +// fetchIP performs an HTTP GET request to the given URL and returns the trimmed response body as the IP. +func fetchIP(ctx context.Context, client *http.Client, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return "", err + } + req.Header.Set("User-Agent", "curl/8.14.1") // some services return the entire HTML page for non-curl user agents + req.Header.Set("Connection", "close") + req.Close = true + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + ip := strings.TrimSpace(string(body)) + if ip == "" { + return "", fmt.Errorf("empty response from %s", url) + } + if _, err := netip.ParseAddr(ip); err != nil { + return "", fmt.Errorf("response is not a valid IP: %s -> %s...", url, ip[:min(len(ip), 7)]) + } + return ip, nil +} + +// WaitForIPChange polls the public IP address every interval until it changes from the current value. +func WaitForIPChange(ctx context.Context, current string, interval time.Duration) (string, error) { + urls := ipURLs + for { + select { + case <-ctx.Done(): + return "", nil + case <-time.After(interval): + ip, err := getPublicIP(ctx, urls) + if err != nil { + return "", nil + } else if ip != current { + return ip, nil + } + urls = append(urls[3:], urls[:3]...) // rotate URLs to avoid hitting the same ones repeatedly + } + } +} diff --git a/cmd/lantern/lantern.go b/cmd/lantern/lantern.go new file mode 100644 index 00000000..dda2f713 --- /dev/null +++ b/cmd/lantern/lantern.go @@ -0,0 +1,141 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "context" + + "github.com/alexflint/go-arg" + + "github.com/getlantern/radiance/ipc" + "github.com/getlantern/radiance/issue" + rlog "github.com/getlantern/radiance/log" +) + +type args struct { + Connect *ConnectCmd `arg:"subcommand:connect" help:"connect to VPN"` + Disconnect *DisconnectCmd `arg:"subcommand:disconnect" help:"disconnect VPN"` + Status *StatusCmd `arg:"subcommand:status" help:"show VPN status"` + Servers *ServersCmd `arg:"subcommand:servers" help:"manage servers"` + Features *FeaturesCmd `arg:"subcommand:features" help:"list available features and their status"` + SmartRouting *SmartRoutingCmd `arg:"subcommand:smart-routing" help:"show or set smart routing"` + AdBlock *AdBlockCmd `arg:"subcommand:ad-block" help:"show or set ad blocking"` + Telemetry *TelemetryCmd `arg:"subcommand:telemetry" help:"show or set telemetry"` + SplitTunnel *SplitTunnelCmd `arg:"subcommand:split-tunnel" help:"split-tunnel settings and filters"` + Account *AccountCmd `arg:"subcommand:account" help:"login, signup, user data, devices, recovery"` + Subscription *SubscriptionCmd `arg:"subcommand:subscription" help:"plans, payments, and billing"` + ReportIssue *ReportIssueCmd `arg:"subcommand:report-issue" help:"report an issue"` + Logs *LogsCmd `arg:"subcommand:logs" help:"tail daemon logs"` + IP *IPCmd `arg:"subcommand:ip" help:"show public IP address"` +} + +func (args) Description() string { + return "Radiance CLI — command-line interface for the Radiance VPN daemon" +} + +type ReportIssueCmd struct { + Type int `arg:"--type,required" help:"0=purchase 1=signin 2=spinner 3=blocked-sites 4=slow 5=link-device 6=crash 9=other 10=update"` + Description string `arg:"--desc,required" help:"issue description"` + Email string `arg:"--email" help:"email address"` + Attachments []string `arg:"--attach" help:"additional attachment paths"` +} + +func runReportIssue(ctx context.Context, c *ipc.Client, cmd *ReportIssueCmd) error { + return c.ReportIssue(ctx, issue.IssueType(cmd.Type), cmd.Description, cmd.Email, cmd.Attachments) +} + +type LogsCmd struct{} + +func tailLogs(ctx context.Context, c *ipc.Client) error { + err := c.TailLogs(ctx, func(entry rlog.LogEntry) { + if entry.Source != "" { + fmt.Printf("%s [%s] %s: %s\n", entry.Time, entry.Level, entry.Source, entry.Message) + } else { + fmt.Printf("%s [%s] %s\n", entry.Time, entry.Level, entry.Message) + } + }) + if ctx.Err() != nil { + fmt.Fprintln(os.Stderr, "\nStopped tailing logs.") + return nil + } + return err +} + +type IPCmd struct{} + +func runIP(ctx context.Context) error { + tctx, tcancel := context.WithTimeout(ctx, 10*time.Second) + defer tcancel() + ip, err := GetPublicIP(tctx) + if err != nil { + return err + } + fmt.Println(ip) + return nil +} + +func main() { + var a args + p := arg.MustParse(&a) + if p.Subcommand() == nil { + p.WriteHelp(os.Stdout) + os.Exit(1) + } + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + client := ipc.NewClient() + defer client.Close() + + if err := run(ctx, client, &a); err != nil { + fmt.Fprintln(os.Stderr, "error:", err) + os.Exit(1) + } +} + +func run(ctx context.Context, c *ipc.Client, a *args) error { + switch { + case a.Connect != nil: + return vpnConnect(ctx, c, a.Connect.Name) + case a.Disconnect != nil: + return c.DisconnectVPN(ctx) + case a.Status != nil: + return vpnStatus(ctx, c) + case a.Servers != nil: + return runServers(ctx, c, a.Servers) + case a.Features != nil: + return runFeatures(ctx, c) + case a.SmartRouting != nil: + return runSmartRouting(ctx, c, a.SmartRouting) + case a.AdBlock != nil: + return runAdBlock(ctx, c, a.AdBlock) + case a.Telemetry != nil: + return runTelemetry(ctx, c, a.Telemetry) + case a.SplitTunnel != nil: + return runSplitTunnel(ctx, c, a.SplitTunnel) + case a.Account != nil: + return runAccount(ctx, c, a.Account) + case a.Subscription != nil: + return runSubscription(ctx, c, a.Subscription) + case a.ReportIssue != nil: + return runReportIssue(ctx, c, a.ReportIssue) + case a.Logs != nil: + return tailLogs(ctx, c) + case a.IP != nil: + return runIP(ctx) + default: + return fmt.Errorf("no subcommand specified") + } +} + +func printJSON(v any) error { + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + return enc.Encode(v) +} diff --git a/cmd/lantern/servers.go b/cmd/lantern/servers.go new file mode 100644 index 00000000..2650b2d2 --- /dev/null +++ b/cmd/lantern/servers.go @@ -0,0 +1,146 @@ +package main + +import ( + "context" + "fmt" + "strings" + + "github.com/getlantern/radiance/ipc" + "github.com/getlantern/radiance/servers" + "github.com/getlantern/radiance/vpn" +) + +type ServersCmd struct { + Show string `arg:"--show" help:"display server by tag"` + AddJSON string `arg:"--add-json" help:"add servers from JSON config"` + AddURL string `arg:"--add-url" help:"add servers from comma-separated URLs"` + SkipCertVerify bool `arg:"--skip-cert-verify" help:"skip cert verification (with --add-url)"` + Remove string `arg:"--remove" help:"comma-separated list of servers to remove"` + List bool `arg:"--list" help:"list servers"` + + PrivateServer *PrivateServerCmd `arg:"subcommand:private" help:"private server operations"` +} + +type PrivateServerCmd struct { + Add string `arg:"--add" help:"add private server with given tag"` + Invite string `arg:"--invite" help:"invite to private server"` + RevokeInvite string `arg:"--revoke-invite" help:"revoke invite"` + IP string `arg:"--ip" help:"server IP"` + Port int `arg:"--port" help:"server port"` + Token string `arg:"--token" help:"access token"` +} + +func runServers(ctx context.Context, c *ipc.Client, cmd *ServersCmd) error { + switch { + case cmd.Show != "": + return serversGet(ctx, c, cmd.Show) + case cmd.AddJSON != "": + return c.AddServersByJSON(ctx, cmd.AddJSON) + case cmd.AddURL != "": + urls := strings.Split(cmd.AddURL, ",") + return c.AddServersByURL(ctx, urls, cmd.SkipCertVerify) + case cmd.Remove != "": + return serversRemove(ctx, c, cmd.Remove) + case cmd.List: + return serversList(ctx, c) + case cmd.PrivateServer != nil: + return runPrivateServer(ctx, c, cmd.PrivateServer) + default: + return fmt.Errorf("must specify one of --get, --add-json, --add-url, --remove, or --list") + } +} + +func runPrivateServer(ctx context.Context, c *ipc.Client, cmd *PrivateServerCmd) error { + switch { + case cmd.Add != "": + return c.AddPrivateServer(ctx, cmd.Add, cmd.IP, cmd.Port, cmd.Token) + case cmd.Invite != "": + code, err := c.InviteToPrivateServer(ctx, cmd.IP, cmd.Port, cmd.Token, cmd.Invite) + if err != nil { + return err + } + fmt.Println(code) + return nil + case cmd.RevokeInvite != "": + return c.RevokePrivateServerInvite(ctx, cmd.IP, cmd.Port, cmd.Token, cmd.RevokeInvite) + default: + return fmt.Errorf("must specify one of --add, --invite, or --revoke-invite") + } +} + +func serversList(ctx context.Context, c *ipc.Client) error { + srvs, err := c.Servers(ctx) + if err != nil { + return err + } + found := false + for group, opts := range srvs { + if len(opts.Outbounds) == 0 && len(opts.Endpoints) == 0 { + continue + } + found = true + fmt.Println(group) + for _, s := range opts.Outbounds { + printServerEntry(s.Tag, s.Type, opts) + } + for _, s := range opts.Endpoints { + printServerEntry(s.Tag, s.Type, opts) + } + } + if !found { + fmt.Println("No servers available") + } + return nil +} + +func printServerEntry(tag, typ string, opts servers.Options) { + fmt.Printf(" %s [%s]", tag, typ) + if loc, ok := opts.Locations[tag]; ok { + fmt.Printf(" — %s, %s", loc.City, loc.Country) + } + fmt.Println() +} + +func serversGet(ctx context.Context, c *ipc.Client, tag string) error { + svr, exists, err := c.GetServerByTag(ctx, tag) + if err != nil { + return err + } + if !exists { + fmt.Println("Server not found") + return nil + } + return printJSON(svr) +} + +func serversSelected(ctx context.Context, c *ipc.Client) error { + svr, exists, err := c.SelectedServer(ctx) + if err != nil { + return err + } + if !exists { + fmt.Println("No server selected") + return nil + } + return printJSON(svr) +} + +func serversAutoSelections(ctx context.Context, c *ipc.Client, watch bool) error { + if watch { + return c.AutoSelectedEvents(ctx, func(ev vpn.AutoSelectedEvent) { + s := ev.Selected + fmt.Printf("Selected: %s\n", s) + }) + } + sel, err := c.AutoSelected(ctx) + if err != nil { + return err + } + fmt.Printf("Selected: %s\n", sel.Tag) + return nil +} + +func serversRemove(ctx context.Context, c *ipc.Client, tags string) error { + tagList := strings.Split(tags, ",") + return c.RemoveServers(ctx, tagList) +} diff --git a/cmd/lantern/settings.go b/cmd/lantern/settings.go new file mode 100644 index 00000000..512cfb3f --- /dev/null +++ b/cmd/lantern/settings.go @@ -0,0 +1,89 @@ +package main + +import ( + "context" + "fmt" + + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/ipc" +) + +type FeaturesCmd struct{} + +func runFeatures(ctx context.Context, c *ipc.Client) error { + f, err := c.Features(ctx) + if err != nil { + return err + } + for k, v := range f { + fmt.Printf("%s: %v\n", k, v) + } + return nil +} + +type SmartRoutingCmd struct { + Enable *bool `arg:"positional" help:"enable or disable smart routing (true|false)"` +} + +func runSmartRouting(ctx context.Context, c *ipc.Client, cmd *SmartRoutingCmd) error { + if cmd.Enable == nil { + s, err := c.Settings(ctx) + if err != nil { + return err + } + fmt.Printf("Smart routing: %v\n", toBool(s[settings.SmartRoutingKey])) + return nil + } + if err := c.EnableSmartRouting(ctx, *cmd.Enable); err != nil { + return err + } + fmt.Printf("Smart routing set to %v\n", *cmd.Enable) + return nil +} + +type AdBlockCmd struct { + Enable *bool `arg:"positional" help:"enable or disable ad blocking (true|false)"` +} + +func runAdBlock(ctx context.Context, c *ipc.Client, cmd *AdBlockCmd) error { + if cmd.Enable == nil { + s, err := c.Settings(ctx) + if err != nil { + return err + } + fmt.Printf("Ad blocking: %v\n", toBool(s[settings.AdBlockKey])) + return nil + } + if err := c.EnableAdBlocking(ctx, *cmd.Enable); err != nil { + return err + } + fmt.Printf("Ad blocking set to %v\n", *cmd.Enable) + return nil +} + +type TelemetryCmd struct { + Enable *bool `arg:"positional" help:"enable or disable telemetry (true|false)"` +} + +func runTelemetry(ctx context.Context, c *ipc.Client, cmd *TelemetryCmd) error { + if cmd.Enable == nil { + s, err := c.Settings(ctx) + if err != nil { + return err + } + fmt.Printf("Telemetry: %v\n", toBool(s[settings.TelemetryKey])) + return nil + } + if err := c.EnableTelemetry(ctx, *cmd.Enable); err != nil { + return err + } + fmt.Printf("Telemetry set to %v\n", *cmd.Enable) + return nil +} + +func toBool(v any) bool { + if v == nil { + return false + } + return fmt.Sprintf("%v", v) == "true" +} diff --git a/cmd/lantern/split_tunnel.go b/cmd/lantern/split_tunnel.go new file mode 100644 index 00000000..15e98734 --- /dev/null +++ b/cmd/lantern/split_tunnel.go @@ -0,0 +1,113 @@ +package main + +import ( + "context" + "fmt" + "strings" + + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/ipc" + "github.com/getlantern/radiance/vpn" +) + +type SplitTunnelCmd struct { + Enable *bool `arg:"positional" help:"enable or disable split tunneling (true|false)"` + List bool `arg:"-l,--list" help:"list current filters"` + Add string `arg:"--add" help:"add filter (TYPE:VALUE, e.g. domain-suffix:example.com)"` + Remove string `arg:"--remove" help:"remove filter (TYPE:VALUE)"` +} + +func runSplitTunnel(ctx context.Context, c *ipc.Client, cmd *SplitTunnelCmd) error { + switch { + case cmd.Add != "": + typ, val, err := parseFilter(cmd.Add) + if err != nil { + return err + } + return c.AddSplitTunnelItems(ctx, buildFilter(typ, val)) + case cmd.Remove != "": + typ, val, err := parseFilter(cmd.Remove) + if err != nil { + return err + } + return c.RemoveSplitTunnelItems(ctx, buildFilter(typ, val)) + case cmd.List: + return splitTunnelList(ctx, c) + case cmd.Enable != nil: + if err := c.EnableSplitTunneling(ctx, *cmd.Enable); err != nil { + return err + } + fmt.Printf("Split tunneling set to %v\n", *cmd.Enable) + return nil + default: + return splitTunnelStatus(ctx, c) + } +} + +func splitTunnelStatus(ctx context.Context, c *ipc.Client) error { + s, err := c.Settings(ctx) + if err != nil { + return err + } + v := s[settings.SplitTunnelKey] + if v == nil { + v = false + } + fmt.Printf("Split tunneling: %v\n", v) + return nil +} + +func splitTunnelList(ctx context.Context, c *ipc.Client) error { + s, err := c.Settings(ctx) + if err != nil { + return err + } + fmt.Println("Enabled:", s[settings.SplitTunnelKey]) + filters, err := c.SplitTunnelFilters(ctx) + if err != nil { + return err + } + fmt.Println(filters.String()) + return nil +} + +// parseFilter splits "TYPE:VALUE" into the internal filter type and value. +func parseFilter(spec string) (string, string, error) { + typ, val, ok := strings.Cut(spec, ":") + if !ok || val == "" { + return "", "", fmt.Errorf("filter format: TYPE:VALUE (e.g. domain-suffix:example.com)") + } + return filterTypeFromArg(typ), val, nil +} + +// filterTypeFromArg converts a CLI arg like "domain-suffix" to the internal type "domainSuffix". +func filterTypeFromArg(a string) string { + s, rest, _ := strings.Cut(a, "-") + if rest != "" { + s += strings.ToUpper(rest[:1]) + rest[1:] + } + return s +} + +func buildFilter(filterType, value string) vpn.SplitTunnelFilter { + var f vpn.SplitTunnelFilter + switch filterType { + case vpn.TypeDomain: + f.Domain = []string{value} + case vpn.TypeDomainSuffix: + f.DomainSuffix = []string{value} + case vpn.TypeDomainKeyword: + f.DomainKeyword = []string{value} + case vpn.TypeDomainRegex: + f.DomainRegex = []string{value} + case vpn.TypeProcessName: + f.ProcessName = []string{value} + case vpn.TypeProcessPath: + f.ProcessPath = []string{value} + case vpn.TypeProcessPathRegex: + f.ProcessPathRegex = []string{value} + case vpn.TypePackageName: + f.PackageName = []string{value} + } + return f +} diff --git a/cmd/lantern/subscription.go b/cmd/lantern/subscription.go new file mode 100644 index 00000000..769a141a --- /dev/null +++ b/cmd/lantern/subscription.go @@ -0,0 +1,271 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/ipc" +) + +type SubscriptionCmd struct { + Plans *SubscriptionPlansCmd `arg:"subcommand:plans" help:"list subscription plans for a channel"` + Activate *ActivateCmd `arg:"subcommand:activate" help:"activate with reseller code"` + StripeSub *StripeSubCmd `arg:"subcommand:stripe-sub" help:"create Stripe subscription"` + Redirect *PaymentRedirectCmd `arg:"subcommand:redirect" help:"get payment redirect URL"` + SubRedirect *SubPaymentRedirectCmd `arg:"subcommand:sub-redirect" help:"get subscription payment redirect URL"` + Referral *ReferralCmd `arg:"subcommand:referral" help:"attach referral code"` + StripeBilling *StripeBillingCmd `arg:"subcommand:stripe-billing" help:"get Stripe billing portal URL"` + Verify *VerifySubscriptionCmd `arg:"subcommand:verify" help:"verify subscription"` +} + +type SubscriptionPlansCmd struct { + Channel string `arg:"--channel" help:"subscription channel"` +} + +type ActivateCmd struct { + Email string `arg:"--email" help:"email address"` + Code string `arg:"--code" help:"reseller code"` +} + +type StripeSubCmd struct { + Email string `arg:"--email" help:"email address"` + PlanID string `arg:"--plan" help:"plan ID"` +} + +type PaymentRedirectCmd struct { + PlanID string `arg:"--plan" help:"plan ID"` + Provider string `arg:"--provider" help:"payment provider"` + Email string `arg:"--email" help:"email address"` + DeviceName string `arg:"--device" help:"device name"` + BillingType string `arg:"--billing-type" default:"subscription" help:"one_time or subscription"` +} + +type SubPaymentRedirectCmd struct { + PlanID string `arg:"--plan" help:"plan ID"` + Provider string `arg:"--provider" help:"payment provider"` + Email string `arg:"--email" help:"email address"` + DeviceName string `arg:"--device" help:"device name"` + BillingType string `arg:"--billing-type" default:"subscription" help:"one_time or subscription"` +} + +type ReferralCmd struct { + Code string `arg:"--code" help:"referral code"` +} + +type StripeBillingCmd struct{} + +type VerifySubscriptionCmd struct { + Service string `arg:"--service" help:"stripe, apple, or google"` + VerifyData string `arg:"--data" help:"verification data as JSON"` +} + +func runSubscription(ctx context.Context, c *ipc.Client, cmd *SubscriptionCmd) error { + switch { + case cmd.Plans != nil: + return subPlans(ctx, c, cmd.Plans) + case cmd.Activate != nil: + return subActivate(ctx, c, cmd.Activate) + case cmd.StripeSub != nil: + return subStripeSub(ctx, c, cmd.StripeSub) + case cmd.Redirect != nil: + return subRedirect(ctx, c, cmd.Redirect) + case cmd.SubRedirect != nil: + return subSubRedirect(ctx, c, cmd.SubRedirect) + case cmd.Referral != nil: + return subReferral(ctx, c, cmd.Referral) + case cmd.StripeBilling != nil: + return subStripeBilling(ctx, c, cmd.StripeBilling) + case cmd.Verify != nil: + return subVerify(ctx, c, cmd.Verify) + default: + return fmt.Errorf("no subcommand specified") + } +} + +func subPlans(ctx context.Context, c *ipc.Client, cmd *SubscriptionPlansCmd) error { + channel := cmd.Channel + if channel == "" { + var err error + channel, err = prompt("Channel: ") + if err != nil { + return err + } + } + plans, err := c.SubscriptionPlans(ctx, channel) + if err != nil { + return err + } + fmt.Println(plans) + return nil +} + +func subActivate(ctx context.Context, c *ipc.Client, cmd *ActivateCmd) error { + email := cmd.Email + code := cmd.Code + var err error + if email == "" { + email, err = prompt("Email: ") + if err != nil { + return err + } + } + if code == "" { + code, err = prompt("Reseller code: ") + if err != nil { + return err + } + } + resp, err := c.ActivationCode(ctx, email, code) + if err != nil { + return err + } + return printJSON(resp) +} + +func subStripeSub(ctx context.Context, c *ipc.Client, cmd *StripeSubCmd) error { + email := cmd.Email + planID := cmd.PlanID + var err error + if email == "" { + email, err = prompt("Email: ") + if err != nil { + return err + } + } + if planID == "" { + planID, err = prompt("Plan ID: ") + if err != nil { + return err + } + } + secret, err := c.NewStripeSubscription(ctx, email, planID) + if err != nil { + return err + } + fmt.Println(secret) + return nil +} + +func promptRedirectData(planID, provider, email, deviceName, billingType string) (account.PaymentRedirectData, error) { + var err error + if planID == "" { + planID, err = prompt("Plan ID: ") + if err != nil { + return account.PaymentRedirectData{}, err + } + } + if provider == "" { + provider, err = prompt("Provider: ") + if err != nil { + return account.PaymentRedirectData{}, err + } + } + if email == "" { + email, err = prompt("Email: ") + if err != nil { + return account.PaymentRedirectData{}, err + } + } + if deviceName == "" { + deviceName, err = prompt("Device name: ") + if err != nil { + return account.PaymentRedirectData{}, err + } + } + if billingType == "" { + billingType = "subscription" + } + return account.PaymentRedirectData{ + Plan: planID, + Provider: provider, + Email: email, + DeviceName: deviceName, + BillingType: account.SubscriptionType(billingType), + }, nil +} + +func subRedirect(ctx context.Context, c *ipc.Client, cmd *PaymentRedirectCmd) error { + data, err := promptRedirectData(cmd.PlanID, cmd.Provider, cmd.Email, cmd.DeviceName, cmd.BillingType) + if err != nil { + return err + } + url, err := c.PaymentRedirect(ctx, data) + if err != nil { + return err + } + fmt.Println(url) + return nil +} + +func subSubRedirect(ctx context.Context, c *ipc.Client, cmd *SubPaymentRedirectCmd) error { + data, err := promptRedirectData(cmd.PlanID, cmd.Provider, cmd.Email, cmd.DeviceName, cmd.BillingType) + if err != nil { + return err + } + url, err := c.SubscriptionPaymentRedirectURL(ctx, data) + if err != nil { + return err + } + fmt.Println(url) + return nil +} + +func subReferral(ctx context.Context, c *ipc.Client, cmd *ReferralCmd) error { + code := cmd.Code + if code == "" { + var err error + code, err = prompt("Referral code: ") + if err != nil { + return err + } + } + ok, err := c.ReferralAttach(ctx, code) + if err != nil { + return err + } + if ok { + fmt.Println("Referral attached successfully") + } else { + fmt.Println("Referral was not attached") + } + return nil +} + +func subStripeBilling(ctx context.Context, c *ipc.Client, cmd *StripeBillingCmd) error { + url, err := c.StripeBillingPortalURL(ctx) + if err != nil { + return err + } + fmt.Println(url) + return nil +} + +func subVerify(ctx context.Context, c *ipc.Client, cmd *VerifySubscriptionCmd) error { + service := cmd.Service + verifyData := cmd.VerifyData + var err error + if service == "" { + service, err = prompt("Service (stripe, apple, or google): ") + if err != nil { + return err + } + } + if verifyData == "" { + verifyData, err = prompt("Verification data (JSON): ") + if err != nil { + return err + } + } + var data map[string]string + if err := json.Unmarshal([]byte(verifyData), &data); err != nil { + return fmt.Errorf("invalid JSON for verification data: %w", err) + } + result, err := c.VerifySubscription(ctx, account.SubscriptionService(service), data) + if err != nil { + return err + } + fmt.Println(result) + return nil +} diff --git a/cmd/lantern/vpn.go b/cmd/lantern/vpn.go new file mode 100644 index 00000000..42efe464 --- /dev/null +++ b/cmd/lantern/vpn.go @@ -0,0 +1,58 @@ +package main + +import ( + "context" + "fmt" + "time" + + "github.com/getlantern/radiance/ipc" + "github.com/getlantern/radiance/vpn" +) + +type ConnectCmd struct { + Name string `arg:"-n,--name" default:"auto" help:"server name to connect to"` +} + +type DisconnectCmd struct{} + +type StatusCmd struct{} + +func vpnConnect(ctx context.Context, c *ipc.Client, tag string) error { + tctx, tcancel := context.WithTimeout(ctx, 5*time.Second) + prevIP, _ := GetPublicIP(tctx) + tcancel() + + if err := c.ConnectVPN(ctx, tag); err != nil { + return err + } + fmt.Printf("Connected (tag: %s)\n", tag) + + start := time.Now() + waitCtx, waitCancel := context.WithTimeout(ctx, 30*time.Second) + defer waitCancel() + ip, err := WaitForIPChange(waitCtx, prevIP, 500*time.Millisecond) + if err == nil && ip != "" { + fmt.Printf("Public IP: %s (took %v)\n", ip, time.Since(start).Truncate(time.Millisecond)) + } + return nil +} + +func vpnStatus(ctx context.Context, c *ipc.Client) error { + status, err := c.VPNStatus(ctx) + if err != nil { + return err + } + line := string(status) + if status == vpn.Connected { + if sel, exists, err := c.SelectedServer(ctx); err == nil && exists { + line += " server=" + sel.Tag + } + } + tctx, tcancel := context.WithTimeout(ctx, 5*time.Second) + if ip, err := GetPublicIP(tctx); err == nil { + line += " ip=" + ip + } + tcancel() + fmt.Println(line) + return nil +} diff --git a/cmd/lanternd/lanternd.go b/cmd/lanternd/lanternd.go index ba159528..6b278329 100644 --- a/cmd/lanternd/lanternd.go +++ b/cmd/lanternd/lanternd.go @@ -1,78 +1,332 @@ package main import ( + "bufio" "context" - "flag" + "errors" "fmt" + "io" "log" "log/slog" "os" + "os/exec" "os/signal" + "path/filepath" "syscall" "time" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" + "github.com/alexflint/go-arg" + "github.com/getlantern/radiance/backend" "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/traces" + "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/ipc" + rlog "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/vpn" - "github.com/getlantern/radiance/vpn/ipc" ) -const tracerName = "github.com/getlantern/radiance/cmd/lanternd" +type runCmd struct { + DataPath string `arg:"--data-path" help:"path to store data"` + LogPath string `arg:"--log-path" help:"path to store logs"` + LogLevel string `arg:"--log-level" default:"info" help:"logging level (trace, debug, info, warn, error)"` +} -var ( - dataPath = flag.String("data-path", "$HOME/.lantern", "Path to store data") - logPath = flag.String("log-path", "$HOME/.lantern", "Path to store logs") - logLevel = flag.String("log-level", "info", "Logging level (trace, debug, info, warn, error)") -) +type installCmd struct { + DataPath string `arg:"--data-path" help:"path to store data"` + LogPath string `arg:"--log-path" help:"path to store logs"` + LogLevel string `arg:"--log-level" default:"info" help:"logging level (trace, debug, info, warn, error)"` +} + +type uninstallCmd struct{} + +type daemonArgs struct { + Run *runCmd `arg:"subcommand:run" help:"run the daemon"` + Install *installCmd `arg:"subcommand:install" help:"install as system service"` + Uninstall *uninstallCmd `arg:"subcommand:uninstall" help:"uninstall system service"` +} + +func (daemonArgs) Description() string { + return "lanternd — Lantern VPN daemon" +} func main() { - flag.Parse() + if maybePlatformService() { + return + } - dataPath := os.ExpandEnv(*dataPath) - logPath := os.ExpandEnv(*logPath) - logLevel := *logLevel + var a daemonArgs + p := arg.MustParse(&a) + if p.Subcommand() == nil { + p.WriteHelp(os.Stdout) + os.Exit(1) + } - slog.Info("Starting lanternd", "version", common.Version, "dataPath", dataPath) - if err := common.Init(dataPath, logPath, logLevel); err != nil { - log.Fatalf("Failed to initialize common: %v\n", err) + var err error + switch { + case a.Run != nil: + dataPath := os.ExpandEnv(withDefault(a.Run.DataPath, defaultDataPath)) + logPath := os.ExpandEnv(withDefault(a.Run.LogPath, defaultLogPath)) + if os.Getenv("_LANTERND_CHILD") != "1" { + err = babysit(os.Args[1:], dataPath, logPath, a.Run.LogLevel) + break + } + ctx, cancel := context.WithCancel(context.Background()) + // Shut down on stdin closure (babysit parent signals us) or OS signal. + go func() { + io.Copy(io.Discard, os.Stdin) + cancel() + }() + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigCh + cancel() + // Restore default signal behavior so a second signal terminates immediately. + signal.Reset(syscall.SIGINT, syscall.SIGTERM) + }() + err = runDaemon(ctx, dataPath, logPath, a.Run.LogLevel) + case a.Install != nil: + err = install( + os.ExpandEnv(withDefault(a.Install.DataPath, defaultDataPath)), + os.ExpandEnv(withDefault(a.Install.LogPath, defaultLogPath)), + a.Install.LogLevel, + ) + case a.Uninstall != nil: + err = uninstall() + } + if err != nil { + log.Fatalf("Error: %v\n", err) + } +} + +func withDefault(val, def string) string { + if val == "" { + return def + } + return val +} + +// copyBin copies the current executable to binPath, creating parent directories +// as needed. It returns the destination path. +func copyBin() (string, error) { + src, err := os.Executable() + if err != nil { + return "", fmt.Errorf("failed to get executable path: %w", err) + } + src, err = filepath.EvalSymlinks(src) + if err != nil { + return "", fmt.Errorf("failed to resolve executable path: %w", err) + } + + dst := binPath + if src == dst { + return dst, nil + } + + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return "", fmt.Errorf("failed to create directory for %s: %w", dst, err) + } + + sf, err := os.Open(src) + if err != nil { + return "", fmt.Errorf("failed to open source binary: %w", err) + } + defer sf.Close() + + df, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) + if err != nil { + return "", fmt.Errorf("failed to create %s: %w", dst, err) + } + defer df.Close() + + if _, err := io.Copy(df, sf); err != nil { + return "", fmt.Errorf("failed to copy binary to %s: %w", dst, err) + } + + slog.Info("Copied binary", "src", src, "dst", dst) + return dst, nil +} + +// childProcess manages a daemon child process. The parent spawns the child, drains its output, +// and can signal graceful shutdown by closing its stdin pipe. If the child crashes, the parent +// cleans up stale VPN network state immediately. +type childProcess struct { + cmd *exec.Cmd + stdin io.Closer + done chan error + dataPath string + logger *slog.Logger +} + +// spawnChild creates and starts a daemon child process with piped I/O. The child's stdout and +// stderr are merged and drained through the provided logger (or os.Stdout as fallback). +func spawnChild(args []string, dataPath, logPath, logLevel string) (*childProcess, error) { + exe, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("failed to get executable path: %w", err) + } + + cmd := exec.Command(exe, args...) + cmd.Env = append(os.Environ(), "_LANTERND_CHILD=1") + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdin pipe: %w", err) + } + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + } + cmd.Stderr = cmd.Stdout // merge stderr into the same pipe + + logger := rlog.NewLogger(rlog.Config{ + LogPath: filepath.Join(logPath, internal.LogFileName), + Level: logLevel, + Prod: true, + DisablePublisher: true, + }) + + go func() { + defer stdoutPipe.Close() + var w io.Writer = os.Stdout + if h, ok := logger.Handler().(rlog.Handler); ok { + w = h.Writer() + } + scanner := bufio.NewScanner(stdoutPipe) + for scanner.Scan() { + if s := scanner.Text(); s != "" { + w.Write([]byte(s + "\n")) + } + } + if err := scanner.Err(); err != nil { + logger.Error("Error reading child process output", "error", err) + } + }() + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start daemon process: %w", err) + } + logger.Info("Started daemon process", "pid", cmd.Process.Pid) + + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + + return &childProcess{ + cmd: cmd, + stdin: stdinPipe, + done: done, + dataPath: dataPath, + logger: logger, + }, nil +} + +// RequestShutdown signals the child to shut down gracefully by closing its stdin pipe. +func (c *childProcess) RequestShutdown() { + c.logger.Info("Requesting child process shutdown") + c.stdin.Close() +} + +// Done returns a channel that receives the child's exit error (nil on clean exit). +func (c *childProcess) Done() <-chan error { + return c.done +} + +// WaitOrKill waits for the child to exit, killing it if it doesn't exit within the timeout. +func (c *childProcess) WaitOrKill(timeout time.Duration) error { + select { + case err := <-c.done: + return err + case <-time.After(timeout): + c.logger.Warn("Child did not exit in time, killing") + c.cmd.Process.Kill() + return <-c.done } +} + +// HandleCrash cleans up stale VPN network state left by a crashed child. +func (c *childProcess) HandleCrash(err error) { + c.logger.Warn("Daemon process exited unexpectedly, cleaning up network state", "error", err) + vpn.ClearNetErrorState() +} - ipcServer, err := initIPC(dataPath, logPath, logLevel) +// babysit runs the daemon as a child process and monitors it. If the child exits unexpectedly +// (crash, panic, etc.), the parent immediately cleans up any stale VPN network state so the OS +// network remains usable without requiring a reboot or manual intervention. +// +// Graceful shutdown is signaled by closing the child's stdin pipe — this works cross-platform, +// including inside a Windows service where there is no console for signal delivery. +func babysit(args []string, dataPath, logPath, logLevel string) error { + child, err := spawnChild(args, dataPath, logPath, logLevel) if err != nil { - log.Fatalf("Failed to initialize IPC: %v\n", err) + return err } + child.logger.Info("Monitoring daemon process") - // Wait for a signal to gracefully shut down. + // On termination signal, close the child's stdin pipe to trigger graceful shutdown. sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - <-sigCh + go func() { + <-sigCh + child.RequestShutdown() + }() - slog.Info("Shutting down...") - time.AfterFunc(15*time.Second, func() { - log.Fatal("Failed to shut down in time, forcing exit.") - }) - ipcServer.Close() + err = <-child.Done() + signal.Stop(sigCh) + + if err != nil { + child.HandleCrash(err) + } + + // Propagate the child's exit code. + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + os.Exit(exitErr.ExitCode()) + } + return err } -func initIPC(dataPath, logPath, logLevel string) (*ipc.Server, error) { - ctx, span := otel.Tracer(tracerName).Start( - context.Background(), - "initIPC", - trace.WithAttributes(attribute.String("dataPath", dataPath)), - ) - defer span.End() +func runDaemon(ctx context.Context, dataPath, logPath, logLevel string) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() - span.AddEvent("initializing IPC server") + slog.Info("Starting lanternd", "version", common.Version, "dataPath", dataPath) + be, err := backend.NewLocalBackend(ctx, backend.Options{ + DataDir: dataPath, + LogDir: logPath, + LogLevel: logLevel, + }) + if err != nil { + return fmt.Errorf("failed to create backend: %w", err) + } + user, err := be.UserData() + if err != nil { + return fmt.Errorf("failed to get current data: %w", err) + } + if user == nil { + if _, err := be.NewUser(ctx); err != nil { + return fmt.Errorf("failed to create new user: %w", err) + } + } - server := ipc.NewServer(vpn.NewTunnelService(dataPath, slog.Default().With("service", "ipc"), nil)) - slog.Debug("starting IPC server") + be.Start() + server := ipc.NewServer(be, !common.IsMobile()) if err := server.Start(); err != nil { - slog.Error("failed to start IPC server", "error", err) - return nil, traces.RecordError(ctx, fmt.Errorf("start IPC server: %w", err)) + return fmt.Errorf("failed to start IPC server: %w", err) + } + + // Wait for context cancellation to gracefully shut down. + <-ctx.Done() + + slog.Info("Shutting down...") + + time.AfterFunc(15*time.Second, func() { + slog.Error("Failed to shut down in time, forcing exit") + os.Exit(1) + }) + + be.Close() + if err := server.Close(); err != nil { + slog.Error("Error closing IPC server", "error", err) } - return server, nil + slog.Info("Shutdown complete") + return nil } diff --git a/cmd/lanternd/lanternd.service b/cmd/lanternd/lanternd.service index de147401..66d98fa3 100644 --- a/cmd/lanternd/lanternd.service +++ b/cmd/lanternd/lanternd.service @@ -5,7 +5,7 @@ After=network-online.target [Service] Type=simple -ExecStart=/usr/sbin/lanternd -data-path /var/lib/lantern -log-path /var/log/lantern -log-level trace +ExecStart=/usr/sbin/lanternd run -data-path /var/lib/lantern -log-path /var/log/lantern -log-level trace Restart=on-failure RestartSec=5s diff --git a/cmd/lanternd/lanternd_darwin.go b/cmd/lanternd/lanternd_darwin.go new file mode 100644 index 00000000..216e2e25 --- /dev/null +++ b/cmd/lanternd/lanternd_darwin.go @@ -0,0 +1,103 @@ +//go:build darwin && !ios + +package main + +import ( + "fmt" + "log/slog" + "os" + "os/exec" + "text/template" +) + +const ( + serviceName = "com.lantern.lanternd" + defaultDataPath = "/Library/Application Support/Lantern" + defaultLogPath = "/Library/Logs/Lantern" + binPath = "/usr/local/bin/" + serviceName +) + +func maybePlatformService() bool { + return false +} + +var launchdPlistTmpl = template.Must(template.New("plist").Parse(` + + + + Label + {{.ServiceName}} + ProgramArguments + + {{.ExePath}} + run + --data-path + {{.DataPath}} + --log-path + {{.LogPath}} + --log-level + {{.LogLevel}} + + RunAtLoad + + KeepAlive + + StandardOutPath + {{.LogPath}}/lanternd.stdout.log + StandardErrorPath + {{.LogPath}}/lanternd.stderr.log + + +`)) + +func plistPath() string { + return fmt.Sprintf("/Library/LaunchDaemons/%s.plist", serviceName) +} + +func install(dataPath, logPath, logLevel string) error { + exe, err := copyBin() + if err != nil { + return err + } + + plist := plistPath() + f, err := os.Create(plist) + if err != nil { + return fmt.Errorf("failed to create plist %s: %w", plist, err) + } + defer f.Close() + + err = launchdPlistTmpl.Execute(f, struct { + ServiceName, ExePath, DataPath, LogPath, LogLevel string + }{serviceName, exe, dataPath, logPath, logLevel}) + if err != nil { + return fmt.Errorf("failed to write plist: %w", err) + } + + slog.Info("Installing launchd service", "plist", plist) + if out, err := exec.Command("launchctl", "load", "-w", plist).CombinedOutput(); err != nil { + return fmt.Errorf("launchctl load: %w\n%s", err, out) + } + + slog.Info("Launchd service installed and started") + return nil +} + +func uninstall() error { + slog.Info("Uninstalling launchd service") + plist := plistPath() + + if out, err := exec.Command("launchctl", "unload", "-w", plist).CombinedOutput(); err != nil { + slog.Warn("Failed to unload service", "error", err, "output", string(out)) + } + + if err := os.Remove(plist); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove plist: %w", err) + } + + slog.Info("Launchd service uninstalled") + if err := os.Remove(binPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove binary: %w", err) + } + return nil +} diff --git a/cmd/lanternd/lanternd_linux.go b/cmd/lanternd/lanternd_linux.go new file mode 100644 index 00000000..bd88aee1 --- /dev/null +++ b/cmd/lanternd/lanternd_linux.go @@ -0,0 +1,103 @@ +package main + +import ( + "fmt" + "log/slog" + "os" + "os/exec" + "text/template" +) + +const ( + serviceName = "lanternd" + defaultDataPath = "/var/lib/lantern" + defaultLogPath = "/var/log/lantern" + binPath = "/usr/bin/" + serviceName +) + +func maybePlatformService() bool { + return false +} + +var systemdUnitTmpl = template.Must(template.New("unit").Parse(`[Unit] +Description=Lantern VPN Daemon +Wants=network-online.target +After=network-online.target + +[Service] +Type=simple +ExecStart={{.ExePath}} run --data-path {{.DataPath}} --log-path {{.LogPath}} --log-level {{.LogLevel}} +Restart=on-failure +RestartSec=5s + +RuntimeDirectory=lantern +RuntimeDirectoryMode=0755 +StateDirectory=lantern +CacheDirectory=lantern +LogsDirectory=lantern + +[Install] +WantedBy=multi-user.target +`)) + +func install(dataPath, logPath, logLevel string) error { + exe, err := copyBin() + if err != nil { + return err + } + + unitPath := fmt.Sprintf("/etc/systemd/system/%s.service", serviceName) + f, err := os.Create(unitPath) + if err != nil { + return fmt.Errorf("failed to create unit file %s: %w", unitPath, err) + } + defer f.Close() + + err = systemdUnitTmpl.Execute(f, struct { + ExePath, DataPath, LogPath, LogLevel string + }{exe, dataPath, logPath, logLevel}) + if err != nil { + return fmt.Errorf("failed to write unit file: %w", err) + } + + slog.Info("Installing systemd service", "unit", unitPath) + for _, args := range [][]string{ + {"systemctl", "daemon-reload"}, + {"systemctl", "enable", serviceName}, + {"systemctl", "start", serviceName}, + } { + if out, err := exec.Command(args[0], args[1:]...).CombinedOutput(); err != nil { + return fmt.Errorf("%v: %w\n%s", args, err, out) + } + } + + slog.Info("Systemd service installed and started") + return nil +} + +func uninstall() error { + slog.Info("Uninstalling systemd service") + for _, args := range [][]string{ + {"systemctl", "stop", serviceName}, + {"systemctl", "disable", serviceName}, + } { + if out, err := exec.Command(args[0], args[1:]...).CombinedOutput(); err != nil { + slog.Warn("Command failed", "cmd", args, "error", err, "output", string(out)) + } + } + + unitPath := fmt.Sprintf("/etc/systemd/system/%s.service", serviceName) + if err := os.Remove(unitPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove unit file: %w", err) + } + + if out, err := exec.Command("systemctl", "daemon-reload").CombinedOutput(); err != nil { + return fmt.Errorf("systemctl daemon-reload: %w\n%s", err, out) + } + + slog.Info("Systemd service uninstalled") + if err := os.Remove(binPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove binary: %w", err) + } + return nil +} diff --git a/cmd/lanternd/lanternd_windows.go b/cmd/lanternd/lanternd_windows.go new file mode 100644 index 00000000..e98dfaf5 --- /dev/null +++ b/cmd/lanternd/lanternd_windows.go @@ -0,0 +1,227 @@ +package main + +import ( + "context" + "fmt" + "log" + "log/slog" + "os" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" +) + +const ( + serviceName = "lanternd" + defaultDataPath = "$PROGRAMDATA\\lantern" + defaultLogPath = "$PROGRAMDATA\\lantern" + binPath = "C:\\Program Files\\Lantern\\" + serviceName + ".exe" +) + +var isWindowsService bool + +func init() { + isSvc, err := svc.IsWindowsService() + if err != nil { + log.Fatalf("Failed to determine if running as Windows service: %v\n", err) + } + isWindowsService = isSvc +} + +func install(dataPath, logPath, logLevel string) error { + dataPath = os.ExpandEnv(dataPath) + logPath = os.ExpandEnv(logPath) + + slog.Info("Installing Windows service..") + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to Windows service manager: %w", err) + } + + if service, err := m.OpenService(serviceName); err == nil { + service.Close() + return fmt.Errorf("service %q is already installed", serviceName) + } + + exe, err := copyBin() + if err != nil { + return err + } + + config := mgr.Config{ + ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, + StartType: mgr.StartAutomatic, + ErrorControl: mgr.ErrorNormal, + DisplayName: serviceName, + Description: "Lantern Daemon Service", + } + + args := []string{ + "run", + "--data-path", dataPath, + "--log-path", logPath, + "--log-level", logLevel, + } + + slog.Info("Creating Windows service", "exe", exe, "args", args) + service, err := m.CreateService(serviceName, exe, config, args...) + if err != nil { + return fmt.Errorf("failed to create %q service: %w", serviceName, err) + } + defer service.Close() + + err = service.SetRecoveryActions([]mgr.RecoveryAction{ + {Type: mgr.ServiceRestart, Delay: 1 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 2 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 4 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 8 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 16 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 32 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 64 * time.Second}, + }, 60) + if err != nil { + return fmt.Errorf("failed to set service recovery actions: %w", err) + } + if err := service.Start(); err != nil { + return fmt.Errorf("failed to start service: %w", err) + } + + slog.Info("Windows service installed successfully") + return nil +} + +func uninstall() error { + slog.Info("Uninstalling Windows service..") + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to Windows service manager: %w", err) + } + defer m.Disconnect() + + service, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("failed to open %q service: %w", serviceName, err) + } + + status, err := service.Query() + if err != nil { + service.Close() + return fmt.Errorf("failed to query service state: %w", err) + } + if status.State != svc.Stopped { + service.Control(svc.Stop) + } + err = service.Delete() + service.Close() + if err != nil { + return fmt.Errorf("failed to delete service: %w", err) + } + + slog.Info("Waiting for service to be removed...") + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + for { + select { + case <-ctx.Done(): + return fmt.Errorf("timed out waiting for service to be removed") + case <-time.After(100 * time.Millisecond): + if service, err = m.OpenService(serviceName); err != nil { + slog.Info("Windows service uninstalled successfully") + if err := os.Remove(binPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove binary: %w", err) + } + return nil + } + service.Close() + } + } +} + +func maybePlatformService() bool { + if !isWindowsService { + return false + } + if err := startWindowsService(); err != nil { + log.Fatalf("Failed to start Windows service: %v\n", err) + } + return true +} + +type service struct{} + +func startWindowsService() error { + return svc.Run(serviceName, &service{}) +} + +func (s *service) Execute(args []string, r <-chan svc.ChangeRequest, status chan<- svc.Status) (bool, uint32) { + status <- svc.Status{State: svc.StartPending} + + // The Execute args from the SCM dispatcher only contain runtime start parameters + // (typically just [serviceName]). The actual configured arguments are baked into + // os.Args via the service ImagePath. Parse from os.Args to get the real values, + // falling back to defaults if not present. + dataPath, logPath, logLevel := parseServiceArgs(os.Args[1:]) + + // Run the daemon as a child process so we can clean up network state if it crashes, + // regardless of whether the SCM is configured to restart the service. + childArgs := []string{"run", "--data-path", dataPath, "--log-path", logPath, "--log-level", logLevel} + child, err := spawnChild(childArgs, dataPath, logPath, logLevel) + if err != nil { + slog.Error("Failed to start daemon", "error", err) + return true, 1 + } + + status <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown} + child.logger.Info("Running as Windows service") + + for { + select { + case err := <-child.Done(): + if err != nil { + child.HandleCrash(err) + } + return true, 1 + case change := <-r: + switch change.Cmd { + case svc.Stop, svc.Shutdown: + status <- svc.Status{State: svc.StopPending} + child.logger.Info("Service stop requested") + child.RequestShutdown() + child.WaitOrKill(15 * time.Second) + return false, windows.NO_ERROR + case svc.Interrogate: + status <- change.CurrentStatus + case svc.SessionChange: + status <- change.CurrentStatus + } + } + } +} + +func parseServiceArgs(args []string) (dataPath, logPath, logLevel string) { + dataPath = os.ExpandEnv(defaultDataPath) + logPath = os.ExpandEnv(defaultLogPath) + logLevel = "info" + for i := 0; i < len(args); i++ { + switch args[i] { + case "--data-path": + if i+1 < len(args) { + dataPath = os.ExpandEnv(args[i+1]) + i++ + } + case "--log-path": + if i+1 < len(args) { + logPath = os.ExpandEnv(args[i+1]) + i++ + } + case "--log-level": + if i+1 < len(args) { + logLevel = args[i+1] + i++ + } + } + } + return +} diff --git a/common/constants.go b/common/constants.go index b86adf79..bb1d46ae 100644 --- a/common/constants.go +++ b/common/constants.go @@ -7,15 +7,11 @@ import ( // Version is the application version, injected at build time via ldflags: // // -X 'github.com/getlantern/radiance/common.Version=x.y.z' -var Version = "dev" +var Version = "9.0.20" const ( Name = "lantern" - // filenames - LogFileName = "lantern.log" - ConfigFileName = "config.json" - ServersFileName = "servers.json" DefaultHTTPTimeout = (60 * time.Second) // API URLs @@ -25,7 +21,6 @@ const ( StageBaseURL = "https://api.staging.iantem.io/v1" ) - // GetProServerURL returns the pro server URL based on the current environment. func GetProServerURL() string { if Stage() { diff --git a/common/env/env.go b/common/env/env.go index de1b9b0b..8a429877 100644 --- a/common/env/env.go +++ b/common/env/env.go @@ -10,41 +10,30 @@ import ( "strconv" "strings" "testing" - - "github.com/getlantern/radiance/internal" ) -type Key = string +type _key string -const ( - LogLevel Key = "RADIANCE_LOG_LEVEL" - LogPath Key = "RADIANCE_LOG_PATH" - DataPath Key = "RADIANCE_DATA_PATH" - DisableFetch Key = "RADIANCE_DISABLE_FETCH_CONFIG" - PrintCurl Key = "RADIANCE_PRINT_CURL" - DisableStdout Key = "RADIANCE_DISABLE_STDOUT_LOG" - ENV Key = "RADIANCE_ENV" - UseSocks Key = "RADIANCE_USE_SOCKS_PROXY" - SocksAddress Key = "RADIANCE_SOCKS_ADDRESS" +var ( + LogLevel _key = "RADIANCE_LOG_LEVEL" + LogPath _key = "RADIANCE_LOG_PATH" + DataPath _key = "RADIANCE_DATA_PATH" + DisableFetch _key = "RADIANCE_DISABLE_FETCH_CONFIG" + PrintCurl _key = "RADIANCE_PRINT_CURL" + DisableStdout _key = "RADIANCE_DISABLE_STDOUT_LOG" + ENV _key = "RADIANCE_ENV" + UseSocks _key = "RADIANCE_USE_SOCKS_PROXY" + SocksAddress _key = "RADIANCE_SOCKS_ADDRESS" + Country _key = "RADIANCE_COUNTRY" + FeatureOverrides _key = "RADIANCE_FEATURE_OVERRIDES" - Testing Key = "RADIANCE_TESTING" -) + Testing _key = "RADIANCE_TESTING" -var ( - keys = []Key{ - LogLevel, - LogPath, - DataPath, - DisableFetch, - PrintCurl, - DisableStdout, - SocksAddress, - UseSocks, - ENV, - } - envVars = map[string]any{} + dotenv = map[string]string{} ) +func (k _key) String() string { return string(k) } + func init() { buf, err := os.ReadFile(".env") if err != nil && !errors.Is(err, fs.ErrNotExist) { @@ -61,55 +50,51 @@ func init() { if len(parts) == 2 { key := strings.TrimSpace(parts[0]) value := strings.TrimSpace(parts[1]) - parseAndSet(key, value) + dotenv[key] = value } } } - - // Check for environment variables and populate envVars, overriding any values from the .env file - for _, key := range keys { - if value, exists := os.LookupEnv(key); exists { - parseAndSet(key, value) - } - } if testing.Testing() { - envVars[Testing] = true - envVars[LogLevel] = "DISABLE" - slog.SetLogLoggerLevel(internal.Disable) + dotenv[Testing.String()] = "true" + dotenv[LogLevel.String()] = "disable" } } -// Get retrieves the value associated with the given key and attempts to cast it to type T. If the -// key does not exist or the type does not match, it returns the zero value of T and false. -func Get[T any](key Key) (T, bool) { - if value, exists := envVars[key]; exists { - if v, ok := value.(T); ok { - return v, true - } +func Get(key _key) (string, bool) { + if value, exists := dotenv[key.String()]; exists { + return value, true } - var zero T - return zero, false + if value, exists := os.LookupEnv(key.String()); exists { + return value, true + } + return "", false } -// SetStagingEnv sets the environment to staging if it has not already been set. -// This is used for testing that need to interact with staging services, -func SetStagingEnv() { - slog.Info("setting environment to staging for testing") - envVars[ENV] = "staging" - envVars[PrintCurl] = true +func GetString(key _key) string { + value, _ := Get(key) + return value } -func parseAndSet(key, value string) { - // Attempt to parse as a boolean - if b, err := strconv.ParseBool(value); err == nil { - envVars[key] = b - return +func GetBool(key _key) bool { + value, exists := Get(key) + if !exists { + return false } - // Attempt to parse as an integer - if i, err := strconv.Atoi(value); err == nil { - envVars[key] = i - return + v, _ := strconv.ParseBool(value) + return v +} + +func GetInt(key _key) int { + value, exists := Get(key) + if !exists { + return 0 } - // Otherwise, store as a string - envVars[key] = value + v, _ := strconv.Atoi(value) + return v +} + +func SetStagingEnv() { + slog.Info("setting environment to staging for testing") + dotenv[ENV.String()] = "staging" + dotenv[PrintCurl.String()] = "true" } diff --git a/backend/headers.go b/common/headers.go similarity index 72% rename from backend/headers.go rename to common/headers.go index 537c59ee..b56fb5cf 100644 --- a/backend/headers.go +++ b/common/headers.go @@ -1,4 +1,4 @@ -package backend +package common import ( "context" @@ -6,12 +6,10 @@ import ( "io" "math/big" "net/http" - "strconv" "time" "github.com/getlantern/timezone" - "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/settings" ) @@ -43,11 +41,11 @@ func NewRequestWithHeaders(ctx context.Context, method, url string, body io.Read // based on consistent packet lengths. req.Header.Add(RandomNoiseHeader, randomizedString()) - req.Header.Set(AppVersionHeader, common.Version) - req.Header.Set(VersionHeader, common.Version) - req.Header.Set(UserIDHeader, strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10)) - req.Header.Set(PlatformHeader, common.Platform) - req.Header.Set(AppNameHeader, common.Name) + req.Header.Set(AppVersionHeader, Version) + req.Header.Set(VersionHeader, Version) + req.Header.Set(UserIDHeader, settings.GetString(settings.UserIDKey)) + req.Header.Set(PlatformHeader, Platform) + req.Header.Set(AppNameHeader, Name) req.Header.Set(DeviceIDHeader, settings.GetString(settings.DeviceIDKey)) if tz, err := timezone.IANANameForTime(time.Now()); err == nil { req.Header.Set(TimeZoneHeader, tz) @@ -55,21 +53,6 @@ func NewRequestWithHeaders(ctx context.Context, method, url string, body io.Read return req, nil } -// NewIssueRequest creates a new HTTP request with the required headers for issue reporting. -func NewIssueRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { - req, err := NewRequestWithHeaders(ctx, method, url, body) - if err != nil { - return nil, err - } - - req.Header.Set("content-type", "application/x-protobuf") - - // data caps - req.Header.Set(SupportedDataCapsHeader, "monthly,weekly,daily") - - return req, nil -} - // randomizedString returns a random string to avoid consistent packet lengths censors // may use to detect Lantern. func randomizedString() string { diff --git a/common/init.go b/common/init.go index e378df11..5a6118de 100644 --- a/common/init.go +++ b/common/init.go @@ -3,25 +3,22 @@ package common import ( "fmt" - "io" "log/slog" "os" "path/filepath" - "runtime" "runtime/debug" "strings" "sync/atomic" - "time" "unicode" "unicode/utf8" "github.com/getlantern/appdir" - "gopkg.in/natefinch/lumberjack.v2" "github.com/getlantern/radiance/common/env" "github.com/getlantern/radiance/common/reporting" "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" ) var ( @@ -29,50 +26,40 @@ var ( ) func Env() string { - e, _ := env.Get[string](env.ENV) - e = strings.ToLower(e) - return e + return strings.ToLower(env.GetString(env.ENV)) } // Prod returns true if the application is running in production environment. // Treating ENV == "" as production is intentional: if RADIANCE_ENV is unset, // we default to production mode to ensure the application runs with safe, non-debug settings. func Prod() bool { - e, _ := env.Get[string](env.ENV) - e = strings.ToLower(e) + e := Env() return e == "production" || e == "prod" || e == "" } // Dev returns true if the application is running in development environment. func Dev() bool { - e, _ := env.Get[string](env.ENV) - e = strings.ToLower(e) + e := Env() return e == "development" || e == "dev" } // Stage returns true if the application is running in staging environment. func Stage() bool { - e, _ := env.Get[string](env.ENV) - e = strings.ToLower(e) + e := Env() return e == "stage" || e == "staging" } +func init() { + if env.GetBool(env.Testing) { + slog.SetDefault(log.NoOpLogger()) + slog.SetLogLoggerLevel(log.Disable) + } +} + // Init initializes the common components of the application. This includes setting up the directories // for data and logs, initializing the logger, and setting up reporting. func Init(dataDir, logDir, logLevel string) error { slog.Info("Initializing common package") - return initialize(dataDir, logDir, logLevel, false) -} - -// InitReadOnly locates the settings file in provided directory and initializes the common components -// in read-only mode using the necessary settings from the settings file. This is used in contexts -// where settings should not be modified, such as in the IPC server or other auxiliary processes. -func InitReadOnly(dataDir, logDir, logLevel string) error { - slog.Info("Initializing in read-only") - return initialize(dataDir, logDir, logLevel, true) -} - -func initialize(dataDir, logDir, logLevel string, readonly bool) error { if initialized.Swap(true) { return nil } @@ -82,32 +69,22 @@ func initialize(dataDir, logDir, logLevel string, readonly bool) error { if err != nil { return fmt.Errorf("failed to setup directories: %w", err) } - if readonly { - // in read-only mode, favor settings from the settings file if given parameters are empty - if logDir == "" && settings.GetString(settings.LogPathKey) != "" { - logs = settings.GetString(settings.LogPathKey) - } - if settings.GetString(settings.LogLevelKey) != "" { - logLevel = settings.GetString(settings.LogLevelKey) - } - } - err = initLogger(filepath.Join(logs, LogFileName), logLevel) - if err != nil { - slog.Error("Error initializing logger", "error", err) - return fmt.Errorf("initialize log: %w", err) - } - if readonly { - settings.SetReadOnly(true) - if err := settings.StartWatching(); err != nil { - return fmt.Errorf("start watching settings file: %w", err) - } - } else { - settings.Set(settings.DataPathKey, data) - settings.Set(settings.LogPathKey, logs) - settings.Set(settings.LogLevelKey, logLevel) + if err = settings.InitSettings(data); err != nil { + return fmt.Errorf("failed to initialize settings: %w", err) } + settings.Set(settings.DataPathKey, data) + settings.Set(settings.LogPathKey, logs) + settings.Set(settings.LogLevelKey, logLevel) + + logger := log.NewLogger(log.Config{ + LogPath: filepath.Join(logs, internal.LogFileName), + Level: logLevel, + Prod: Prod(), + }) + slog.SetDefault(logger) + slog.Info("Using data and log directories", "dataDir", data, "logDir", logs) createCrashReporter() if Dev() { @@ -133,7 +110,7 @@ func logModuleInfo() { } func createCrashReporter() { - crashFilePath := filepath.Join(settings.GetString(settings.LogPathKey), "lantern_crash.log") + crashFilePath := filepath.Join(settings.GetString(settings.LogPathKey), internal.CrashLogFileName) f, err := os.OpenFile(crashFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { slog.Error("Failed to open crash log file", "error", err) @@ -144,147 +121,16 @@ func createCrashReporter() { } } -// initLogger reconfigures the default slog.Logger to write to a file and stdout and sets the log level. -// The log level is determined, first by the environment variable if set and valid, then by the provided level. -// If both are invalid and/or not set, it defaults to "info". -func initLogger(logPath, level string) error { - if elevel, hasLevel := env.Get[string](env.LogLevel); hasLevel { - level = elevel - } - var lvl slog.Level - if level != "" { - var err error - lvl, err = internal.ParseLogLevel(level) - if err != nil { - slog.Warn("Failed to parse log level", "error", err) - } else { - slog.SetLogLoggerLevel(lvl) - } - } - if lvl == internal.Disable { - return nil - } - - // lumberjack will create the log file if it does not exist with permissions 0600 otherwise it - // carries over the existing permissions. So we create it here with 0644 so we don't need root/admin - // privileges or chown/chmod to read it. - f, err := os.OpenFile(logPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) - if err != nil { - slog.Warn("Failed to pre-create log file", "error", err, "path", logPath) - } else { - f.Close() - } - - logRotator := &lumberjack.Logger{ - Filename: logPath, // Log file path - MaxSize: 25, // Rotate log when it reaches 25 MB - MaxBackups: 2, // Keep up to 2 rotated log files - MaxAge: 30, // Retain old log files for up to 30 days - Compress: Prod(), // Compress rotated log files - } - - loggingToStdOut := true - var logWriter io.Writer - if noStdout, _ := env.Get[bool](env.DisableStdout); noStdout { - logWriter = logRotator - loggingToStdOut = false - } else if isWindowsProd() { - // For some reason, logging to both stdout and a file on Windows - // causes issues with some Windows services where the logs - // do not get written to the file. So in prod mode on Windows, - // we log to file only. See: - // https://www.reddit.com/r/golang/comments/1fpo3cg/golang_windows_service_cannot_write_log_files/ - logWriter = logRotator - loggingToStdOut = false - } else { - logWriter = io.MultiWriter(os.Stdout, logRotator) - } - runtime.AddCleanup(&logWriter, func(f *os.File) { - f.Close() - }, f) - logger := slog.New(slog.NewTextHandler(logWriter, &slog.HandlerOptions{ - AddSource: true, - Level: lvl, - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - switch a.Key { - case slog.TimeKey: - if t, ok := a.Value.Any().(time.Time); ok { - a.Value = slog.StringValue(t.UTC().Format("2006-01-02 15:04:05.000 UTC")) - } - return a - case slog.SourceKey: - source, ok := a.Value.Any().(*slog.Source) - if !ok { - return a - } - // remove github.com/ to get pkg name - var service, fn string - fields := strings.SplitN(source.Function, "/", 4) - switch len(fields) { - case 0, 1, 2: - file := filepath.Base(source.File) - a.Value = slog.StringValue(fmt.Sprintf("%s:%d", file, source.Line)) - return a - case 3: - pf := strings.SplitN(fields[2], ".", 2) - service, fn = pf[0], pf[1] - default: - service = fields[2] - fn = strings.SplitN(fields[3], ".", 2)[1] - } - - _, file, fnd := strings.Cut(source.File, service+"/") - if !fnd { - file = filepath.Base(source.File) - } - src := slog.GroupValue( - slog.String("func", fn), - slog.String("file", fmt.Sprintf("%s:%d", file, source.Line)), - ) - a.Value = slog.GroupValue( - slog.String("service", service), - slog.Any("source", src), - ) - a.Key = "" - case slog.LevelKey: - // format the log level to account for the custom levels defined in internal/util.go, i.e. trace - // otherwise, slog will print as "DEBUG-4" (trace) or similar - level := a.Value.Any().(slog.Level) - a.Value = slog.StringValue(internal.FormatLogLevel(level)) - } - return a - }, - })) - if !loggingToStdOut { - if IsWindows() { - fmt.Printf("Logging to file only on Windows prod -- run with RADIANCE_ENV=dev to enable stdout path: %s, level: %s\n", logPath, internal.FormatLogLevel(lvl)) - } else { - fmt.Printf("Logging to file only -- RADIANCE_DISABLE_STDOUT_LOG is set path: %s, level: %s\n", logPath, internal.FormatLogLevel(lvl)) - } - } else { - fmt.Printf("Logging to file and stdout path: %s, level: %s\n", logPath, internal.FormatLogLevel(lvl)) - } - slog.SetDefault(logger) - return nil -} - -func isWindowsProd() bool { - if !IsWindows() { - return false - } - return !Dev() -} - // setupDirectories creates the data and logs directories, and needed subdirectories if they do // not exist. If data or logs are the empty string, it will use the user's config directory retrieved // from the OS. func setupDirectories(data, logs string) (dataDir, logDir string, err error) { - if d, ok := env.Get[string](env.DataPath); ok { + if d, ok := env.Get(env.DataPath); ok { data = d } else if data == "" { data = outDir("data") } - if l, ok := env.Get[string](env.LogPath); ok { + if l, ok := env.Get(env.LogPath); ok { logs = l } else if logs == "" { logs = outDir("logs") @@ -296,9 +142,6 @@ func setupDirectories(data, logs string) (dataDir, logDir string, err error) { return data, logs, fmt.Errorf("failed to create directory %s: %w", path, err) } } - if err := settings.InitSettings(data); err != nil { - return data, logs, fmt.Errorf("failed to initialize settings: %w", err) - } return data, logs, nil } diff --git a/common/settings/settings.go b/common/settings/settings.go index 75cc682b..ff481a86 100644 --- a/common/settings/settings.go +++ b/common/settings/settings.go @@ -1,3 +1,4 @@ +// Package settings provides a simple interface for storing and retrieving user settings. package settings import ( @@ -9,7 +10,6 @@ import ( "path/filepath" "strings" "sync" - "sync/atomic" "time" "github.com/knadh/koanf/parsers/json" @@ -17,37 +17,54 @@ import ( "github.com/knadh/koanf/v2" "github.com/getlantern/radiance/common/atomicfile" - "github.com/getlantern/radiance/events" - "github.com/getlantern/radiance/internal" ) -// Keys for various settings. +type _key string + const ( - CountryCodeKey = "country_code" - LocaleKey = "locale" - DeviceIDKey = "device_id" - DataPathKey = "data_path" - LogPathKey = "log_path" - EmailKey = "email" - UserLevelKey = "user_level" - TokenKey = "token" - JwtTokenKey = "jwt_token" - UserIDKey = "user_id" - DevicesKey = "devices" - LogLevelKey = "log_level" - LoginResponseKey = "login_response" - SmartRoutingKey = "smart_routing" - AdBlockKey = "ad_block" - filePathKey = "file_path" - - settingsFileName = "local.json" + // Keys for various settings. + // General settings keys. + DataPathKey _key = "data_path" // string + LogPathKey _key = "log_path" // string + LogLevelKey _key = "log_level" // string + CountryCodeKey _key = "country_code" // string + LocaleKey _key = "locale" // string + DeviceIDKey _key = "device_id" // string/int + + // Application behavior related keys. + TelemetryKey _key = "telemetry_enabled" // bool + ConfigFetchDisabledKey _key = "config_fetch_disabled" // bool + + // User account related keys. + EmailKey _key = "email" // string + UserIDKey _key = "user_id" // string + UserLevelKey _key = "user_level" // string + TokenKey _key = "token" // string + JwtTokenKey _key = "jwt_token" // string + DevicesKey _key = "devices" // []Device + UserDataKey _key = "user_data" // [account.UserData] + OAuthLoginKey _key = "oauth_login" // bool + + // VPN related keys. + SmartRoutingKey _key = "smart_routing" // bool + SplitTunnelKey _key = "split_tunnel" // bool + AdBlockKey _key = "ad_block" // bool + AutoConnectKey _key = "auto_connect" // bool + SelectedServerKey _key = "selected_server" // [servers.Server] Server.Options is not stored + + PreferredLocationKey _key = "preferred_location" // [common.PreferredLocation] + + settingsFileName = "settings.json" ) +var ErrNotExist = errors.New("key does not exist") + +func (k _key) String() string { return string(k) } + type settings struct { k *koanf.Koanf - readOnly atomic.Bool initialized bool - watcher *internal.FileWatcher + filePath string mu sync.Mutex } @@ -55,60 +72,38 @@ var k = &settings{ k: koanf.New("."), } -var ErrReadOnly = errors.New("read-only") +func init() { + // set default values. + k.k.Set(LocaleKey.String(), "fa-IR") + k.k.Set(UserLevelKey.String(), "free") +} -// InitSettings initializes the config for user settings, which can be used by both the tunnel process and -// the main application process to read user preferences like locale. +// InitSettings initializes the config for user settings. func InitSettings(fileDir string) error { k.mu.Lock() defer k.mu.Unlock() if k.initialized { return nil } - if err := initialize(fileDir); err != nil { - return fmt.Errorf("initializing settings: %w", err) - } - k.initialized = true - return nil -} - -func initialize(fileDir string) error { - k.k = koanf.New(".") if err := os.MkdirAll(fileDir, 0755); err != nil { return fmt.Errorf("failed to create data directory: %v", err) } - filePath := filepath.Join(fileDir, settingsFileName) - switch err := loadSettings(filePath); { + k.filePath = filepath.Join(fileDir, settingsFileName) + switch err := loadSettings(k.filePath); { case errors.Is(err, fs.ErrNotExist): - slog.Warn("settings file not found", "path", filePath) // file may not have been created yet - if err := setDefaults(filePath); err != nil { - return fmt.Errorf("setting default settings: %w", err) - } + slog.Warn("settings file not found", "path", k.filePath) // file may not have been created yet return save() case err != nil: return fmt.Errorf("loading settings: %w", err) } - return nil -} - -func setDefaults(filePath string) error { - // We need to set the file path first because the save function reads it as soon as we set any key. - if err := k.k.Set(filePathKey, filePath); err != nil { - return fmt.Errorf("failed to set file path: %w", err) - } - if err := k.k.Set(LocaleKey, "fa-IR"); err != nil { - return fmt.Errorf("failed to set default locale: %w", err) - } - if err := k.k.Set(UserLevelKey, "free"); err != nil { - return fmt.Errorf("failed to set default user level: %w", err) - } + k.initialized = true return nil } func loadSettings(path string) error { contents, err := atomicfile.ReadFile(path) if err != nil { - return fmt.Errorf("loading settings (read-only): %w", err) + return fmt.Errorf("loading settings: %w", err) } kk := koanf.New(".") if err := kk.Load(rawbytes.Provider(contents), json.Parser()); err != nil { @@ -118,107 +113,106 @@ func loadSettings(path string) error { return nil } -func SetReadOnly(readOnly bool) { - k.readOnly.Store(readOnly) +func Get(key _key) any { + return k.k.Get(key.String()) } -func StartWatching() error { - k.mu.Lock() - defer k.mu.Unlock() - if !k.initialized { - return errors.New("settings not initialized") - } - if k.watcher != nil { - return errors.New("settings file watcher already started") - } +func GetString(key _key) string { + return k.k.String(key.String()) +} - path := k.k.String(filePathKey) - watcher := internal.NewFileWatcher(path, func() { - if err := loadSettings(path); err != nil { - slog.Error("reloading settings file", "error", err) - } - }) - if err := watcher.Start(); err != nil { - return fmt.Errorf("starting settings file watcher: %w", err) - } - k.watcher = watcher - // reload settings once at start in case there were changes before we started watching - if err := loadSettings(path); err != nil && !errors.Is(err, fs.ErrNotExist) { - return err - } - return nil +func GetBool(key _key) bool { + return k.k.Bool(key.String()) } -// StopWatching stops watching the settings file for changes. This is only relevant in read-only mode. -func StopWatching() { - k.mu.Lock() - defer k.mu.Unlock() - if k.watcher != nil { - k.watcher.Close() - k.watcher = nil - } +func GetInt(key _key) int { + return k.k.Int(key.String()) } -func Get(key string) any { - return k.k.Get(key) +func GetInt64(key _key) int64 { + return k.k.Int64(key.String()) } -func GetString(key string) string { - return k.k.String(key) +func GetFloat64(key _key) float64 { + return k.k.Float64(key.String()) } -func GetBool(key string) bool { - return k.k.Bool(key) +func GetStringSlice(key _key) []string { + return k.k.Strings(key.String()) } -func GetInt(key string) int { - return k.k.Int(key) +func GetDuration(key _key) time.Duration { + return k.k.Duration(key.String()) } -func GetInt64(key string) int64 { - return k.k.Int64(key) +func GetStruct(key _key, out any) error { + return k.k.Unmarshal(key.String(), out) } -func GetFloat64(key string) float64 { - return k.k.Float64(key) +func Exists(key _key) bool { + return k.k.Exists(key.String()) } -func GetStringSlice(key string) []string { - return k.k.Strings(key) +func Set(key _key, value any) error { + err := k.k.Set(key.String(), value) + if err != nil { + return fmt.Errorf("could not set key %s: %w", key, err) + } + return save() } -func GetDuration(key string) time.Duration { - return k.k.Duration(key) +func Clear(key _key) { + k.k.Delete(key.String()) } -func GetStruct(key string, out any) error { - return k.k.Unmarshal(key, out) +type Settings map[_key]any + +func (s Settings) Diff(s2 Settings) Settings { + diff := make(Settings) + for k, v1 := range s { + if v2, ok := s2[k]; !ok || v1 != v2 { + diff[k] = v1 + } + } + return diff } -func Set(key string, value any) error { - if k.readOnly.Load() { - return ErrReadOnly +func GetAll() Settings { + s := make(Settings) + for key, value := range k.k.All() { + s[_key(key)] = value } - err := k.k.Set(key, value) - if err != nil { - return fmt.Errorf("could not set key %s: %w", key, err) + return s +} + +func GetAllFor(keys ..._key) Settings { + if len(keys) == 0 { + return GetAll() + } + s := make(Settings) + for _, key := range keys { + s[key] = k.k.Get(key.String()) + } + return s +} + +// Patch takes a map of settings to update and applies them all at once. +func Patch(updates Settings) error { + for key, value := range updates { + if err := k.k.Set(_key(key).String(), value); err != nil { + return fmt.Errorf("could not set key %s: %w", key, err) + } } return save() } func save() error { - if k.readOnly.Load() { - return ErrReadOnly - } - if GetString(filePathKey) == "" { - return errors.New("settings file path is not set") - } out, err := k.k.Marshal(json.Parser()) if err != nil { return fmt.Errorf("could not marshal koanf file: %w", err) } - err = atomicfile.WriteFile(GetString(filePathKey), out, 0644) + err = atomicfile.WriteFile(k.filePath, out, 0644) if err != nil { return fmt.Errorf("could not write koanf file: %w", err) } @@ -229,14 +223,8 @@ func save() error { func Reset() { k.mu.Lock() defer k.mu.Unlock() - if !k.readOnly.Load() { - if k.watcher != nil { - k.watcher.Close() - k.watcher = nil - } - k.k = koanf.New(".") - k.initialized = false - } + k.k = koanf.New(".") + k.initialized = false } func IsPro() bool { @@ -254,7 +242,3 @@ func Devices() ([]Device, error) { err := GetStruct(DevicesKey, &devices) return devices, err } - -type UserChangeEvent struct { - events.Event -} diff --git a/common/settings/settings_test.go b/common/settings/settings_test.go index 21f16bd2..585205c2 100644 --- a/common/settings/settings_test.go +++ b/common/settings/settings_test.go @@ -5,190 +5,28 @@ import ( "path/filepath" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + _ "github.com/getlantern/radiance/common/env" ) func TestInitSettings(t *testing.T) { - t.Run("first run - no config file exists", func(t *testing.T) { - tempDir := t.TempDir() - err := initialize(tempDir) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - // Verify default locale was set - locale := Get(LocaleKey) - if locale != "fa-IR" { - t.Errorf("expected default locale 'fa-IR', got %s", locale) - } - }) - t.Run("existing valid config file", func(t *testing.T) { - // Create a temporary directory tempDir := t.TempDir() + path := filepath.Join(tempDir, settingsFileName) + content := []byte(`{"locale": "en-US", "country_code": "US"}`) + require.NoError(t, os.WriteFile(path, content, 0644), "failed to create test config file") - // Create a valid config file - configPath := filepath.Join(tempDir, "local.json") - configContent := []byte(`{"locale": "en-US", "country_code": "US"}`) - if err := os.WriteFile(configPath, configContent, 0644); err != nil { - t.Fatalf("failed to create test config file: %v", err) - } - - err := initialize(tempDir) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - // Verify config was loaded - locale := Get(LocaleKey) - if locale != "en-US" { - t.Errorf("expected locale 'en-US', got %s", locale) - } - - countryCode := Get(CountryCodeKey) - if countryCode != "US" { - t.Errorf("expected country_code 'US', got %s", countryCode) - } + require.NoError(t, InitSettings(tempDir), "failed to initialize settings") + assert.Equal(t, "en-US", Get(LocaleKey)) + assert.Equal(t, "US", Get(CountryCodeKey)) }) t.Run("invalid config file", func(t *testing.T) { - // Create a temporary directory - tempDir := t.TempDir() - - // Create an invalid config file - configPath := filepath.Join(tempDir, "local.json") - configContent := []byte(`{invalid json}`) - if err := os.WriteFile(configPath, configContent, 0644); err != nil { - t.Fatalf("failed to create test config file: %v", err) - } - - err := initialize(tempDir) - if err == nil { - t.Fatal("expected error for invalid config file, got nil") - } - }) - - t.Run("non-existent directory", func(t *testing.T) { - // Use a non-existent directory - nonExistentDir := filepath.Join(os.TempDir(), "non-existent-dir-123456789") - - err := initialize(nonExistentDir) - if err != nil { - t.Fatalf("expected no error for non-existent directory (first run), got %v", err) - } - }) -} - -func TestSetStruct(t *testing.T) { - tempDir := t.TempDir() - err := initialize(tempDir) - if err != nil { - t.Fatalf("expected no error initializing settings, got %v", err) - } - - err = Set("testStruct", struct { - Field1 string - Field2 int - }{ - Field1: "value1", - Field2: 42, + path := filepath.Join(t.TempDir(), settingsFileName) + content := []byte(`{invalid json}`) + require.NoError(t, os.WriteFile(path, content, 0644), "failed to create test config file") + require.Error(t, loadSettings(path), "expected error for invalid config file") }) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - var result struct { - Field1 string - Field2 int - } - err = GetStruct("testStruct", &result) - if err != nil { - t.Fatalf("expected no error retrieving struct, got %v", err) - } - - if result.Field1 != "value1" || result.Field2 != 42 { - t.Errorf("expected struct {Field1: 'value1', Field2: 42}, got %+v", result) - } - - // Reset koanf state - Reset() - result.Field1 = "" - result.Field2 = 0 - - // At first, the struct should not be present. - err = GetStruct("testStruct", &result) - if err != nil { - t.Fatalf("expected no error retrieving struct, got %v", err) - } - - if result.Field1 != "" || result.Field2 != 0 { - t.Errorf("expected struct {Field1: '', Field2: 0}, got %+v", result) - } - - err = initialize(tempDir) - if err != nil { - t.Fatalf("expected no error re-initializing settings, got %v", err) - } - - var result2 struct { - Field1 string - Field2 int - } - err = GetStruct("testStruct", &result2) - if err != nil { - t.Fatalf("expected no error retrieving struct after re-init, got %v", err) - } - - if result2.Field1 != "value1" || result2.Field2 != 42 { - t.Errorf("expected struct {Field1: 'value1', Field2: 42} after re-init, got %+v", result2) - } -} - -func TestStructSlicePersistence(t *testing.T) { - tempDir := t.TempDir() - err := initialize(tempDir) - if err != nil { - t.Fatalf("expected no error initializing settings, got %v", err) - } - - type Item struct { - Name string - Value int - } - - items := []Item{ - {Name: "item1", Value: 1}, - {Name: "item2", Value: 2}, - } - - err = Set("itemList", items) - if err != nil { - t.Fatalf("expected no error setting struct slice, got %v", err) - } - - var retrievedItems []Item - err = GetStruct("itemList", &retrievedItems) - if err != nil { - t.Fatalf("expected no error retrieving struct slice, got %v", err) - } - - if len(retrievedItems) != 2 || retrievedItems[0].Name != "item1" || retrievedItems[1].Value != 2 { - t.Errorf("retrieved struct slice does not match expected values: %+v", retrievedItems) - } - - retrievedItems = nil - err = initialize(tempDir) - if err != nil { - t.Fatalf("expected no error re-initializing settings, got %v", err) - } - - var retrievedItems2 []Item - err = GetStruct("itemList", &retrievedItems2) - if err != nil { - t.Fatalf("expected no error retrieving struct slice after re-init, got %v", err) - } - - if len(retrievedItems2) != 2 || retrievedItems2[0].Name != "item1" || retrievedItems2[1].Value != 2 { - t.Errorf("retrieved struct slice after re-init does not match expected values: %+v", retrievedItems2) - } } diff --git a/common/types.go b/common/types.go new file mode 100644 index 00000000..fc1db8fb --- /dev/null +++ b/common/types.go @@ -0,0 +1,7 @@ +package common + +import ( + C "github.com/getlantern/common" +) + +type PreferredLocation = C.ServerLocation diff --git a/config/config.go b/config/config.go index 18473d7d..4b7880c5 100644 --- a/config/config.go +++ b/config/config.go @@ -10,6 +10,7 @@ import ( "fmt" "io/fs" "log/slog" + "net/http" "os" "path/filepath" "reflect" @@ -27,21 +28,18 @@ import ( box "github.com/getlantern/lantern-box" lbO "github.com/getlantern/lantern-box/option" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - - "github.com/getlantern/radiance/api" + "github.com/getlantern/radiance/account" "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/atomicfile" "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/events" + "github.com/getlantern/radiance/internal" "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/traces" ) const ( - maxRetryDelay = 2 * time.Minute + maxRetryDelay = 2 * time.Minute + defaultPollInterval = 10 * time.Minute ) var ( @@ -50,84 +48,91 @@ var ( ErrFetchingConfig = errors.New("failed to fetch config") ) -// Config includes all configuration data from the Lantern API as well as any stored local preferences. -type Config struct { - ConfigResponse C.ConfigResponse - PreferredLocation C.ServerLocation -} +// Config includes all configuration data from the Lantern API +type Config = C.ConfigResponse type ServerManager interface { - SetServers(serverGroup string, opts servers.Options) error + SetServers(serverGroup servers.ServerGroup, opts servers.Options) error } -// ListenerFunc is a function that is called when the configuration changes. -type ListenerFunc func(oldConfig, newConfig *Config) error - type Options struct { - PollInterval time.Duration - SvrManager ServerManager - DataDir string - Locale string - APIHandler *api.APIClient + PollInterval time.Duration + DataPath string + Locale string + AccountClient *account.Client + Logger *slog.Logger + HTTPClient *http.Client } // ConfigHandler handles fetching the proxy configuration from the proxy server. It provides access // to the most recent configuration. type ConfigHandler struct { // config holds a configResult. - config atomic.Pointer[Config] - ftr Fetcher - svrManager ServerManager - - ctx context.Context - cancel context.CancelFunc - fetchDisabled bool - configPath string - wgKeyPath string - preferredLocation atomic.Pointer[C.ServerLocation] - configMu sync.RWMutex + config atomic.Pointer[Config] + ftr Fetcher + logger *slog.Logger + options Options + + ctx context.Context + cancel context.CancelFunc + fetchDisabled bool + pollInterval time.Duration + configPath string + wgKeyPath string + configMu sync.RWMutex + startOnce sync.Once } // NewConfigHandler creates a new ConfigHandler that fetches the proxy configuration every pollInterval. -func NewConfigHandler(options Options) *ConfigHandler { - configPath := filepath.Join(options.DataDir, common.ConfigFileName) - ctx, cancel := context.WithCancel(context.Background()) +func NewConfigHandler(ctx context.Context, options Options) *ConfigHandler { + ctx, cancel := context.WithCancel(ctx) + pollInterval := options.PollInterval + if pollInterval == 0 { + pollInterval = defaultPollInterval + } + logger := options.Logger + if logger == nil { + logger = slog.Default() + } + dir := options.DataPath ch := &ConfigHandler{ - fetchDisabled: options.PollInterval <= 0, + fetchDisabled: pollInterval < 0, ctx: ctx, cancel: cancel, - configPath: configPath, - wgKeyPath: filepath.Join(options.DataDir, "wg.key"), - svrManager: options.SvrManager, + pollInterval: pollInterval, + configPath: filepath.Join(dir, internal.ConfigFileName), + wgKeyPath: filepath.Join(dir, "wg.key"), + logger: logger, + options: options, } - // Set the preferred location to an empty struct to define the underlying type. - ch.preferredLocation.Store(&C.ServerLocation{}) - - if err := os.MkdirAll(filepath.Dir(options.DataDir), 0o755); err != nil { - slog.Error("creating config directory", "error", err) + if err := os.MkdirAll(dir, 0o755); err != nil { + ch.logger.Error("creating config directory", "error", err) } - if err := ch.loadConfig(); err != nil { - slog.Error("failed to load config", "error", err) - } - - if !ch.fetchDisabled { - ch.ftr = newFetcher(options.Locale, options.APIHandler) - go ch.fetchLoop(options.PollInterval) - events.Subscribe(func(evt settings.UserChangeEvent) { - slog.Debug("User change detected that requires config refetch") - if err := ch.fetchConfig(); err != nil { - slog.Error("Failed to fetch config", "error", err) - } - }) + ch.logger.Error("failed to load config", "error", err) } return ch } +func (ch *ConfigHandler) Start() { + ch.startOnce.Do(func() { + if !ch.fetchDisabled { + ch.ftr = newFetcher(ch.options.Locale, ch.options.AccountClient, ch.options.HTTPClient) + go ch.fetchLoop(ch.pollInterval) + events.Subscribe(func(evt account.UserChangeEvent) { + ch.logger.Debug("User change detected that requires config refetch") + if err := ch.fetchConfig(); err != nil { + ch.logger.Error("Failed to fetch config", "error", err) + } + }) + } + }) +} + var ErrNoWGKey = errors.New("no wg key") func (ch *ConfigHandler) loadWGKey() (wgtypes.Key, error) { - buf, err := os.ReadFile(ch.wgKeyPath) + buf, err := atomicfile.ReadFile(ch.wgKeyPath) if os.IsNotExist(err) { return wgtypes.Key{}, ErrNoWGKey } @@ -141,25 +146,6 @@ func (ch *ConfigHandler) loadWGKey() (wgtypes.Key, error) { return key, nil } -// SetPreferredServerLocation sets the preferred server location to connect to -func (ch *ConfigHandler) SetPreferredServerLocation(country, city string) { - preferred := &C.ServerLocation{ - Country: country, - City: city, - } - // We store the preferred location in memory in case we haven't fetched a config yet. - ch.preferredLocation.Store(preferred) - ch.modifyConfig(func(cfg *Config) { - cfg.PreferredLocation = *preferred - }) - // fetch the config with the new preferred location on a separate goroutine - go func() { - if err := ch.fetchConfig(); err != nil { - slog.Error("Failed to fetch config: %v", "error", err) - } - }() -} - func (ch *ConfigHandler) fetchConfig() error { if ch.fetchDisabled { return fmt.Errorf("fetching config is disabled") @@ -167,17 +153,6 @@ func (ch *ConfigHandler) fetchConfig() error { if ch.isClosed() { return fmt.Errorf("config handler is closed") } - var preferred C.ServerLocation - oldConfig, err := ch.GetConfig() - if err != nil { - slog.Info("No stored config yet -- using in-memory server location", "error", err) - storedLocation := ch.preferredLocation.Load() - if storedLocation != nil { - preferred = *storedLocation - } - } else { - preferred = oldConfig.PreferredLocation - } privateKey, err := ch.loadWGKey() if err != nil && !errors.Is(err, ErrNoWGKey) { @@ -190,25 +165,30 @@ func (ch *ConfigHandler) fetchConfig() error { return fmt.Errorf("failed to generate wg keys: %w", keyErr) } - if writeErr := os.WriteFile(ch.wgKeyPath, []byte(privateKey.String()), 0o600); writeErr != nil { + if writeErr := atomicfile.WriteFile(ch.wgKeyPath, []byte(privateKey.String()), 0o600); writeErr != nil { return fmt.Errorf("writing wg key file: %w", writeErr) } } - slog.Info("Fetching config") + ch.logger.Info("Fetching config") + preferred := common.PreferredLocation{} + if err := settings.GetStruct(settings.PreferredLocationKey, &preferred); err != nil { + ch.logger.Error("failed to get preferred location from settings", "error", err) + } + resp, err := ch.ftr.fetchConfig(ch.ctx, preferred, privateKey.PublicKey().String()) if err != nil { return fmt.Errorf("%w: %w", ErrFetchingConfig, err) } if resp == nil { - slog.Info("no new config available") + ch.logger.Info("no new config available") return nil } - slog.Info("Config fetched from server") + ch.logger.Info("Config fetched from server") // Save the raw config for debugging - if writeErr := os.WriteFile(strings.TrimSuffix(ch.configPath, ".json")+"_raw.json", resp, 0o600); writeErr != nil { - slog.Error("writing raw config file", "error", writeErr) + if writeErr := atomicfile.WriteFile(strings.TrimSuffix(ch.configPath, ".json")+"_raw.json", resp, 0o600); writeErr != nil { + ch.logger.Error("writing raw config file", "error", writeErr) } // Otherwise, we keep the previous config and store any error that might have occurred. @@ -218,68 +198,21 @@ func (ch *ConfigHandler) fetchConfig() error { // On the other hand, if we have a new config, we want to overwrite any previous error. confResp, err := singjson.UnmarshalExtendedContext[C.ConfigResponse](box.BaseContext(), resp) if err != nil { - slog.Error("failed to parse config", "error", err) + ch.logger.Error("failed to parse config", "error", err) return fmt.Errorf("parsing config: %w", err) } cleanTags(&confResp) if err = setWireGuardKeyInOptions(confResp.Options.Endpoints, privateKey); err != nil { - slog.Error("failed to replace private key", "error", err) + ch.logger.Error("failed to replace private key", "error", err) return fmt.Errorf("setting wireguard private key: %w", err) } setCustomProtocolOptions(confResp.Options.Outbounds) - if err := ch.setConfig(&Config{ConfigResponse: confResp}); err == nil { - cfg := ch.config.Load().ConfigResponse - locs := make(map[string]C.ServerLocation, len(cfg.OutboundLocations)+len(cfg.Servers)) - // Track which cities are already covered by active outbounds. - coveredCities := make(map[string]bool, len(cfg.OutboundLocations)) - for k, v := range cfg.OutboundLocations { - if v == nil { - slog.Warn("Server location is nil, skipping", "tag", k) - continue - } - locs[k] = *v - coveredCities[v.City+"|"+v.CountryCode] = true - } - // Include available server locations not already covered by active - // outbounds so the client's location picker shows every location. - for _, sl := range cfg.Servers { - if coveredCities[sl.City+"|"+sl.CountryCode] { - continue - } - key := strings.ToLower(strings.ReplaceAll(sl.City, " ", "-") + "-" + sl.CountryCode) - locs[key] = sl - } - opts := servers.Options{ - Outbounds: cfg.Options.Outbounds, - Endpoints: cfg.Options.Endpoints, - Locations: locs, - URLOverrides: cfg.BanditURLOverrides, - } - if len(cfg.BanditURLOverrides) > 0 { - slog.Info("Config includes bandit URL overrides", - "override_count", len(cfg.BanditURLOverrides), - "outbound_count", len(cfg.Options.Outbounds), - "endpoint_count", len(cfg.Options.Endpoints), - ) - // Create a marker span linked to the API's bandit trace so the - // config fetch appears in the same distributed trace as the callback. - if ctx, ok := traces.ExtractBanditTraceContext(cfg.BanditURLOverrides); ok { - _, span := otel.Tracer(tracerName).Start(ctx, "radiance.config_received", - trace.WithAttributes( - attribute.Int("bandit.override_count", len(cfg.BanditURLOverrides)), - attribute.Int("bandit.outbound_count", len(cfg.Options.Outbounds)), - ), - ) - span.End() // point-in-time marker — config was received at this timestamp - } - } - if err := ch.svrManager.SetServers(servers.SGLantern, opts); err != nil { - slog.Error("setting servers in manager", "error", err) - } + if err := ch.setConfig(&confResp); err != nil { + ch.logger.Error("failed to set config", "error", err) + return fmt.Errorf("setting config: %w", err) } - - slog.Info("Config fetched") + ch.logger.Info("Config fetched") return nil } @@ -296,7 +229,6 @@ func setCustomProtocolOptions(outbounds []option.Outbound) { } } -// TODO: move this to lantern-cloud func cleanTags(cfg *C.ConfigResponse) { opts := cfg.Options locs := cfg.OutboundLocations @@ -339,7 +271,7 @@ func (ch *ConfigHandler) fetchLoop(pollInterval time.Duration) { backoff := common.NewBackoff(maxRetryDelay) for { if err := ch.fetchConfig(); err != nil { - slog.Error("Failed to fetch config. Retrying", "error", err) + ch.logger.Error("Failed to fetch config. Retrying", "error", err) backoff.Wait(ch.ctx) if ch.ctx.Err() != nil { return @@ -372,7 +304,7 @@ func (ch *ConfigHandler) isClosed() bool { // loadConfig loads the config file from the disk. If the config file is not found, it returns // nil. func (ch *ConfigHandler) loadConfig() error { - slog.Debug("reading config file") + ch.logger.Debug("reading config file") cfg, err := Load(ch.configPath) if err != nil { return fmt.Errorf("reading config file: %w", err) @@ -392,14 +324,22 @@ func Load(path string) (*Config, error) { if err != nil { return nil, fmt.Errorf("reading config file: %w", err) } - cfg, err := unmarshalConfig(buf) + ctx := box.BaseContext() + cfg, err := singjson.UnmarshalExtendedContext[*Config](ctx, buf) + if err != nil { + // try to migrate from old format if parsing fails + // TODO(3/06, garmr-ulfr): remove this migration code after a few releases + if cfg, err = migrateToNewFmt(buf); err == nil { + saveConfig(cfg, path) + } + } if err != nil { return nil, fmt.Errorf("parsing config: %w", err) } return cfg, nil } -func unmarshalConfig(data []byte) (*Config, error) { +func migrateToNewFmt(data []byte) (*Config, error) { type T struct { ConfigResponse json.RawMessage PreferredLocation C.ServerLocation @@ -412,10 +352,8 @@ func unmarshalConfig(data []byte) (*Config, error) { if err != nil { return nil, err } - return &Config{ - ConfigResponse: opts, - PreferredLocation: tmp.PreferredLocation, - }, nil + settings.Set(settings.PreferredLocationKey, &tmp.PreferredLocation) + return &opts, nil } // saveConfig saves the config to the disk. It creates the config file if it doesn't exist. @@ -440,27 +378,20 @@ func (ch *ConfigHandler) GetConfig() (*Config, error) { } func (ch *ConfigHandler) setConfig(cfg *Config) error { - slog.Info("Setting config") + ch.logger.Info("Setting config") if cfg == nil { - slog.Warn("Config is nil, not setting") + ch.logger.Warn("Config is nil, not setting") return nil } oldConfig, _ := ch.GetConfig() - if cfg.PreferredLocation == (C.ServerLocation{}) { - storedLocation := ch.preferredLocation.Load() - if storedLocation != nil { - cfg.PreferredLocation = *storedLocation - } - } - ch.config.Store(cfg) - slog.Debug("Saving config", "path", ch.configPath) + ch.logger.Debug("Saving config", "path", ch.configPath) if err := saveConfig(cfg, ch.configPath); err != nil { - slog.Error("saving config", "error", err) + ch.logger.Error("saving config", "error", err) return fmt.Errorf("saving config: %w", err) } - slog.Info("saved new config") - slog.Info("Config set") + ch.logger.Info("saved new config") + ch.logger.Info("Config set") if !ch.isClosed() { emit(oldConfig, cfg) } @@ -479,21 +410,3 @@ func emit(old, new *Config) { events.Emit(NewConfigEvent{Old: old, New: new}) } } - -// modifyConfig saves the config to the disk with the given config. It creates the config file -// if it doesn't exist. -func (ch *ConfigHandler) modifyConfig(fn func(cfg *Config)) { - ch.configMu.Lock() - cfg, err := ch.GetConfig() - if err != nil { - // This could happen if we haven't successfully fetched the config yet. - slog.Error("getting config", "error", err) - ch.configMu.Unlock() - return - } - // Call the function with the config - // and save the config to the disk. - fn(cfg) - ch.configMu.Unlock() - ch.setConfig(cfg) -} diff --git a/config/config_test.go b/config/config_test.go index 2282d8cd..c66568ca 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -6,30 +6,28 @@ import ( "errors" "os" "path/filepath" - "sync/atomic" "testing" C "github.com/getlantern/common" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/servers" ) func TestSaveConfig(t *testing.T) { // Setup temporary directory for testing tempDir := t.TempDir() - configPath := filepath.Join(tempDir, common.ConfigFileName) + configPath := filepath.Join(tempDir, internal.ConfigFileName) // Create a sample config to save expectedConfig := Config{ - ConfigResponse: C.ConfigResponse{ - // Populate with sample data - Servers: []C.ServerLocation{ - {Country: "US", City: "New York"}, - {Country: "UK", City: "London"}, - }, + // Populate with sample data + Servers: []C.ServerLocation{ + {Country: "US", City: "New York"}, + {Country: "UK", City: "London"}, }, } // Save the config @@ -50,7 +48,7 @@ func TestSaveConfig(t *testing.T) { func TestGetConfig(t *testing.T) { // Setup temporary directory for testing tempDir := t.TempDir() - configPath := filepath.Join(tempDir, common.ConfigFileName) + configPath := filepath.Join(tempDir, internal.ConfigFileName) // Create a ConfigHandler with the mock parser ch := &ConfigHandler{ @@ -67,11 +65,9 @@ func TestGetConfig(t *testing.T) { // Test case: Valid config set t.Run("ValidConfigSet", func(t *testing.T) { expectedConfig := &Config{ - ConfigResponse: C.ConfigResponse{ - Servers: []C.ServerLocation{ - {Country: "US", City: "New York"}, - {Country: "UK", City: "London"}, - }, + Servers: []C.ServerLocation{ + {Country: "US", City: "New York"}, + {Country: "UK", City: "London"}, }, } @@ -84,53 +80,10 @@ func TestGetConfig(t *testing.T) { }) } -func TestSetPreferredServerLocation(t *testing.T) { - // Setup temporary directory for testing - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, common.ConfigFileName) - - // Create a ConfigHandler with the mock parser - ctx, cancel := context.WithCancel(context.Background()) - ch := &ConfigHandler{ - configPath: configPath, - ftr: newFetcher("en-US", nil), - ctx: ctx, - cancel: cancel, - } - - ch.config.Store(&Config{ - ConfigResponse: C.ConfigResponse{ - Servers: []C.ServerLocation{ - {Country: "US", City: "New York"}, - {Country: "UK", City: "London"}, - }, - }, - PreferredLocation: C.ServerLocation{ - Country: "US", - City: "New York", - }, - }) - - // Test case: Set preferred server location - t.Run("SetPreferredServerLocation", func(t *testing.T) { - country := "US" - city := "Los Angeles" - - // Call SetPreferredServerLocation - ch.SetPreferredServerLocation(country, city) - - // Verify the preferred location is updated - actualConfig, err := ch.GetConfig() - require.NoError(t, err, "Should not return an error when getting config") - assert.Equal(t, country, actualConfig.PreferredLocation.Country, "Preferred country should match") - assert.Equal(t, city, actualConfig.PreferredLocation.City, "Preferred city should match") - }) -} - func TestHandlerFetchConfig(t *testing.T) { // Setup temporary directory for testing tempDir := t.TempDir() - configPath := filepath.Join(tempDir, common.ConfigFileName) + configPath := filepath.Join(tempDir, internal.ConfigFileName) // Mock fetcher mockFetcher := &MockFetcher{} @@ -138,13 +91,12 @@ func TestHandlerFetchConfig(t *testing.T) { // Create a ConfigHandler with the mock parser and fetcher ctx, cancel := context.WithCancel(context.Background()) ch := &ConfigHandler{ - configPath: configPath, - preferredLocation: atomic.Pointer[C.ServerLocation]{}, - ftr: mockFetcher, - wgKeyPath: filepath.Join(tempDir, "wg.key"), - svrManager: &mockSrvManager{}, - ctx: ctx, - cancel: cancel, + configPath: configPath, + ftr: mockFetcher, + wgKeyPath: filepath.Join(tempDir, "wg.key"), + ctx: ctx, + cancel: cancel, + logger: log.NoOpLogger(), } // Test case: No server location set @@ -160,8 +112,8 @@ func TestHandlerFetchConfig(t *testing.T) { require.NoError(t, err, "Should not return an error when no server location is set") actualConfig, err := ch.GetConfig() require.NoError(t, err, "Should not return an error when getting config") - assert.Equal(t, "US", actualConfig.ConfigResponse.Servers[0].Country, "First server country should match") - assert.Equal(t, "New York", actualConfig.ConfigResponse.Servers[0].City, "First server city should match") + assert.Equal(t, "US", actualConfig.Servers[0].Country, "First server country should match") + assert.Equal(t, "New York", actualConfig.Servers[0].City, "First server city should match") }) // Test case: No stored config, fetch succeeds @@ -174,15 +126,13 @@ func TestHandlerFetchConfig(t *testing.T) { }`) mockFetcher.err = nil - ch.preferredLocation.Store(&C.ServerLocation{Country: "US", City: "New York"}) - err := ch.fetchConfig() require.NoError(t, err, "Should not return an error when fetch succeeds") actualConfig, err := ch.GetConfig() require.NoError(t, err, "Should not return an error when getting config") - assert.Equal(t, "US", actualConfig.ConfigResponse.Servers[0].Country, "First server country should match") - assert.Equal(t, "New York", actualConfig.ConfigResponse.Servers[0].City, "First server city should match") + assert.Equal(t, "US", actualConfig.Servers[0].Country, "First server country should match") + assert.Equal(t, "New York", actualConfig.Servers[0].City, "First server city should match") }) // Test case: Fetch fails diff --git a/config/fetcher.go b/config/fetcher.go index b9419b75..2b97c233 100644 --- a/config/fetcher.go +++ b/config/fetcher.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "errors" - "os" "fmt" "io" @@ -23,12 +22,11 @@ import ( "github.com/getlantern/lantern-box/protocol" - "github.com/getlantern/radiance/api" - "github.com/getlantern/radiance/backend" + "github.com/getlantern/radiance/account" "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/env" "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal" - "github.com/getlantern/radiance/kindling" + "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/traces" ) @@ -40,7 +38,7 @@ type Fetcher interface { // preferred is used to select the server location. // If preferred is empty, the server will select the best location. // The lastModified time is used to check if the configuration has changed since the last request. - fetchConfig(ctx context.Context, preferred C.ServerLocation, wgPublicKey string) ([]byte, error) + fetchConfig(ctx context.Context, preferred common.PreferredLocation, wgPublicKey string) ([]byte, error) } // fetcher is responsible for fetching the configuration from the server. @@ -48,20 +46,27 @@ type fetcher struct { lastModified time.Time locale string etag string - apiClient *api.APIClient + baseURL string + apiClient *account.Client + httpClient *http.Client } // newFetcher creates a new fetcher with the given http client. -func newFetcher(locale string, apiClient *api.APIClient) Fetcher { +func newFetcher(locale string, apiClient *account.Client, httpClient *http.Client) Fetcher { + if httpClient == nil { + httpClient = &http.Client{} + } return &fetcher{ lastModified: time.Time{}, locale: locale, + baseURL: common.GetBaseURL(), apiClient: apiClient, + httpClient: httpClient, } } // fetchConfig fetches the configuration from the server. Nil is returned if no new config is available. -func (f *fetcher) fetchConfig(ctx context.Context, preferred C.ServerLocation, wgPublicKey string) ([]byte, error) { +func (f *fetcher) fetchConfig(ctx context.Context, preferred common.PreferredLocation, wgPublicKey string) ([]byte, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "config_fetcher.fetchConfig") defer span.End() // If we don't have a user ID or token, create a new user. @@ -73,7 +78,7 @@ func (f *fetcher) fetchConfig(ctx context.Context, preferred C.ServerLocation, w Platform: common.Platform, AppName: common.Name, DeviceID: settings.GetString(settings.DeviceIDKey), - UserID: fmt.Sprintf("%d", settings.GetInt64(settings.UserIDKey)), + UserID: settings.GetString(settings.UserIDKey), ProToken: settings.GetString(settings.TokenKey), WGPublicKey: wgPublicKey, Backend: C.SINGBOX, @@ -98,7 +103,7 @@ func (f *fetcher) fetchConfig(ctx context.Context, preferred C.ServerLocation, w if buf == nil { // no new config available return nil, nil } - slog.Log(nil, internal.LevelTrace, "received config", "config", string(buf)) + slog.Log(nil, log.LevelTrace, "received config", "config", string(buf)) f.lastModified = time.Now() return buf, nil @@ -142,18 +147,18 @@ func (f *fetcher) ensureUser(ctx context.Context) error { func (f *fetcher) send(ctx context.Context, body io.Reader) ([]byte, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "config_fetcher.send") defer span.End() - req, err := backend.NewRequestWithHeaders(ctx, http.MethodPost, common.GetBaseURL()+"/config-new", body) + req, err := common.NewRequestWithHeaders(ctx, http.MethodPost, f.baseURL+"/config-new", body) if err != nil { return nil, fmt.Errorf("could not create request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Cache-Control", "no-cache") - if val, exists := os.LookupEnv("RADIANCE_COUNTRY"); exists { + if val := env.GetString(env.Country); val != "" { slog.Info("Setting x-lantern-client-country header", "country", val) req.Header.Set("x-lantern-client-country", val) } - if val, exists := os.LookupEnv("RADIANCE_FEATURE_OVERRIDE"); exists && val != "" { + if val := env.GetString(env.FeatureOverrides); val != "" { slog.Info("Setting X-Lantern-Feature-Override header", "features", val) req.Header.Set("X-Lantern-Feature-Override", val) } @@ -165,7 +170,7 @@ func (f *fetcher) send(ctx context.Context, body io.Reader) ([]byte, error) { req.Header.Set("If-None-Match", f.etag) } - resp, err := kindling.HTTPClient().Do(req) + resp, err := f.httpClient.Do(req) if err != nil { return nil, traces.RecordError(ctx, fmt.Errorf("could not send request: %w", err)) } diff --git a/config/fetcher_test.go b/config/fetcher_test.go index f561c13e..4645ef0d 100644 --- a/config/fetcher_test.go +++ b/config/fetcher_test.go @@ -1,81 +1,21 @@ package config import ( - "bytes" - "context" "encoding/json" "io" "net/http" - "path/filepath" + "net/http/httptest" "testing" C "github.com/getlantern/common" - "github.com/getlantern/kindling" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/getlantern/radiance/api" "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/reporting" "github.com/getlantern/radiance/common/settings" - rkindling "github.com/getlantern/radiance/kindling" - "github.com/getlantern/radiance/kindling/fronted" ) -func TestDomainFrontingFetchConfig(t *testing.T) { - // Disable this test for now since it depends on external service. - t.Skip("Skipping TestDomainFrontingFetchConfig since it depends on external service.") - dataDir := t.TempDir() - f, err := fronted.NewFronted(context.Background(), reporting.PanicListener, filepath.Join(dataDir, "fronted_cache.json"), io.Discard) - require.NoError(t, err) - k := kindling.NewKindling( - "radiance-df-test", - kindling.WithDomainFronting(f), - ) - rkindling.SetKindling(k) - fetcher := newFetcher("en-US", &api.APIClient{}) - - privateKey, err := wgtypes.GenerateKey() - require.NoError(t, err) - - _, err = fetcher.fetchConfig(context.Background(), C.ServerLocation{Country: "US"}, privateKey.PublicKey().String()) - // We expect a 500 error since the user does not have any matching tracks. - require.Error(t, err) - assert.Contains(t, err.Error(), "no lantern-cloud tracks") -} - -func TestProxylessFetchConfig(t *testing.T) { - // Disable this test for now since it depends on external service. - t.Skip("Skipping TestProxylessFetchConfig since it depends on external service.") - k := kindling.NewKindling( - "radiance-df-test", - kindling.WithProxyless("df.iantem.io"), - ) - rkindling.SetKindling(k) - fetcher := newFetcher("en-US", &api.APIClient{}) - - privateKey, err := wgtypes.GenerateKey() - require.NoError(t, err) - - _, err = fetcher.fetchConfig(context.Background(), C.ServerLocation{Country: "US"}, privateKey.PublicKey().String()) - // We expect a 500 error since the user does not have any matching tracks. - require.Error(t, err) - assert.Contains(t, err.Error(), "no lantern-cloud tracks") - -} - -type mockRoundTripper struct { - req *http.Request - resp *http.Response - err error -} - -func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - m.req = req - return m.resp, m.err -} - func TestFetchConfig(t *testing.T) { settings.InitSettings(t.TempDir()) settings.Set(settings.DeviceIDKey, "mock-device-id") @@ -86,25 +26,20 @@ func TestFetchConfig(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - preferredServerLoc *C.ServerLocation - mockResponse *http.Response - mockError error - expectedConfig []byte - expectedErrorMessage string + name string + preferredServerLoc *C.ServerLocation + serverStatus int + serverBody string + expectedConfig []byte + expectError bool }{ { - name: "successful fetch with new config", + name: "successful fetch", preferredServerLoc: &C.ServerLocation{ Country: "US", }, - mockResponse: &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(func() []byte { - data := []byte(`{"key":"value"}`) - return data - }())), - }, + serverStatus: http.StatusOK, + serverBody: `{"key":"value"}`, expectedConfig: []byte(`{"key":"value"}`), }, { @@ -112,81 +47,54 @@ func TestFetchConfig(t *testing.T) { preferredServerLoc: &C.ServerLocation{ Country: "US", }, - mockResponse: &http.Response{ - StatusCode: http.StatusNotModified, - Body: io.NopCloser(bytes.NewReader(nil)), - }, + serverStatus: http.StatusNotModified, expectedConfig: nil, }, - { - name: "error during request", - preferredServerLoc: &C.ServerLocation{ - Country: "US", - }, - mockError: context.DeadlineExceeded, - expectedErrorMessage: "context deadline exceeded", - }, } - apiClient := &api.APIClient{} - defer apiClient.Reset() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockRT := &mockRoundTripper{ - resp: tt.mockResponse, - err: tt.mockError, - } - rkindling.SetKindling(&mockKindling{ - &http.Client{ - Transport: mockRT, - }, - }) - fetcher := newFetcher("en-US", &api.APIClient{}) + var capturedReq *http.Request + var capturedBody []byte + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + capturedBody = body + capturedReq = r + w.WriteHeader(tt.serverStatus) + if tt.serverBody != "" { + w.Write([]byte(tt.serverBody)) + } + })) + defer srv.Close() - gotConfig, err := fetcher.fetchConfig(t.Context(), *tt.preferredServerLoc, privateKey.PublicKey().String()) + f := newFetcher("en-US", nil, srv.Client()).(*fetcher) + f.baseURL = srv.URL - if tt.expectedErrorMessage != "" { + gotConfig, err := f.fetchConfig(t.Context(), *tt.preferredServerLoc, privateKey.PublicKey().String()) + + if tt.expectError { require.Error(t, err) - assert.Contains(t, err.Error(), tt.expectedErrorMessage) } else { require.NoError(t, err) assert.Equal(t, tt.expectedConfig, gotConfig) } - if tt.mockResponse != nil { - require.NotNil(t, mockRT.req) - assert.Equal(t, "application/json", mockRT.req.Header.Get("Content-Type")) - assert.Equal(t, "no-cache", mockRT.req.Header.Get("Cache-Control")) + require.NotNil(t, capturedReq) + assert.Equal(t, "application/json", capturedReq.Header.Get("Content-Type")) + assert.Equal(t, "no-cache", capturedReq.Header.Get("Cache-Control")) - body, err := io.ReadAll(mockRT.req.Body) - require.NoError(t, err) + var confReq C.ConfigRequest + err = json.Unmarshal(capturedBody, &confReq) + require.NoError(t, err) - var confReq C.ConfigRequest - err = json.Unmarshal(body, &confReq) - require.NoError(t, err) - - assert.Equal(t, common.Platform, confReq.Platform) - assert.Equal(t, common.Name, confReq.AppName) - assert.Equal(t, settings.GetString(settings.DeviceIDKey), confReq.DeviceID) - assert.Equal(t, privateKey.PublicKey().String(), confReq.WGPublicKey) - if tt.preferredServerLoc != nil { - assert.Equal(t, tt.preferredServerLoc, confReq.PreferredLocation) - } + assert.Equal(t, common.Platform, confReq.Platform) + assert.Equal(t, common.Name, confReq.AppName) + assert.Equal(t, settings.GetString(settings.DeviceIDKey), confReq.DeviceID) + assert.Equal(t, privateKey.PublicKey().String(), confReq.WGPublicKey) + if tt.preferredServerLoc != nil { + assert.Equal(t, tt.preferredServerLoc, confReq.PreferredLocation) } }) } } - -type mockKindling struct { - c *http.Client -} - -// NewHTTPClient returns a new HTTP client that is configured to use kindling. -func (m *mockKindling) NewHTTPClient() *http.Client { - return m.c -} - -// ReplaceTransport replaces an existing transport RoundTripper generator with the provided one. -func (m *mockKindling) ReplaceTransport(name string, rt func(ctx context.Context, addr string) (http.RoundTripper, error)) error { - panic("not implemented") // TODO: Implement -} diff --git a/events/events.go b/events/events.go index 7889f1f1..a55e48e8 100644 --- a/events/events.go +++ b/events/events.go @@ -27,6 +27,7 @@ package events import ( + "reflect" "sync" ) @@ -36,7 +37,7 @@ type Event interface { } var ( - subscriptions = make(map[any]map[*Subscription[Event]]func(any)) + subscriptions = make(map[reflect.Type]map[*Subscription[Event]]func(any)) subscriptionsMu sync.RWMutex ) @@ -50,12 +51,12 @@ type Subscription[T Event] struct { func Subscribe[T Event](callback func(evt T)) *Subscription[T] { subscriptionsMu.Lock() defer subscriptionsMu.Unlock() - var evt T - if subscriptions[evt] == nil { - subscriptions[evt] = make(map[*Subscription[Event]]func(any)) + key := reflect.TypeFor[T]() + if subscriptions[key] == nil { + subscriptions[key] = make(map[*Subscription[Event]]func(any)) } sub := &Subscription[T]{} - subscriptions[evt][(*Subscription[Event])(sub)] = func(e any) { callback(e.(T)) } + subscriptions[key][(*Subscription[Event])(sub)] = func(e any) { callback(e.(T)) } return sub } @@ -77,11 +78,11 @@ func SubscribeOnce[T Event](callback func(evt T)) *Subscription[T] { func Unsubscribe[T Event](sub *Subscription[T]) { subscriptionsMu.Lock() defer subscriptionsMu.Unlock() - var evt T - if subs, ok := subscriptions[evt]; ok { + key := reflect.TypeFor[T]() + if subs, ok := subscriptions[key]; ok { delete(subs, (*Subscription[Event])(sub)) if len(subs) == 0 { - delete(subscriptions, evt) + delete(subscriptions, key) } } } @@ -95,8 +96,7 @@ func (e *Subscription[T]) Unsubscribe() { func Emit[T Event](evt T) { subscriptionsMu.RLock() defer subscriptionsMu.RUnlock() - var e T - if subs, ok := subscriptions[e]; ok { + if subs, ok := subscriptions[reflect.TypeFor[T]()]; ok { for _, cb := range subs { go cb(evt) } diff --git a/go.mod b/go.mod index ed8d6c69..7e6db629 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ replace github.com/refraction-networking/water => github.com/getlantern/water v0 require ( github.com/1Password/srp v0.2.0 github.com/Microsoft/go-winio v0.6.2 + github.com/alexflint/go-arg v1.6.1 github.com/alitto/pond v1.9.2 github.com/getlantern/amp v0.0.0-20260305201851-782bc8045e58 github.com/getlantern/appdir v0.0.0-20250324200952-507a0625eb01 @@ -36,7 +37,6 @@ require ( github.com/getlantern/lantern-box v0.0.51 github.com/getlantern/pluriconfig v0.0.0-20251126214241-8cc8bc561535 github.com/getlantern/timezone v0.0.0-20210901200113-3f9de9d360c9 - github.com/go-resty/resty/v2 v2.16.5 github.com/goccy/go-yaml v1.19.0 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 @@ -44,6 +44,7 @@ require ( github.com/knadh/koanf/parsers/json v1.0.0 github.com/knadh/koanf/providers/rawbytes v1.0.0 github.com/knadh/koanf/v2 v2.3.0 + github.com/r3labs/sse/v2 v2.10.0 github.com/sagernet/sing v0.7.18 github.com/sagernet/sing-box v1.12.22 github.com/stretchr/testify v1.11.1 @@ -52,7 +53,7 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.41.0 go.opentelemetry.io/otel/sdk v1.41.0 go.opentelemetry.io/otel/sdk/metric v1.41.0 - go.uber.org/mock v0.5.0 + golang.org/x/term v0.40.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 google.golang.org/protobuf v1.36.11 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -67,6 +68,7 @@ require ( github.com/akutz/memconn v0.1.0 // indirect github.com/alecthomas/atomic v0.1.0-alpha2 // indirect github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect + github.com/alexflint/go-scalar v1.2.0 // indirect github.com/alitto/pond/v2 v2.1.5 // indirect github.com/anacrolix/btree v0.0.0-20251201064447-d86c3fa41bd8 // indirect github.com/anacrolix/chansync v0.7.0 // indirect @@ -114,6 +116,7 @@ require ( github.com/getlantern/algeneva v0.0.0-20250307163401-1824e7b54f52 // indirect github.com/getlantern/lantern-water v0.0.0-20260317143726-e0ee64a11d90 // indirect github.com/getlantern/samizdat v0.0.3-0.20260310125445-325cf1bd1b60 // indirect + github.com/go-chi/chi/v5 v5.2.2 // indirect github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288 // indirect github.com/go-llsqlite/adapter v0.0.0-20230927005056-7f5ce7f0c916 // indirect github.com/go-llsqlite/crawshaw v0.5.6-0.20250312230104-194977a03421 // indirect @@ -205,15 +208,16 @@ require ( go.etcd.io/bbolt v1.3.6 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect + go.uber.org/mock v0.5.0 // indirect go.uber.org/zap/exp v0.3.0 // indirect go4.org v0.0.0-20230225012048-214862532bf5 // indirect go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect golang.getoutline.org/sdk v0.0.21 // indirect golang.getoutline.org/sdk/x v0.1.0 // indirect - golang.org/x/term v0.40.0 // indirect golang.org/x/text v0.34.0 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect + gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect modernc.org/libc v1.22.3 // indirect modernc.org/mathutil v1.5.0 // indirect modernc.org/memory v1.5.0 // indirect @@ -234,14 +238,13 @@ require ( github.com/getlantern/ops v0.0.0-20231025133620-f368ab734534 // indirect github.com/getlantern/osversion v0.0.0-20240418205916-2e84a4a4e175 github.com/getsentry/sentry-go v0.31.1 - github.com/go-chi/chi/v5 v5.2.2 github.com/go-chi/render v1.0.3 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect - github.com/gofrs/uuid/v5 v5.3.2 + github.com/gofrs/uuid/v5 v5.3.2 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/hashicorp/yamux v0.1.2 // indirect diff --git a/go.sum b/go.sum index bfbccaa9..b0dac2f4 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,10 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= +github.com/alexflint/go-arg v1.6.1 h1:uZogJ6VDBjcuosydKgvYYRhh9sRCusjOvoOLZopBlnA= +github.com/alexflint/go-arg v1.6.1/go.mod h1:nQ0LFYftLJ6njcaee0sU+G0iS2+2XJQfA8I062D0LGc= +github.com/alexflint/go-scalar v1.2.0 h1:WR7JPKkeNpnYIOfHRa7ivM21aWAdHD0gEWHCx+WQBRw= +github.com/alexflint/go-scalar v1.2.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= github.com/alitto/pond v1.9.2 h1:9Qb75z/scEZVCoSU+osVmQ0I0JOeLfdTDafrbcJ8CLs= github.com/alitto/pond v1.9.2/go.mod h1:xQn3P/sHTYcU/1BR3i86IGIrilcrGC2LiS+E2+CJWsI= github.com/alitto/pond/v2 v2.1.5 h1:2pp/KAPcb02NSpHsjjnxnrTDzogMLsq+vFf/L0DB84A= @@ -307,8 +311,6 @@ github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= -github.com/go-resty/resty/v2 v2.16.5 h1:hBKqmWrr7uRc3euHVqmh1HTHcKn99Smr7o5spptdhTM= -github.com/go-resty/resty/v2 v2.16.5/go.mod h1:hkJtXbA2iKHzJheXYvQ8snQES5ZLGKMwQ07xAwp/fiA= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= @@ -605,6 +607,8 @@ github.com/protolambda/ctxlock v0.1.0 h1:rCUY3+vRdcdZXqT07iXgyr744J2DU2LCBIXowYA github.com/protolambda/ctxlock v0.1.0/go.mod h1:vefhX6rIZH8rsg5ZpOJfEDYQOppZi19SfPiGOFrNnwM= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= +github.com/r3labs/sse/v2 v2.10.0 h1:hFEkLLFY4LDifoHdiCN/LlGBAdVJYsANaLqNYa1l/v0= +github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEmkNJ7I= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo= github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= @@ -874,6 +878,7 @@ golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= @@ -1061,6 +1066,8 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y= +gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/constants.go b/internal/constants.go new file mode 100644 index 00000000..dd0cd197 --- /dev/null +++ b/internal/constants.go @@ -0,0 +1,10 @@ +package internal + +const ( + DebugBoxOptionsFileName = "debug-box-options.json" + ConfigFileName = "config.json" + ServersFileName = "servers.json" + SplitTunnelFileName = "split-tunnel.json" + LogFileName = "lantern.log" + CrashLogFileName = "lantern-crash.log" +) diff --git a/internal/log.go b/internal/log.go deleted file mode 100644 index 3f1c6195..00000000 --- a/internal/log.go +++ /dev/null @@ -1,73 +0,0 @@ -package internal - -import ( - "fmt" - "io" - "log/slog" - "strings" -) - -const ( - // slog does not define trace and fatal levels, so we define them here. - LevelTrace = slog.LevelDebug - 4 - LevelDebug = slog.LevelDebug - LevelInfo = slog.LevelInfo - LevelWarn = slog.LevelWarn - LevelError = slog.LevelError - LevelFatal = slog.LevelError + 4 - LevelPanic = slog.LevelError + 8 - - Disable = slog.LevelInfo + 1000 // A level that disables logging, used for testing or no-op logger. -) - -// ParseLogLevel parses a string representation of a log level and returns the corresponding slog.Level. -// If the level is not recognized, it returns LevelInfo. -func ParseLogLevel(level string) (slog.Level, error) { - switch strings.ToLower(level) { - case "trace": - return LevelTrace, nil - case "debug": - return LevelDebug, nil - case "info": - return LevelInfo, nil - case "warn", "warning": - return LevelWarn, nil - case "error": - return LevelError, nil - case "fatal": - return LevelFatal, nil - case "panic": - return LevelPanic, nil - case "disable", "none", "off": - return Disable, nil - default: - return LevelInfo, fmt.Errorf("unknown log level: %s", level) - } -} - -func FormatLogLevel(level slog.Level) string { - switch { - case level < LevelDebug: - return "TRACE" - case level < LevelInfo: - return "DEBUG" - case level < LevelWarn: - return "INFO" - case level < LevelError: - return "WARN" - case level < LevelFatal: - return "ERROR" - case level < LevelPanic: - return "FATAL" - default: - return "PANIC" - } -} - -// NoOpLogger returns a no-op logger that does not log anything. -func NoOpLogger() *slog.Logger { - // Create a no-op logger that does nothing. - return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ - Level: Disable, - })) -} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index 44dbc0a8..8bdb3f05 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -2,7 +2,6 @@ package testutil import ( "testing" - _ "unsafe" // for go:linkname "github.com/getlantern/radiance/common/settings" ) @@ -15,8 +14,4 @@ func SetPathsForTesting(t *testing.T) { tmp := t.TempDir() settings.Set(settings.DataPathKey, tmp) settings.Set(settings.LogPathKey, tmp) - ipc_serverTestSetup(tmp + "/lantern.sock") } - -//go:linkname ipc_serverTestSetup -func ipc_serverTestSetup(path string) diff --git a/ipc/client.go b/ipc/client.go new file mode 100644 index 00000000..a24eb9cd --- /dev/null +++ b/ipc/client.go @@ -0,0 +1,592 @@ +package ipc + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "os" + "syscall" + + box "github.com/getlantern/lantern-box" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/issue" + "github.com/getlantern/radiance/servers" + "github.com/getlantern/radiance/vpn" + + sjson "github.com/sagernet/sing/common/json" +) + +func newClient() *Client { + return &Client{ + http: &http.Client{ + Transport: &http.Transport{ + DialContext: dialContext, + ForceAttemptHTTP2: true, + Protocols: &protocols, + }, + }, + } +} + +// doJSON executes an HTTP request and decodes the JSON response into dst. +func (c *Client) doJSON(ctx context.Context, method, endpoint string, body, dst any) error { + data, err := c.do(ctx, method, endpoint, body) + if err != nil { + return err + } + if dst == nil { + return nil + } + return json.Unmarshal(data, dst) +} + +// Error is returned by Client methods when the server responds with an error status. +type Error struct { + Status int + Message string +} + +func (e *Error) Error() string { + return fmt.Sprintf("ipc: status %d: %s", e.Status, e.Message) +} + +// IsNotFound reports whether the error is a 404 response. +func IsNotFound(err error) bool { + var e *Error + return errors.As(err, &e) && e.Status == http.StatusNotFound +} + +///////////// +// VPN // +///////////// + +// VPNStatus returns the current VPN connection status. +func (c *Client) VPNStatus(ctx context.Context) (vpn.VPNStatus, error) { + var status vpn.VPNStatus + err := c.doJSON(ctx, http.MethodGet, vpnStatusEndpoint, nil, &status) + return status, err +} + +// ConnectVPN connects the VPN using the given server tag. +func (c *Client) ConnectVPN(ctx context.Context, tag string) error { + _, err := c.do(ctx, http.MethodPost, vpnConnectEndpoint, TagRequest{Tag: tag}) + return err +} + +// DisconnectVPN disconnects the VPN. +func (c *Client) DisconnectVPN(ctx context.Context) error { + _, err := c.do(ctx, http.MethodPost, vpnDisconnectEndpoint, nil) + return err +} + +// RestartVPN restarts the VPN connection. +func (c *Client) RestartVPN(ctx context.Context) error { + _, err := c.do(ctx, http.MethodPost, vpnRestartEndpoint, nil) + return err +} + +// VPNConnections returns all VPN connections (active and recently closed). +func (c *Client) VPNConnections(ctx context.Context) ([]vpn.Connection, error) { + var conns []vpn.Connection + err := c.doJSON(ctx, http.MethodGet, vpnConnectionsEndpoint, nil, &conns) + return conns, err +} + +// ActiveVPNConnections returns currently active VPN connections. +func (c *Client) ActiveVPNConnections(ctx context.Context) ([]vpn.Connection, error) { + var conns []vpn.Connection + err := c.doJSON(ctx, http.MethodGet, vpnConnectionsEndpoint+"?active=true", nil, &conns) + return conns, err +} + +// RunOfflineURLTests runs URL performance tests when offline (VPN disconnected) and caches the +// results. This enables autoconnect to select the best server for the initial connection. +func (c *Client) RunOfflineURLTests(ctx context.Context) error { + _, err := c.do(ctx, http.MethodPost, vpnOfflineTestsEndpoint, nil) + return err +} + +// VPNStatusEvents connects to the VPN status event stream. It calls handler for each event +// received until ctx is cancelled or the connection is closed. +func (c *Client) VPNStatusEvents(ctx context.Context, handler func(vpn.StatusUpdateEvent)) error { + return c.sseStream(ctx, vpnStatusEventsEndpoint, func(data []byte) { + var evt vpn.StatusUpdateEvent + if err := json.Unmarshal(data, &evt); err != nil { + return + } + handler(evt) + }) +} + +/////////////////////// +// Server selection // +/////////////////////// + +var boxCtx = box.BaseContext() + +// SelectServer selects the server with the given tag. +func (c *Client) SelectServer(ctx context.Context, tag string) error { + _, err := c.do(ctx, http.MethodPost, serverSelectedEndpoint, TagRequest{Tag: tag}) + return err +} + +// SelectedServer returns the currently selected server and whether it still exists. +func (c *Client) SelectedServer(ctx context.Context) (servers.Server, bool, error) { + data, err := c.do(ctx, http.MethodGet, serverSelectedEndpoint, nil) + if err != nil { + return servers.Server{}, false, err + } + resp, err := sjson.UnmarshalExtendedContext[SelectedServerResponse](boxCtx, data) + return resp.Server, resp.Exists, err +} + +// AutoSelected returns the server that's currently auto-selected. +func (c *Client) AutoSelected(ctx context.Context) (servers.Server, error) { + var selected servers.Server + err := c.doJSON(ctx, http.MethodGet, serverAutoSelectedEndpoint, nil, &selected) + return selected, err +} + +// AutoSelectedEvents connects to the auto-selected event stream. It calls handler for each +// event received until ctx is cancelled or the connection is closed. +func (c *Client) AutoSelectedEvents(ctx context.Context, handler func(vpn.AutoSelectedEvent)) error { + return c.sseStream(ctx, serverAutoSelectedEventsEndpoint, func(data []byte) { + var evt vpn.AutoSelectedEvent + if err := json.Unmarshal(data, &evt); err != nil { + return + } + handler(evt) + }) +} + +/////////////////////// +// Server management // +/////////////////////// + +// Servers returns all server groups. +func (c *Client) Servers(ctx context.Context) (servers.Servers, error) { + data, err := c.do(ctx, http.MethodGet, serversEndpoint, nil) + if err != nil { + return nil, err + } + return sjson.UnmarshalExtendedContext[servers.Servers](boxCtx, data) +} + +// GetServerByTag returns the server with the given tag. +func (c *Client) GetServerByTag(ctx context.Context, tag string) (servers.Server, bool, error) { + q := url.Values{"tag": {tag}} + data, err := c.do(ctx, http.MethodGet, serversEndpoint+"?"+q.Encode(), nil) + if err != nil { + if IsNotFound(err) { + return servers.Server{}, false, nil + } + return servers.Server{}, false, err + } + server, err := sjson.UnmarshalExtendedContext[servers.Server](boxCtx, data) + return server, true, nil +} + +// AddServers adds servers to the given group. +func (c *Client) AddServers(ctx context.Context, group servers.ServerGroup, options servers.Options) error { + req := AddServersRequest{Group: group, Options: options} + body, err := sjson.MarshalContext(boxCtx, req) + if err != nil { + return fmt.Errorf("marshal add servers request: %w", err) + } + _, err = c.do(ctx, http.MethodPost, serversAddEndpoint, body) + return err +} + +// RemoveServers removes servers by tag from the given group. +func (c *Client) RemoveServers(ctx context.Context, tags []string) error { + _, err := c.do(ctx, http.MethodPost, serversRemoveEndpoint, RemoveServersRequest{Tags: tags}) + return err +} + +// AddServersByJSON adds servers from a JSON configuration string. +func (c *Client) AddServersByJSON(ctx context.Context, config string) error { + _, err := c.do(ctx, http.MethodPost, serversFromJSONEndpoint, JSONConfigRequest{Config: config}) + return err +} + +// AddServersByURL adds servers from the given URLs. +func (c *Client) AddServersByURL(ctx context.Context, urls []string, skipCertVerification bool) error { + _, err := c.do(ctx, http.MethodPost, serversFromURLsEndpoint, URLsRequest{URLs: urls, SkipCertVerification: skipCertVerification}) + return err +} + +// AddPrivateServer adds a private server. +func (c *Client) AddPrivateServer(ctx context.Context, tag, ip string, port int, accessToken string) error { + _, err := c.do(ctx, http.MethodPost, serversPrivateEndpoint, PrivateServerRequest{Tag: tag, IP: ip, Port: port, AccessToken: accessToken}) + return err +} + +// InviteToPrivateServer creates an invite for a private server and returns the invite code. +func (c *Client) InviteToPrivateServer(ctx context.Context, ip string, port int, accessToken, inviteName string) (string, error) { + var resp CodeResponse + err := c.doJSON(ctx, http.MethodPost, serversPrivateInviteEndpoint, + PrivateServerInviteRequest{IP: ip, Port: port, AccessToken: accessToken, InviteName: inviteName}, &resp) + return resp.Code, err +} + +// RevokePrivateServerInvite revokes an invite for a private server. +func (c *Client) RevokePrivateServerInvite(ctx context.Context, ip string, port int, accessToken, inviteName string) error { + _, err := c.do(ctx, http.MethodDelete, serversPrivateInviteEndpoint, + PrivateServerInviteRequest{IP: ip, Port: port, AccessToken: accessToken, InviteName: inviteName}) + return err +} + +////////////// +// Settings // +////////////// + +// Features returns the feature flags from the current configuration. +func (c *Client) Features(ctx context.Context) (map[string]bool, error) { + var features map[string]bool + err := c.doJSON(ctx, http.MethodGet, featuresEndpoint, nil, &features) + return features, err +} + +// Settings returns the current settings as a map of key-value pairs. +func (c *Client) Settings(ctx context.Context) (settings.Settings, error) { + var s settings.Settings + err := c.doJSON(ctx, http.MethodGet, settingsEndpoint, nil, &s) + return s, err +} + +// PatchSettings updates settings with the given key-value pairs and returns the full updates settings. +func (c *Client) PatchSettings(ctx context.Context, updates settings.Settings) (settings.Settings, error) { + var s settings.Settings + err := c.doJSON(ctx, http.MethodPatch, settingsEndpoint, updates, &s) + return s, err +} + +func (c *Client) EnableTelemetry(ctx context.Context, enable bool) error { + _, err := c.PatchSettings(ctx, settings.Settings{settings.TelemetryKey: enable}) + return err +} + +func (c *Client) EnableSplitTunneling(ctx context.Context, enable bool) error { + _, err := c.PatchSettings(ctx, settings.Settings{settings.SplitTunnelKey: enable}) + return err +} + +func (c *Client) EnableSmartRouting(ctx context.Context, enable bool) error { + _, err := c.PatchSettings(ctx, settings.Settings{settings.SmartRoutingKey: enable}) + return err +} + +func (c *Client) EnableAdBlocking(ctx context.Context, enable bool) error { + _, err := c.PatchSettings(ctx, settings.Settings{settings.AdBlockKey: enable}) + return err +} + +////////////////// +// Split Tunnel // +///////////////// + +// SplitTunnelFilters returns the current split tunnel configuration. +func (c *Client) SplitTunnelFilters(ctx context.Context) (vpn.SplitTunnelFilter, error) { + var filter vpn.SplitTunnelFilter + err := c.doJSON(ctx, http.MethodGet, splitTunnelEndpoint, nil, &filter) + return filter, err +} + +// AddSplitTunnelItems adds items to the split tunnel filter. +func (c *Client) AddSplitTunnelItems(ctx context.Context, items vpn.SplitTunnelFilter) error { + _, err := c.do(ctx, http.MethodPost, splitTunnelEndpoint, items) + return err +} + +// RemoveSplitTunnelItems removes items from the split tunnel filter. +func (c *Client) RemoveSplitTunnelItems(ctx context.Context, items vpn.SplitTunnelFilter) error { + _, err := c.do(ctx, http.MethodDelete, splitTunnelEndpoint, items) + return err +} + +///////////// +// Account // +///////////// + +// NewUser creates a new anonymous user. +func (c *Client) NewUser(ctx context.Context) (*account.UserData, error) { + var userData account.UserData + if err := c.doJSON(ctx, http.MethodPost, accountNewUserEndpoint, nil, &userData); err != nil { + return nil, err + } + return &userData, nil +} + +// Login authenticates the user with email and password. +func (c *Client) Login(ctx context.Context, email, password string) (*account.UserData, error) { + var userData account.UserData + err := c.doJSON(ctx, http.MethodPost, accountLoginEndpoint, + EmailPasswordRequest{Email: email, Password: password}, &userData) + if err != nil { + return nil, err + } + return &userData, nil +} + +// Logout logs the user out. +func (c *Client) Logout(ctx context.Context, email string) (*account.UserData, error) { + var userData account.UserData + if err := c.doJSON(ctx, http.MethodPost, accountLogoutEndpoint, EmailRequest{Email: email}, &userData); err != nil { + return nil, err + } + return &userData, nil +} + +// FetchUserData fetches fresh user data from the remote server. +func (c *Client) FetchUserData(ctx context.Context) (*account.UserData, error) { + return c.userData(ctx, true) +} + +// UserData returns locally cached user data. +func (c *Client) UserData(ctx context.Context) (*account.UserData, error) { + return c.userData(ctx, false) +} + +func (c *Client) userData(ctx context.Context, fetch bool) (*account.UserData, error) { + var userData account.UserData + url := fmt.Sprintf("%s?fetch=%v", accountUserDataEndpoint, fetch) + if err := c.doJSON(ctx, http.MethodGet, url, nil, &userData); err != nil { + return nil, err + } + return &userData, nil +} + +// UserDevices returns the list of devices linked to the user's account. +func (c *Client) UserDevices(ctx context.Context) ([]settings.Device, error) { + var devices []settings.Device + err := c.doJSON(ctx, http.MethodGet, accountDevicesEndpoint, nil, &devices) + return devices, err +} + +// RemoveDevice removes a device from the user's account. +func (c *Client) RemoveDevice(ctx context.Context, deviceID string) (*account.LinkResponse, error) { + var resp account.LinkResponse + if err := c.doJSON(ctx, http.MethodDelete, accountDevicesEndpoint+url.PathEscape(deviceID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// SignUp creates a new account with the given email and password. +func (c *Client) SignUp(ctx context.Context, email, password string) ([]byte, *account.SignupResponse, error) { + var resp SignupResponse + err := c.doJSON( + ctx, http.MethodPost, accountSignupEndpoint, + EmailPasswordRequest{Email: email, Password: password}, &resp, + ) + if err != nil { + return nil, nil, err + } + return resp.Salt, resp.Response, nil +} + +// SignupEmailConfirmation confirms the signup email with the given code. +func (c *Client) SignupEmailConfirmation(ctx context.Context, email, code string) error { + _, err := c.do(ctx, http.MethodPost, accountSignupEndpoint+"confirm", EmailCodeRequest{Email: email, Code: code}) + return err +} + +// SignupEmailResendCode requests a resend of the signup confirmation email. +func (c *Client) SignupEmailResendCode(ctx context.Context, email string) error { + _, err := c.do(ctx, http.MethodPost, accountSignupEndpoint+"resend", EmailRequest{Email: email}) + return err +} + +// StartChangeEmail initiates an email address change. +func (c *Client) StartChangeEmail(ctx context.Context, newEmail, password string) error { + _, err := c.do(ctx, http.MethodPost, accountEmailEndpoint+"/start", ChangeEmailStartRequest{NewEmail: newEmail, Password: password}) + return err +} + +// CompleteChangeEmail completes an email address change. +func (c *Client) CompleteChangeEmail(ctx context.Context, newEmail, password, code string) error { + _, err := c.do(ctx, http.MethodPost, accountEmailEndpoint+"/complete", + ChangeEmailCompleteRequest{NewEmail: newEmail, Password: password, Code: code}) + return err +} + +// StartRecoveryByEmail initiates account recovery by email. +func (c *Client) StartRecoveryByEmail(ctx context.Context, email string) error { + _, err := c.do(ctx, http.MethodPost, accountRecoveryEndpoint+"/start", EmailRequest{Email: email}) + return err +} + +// CompleteRecoveryByEmail completes account recovery with a new password and code. +func (c *Client) CompleteRecoveryByEmail(ctx context.Context, email, newPassword, code string) error { + _, err := c.do(ctx, http.MethodPost, accountRecoveryEndpoint+"/complete", + RecoveryCompleteRequest{Email: email, NewPassword: newPassword, Code: code}) + return err +} + +// ValidateEmailRecoveryCode validates the recovery code without completing the recovery. +func (c *Client) ValidateEmailRecoveryCode(ctx context.Context, email, code string) error { + _, err := c.do(ctx, http.MethodPost, accountRecoveryEndpoint+"/validate", EmailCodeRequest{Email: email, Code: code}) + return err +} + +// DeleteAccount deletes the user's account. +func (c *Client) DeleteAccount(ctx context.Context, email, password string) (*account.UserData, error) { + var userData account.UserData + err := c.doJSON(ctx, http.MethodDelete, accountDeleteEndpoint, + EmailPasswordRequest{Email: email, Password: password}, &userData) + if err != nil { + return nil, err + } + return &userData, nil +} + +// OAuthLoginUrl returns the OAuth login URL for the given provider. +func (c *Client) OAuthLoginUrl(ctx context.Context, provider string) (string, error) { + var resp URLResponse + q := url.Values{"provider": {provider}} + err := c.doJSON(ctx, http.MethodGet, accountOAuthEndpoint+"?"+q.Encode(), nil, &resp) + return resp.URL, err +} + +// OAuthLoginCallback exchanges an OAuth token for user data. +func (c *Client) OAuthLoginCallback(ctx context.Context, oAuthToken string) (*account.UserData, error) { + var userData account.UserData + err := c.doJSON(ctx, http.MethodPost, accountOAuthEndpoint, + OAuthTokenRequest{OAuthToken: oAuthToken}, &userData) + if err != nil { + return nil, err + } + return &userData, nil +} + +// DataCapInfo returns the current data cap information as a JSON string. +func (c *Client) DataCapInfo(ctx context.Context) (*account.DataCapInfo, error) { + var resp account.DataCapInfo + err := c.doJSON(ctx, http.MethodGet, accountDataCapEndpoint, nil, &resp) + return &resp, err +} + +// DataCapStream connects to the data cap event stream. It calls handler for each event +// received until ctx is cancelled or the connection is closed. +func (c *Client) DataCapStream(ctx context.Context, handler func(account.DataCapInfo)) error { + return c.sseStream(ctx, accountDataCapStreamEndpoint, func(data []byte) { + var info account.DataCapInfo + if err := json.Unmarshal(data, &info); err != nil { + return + } + handler(info) + }) +} + +/////////////////// +// Subscriptions // +/////////////////// + +// ActivationCode purchases a subscription using a reseller code. +func (c *Client) ActivationCode(ctx context.Context, email, resellerCode string) (*account.PurchaseResponse, error) { + var resp account.PurchaseResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionActivationEndpoint, + ActivationRequest{Email: email, ResellerCode: resellerCode}, &resp) + if err != nil { + return nil, err + } + return &resp, nil +} + +// NewStripeSubscription creates a new Stripe subscription and returns the client secret. +func (c *Client) NewStripeSubscription(ctx context.Context, email, planID string) (string, error) { + var resp ClientSecretResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionStripeEndpoint, + StripeSubscriptionRequest{Email: email, PlanID: planID}, &resp) + return resp.ClientSecret, err +} + +// PaymentRedirect returns a payment redirect URL. +func (c *Client) PaymentRedirect(ctx context.Context, data account.PaymentRedirectData) (string, error) { + var resp URLResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionPaymentRedirectEndpoint, data, &resp) + return resp.URL, err +} + +// ReferralAttach attaches a referral code to the current user. +func (c *Client) ReferralAttach(ctx context.Context, code string) (bool, error) { + var resp SuccessResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionReferralEndpoint, CodeRequest{Code: code}, &resp) + return resp.Success, err +} + +// StripeBillingPortalURL returns the Stripe billing portal URL. +func (c *Client) StripeBillingPortalURL(ctx context.Context) (string, error) { + var resp URLResponse + err := c.doJSON(ctx, http.MethodGet, subscriptionBillingPortalEndpoint, nil, &resp) + return resp.URL, err +} + +// SubscriptionPaymentRedirectURL returns a subscription payment redirect URL. +func (c *Client) SubscriptionPaymentRedirectURL(ctx context.Context, data account.PaymentRedirectData) (string, error) { + var resp URLResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionPaymentRedirectURLEndpoint, data, &resp) + return resp.URL, err +} + +// SubscriptionPlans returns available subscription plans for the given channel. +func (c *Client) SubscriptionPlans(ctx context.Context, channel string) (string, error) { + var resp PlansResponse + q := url.Values{"channel": {channel}} + err := c.doJSON(ctx, http.MethodGet, subscriptionPlansEndpoint+"?"+q.Encode(), nil, &resp) + return resp.Plans, err +} + +// VerifySubscription verifies a subscription purchase. +func (c *Client) VerifySubscription(ctx context.Context, service account.SubscriptionService, data map[string]string) (string, error) { + var resp ResultResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionVerifyEndpoint, + VerifySubscriptionRequest{Service: service, Data: data}, &resp) + return resp.Result, err +} + +/////////// +// Issue // +/////////// + +// ReportIssue submits an issue report. additionalAttachments is a list of file paths for additional +// files to include. Logs, diagnostics, and the config response are included automatically and do +// not need to be specified. +func (c *Client) ReportIssue(ctx context.Context, issueType issue.IssueType, description, email string, additionalAttachments []string) error { + _, err := c.do(ctx, http.MethodPost, issueEndpoint, + IssueReportRequest{IssueType: issueType, Description: description, Email: email, AdditionalAttachments: additionalAttachments}) + return err +} + +///////////// +// helpers // +///////////// + +// isConnectionError reports whether err indicates that the IPC socket is unreachable +// (e.g. connection refused or socket file not found). +func isConnectionError(err error) bool { + var opErr *net.OpError + if errors.As(err, &opErr) { + // connection refused (server not listening) + if errors.Is(opErr.Err, syscall.ECONNREFUSED) { + return true + } + // socket file does not exist (server never started / was cleaned up) + if errors.Is(opErr.Err, syscall.ENOENT) { + return true + } + // check wrapped syscall errors + var sysErr *os.SyscallError + if errors.As(opErr.Err, &sysErr) { + return errors.Is(sysErr.Err, syscall.ECONNREFUSED) || errors.Is(sysErr.Err, syscall.ENOENT) + } + } + // Also check the unwrapped error directly for cases where the wrapping differs by platform + return errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT) +} diff --git a/ipc/client_mobile.go b/ipc/client_mobile.go new file mode 100644 index 00000000..6d6a9fc6 --- /dev/null +++ b/ipc/client_mobile.go @@ -0,0 +1,233 @@ +//go:build android || ios || (darwin && !standalone) + +package ipc + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "time" + + "github.com/getlantern/radiance/backend" + "github.com/getlantern/radiance/common/settings" + rlog "github.com/getlantern/radiance/log" +) + +type Client struct { + http *http.Client + localapi *localapi + mu sync.RWMutex +} + +func NewClient(ctx context.Context, opts backend.Options) (*Client, error) { + b, err := backend.NewLocalBackend(ctx, opts) + if err != nil { + return nil, fmt.Errorf("create local backend: %w", err) + } + b.Start() + c := newClient() + c.localapi = newLocalAPI(b, false) + return c, nil +} + +// Close releases resources held by the client, including any local backend. +func (c *Client) Close() { + c.stopLocal() + c.http.CloseIdleConnections() +} + +func (c *Client) stopLocal() { + c.mu.Lock() + defer c.mu.Unlock() + if be := c.localapi.setBackend(nil); be != nil { + be.Close() + } +} + +// do executes an HTTP request with an optional JSON body and returns the raw response body. If +// body needs to be marshaled using sing/json, it should be pre-marshaled to []byte before passing +// to do. do returns an error if the response status is >= 400. +func (c *Client) do(ctx context.Context, method, endpoint string, body any) ([]byte, error) { + var bodyReader io.Reader + if body != nil { + switch body := body.(type) { + case []byte: + bodyReader = bytes.NewReader(body) + default: + data, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + bodyReader = bytes.NewReader(data) + } + } + + req, err := http.NewRequestWithContext(ctx, method, apiURL+endpoint, bodyReader) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := c.http.Do(req) + if err != nil { + if isConnectionError(err) { + c.mu.Lock() + defer c.mu.Unlock() + if be := c.localapi.be.Load(); be == nil { + opts := backend.Options{ + DataDir: settings.GetString(settings.DataPathKey), + LogDir: settings.GetString(settings.LogPathKey), + Locale: settings.GetString(settings.LocaleKey), + DeviceID: settings.GetString(settings.DeviceIDKey), + LogLevel: settings.GetString(settings.LogLevelKey), + TelemetryConsent: settings.GetBool(settings.TelemetryKey), + } + be, err = backend.NewLocalBackend(ctx, opts) + if err != nil { + return nil, fmt.Errorf("create local backend: %w", err) + } + c.localapi.setBackend(be) + } + if br, ok := bodyReader.(*bytes.Reader); ok { + br.Seek(0, io.SeekStart) + } + req, _ = http.NewRequestWithContext(ctx, method, apiURL+endpoint, bodyReader) + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + return c.doLocal(req) + } + return nil, fmt.Errorf("ipc request %s %s: %w", method, endpoint, err) + } + c.stopLocal() // IPC is reachable; shut down local backend if still running + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + if resp.StatusCode >= 400 { + return nil, &Error{ + Status: resp.StatusCode, + Message: strings.TrimSpace(string(respBody)), + } + } + return respBody, nil +} + +// doLocal serves the request through the given in-process handler. +func (c *Client) doLocal(req *http.Request) ([]byte, error) { + rec := httptest.NewRecorder() + c.localapi.ServeHTTP(rec, req) + + body := rec.Body.Bytes() + if rec.Code >= 400 { + return nil, &Error{ + Status: rec.Code, + Message: strings.TrimSpace(string(body)), + } + } + return body, nil +} + +// TailLogs connects to the log stream endpoint and calls handler for each log +// entry received until ctx is cancelled or the connection is closed. +func (c *Client) TailLogs(ctx context.Context, handler func(rlog.LogEntry)) error { + merged := make(chan rlog.LogEntry, 64) + + // Always tail local logs. + localCh, unsub := rlog.Subscribe() + defer unsub() + go func() { + for { + select { + case entry := <-localCh: + select { + case merged <- entry: + default: + } + case <-ctx.Done(): + return + } + } + }() + + // Tail server logs whenever the IPC server is reachable. + go func() { + for ctx.Err() == nil { + c.sseStream(ctx, logsStreamEndpoint, func(data []byte) { + var entry rlog.LogEntry + if json.Unmarshal(data, &entry) == nil { + select { + case merged <- entry: + default: + } + } + }) + // Server unavailable or disconnected; wait before retrying. + select { + case <-time.After(500 * time.Millisecond): + case <-ctx.Done(): + return + } + } + }() + + for { + select { + case entry := <-merged: + handler(entry) + case <-ctx.Done(): + return ctx.Err() + } + } +} + +// sseStream connects to an SSE endpoint and calls handler for each event data line. +// Blocks until ctx is cancelled or the connection is closed. +func (c *Client) sseStream(ctx context.Context, endpoint string, handler func([]byte)) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL+endpoint, nil) + if err != nil { + return fmt.Errorf("create SSE request: %w", err) + } + req.Header.Set("Accept", "text/event-stream") + + resp, err := c.http.Do(req) + if err != nil { + c.mu.RLock() + hasFallback := c.localapi != nil + c.mu.RUnlock() + if hasFallback && isConnectionError(err) { + return ErrIPCNotRunning + } + return fmt.Errorf("SSE connect %s: %w", endpoint, err) + } + c.stopLocal() // IPC is reachable; shut down local backend if still running + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return &Error{Status: resp.StatusCode, Message: strings.TrimSpace(string(body))} + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if data, ok := strings.CutPrefix(line, "data: "); ok { + handler([]byte(data)) + } + } + if err := scanner.Err(); err != nil && ctx.Err() == nil { + return fmt.Errorf("SSE %s: read: %w", endpoint, err) + } + return nil +} diff --git a/ipc/client_nonmobile.go b/ipc/client_nonmobile.go new file mode 100644 index 00000000..6506ed1c --- /dev/null +++ b/ipc/client_nonmobile.go @@ -0,0 +1,123 @@ +//go:build (!android && !ios && !darwin) || (darwin && standalone) + +package ipc + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + rlog "github.com/getlantern/radiance/log" +) + +// Client communicates with the IPC server over a local socket. +type Client struct { + http *http.Client +} + +// NewClient creates a new IPC client that communicates exclusively through the IPC server. +func NewClient() *Client { + return newClient() +} + +// Close releases resources held by the client, including any local backend. +func (c *Client) Close() { + c.http.CloseIdleConnections() +} + +// do executes an HTTP request with an optional JSON body and returns the raw response body. If +// body needs to be marshaled using sing/json, it should be pre-marshaled to []byte before passing +// to do. do returns an error if the response status is >= 400. +func (c *Client) do(ctx context.Context, method, endpoint string, body any) ([]byte, error) { + var bodyReader io.Reader + if body != nil { + switch body := body.(type) { + case []byte: + bodyReader = bytes.NewReader(body) + default: + data, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + bodyReader = bytes.NewReader(data) + } + } + + req, err := http.NewRequestWithContext(ctx, method, apiURL+endpoint, bodyReader) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("ipc request %s %s: %w", method, endpoint, err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + if resp.StatusCode >= 400 { + return nil, &Error{ + Status: resp.StatusCode, + Message: strings.TrimSpace(string(respBody)), + } + } + return respBody, nil +} + +// TailLogs connects to the log stream endpoint and calls handler for each log +// entry received until ctx is cancelled or the connection is closed. +func (c *Client) TailLogs(ctx context.Context, handler func(rlog.LogEntry)) error { + return c.sseStream(ctx, logsStreamEndpoint, func(data []byte) { + var entry rlog.LogEntry + if json.Unmarshal(data, &entry) == nil { + handler(entry) + } + }) +} + +// sseStream connects to an SSE endpoint and calls handler for each event data line. +// Blocks until ctx is cancelled or the connection is closed. +func (c *Client) sseStream(ctx context.Context, endpoint string, handler func([]byte)) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL+endpoint, nil) + if err != nil { + return fmt.Errorf("create SSE request: %w", err) + } + req.Header.Set("Accept", "text/event-stream") + + resp, err := c.http.Do(req) + if err != nil { + if isConnectionError(err) { + return ErrIPCNotRunning + } + return fmt.Errorf("SSE connect %s: %w", endpoint, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return &Error{Status: resp.StatusCode, Message: strings.TrimSpace(string(body))} + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if data, ok := strings.CutPrefix(line, "data: "); ok { + handler([]byte(data)) + } + } + if err := scanner.Err(); err != nil && ctx.Err() == nil { + return fmt.Errorf("SSE %s: read: %w", endpoint, err) + } + return nil +} diff --git a/vpn/ipc/conn_nonwindows.go b/ipc/conn_nonwindows.go similarity index 91% rename from vpn/ipc/conn_nonwindows.go rename to ipc/conn_nonwindows.go index 76266fd8..6aee1c41 100644 --- a/vpn/ipc/conn_nonwindows.go +++ b/ipc/conn_nonwindows.go @@ -15,11 +15,9 @@ import ( const apiURL = "http://lantern" -func dialContext(_ context.Context, _, _ string) (net.Conn, error) { - return net.DialUnix("unix", nil, &net.UnixAddr{ - Name: socketPath(), - Net: "unix", - }) +func dialContext(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", socketPath()) } type sockListener struct { diff --git a/vpn/ipc/conn_windows.go b/ipc/conn_windows.go similarity index 100% rename from vpn/ipc/conn_windows.go rename to ipc/conn_windows.go diff --git a/vpn/ipc/middlewares.go b/ipc/middlewares.go similarity index 58% rename from vpn/ipc/middlewares.go rename to ipc/middlewares.go index 56716242..1563c3ac 100644 --- a/vpn/ipc/middlewares.go +++ b/ipc/middlewares.go @@ -6,17 +6,16 @@ import ( "log/slog" "net/http" - "github.com/go-chi/chi/v5/middleware" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/propagation" semconv "go.opentelemetry.io/otel/semconv/v1.4.0" "go.opentelemetry.io/otel/trace" - "github.com/getlantern/radiance/internal" + rlog "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/traces" ) -func log(h http.Handler) http.Handler { +func logger(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Pull the trace ID from the request, if it exists. ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header)) @@ -24,7 +23,7 @@ func log(h http.Handler) http.Handler { span := trace.SpanFromContext(r.Context()) span.SetAttributes(semconv.HTTPRouteKey.String(r.URL.Path)) - slog.Log(r.Context(), internal.LevelTrace, "IPC request", "method", r.Method, "path", r.URL.Path) + slog.Log(r.Context(), rlog.LevelTrace, "IPC request", "method", r.Method, "path", r.URL.Path) h.ServeHTTP(w, r) }) } @@ -36,15 +35,41 @@ func tracer(next http.Handler) http.Handler { r = r.WithContext(ctx) var buf bytes.Buffer - ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) - ww.Tee(&buf) + ww := &statusRecorder{ResponseWriter: w, body: &buf} next.ServeHTTP(ww, r) - if ww.Status() >= 400 { - traces.RecordError(ctx, fmt.Errorf("status %d: %s", ww.Status(), buf.String())) + if ww.status >= 400 { + traces.RecordError(ctx, fmt.Errorf("status %d: %s", ww.status, buf.String())) } }) } +// statusRecorder wraps http.ResponseWriter to capture the status code and response body. +type statusRecorder struct { + http.ResponseWriter + status int + body *bytes.Buffer +} + +func (r *statusRecorder) WriteHeader(code int) { + r.status = code + r.ResponseWriter.WriteHeader(code) +} + +func (r *statusRecorder) Write(b []byte) (int, error) { + if r.status == 0 { + r.status = http.StatusOK + } + r.body.Write(b) + return r.ResponseWriter.Write(b) +} + +// Flush implements http.Flusher if the underlying ResponseWriter supports it. +func (r *statusRecorder) Flush() { + if f, ok := r.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + func authPeer(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { peer := usrFromContext(r.Context()) diff --git a/ipc/server.go b/ipc/server.go new file mode 100644 index 00000000..62986398 --- /dev/null +++ b/ipc/server.go @@ -0,0 +1,1034 @@ +// Package ipc implements the IPC server for communicating between the client and the VPN service. +// It provides HTTP endpoints for retrieving statistics, managing groups, selecting outbounds, +// changing modes, and closing connections. +package ipc + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net" + "net/http" + "sync/atomic" + "time" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/backend" + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/events" + rlog "github.com/getlantern/radiance/log" + "github.com/getlantern/radiance/vpn" + + sjson "github.com/sagernet/sing/common/json" +) + +const ( + tracerName = "github.com/getlantern/radiance/ipc" + + // VPN endpoints + vpnStatusEndpoint = "/vpn/status" + vpnConnectEndpoint = "/vpn/connect" + vpnDisconnectEndpoint = "/vpn/disconnect" + vpnRestartEndpoint = "/vpn/restart" + vpnConnectionsEndpoint = "/vpn/connections" + vpnOfflineTestsEndpoint = "/vpn/offline-tests" + vpnStatusEventsEndpoint = "/vpn/status/events" + + // Server selection endpoints + serverSelectedEndpoint = "/server/selected" + serverAutoSelectedEndpoint = "/server/auto-selected" + serverAutoSelectedEventsEndpoint = "/server/auto-selected/events" + + // Server management endpoints + serversEndpoint = "/servers" + serversAddEndpoint = "/servers/add" + serversRemoveEndpoint = "/servers/remove" + serversFromJSONEndpoint = "/servers/json" + serversFromURLsEndpoint = "/servers/urls" + serversPrivateEndpoint = "/servers/private" + serversPrivateInviteEndpoint = "/servers/private/invite" + + // Settings endpoints + featuresEndpoint = "/settings/features" + settingsEndpoint = "/settings" + + // Split tunnel endpoint + splitTunnelEndpoint = "/split-tunnel" + + // Account endpoints + accountNewUserEndpoint = "/account/new-user" + accountLoginEndpoint = "/account/login" + accountLogoutEndpoint = "/account/logout" + accountUserDataEndpoint = "/account/user" + accountDevicesEndpoint = "/account/devices/" + accountSignupEndpoint = "/account/signup/" + accountEmailEndpoint = "/account/email" + accountRecoveryEndpoint = "/account/recovery" + accountDeleteEndpoint = "/account/delete" + accountOAuthEndpoint = "/account/oauth" + accountDataCapEndpoint = "/account/datacap" + accountDataCapStreamEndpoint = "/account/datacap/stream" + + // Subscription endpoints + subscriptionActivationEndpoint = "/subscription/activation" + subscriptionStripeEndpoint = "/subscription/stripe" + subscriptionPaymentRedirectEndpoint = "/subscription/payment-redirect" + subscriptionReferralEndpoint = "/subscription/referral" + subscriptionBillingPortalEndpoint = "/subscription/billing-portal" + subscriptionPaymentRedirectURLEndpoint = "/subscription/payment-redirect-url" + subscriptionPlansEndpoint = "/subscription/plans" + subscriptionVerifyEndpoint = "/subscription/verify" + + // Issue endpoint + issueEndpoint = "/issue" + + // Logs endpoint + logsStreamEndpoint = "/logs/stream" +) + +var ( + protocols = func() http.Protocols { + var p http.Protocols + p.SetUnencryptedHTTP2(true) + return p + }() + + ErrServiceIsNotReady = errors.New("service is not ready") + ErrIPCNotRunning = errors.New("IPC not running") +) + +// Server represents the IPC server that communicates over a Unix domain socket for Unix-like +// systems, and a named pipe for Windows. +type Server struct { + svr *http.Server + closed atomic.Bool +} + +// NewServer creates a new Server instance with the provided Backend. +func NewServer(b *backend.LocalBackend, withAuth bool) *Server { + // Only add auth middleware if not running on mobile, since mobile platforms have their own + // sandboxing and permission models. + svr := &http.Server{ + Handler: newLocalAPI(b, withAuth), + ReadTimeout: 5 * time.Second, + Protocols: &protocols, + } + if withAuth { + svr.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + peer, err := getConnPeer(c) + if err != nil { + slog.Error("Failed to get peer credentials", "error", err) + } + return contextWithUsr(ctx, peer) + } + } + return &Server{svr: svr} +} + +// Start begins listening for incoming IPC requests. +func (s *Server) Start() error { + if s.closed.Load() { + return errors.New("IPC server is closed") + } + l, err := listen() + if err != nil { + return fmt.Errorf("IPC server: listen: %w", err) + } + go func() { + slog.Info("IPC server started", "address", l.Addr().String()) + if err := s.svr.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) { + slog.Error("IPC server error", "error", err) + } + s.closed.Store(true) + }() + return nil +} + +// Close shuts down the IPC server. +func (s *Server) Close() error { + if s.closed.Swap(true) { + return nil + } + slog.Info("Closing IPC server") + return s.svr.Close() +} + +type backendKey struct{} + +type localapi struct { + be atomic.Pointer[backend.LocalBackend] + handler http.Handler +} + +// backend returns the LocalBackend snapshotted at the start of the request. +func (s *localapi) backend(ctx context.Context) *backend.LocalBackend { + return ctx.Value(backendKey{}).(*backend.LocalBackend) +} + +func newLocalAPI(b *backend.LocalBackend, withAuth bool) *localapi { + s := &localapi{} + s.be.Store(b) + + mux := http.NewServeMux() + + // traced wraps a handler with the tracer middleware. + traced := func(h http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + tracer(http.HandlerFunc(h)).ServeHTTP(w, r) + } + } + + // VPN + mux.HandleFunc("GET "+vpnStatusEndpoint, traced(s.vpnStatusHandler)) + mux.HandleFunc("POST "+vpnConnectEndpoint, traced(s.vpnConnectHandler)) + mux.HandleFunc("POST "+vpnDisconnectEndpoint, traced(s.vpnDisconnectHandler)) + mux.HandleFunc("POST "+vpnRestartEndpoint, traced(s.vpnRestartHandler)) + mux.HandleFunc("GET "+vpnConnectionsEndpoint, traced(s.vpnConnectionsHandler)) + mux.HandleFunc("POST "+vpnOfflineTestsEndpoint, traced(s.vpnOfflineTestsHandler)) + + // SSE routes skip the tracer middleware since it buffers the entire response body. + mux.HandleFunc("GET "+vpnStatusEventsEndpoint, s.vpnStatusEventsHandler) + + // Server selection + mux.HandleFunc(serverSelectedEndpoint, traced(s.serverSelectedHandler)) + mux.HandleFunc("GET "+serverAutoSelectedEndpoint, traced(s.serverAutoSelectedHandler)) + mux.HandleFunc("GET "+serverAutoSelectedEventsEndpoint, s.serverAutoSelectedEventsHandler) + + // Server management + mux.HandleFunc("GET "+serversEndpoint, traced(s.serversHandler)) + mux.HandleFunc("POST "+serversAddEndpoint, traced(s.serversAddHandler)) + mux.HandleFunc("POST "+serversRemoveEndpoint, traced(s.serversRemoveHandler)) + mux.HandleFunc("POST "+serversFromJSONEndpoint, traced(s.serversFromJSONHandler)) + mux.HandleFunc("POST "+serversFromURLsEndpoint, traced(s.serversFromURLsHandler)) + mux.HandleFunc("POST "+serversPrivateEndpoint, traced(s.serversPrivateAddHandler)) + mux.HandleFunc(serversPrivateInviteEndpoint, traced(s.serversPrivateInviteHandler)) + + // Settings + mux.HandleFunc("GET "+featuresEndpoint, traced(s.featuresHandler)) + mux.HandleFunc(settingsEndpoint, traced(s.settingsHandler)) + + // Split tunnel + mux.HandleFunc(splitTunnelEndpoint, traced(s.splitTunnelHandler)) + + // Account + mux.HandleFunc("POST "+accountNewUserEndpoint, traced(s.accountNewUserHandler)) + mux.HandleFunc("POST "+accountLoginEndpoint, traced(s.accountLoginHandler)) + mux.HandleFunc("POST "+accountLogoutEndpoint, traced(s.accountLogoutHandler)) + mux.HandleFunc("GET "+accountUserDataEndpoint, traced(s.accountUserDataHandler)) + mux.HandleFunc(accountDevicesEndpoint+"{deviceID...}", traced(s.accountDevicesHandler)) + mux.HandleFunc("POST "+accountSignupEndpoint+"{action...}", traced(s.accountSignupHandler)) + mux.HandleFunc("POST "+accountEmailEndpoint+"/{action}", traced(s.accountEmailHandler)) + mux.HandleFunc("POST "+accountRecoveryEndpoint+"/{action}", traced(s.accountRecoveryHandler)) + mux.HandleFunc("DELETE "+accountDeleteEndpoint, traced(s.accountDeleteHandler)) + mux.HandleFunc(accountOAuthEndpoint, traced(s.accountOAuthHandler)) + mux.HandleFunc("GET "+accountDataCapEndpoint, traced(s.accountDataCapHandler)) + + // SSE routes skip the tracer middleware since it buffers the entire response body. + mux.HandleFunc("GET "+accountDataCapStreamEndpoint, s.accountDataCapStreamHandler) + + // Subscriptions + mux.HandleFunc("POST "+subscriptionActivationEndpoint, traced(s.subscriptionActivationHandler)) + mux.HandleFunc("POST "+subscriptionStripeEndpoint, traced(s.subscriptionStripeHandler)) + mux.HandleFunc("POST "+subscriptionPaymentRedirectEndpoint, traced(s.subscriptionPaymentRedirectHandler)) + mux.HandleFunc("POST "+subscriptionReferralEndpoint, traced(s.subscriptionReferralHandler)) + mux.HandleFunc("GET "+subscriptionBillingPortalEndpoint, traced(s.subscriptionBillingPortalHandler)) + mux.HandleFunc("POST "+subscriptionPaymentRedirectURLEndpoint, traced(s.subscriptionPaymentRedirectURLHandler)) + mux.HandleFunc("GET "+subscriptionPlansEndpoint, traced(s.subscriptionPlansHandler)) + mux.HandleFunc("POST "+subscriptionVerifyEndpoint, traced(s.subscriptionVerifyHandler)) + + // Issue + mux.HandleFunc("POST "+issueEndpoint, traced(s.issueReportHandler)) + + // Logs (SSE, skip tracer) + mux.HandleFunc("GET "+logsStreamEndpoint, s.logsStreamHandler) + + // Build the middleware chain: log -> (optional auth) -> mux + var handler http.Handler = mux + if withAuth { + handler = authPeer(handler) + } + handler = logger(handler) + s.handler = handler + + return s +} + +func (s *localapi) setBackend(b *backend.LocalBackend) *backend.LocalBackend { + return s.be.Swap(b) +} + +func (s *localapi) ServeHTTP(w http.ResponseWriter, r *http.Request) { + b := s.be.Load() + if b == nil { + http.Error(w, "service is not ready", http.StatusServiceUnavailable) + return + } + ctx := context.WithValue(r.Context(), backendKey{}, b) + s.handler.ServeHTTP(w, r.WithContext(ctx)) +} + +func writeJSON(w http.ResponseWriter, status int, v any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(v); err != nil { + slog.Error("IPC: failed to write JSON response", "error", err) + } +} + +func decodeJSON(r *http.Request, v any) error { + return json.NewDecoder(r.Body).Decode(v) +} + +func writeSingJSON[T any](w http.ResponseWriter, status int, v T) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := sjson.NewEncoderContext(boxCtx, w).Encode(v); err != nil { + slog.Error("IPC: failed to write JSON response", "error", err) + } +} + +func decodeSingJSON(r *http.Request, v any) error { + return sjson.NewDecoderContext(boxCtx, r.Body).Decode(v) +} + +// sseWriter sets headers for a Server-Sent Events response and returns the flusher. +// Returns nil if the ResponseWriter does not support flushing. +func sseWriter(w http.ResponseWriter) http.Flusher { + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming not supported", http.StatusInternalServerError) + return nil + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + return flusher +} + +///////////// +// VPN // +///////////// + +func (s *localapi) vpnStatusHandler(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, s.backend(r.Context()).VPNStatus()) +} + +func (s *localapi) vpnConnectHandler(w http.ResponseWriter, r *http.Request) { + var req TagRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).ConnectVPN(req.Tag); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) vpnDisconnectHandler(w http.ResponseWriter, r *http.Request) { + if err := s.backend(r.Context()).DisconnectVPN(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) vpnRestartHandler(w http.ResponseWriter, r *http.Request) { + if err := s.backend(r.Context()).RestartVPN(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +// vpnConnectionsHandler handles GET /vpn/connections/ (all) and GET /vpn/connections/active. +func (s *localapi) vpnConnectionsHandler(w http.ResponseWriter, r *http.Request) { + var ( + conns []vpn.Connection + err error + ) + if r.URL.Query().Get("active") == "true" { + conns, err = s.backend(r.Context()).ActiveVPNConnections() + } else { + conns, err = s.backend(r.Context()).VPNConnections() + } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, conns) +} + +func (s *localapi) vpnOfflineTestsHandler(w http.ResponseWriter, r *http.Request) { + if err := s.backend(r.Context()).RunOfflineURLTests(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) vpnStatusEventsHandler(w http.ResponseWriter, r *http.Request) { + flusher := sseWriter(w) + if flusher == nil { + return + } + ch := make(chan []byte, 16) + sub := events.Subscribe(func(evt vpn.StatusUpdateEvent) { + data, err := json.Marshal(evt) + if err != nil { + return + } + select { + case ch <- data: + default: + } + }) + defer sub.Unsubscribe() + for { + select { + case data := <-ch: + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} + +/////////////////////// +// Server selection // +/////////////////////// + +// serverSelectedHandler handles GET /server/selected (read) and POST /server/selected (set). +func (s *localapi) serverSelectedHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var req TagRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).SelectServer(req.Tag); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + return + } + server, exists, err := s.backend(r.Context()).SelectedServer() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeSingJSON(w, http.StatusOK, SelectedServerResponse{Server: server, Exists: exists}) +} + +func (s *localapi) serverAutoSelectedHandler(w http.ResponseWriter, r *http.Request) { + selected, err := s.backend(r.Context()).CurrentAutoSelectedServer() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, selected) +} + +func (s *localapi) serverAutoSelectedEventsHandler(w http.ResponseWriter, r *http.Request) { + flusher := sseWriter(w) + if flusher == nil { + return + } + ch := make(chan []byte, 16) + sub := events.Subscribe(func(evt vpn.AutoSelectedEvent) { + data, err := json.Marshal(evt) + if err != nil { + return + } + select { + case ch <- data: + default: + } + }) + defer sub.Unsubscribe() + for { + select { + case data := <-ch: + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} + +/////////////////////// +// Server management // +/////////////////////// + +// serversHandler handles GET /servers +func (s *localapi) serversHandler(w http.ResponseWriter, r *http.Request) { + if tag := r.URL.Query().Get("tag"); tag != "" { + server, found := s.backend(r.Context()).GetServerByTag(tag) + if !found { + http.Error(w, "server not found", http.StatusNotFound) + return + } + writeSingJSON(w, http.StatusOK, server) + return + } + writeSingJSON(w, http.StatusOK, s.backend(r.Context()).Servers()) +} + +func (s *localapi) serversAddHandler(w http.ResponseWriter, r *http.Request) { + var req AddServersRequest + if err := decodeSingJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).AddServers(req.Group, req.Options); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) serversRemoveHandler(w http.ResponseWriter, r *http.Request) { + var req RemoveServersRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).RemoveServers(req.Tags); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) serversFromJSONHandler(w http.ResponseWriter, r *http.Request) { + var req JSONConfigRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).AddServersByJSON(req.Config); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) serversFromURLsHandler(w http.ResponseWriter, r *http.Request) { + var req URLsRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).AddServersByURL(req.URLs, req.SkipCertVerification); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) serversPrivateAddHandler(w http.ResponseWriter, r *http.Request) { + var req PrivateServerRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err := s.backend(r.Context()).AddPrivateServer(req.Tag, req.IP, req.Port, req.AccessToken, req.Location, req.Joined) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +// serversPrivateInviteHandler handles POST (create) and DELETE (revoke) on /servers/private/invite. +func (s *localapi) serversPrivateInviteHandler(w http.ResponseWriter, r *http.Request) { + var req PrivateServerInviteRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if r.Method == http.MethodDelete { + if err := s.backend(r.Context()).RevokePrivateServerInvite(req.IP, req.Port, req.AccessToken, req.InviteName); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + return + } + code, err := s.backend(r.Context()).InviteToPrivateServer(req.IP, req.Port, req.AccessToken, req.InviteName) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, CodeResponse{Code: code}) +} + +////////////// +// Settings // +////////////// + +func (s *localapi) featuresHandler(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, s.backend(r.Context()).Features()) +} + +func (s *localapi) settingsHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPatch: + var updates settings.Settings + if err := decodeJSON(r, &updates); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).PatchSettings(updates); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + fallthrough + case http.MethodGet: + writeJSON(w, http.StatusOK, settings.GetAll()) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +///////////////// +// Split Tunnel // +///////////////// + +// splitTunnelHandler handles GET (read), POST (add), and DELETE (remove) on /split-tunnel. +func (s *localapi) splitTunnelHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + writeJSON(w, http.StatusOK, s.backend(r.Context()).SplitTunnelFilters()) + return + } + var items vpn.SplitTunnelFilter + if err := decodeJSON(r, &items); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + var err error + switch r.Method { + case http.MethodPost: + err = s.backend(r.Context()).AddSplitTunnelItems(items) + case http.MethodDelete: + err = s.backend(r.Context()).RemoveSplitTunnelItems(items) + } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +///////////// +// Account // +///////////// + +func (s *localapi) accountNewUserHandler(w http.ResponseWriter, r *http.Request) { + userData, err := s.backend(r.Context()).NewUser(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, userData) +} + +func (s *localapi) accountLoginHandler(w http.ResponseWriter, r *http.Request) { + var req EmailPasswordRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userData, err := s.backend(r.Context()).Login(r.Context(), req.Email, req.Password) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + writeJSON(w, http.StatusOK, userData) +} + +func (s *localapi) accountLogoutHandler(w http.ResponseWriter, r *http.Request) { + var req EmailRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userData, err := s.backend(r.Context()).Logout(r.Context(), req.Email) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, userData) +} + +func (s *localapi) accountUserDataHandler(w http.ResponseWriter, r *http.Request) { + var userData *account.UserData + var err error + if r.URL.Query().Get("fetch") == "true" { + userData, err = s.backend(r.Context()).FetchUserData(r.Context()) + } else { + userData, err = s.backend(r.Context()).UserData() + } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, userData) +} + +// accountDevicesHandler handles GET /account/devices (list) and DELETE /account/devices/{deviceID} (remove). +func (s *localapi) accountDevicesHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodDelete { + resp, err := s.backend(r.Context()).RemoveDevice(r.Context(), r.PathValue("deviceID")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, resp) + return + } + devices, err := s.backend(r.Context()).UserDevices() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, devices) +} + +// accountSignupHandler handles POST /account/signup, /account/signup/confirm, and /account/signup/resend. +func (s *localapi) accountSignupHandler(w http.ResponseWriter, r *http.Request) { + switch r.PathValue("action") { + case "confirm": + var req EmailCodeRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).SignupEmailConfirmation(r.Context(), req.Email, req.Code); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + case "resend": + var req EmailRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).SignupEmailResendCode(r.Context(), req.Email); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + default: + var req EmailPasswordRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + salt, resp, err := s.backend(r.Context()).SignUp(r.Context(), req.Email, req.Password) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, SignupResponse{Salt: salt, Response: resp}) + } +} + +// accountEmailHandler handles POST /account/email/{action} for start and complete. +func (s *localapi) accountEmailHandler(w http.ResponseWriter, r *http.Request) { + var err error + switch r.PathValue("action") { + case "start": + var req ChangeEmailStartRequest + if err = decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = s.backend(r.Context()).StartChangeEmail(r.Context(), req.NewEmail, req.Password) + case "complete": + var req ChangeEmailCompleteRequest + if err = decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = s.backend(r.Context()).CompleteChangeEmail(r.Context(), req.NewEmail, req.Password, req.Code) + } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +// accountRecoveryHandler handles POST /account/recovery/{action} for start, complete, and validate. +func (s *localapi) accountRecoveryHandler(w http.ResponseWriter, r *http.Request) { + var err error + switch r.PathValue("action") { + case "start": + var req EmailRequest + if err = decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = s.backend(r.Context()).StartRecoveryByEmail(r.Context(), req.Email) + case "complete": + var req RecoveryCompleteRequest + if err = decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = s.backend(r.Context()).CompleteRecoveryByEmail(r.Context(), req.Email, req.NewPassword, req.Code) + case "validate": + var req EmailCodeRequest + if err = decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = s.backend(r.Context()).ValidateEmailRecoveryCode(r.Context(), req.Email, req.Code) + } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) accountDeleteHandler(w http.ResponseWriter, r *http.Request) { + var req EmailPasswordRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userData, err := s.backend(r.Context()).DeleteAccount(r.Context(), req.Email, req.Password) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, userData) +} + +// accountOAuthHandler handles GET /account/oauth (login URL) and POST /account/oauth (callback). +func (s *localapi) accountOAuthHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var req OAuthTokenRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userData, err := s.backend(r.Context()).OAuthLoginCallback(r.Context(), req.OAuthToken) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, userData) + return + } + provider := r.URL.Query().Get("provider") + if provider == "" { + http.Error(w, "provider is required", http.StatusBadRequest) + return + } + u, err := s.backend(r.Context()).OAuthLoginUrl(r.Context(), provider) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, URLResponse{URL: u}) +} + +func (s *localapi) accountDataCapHandler(w http.ResponseWriter, r *http.Request) { + info, err := s.backend(r.Context()).DataCapInfo(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, info) +} + +func (s *localapi) accountDataCapStreamHandler(w http.ResponseWriter, r *http.Request) { + flusher := sseWriter(w) + if flusher == nil { + return + } + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + poll := func() { + info, err := s.backend(r.Context()).DataCapInfo(r.Context()) + if err != nil { + slog.Error("datacap poll error", "error", err) + return + } + data, err := json.Marshal(info) + if err != nil { + return + } + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + // Send initial data immediately + poll() + for { + select { + case <-ticker.C: + poll() + case <-r.Context().Done(): + return + } + } +} + +/////////////////// +// Subscriptions // +/////////////////// + +func (s *localapi) subscriptionActivationHandler(w http.ResponseWriter, r *http.Request) { + var req ActivationRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp, err := s.backend(r.Context()).ActivationCode(r.Context(), req.Email, req.ResellerCode) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, resp) +} + +func (s *localapi) subscriptionStripeHandler(w http.ResponseWriter, r *http.Request) { + var req StripeSubscriptionRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + clientSecret, err := s.backend(r.Context()).NewStripeSubscription(r.Context(), req.Email, req.PlanID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, ClientSecretResponse{ClientSecret: clientSecret}) +} + +func (s *localapi) subscriptionPaymentRedirectHandler(w http.ResponseWriter, r *http.Request) { + var req account.PaymentRedirectData + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + u, err := s.backend(r.Context()).PaymentRedirect(r.Context(), req) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, URLResponse{URL: u}) +} + +func (s *localapi) subscriptionReferralHandler(w http.ResponseWriter, r *http.Request) { + var req CodeRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + ok, err := s.backend(r.Context()).ReferralAttach(r.Context(), req.Code) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, SuccessResponse{Success: ok}) +} + +func (s *localapi) subscriptionBillingPortalHandler(w http.ResponseWriter, r *http.Request) { + u, err := s.backend(r.Context()).StripeBillingPortalURL(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, URLResponse{URL: u}) +} + +func (s *localapi) subscriptionPaymentRedirectURLHandler(w http.ResponseWriter, r *http.Request) { + var req account.PaymentRedirectData + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + u, err := s.backend(r.Context()).SubscriptionPaymentRedirectURL(r.Context(), req) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, URLResponse{URL: u}) +} + +func (s *localapi) subscriptionPlansHandler(w http.ResponseWriter, r *http.Request) { + plans, err := s.backend(r.Context()).SubscriptionPlans(r.Context(), r.URL.Query().Get("channel")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, PlansResponse{Plans: plans}) +} + +func (s *localapi) subscriptionVerifyHandler(w http.ResponseWriter, r *http.Request) { + var req VerifySubscriptionRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + result, err := s.backend(r.Context()).VerifySubscription(r.Context(), req.Service, req.Data) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, ResultResponse{Result: result}) +} + +/////////// +// Issue // +/////////// + +func (s *localapi) issueReportHandler(w http.ResponseWriter, r *http.Request) { + var req IssueReportRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).ReportIssue(req.IssueType, req.Description, req.Email, req.AdditionalAttachments); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +/////////// +// Logs // +/////////// + +func (s *localapi) logsStreamHandler(w http.ResponseWriter, r *http.Request) { + flusher := sseWriter(w) + if flusher == nil { + return + } + ch, unsub := rlog.Subscribe() + defer unsub() + for { + select { + case entry := <-ch: + data, err := json.Marshal(entry) + if err != nil { + continue + } + fmt.Fprintf(w, "data: %s\n", data) + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} diff --git a/vpn/ipc/socket.go b/ipc/socket.go similarity index 100% rename from vpn/ipc/socket.go rename to ipc/socket.go diff --git a/vpn/ipc/socket_mobile.go b/ipc/socket_mobile.go similarity index 74% rename from vpn/ipc/socket_mobile.go rename to ipc/socket_mobile.go index 6383a570..c7289f1e 100644 --- a/vpn/ipc/socket_mobile.go +++ b/ipc/socket_mobile.go @@ -11,7 +11,7 @@ import ( "syscall" "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" ) // this is a no-op on mobile @@ -41,7 +41,7 @@ func getNonRootOwner(path string) (uid, gid int) { return uid, gid } - slog.Log(context.Background(), internal.LevelTrace, "searching for non-root owner of", "path", path) + slog.Log(context.Background(), log.LevelTrace, "searching for non-root owner of", "path", path) for { parentDir := filepath.Dir(path) if parentDir == path || parentDir == "/" { @@ -51,7 +51,7 @@ func getNonRootOwner(path string) (uid, gid int) { fInfo, err := os.Stat(path) if err != nil { - slog.Log(context.Background(), internal.LevelTrace, "stat error", "path", path, "error", err) + slog.Log(context.Background(), log.LevelTrace, "stat error", "path", path, "error", err) continue } stat, ok := fInfo.Sys().(*syscall.Stat_t) @@ -59,11 +59,11 @@ func getNonRootOwner(path string) (uid, gid int) { continue } if int(stat.Uid) != 0 { - slog.Log(context.Background(), internal.LevelTrace, "found non-root owner", "path", path, "uid", stat.Uid, "gid", stat.Gid) + slog.Log(context.Background(), log.LevelTrace, "found non-root owner", "path", path, "uid", stat.Uid, "gid", stat.Gid) return int(stat.Uid), int(stat.Gid) } } - if slog.Default().Enabled(context.Background(), internal.LevelTrace) { + if slog.Default().Enabled(context.Background(), log.LevelTrace) { slog.Warn("falling back to root owner for", "path", path) } return uid, gid diff --git a/vpn/ipc/testsetup.go b/ipc/testsetup.go similarity index 100% rename from vpn/ipc/testsetup.go rename to ipc/testsetup.go diff --git a/ipc/types.go b/ipc/types.go new file mode 100644 index 00000000..d012e489 --- /dev/null +++ b/ipc/types.go @@ -0,0 +1,146 @@ +package ipc + +import ( + "github.com/getlantern/common" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/issue" + "github.com/getlantern/radiance/servers" +) + +// Shared request types used by both client and server. + +type TagRequest struct { + Tag string `json:"tag"` +} + +type EmailRequest struct { + Email string `json:"email"` +} + +type EmailPasswordRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type EmailCodeRequest struct { + Email string `json:"email"` + Code string `json:"code"` +} + +type OAuthTokenRequest struct { + OAuthToken string `json:"oAuthToken"` +} + +type CodeRequest struct { + Code string `json:"code"` +} + +type JSONConfigRequest struct { + Config string `json:"config"` +} + +type AddServersRequest struct { + Group servers.ServerGroup `json:"group"` + Options servers.Options `json:"options"` +} + +type RemoveServersRequest struct { + Tags []string `json:"tags"` +} + +type URLsRequest struct { + URLs []string `json:"urls"` + SkipCertVerification bool `json:"skipCertVerification"` +} + +type PrivateServerRequest struct { + Tag string `json:"tag"` + IP string `json:"ip"` + Port int `json:"port"` + AccessToken string `json:"accessToken"` + Location common.ServerLocation `json:"location"` + Joined bool `json:"joined"` +} + +type PrivateServerInviteRequest struct { + IP string `json:"ip"` + Port int `json:"port"` + AccessToken string `json:"accessToken"` + InviteName string `json:"inviteName"` +} + +type ChangeEmailStartRequest struct { + NewEmail string `json:"newEmail"` + Password string `json:"password"` +} + +type ChangeEmailCompleteRequest struct { + NewEmail string `json:"newEmail"` + Password string `json:"password"` + Code string `json:"code"` +} + +type RecoveryCompleteRequest struct { + Email string `json:"email"` + NewPassword string `json:"newPassword"` + Code string `json:"code"` +} + +type ActivationRequest struct { + Email string `json:"email"` + ResellerCode string `json:"resellerCode"` +} + +type StripeSubscriptionRequest struct { + Email string `json:"email"` + PlanID string `json:"planID"` +} + +type VerifySubscriptionRequest struct { + Service account.SubscriptionService `json:"service"` + Data map[string]string `json:"data"` +} + +type IssueReportRequest struct { + IssueType issue.IssueType `json:"issueType"` + Description string `json:"description"` + Email string `json:"email"` + AdditionalAttachments []string `json:"additionalAttachments"` +} + +// Shared response types used by both client and server. + +type SelectedServerResponse struct { + Server servers.Server `json:"server"` + Exists bool `json:"exists"` +} + +type SignupResponse struct { + Salt []byte `json:"salt"` + Response *account.SignupResponse `json:"response"` +} + +type URLResponse struct { + URL string `json:"url"` +} + +type CodeResponse struct { + Code string `json:"code"` +} + +type ClientSecretResponse struct { + ClientSecret string `json:"clientSecret"` +} + +type SuccessResponse struct { + Success bool `json:"success"` +} + +type PlansResponse struct { + Plans string `json:"plans"` +} + +type ResultResponse struct { + Result string `json:"result"` +} diff --git a/vpn/ipc/usr.go b/ipc/usr.go similarity index 100% rename from vpn/ipc/usr.go rename to ipc/usr.go diff --git a/vpn/ipc/usr_darwin.go b/ipc/usr_darwin.go similarity index 100% rename from vpn/ipc/usr_darwin.go rename to ipc/usr_darwin.go diff --git a/vpn/ipc/usr_linux.go b/ipc/usr_linux.go similarity index 100% rename from vpn/ipc/usr_linux.go rename to ipc/usr_linux.go diff --git a/vpn/ipc/usr_windows.go b/ipc/usr_windows.go similarity index 100% rename from vpn/ipc/usr_windows.go rename to ipc/usr_windows.go diff --git a/vpn/ipc/zsyscall_windows.go b/ipc/zsyscall_windows.go similarity index 100% rename from vpn/ipc/zsyscall_windows.go rename to ipc/zsyscall_windows.go diff --git a/issue/archive.go b/issue/archive.go new file mode 100644 index 00000000..6a675943 --- /dev/null +++ b/issue/archive.go @@ -0,0 +1,212 @@ +package issue + +import ( + "archive/zip" + "bytes" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" +) + +// buildIssueArchive creates a zip archive containing the log file and additional +// attachment files. The total compressed archive size will not exceed maxSize bytes. +// +// Additional files are included only if space permits after the log. +func buildIssueArchive(logPath string, additionalFiles []string, maxSize int64) ([]byte, error) { + logData, err := snapshotLogFile(logPath, maxSize) + if err != nil { + slog.Warn("unable to snapshot log file, trying additional files only", "path", logPath, "error", err) + } + + extras := readExtraFiles(additionalFiles) + + return fitArchive(logData, extras, maxSize) +} + +// snapshotLogFile opens the log file, records its current size, and reads the tail +// up to a reasonable cap. +func snapshotLogFile(logPath string, maxCompressed int64) ([]byte, error) { + f, err := os.Open(logPath) + if err != nil { + return nil, err + } + defer f.Close() + + fi, err := f.Stat() + if err != nil { + return nil, err + } + + size := fi.Size() + if size == 0 { + return nil, nil + } + + // Cap the amount we read: even with poor compression, we'd never need more + // than maxCompressed * 20 bytes of uncompressed log to fill the archive. + maxRead := maxCompressed * 20 + readSize := size + if readSize > maxRead { + readSize = maxRead + } + + // Seek to read only the tail (most recent logs). + if size > readSize { + if _, err := f.Seek(size-readSize, io.SeekStart); err != nil { + return nil, err + } + } + + data := make([]byte, readSize) + n, err := io.ReadFull(f, data) + if err != nil && err != io.ErrUnexpectedEOF { + return nil, fmt.Errorf("reading log file: %w", err) + } + return data[:n], nil +} + +type extraFile struct { + name string + data []byte +} + +func readExtraFiles(paths []string) []extraFile { + var files []extraFile + for _, p := range paths { + data, err := os.ReadFile(p) + if err != nil { + slog.Warn("unable to read additional file", "path", p, "error", err) + continue + } + files = append(files, extraFile{ + name: filepath.Base(p), + data: data, + }) + } + return files +} + +// fitArchive builds a zip archive that fits within maxSize, prioritizing log data. +func fitArchive(logData []byte, extras []extraFile, maxSize int64) ([]byte, error) { + if len(logData) == 0 && len(extras) == 0 { + return nil, nil + } + + // Try everything. + buf, err := writeArchive(logData, extras) + if err != nil { + return nil, err + } + if int64(buf.Len()) <= maxSize { + return buf.Bytes(), nil + } + + // Try full log, no extras. + if len(logData) > 0 { + buf, err = writeArchive(logData, nil) + if err != nil { + return nil, err + } + if int64(buf.Len()) <= maxSize { + // Full log fits — greedily add extras that still fit. + return addExtrasGreedily(logData, extras, maxSize) + } + + // Full log doesn't fit — binary search for the maximum tail. + tailSize := searchMaxLogTail(logData, maxSize) + tail := logData[len(logData)-tailSize:] + return addExtrasGreedily(tail, extras, maxSize) + } + + // No log data — try extras only. + return addExtrasGreedily(nil, extras, maxSize) +} + +const logArchiveName = "lantern.log" + +func writeArchive(logData []byte, extras []extraFile) (*bytes.Buffer, error) { + buf := new(bytes.Buffer) + w := zip.NewWriter(buf) + + if len(logData) > 0 { + fw, err := w.Create(logArchiveName) + if err != nil { + return nil, err + } + if _, err := fw.Write(logData); err != nil { + return nil, err + } + } + + for _, f := range extras { + fw, err := w.Create("attachments/" + f.name) + if err != nil { + return nil, err + } + if _, err := fw.Write(f.data); err != nil { + return nil, err + } + } + + if err := w.Close(); err != nil { + return nil, err + } + return buf, nil +} + +// searchMaxLogTail binary-searches for the largest tail of logData (in 256KB chunks) +// that compresses into a zip archive not exceeding maxSize. +func searchMaxLogTail(logData []byte, maxSize int64) int { + const chunkSize = 256 * 1024 + n := len(logData) + lo, hi := 1, (n+chunkSize-1)/chunkSize + best := 0 + + for lo <= hi { + mid := lo + (hi-lo)/2 + tailBytes := mid * chunkSize + if tailBytes > n { + tailBytes = n + } + + buf, err := writeArchive(logData[n-tailBytes:], nil) + if err != nil { + hi = mid - 1 + continue + } + if int64(buf.Len()) <= maxSize { + best = tailBytes + lo = mid + 1 + } else { + hi = mid - 1 + } + } + return best +} + +// addExtrasGreedily starts from the given log data and adds extra files one by one, +// keeping each only if the archive still fits within maxSize. +func addExtrasGreedily(logData []byte, extras []extraFile, maxSize int64) ([]byte, error) { + var included []extraFile + buf, err := writeArchive(logData, nil) + if err != nil { + return nil, err + } + lastGood := buf.Bytes() + + for _, f := range extras { + // Safe append that won't modify the existing slice's backing array. + trial := append(included[:len(included):len(included)], f) + buf, err := writeArchive(logData, trial) + if err != nil { + continue + } + if int64(buf.Len()) <= maxSize { + included = trial + lastGood = buf.Bytes() + } + } + return lastGood, nil +} diff --git a/issue/archive_test.go b/issue/archive_test.go new file mode 100644 index 00000000..46efbb07 --- /dev/null +++ b/issue/archive_test.go @@ -0,0 +1,411 @@ +package issue + +import ( + "archive/zip" + "bytes" + "crypto/rand" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSnapshotLogFile(t *testing.T) { + t.Run("reads full file when small", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "test.log") + content := "line1\nline2\nline3\n" + require.NoError(t, os.WriteFile(logPath, []byte(content), 0644)) + + data, err := snapshotLogFile(logPath, 1024*1024) + require.NoError(t, err) + assert.Equal(t, content, string(data)) + }) + + t.Run("reads only tail when file exceeds cap", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "test.log") + + // maxCompressed=100 → maxRead = 100*20 = 2000 + // Write 5000 bytes so the file exceeds the cap. + full := strings.Repeat("X", 5000) + require.NoError(t, os.WriteFile(logPath, []byte(full), 0644)) + + data, err := snapshotLogFile(logPath, 100) + require.NoError(t, err) + assert.Equal(t, 2000, len(data)) + // Should be the tail of the file. + assert.Equal(t, full[3000:], string(data)) + }) + + t.Run("returns nil for empty file", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "empty.log") + require.NoError(t, os.WriteFile(logPath, nil, 0644)) + + data, err := snapshotLogFile(logPath, 1024*1024) + require.NoError(t, err) + assert.Nil(t, data) + }) + + t.Run("returns error for missing file", func(t *testing.T) { + _, err := snapshotLogFile("/nonexistent/path.log", 1024*1024) + assert.Error(t, err) + }) + + t.Run("snapshot is stable after file rotation", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "test.log") + original := "original log content\n" + require.NoError(t, os.WriteFile(logPath, []byte(original), 0644)) + + // Open and snapshot size (simulating what snapshotLogFile does internally). + f, err := os.Open(logPath) + require.NoError(t, err) + defer f.Close() + + fi, err := f.Stat() + require.NoError(t, err) + size := fi.Size() + + // Simulate rotation: rename the file and create a new one. + require.NoError(t, os.Rename(logPath, logPath+".1")) + require.NoError(t, os.WriteFile(logPath, []byte("new log content\n"), 0644)) + + // The original fd should still read the original data. + data := make([]byte, size) + n, err := f.Read(data) + require.NoError(t, err) + assert.Equal(t, original, string(data[:n])) + }) +} + +func TestReadExtraFiles(t *testing.T) { + t.Run("reads existing files", func(t *testing.T) { + dir := t.TempDir() + f1 := filepath.Join(dir, "a.txt") + f2 := filepath.Join(dir, "b.txt") + require.NoError(t, os.WriteFile(f1, []byte("aaa"), 0644)) + require.NoError(t, os.WriteFile(f2, []byte("bbb"), 0644)) + + files := readExtraFiles([]string{f1, f2}) + require.Len(t, files, 2) + assert.Equal(t, "a.txt", files[0].name) + assert.Equal(t, "aaa", string(files[0].data)) + assert.Equal(t, "b.txt", files[1].name) + assert.Equal(t, "bbb", string(files[1].data)) + }) + + t.Run("skips missing files", func(t *testing.T) { + dir := t.TempDir() + existing := filepath.Join(dir, "exists.txt") + require.NoError(t, os.WriteFile(existing, []byte("data"), 0644)) + + files := readExtraFiles([]string{"/no/such/file", existing}) + require.Len(t, files, 1) + assert.Equal(t, "exists.txt", files[0].name) + }) + + t.Run("nil input returns nil", func(t *testing.T) { + files := readExtraFiles(nil) + assert.Nil(t, files) + }) +} + +func TestWriteArchive(t *testing.T) { + t.Run("log only", func(t *testing.T) { + logData := []byte("some log content") + buf, err := writeArchive(logData, nil) + require.NoError(t, err) + + entries := readZipEntries(t, buf.Bytes()) + require.Len(t, entries, 1) + assert.Equal(t, logArchiveName, entries[0].name) + assert.Equal(t, "some log content", entries[0].content) + }) + + t.Run("log with extras", func(t *testing.T) { + logData := []byte("log line") + extras := []extraFile{ + {name: "config.json", data: []byte(`{"key":"val"}`)}, + {name: "screenshot.png", data: []byte("fake png")}, + } + buf, err := writeArchive(logData, extras) + require.NoError(t, err) + + entries := readZipEntries(t, buf.Bytes()) + require.Len(t, entries, 3) + assert.Equal(t, logArchiveName, entries[0].name) + assert.Equal(t, "attachments/config.json", entries[1].name) + assert.Equal(t, "attachments/screenshot.png", entries[2].name) + }) + + t.Run("extras only", func(t *testing.T) { + extras := []extraFile{{name: "file.txt", data: []byte("hello")}} + buf, err := writeArchive(nil, extras) + require.NoError(t, err) + + entries := readZipEntries(t, buf.Bytes()) + require.Len(t, entries, 1) + assert.Equal(t, "attachments/file.txt", entries[0].name) + }) + + t.Run("empty inputs", func(t *testing.T) { + buf, err := writeArchive(nil, nil) + require.NoError(t, err) + // Should produce a valid but empty zip. + entries := readZipEntries(t, buf.Bytes()) + assert.Empty(t, entries) + }) +} + +func TestFitArchive(t *testing.T) { + t.Run("everything fits", func(t *testing.T) { + logData := []byte("small log") + extras := []extraFile{{name: "a.txt", data: []byte("small")}} + result, err := fitArchive(logData, extras, 1024*1024) + require.NoError(t, err) + require.NotNil(t, result) + + entries := readZipEntries(t, result) + assert.Len(t, entries, 2) + }) + + t.Run("nil log and nil extras returns nil", func(t *testing.T) { + result, err := fitArchive(nil, nil, 1024*1024) + require.NoError(t, err) + assert.Nil(t, result) + }) + + t.Run("extras dropped when too large", func(t *testing.T) { + logData := []byte("log data") + // Make an extra that's big enough to push past a small maxSize. + bigExtra := extraFile{name: "big.bin", data: bytes.Repeat([]byte{0xFF}, 50*1024)} + + // Find the compressed size of just the log. + logOnly, err := writeArchive(logData, nil) + require.NoError(t, err) + maxSize := int64(logOnly.Len()) + 100 // just barely enough for log, not the extra + + result, err := fitArchive(logData, []extraFile{bigExtra}, maxSize) + require.NoError(t, err) + + entries := readZipEntries(t, result) + require.Len(t, entries, 1) + assert.Equal(t, logArchiveName, entries[0].name) + assert.Equal(t, "log data", entries[0].content) + }) + + t.Run("log truncated to tail when too large", func(t *testing.T) { + // Use incompressible random data (2MB) with a budget that fits ~1-2 + // chunks (256KB each) but not the full log. + logData := make([]byte, 2*1024*1024) // 2MB + _, err := rand.Read(logData) + require.NoError(t, err) + + maxSize := int64(512 * 1024) // 512KB + + result, err := fitArchive(logData, nil, maxSize) + require.NoError(t, err) + assert.LessOrEqual(t, int64(len(result)), maxSize) + + entries := readZipEntries(t, result) + require.Len(t, entries, 1) + assert.Equal(t, logArchiveName, entries[0].name) + + // The included content should be a tail of the original. + content := entries[0].content + assert.True(t, len(content) < len(logData), "log should be truncated") + assert.Equal(t, string(logData[len(logData)-len(content):]), content, + "included content should be the tail of the original log") + }) + + t.Run("extras only when no log", func(t *testing.T) { + extras := []extraFile{ + {name: "a.txt", data: []byte("aaa")}, + {name: "b.txt", data: []byte("bbb")}, + } + result, err := fitArchive(nil, extras, 1024*1024) + require.NoError(t, err) + + entries := readZipEntries(t, result) + assert.Len(t, entries, 2) + }) +} + +func TestSearchMaxLogTail(t *testing.T) { + t.Run("all fits", func(t *testing.T) { + logData := []byte("small log data") + tailSize := searchMaxLogTail(logData, 1024*1024) + assert.Equal(t, len(logData), tailSize) + }) + + t.Run("truncates incompressible data", func(t *testing.T) { + logData := make([]byte, 1024*1024) // 1MB random + _, err := rand.Read(logData) + require.NoError(t, err) + + maxSize := int64(300 * 1024) // 300KB + tailSize := searchMaxLogTail(logData, maxSize) + assert.Greater(t, tailSize, 0) + assert.Less(t, tailSize, len(logData)) + + // Verify the result actually fits. + buf, err := writeArchive(logData[len(logData)-tailSize:], nil) + require.NoError(t, err) + assert.LessOrEqual(t, int64(buf.Len()), maxSize) + }) +} + +func TestAddExtrasGreedily(t *testing.T) { + t.Run("adds all when they fit", func(t *testing.T) { + logData := []byte("log") + extras := []extraFile{ + {name: "a.txt", data: []byte("aaa")}, + {name: "b.txt", data: []byte("bbb")}, + } + result, err := addExtrasGreedily(logData, extras, 1024*1024) + require.NoError(t, err) + + entries := readZipEntries(t, result) + assert.Len(t, entries, 3) + }) + + t.Run("skips extras that would exceed limit", func(t *testing.T) { + logData := []byte("log") + small := extraFile{name: "small.txt", data: []byte("s")} + big := extraFile{name: "big.bin", data: bytes.Repeat([]byte{0xFF}, 50*1024)} + + // Budget enough for log + small, but not big. + bufWithSmall, err := writeArchive(logData, []extraFile{small}) + require.NoError(t, err) + maxSize := int64(bufWithSmall.Len()) + 50 // tight budget + + result, err := addExtrasGreedily(logData, []extraFile{small, big}, maxSize) + require.NoError(t, err) + + entries := readZipEntries(t, result) + names := make([]string, len(entries)) + for i, e := range entries { + names[i] = e.name + } + assert.Contains(t, names, logArchiveName) + assert.Contains(t, names, "attachments/small.txt") + assert.NotContains(t, names, "attachments/big.bin") + }) + + t.Run("no extras returns log only", func(t *testing.T) { + logData := []byte("log content") + result, err := addExtrasGreedily(logData, nil, 1024*1024) + require.NoError(t, err) + + entries := readZipEntries(t, result) + require.Len(t, entries, 1) + assert.Equal(t, logArchiveName, entries[0].name) + }) +} + +func TestBuildIssueArchive(t *testing.T) { + t.Run("end to end with log and extras", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "lantern.log") + require.NoError(t, os.WriteFile(logPath, []byte("log line 1\nlog line 2\n"), 0644)) + + extra := filepath.Join(dir, "extra.txt") + require.NoError(t, os.WriteFile(extra, []byte("extra content"), 0644)) + + result, err := buildIssueArchive(logPath, []string{extra}, 1024*1024) + require.NoError(t, err) + require.NotNil(t, result) + + entries := readZipEntries(t, result) + require.Len(t, entries, 2) + assert.Equal(t, logArchiveName, entries[0].name) + assert.Equal(t, "log line 1\nlog line 2\n", entries[0].content) + assert.Equal(t, "attachments/extra.txt", entries[1].name) + }) + + t.Run("missing log file still includes extras", func(t *testing.T) { + dir := t.TempDir() + extra := filepath.Join(dir, "extra.txt") + require.NoError(t, os.WriteFile(extra, []byte("data"), 0644)) + + result, err := buildIssueArchive(filepath.Join(dir, "nonexistent.log"), []string{extra}, 1024*1024) + require.NoError(t, err) + require.NotNil(t, result) + + entries := readZipEntries(t, result) + require.Len(t, entries, 1) + assert.Equal(t, "attachments/extra.txt", entries[0].name) + }) + + t.Run("archive respects maxSize", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "lantern.log") + // Write incompressible data (2MB). + logContent := make([]byte, 2*1024*1024) + _, err := rand.Read(logContent) + require.NoError(t, err) + require.NoError(t, os.WriteFile(logPath, logContent, 0644)) + + maxSize := int64(512 * 1024) + result, err := buildIssueArchive(logPath, nil, maxSize) + require.NoError(t, err) + assert.LessOrEqual(t, int64(len(result)), maxSize) + + // Verify it contains the tail. + entries := readZipEntries(t, result) + require.Len(t, entries, 1) + content := entries[0].content + assert.Equal(t, string(logContent[len(logContent)-len(content):]), content) + }) + + t.Run("snapshot excludes data written after call", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "lantern.log") + original := "before snapshot\n" + require.NoError(t, os.WriteFile(logPath, []byte(original), 0644)) + + // Snapshot the file. + data, err := snapshotLogFile(logPath, 1024*1024) + require.NoError(t, err) + + // Append after snapshot. + f, err := os.OpenFile(logPath, os.O_APPEND|os.O_WRONLY, 0644) + require.NoError(t, err) + _, err = f.WriteString("after snapshot\n") + require.NoError(t, err) + f.Close() + + // Snapshot should only contain original content. + assert.Equal(t, original, string(data)) + }) +} + +// --- test helpers --- + +type zipEntry struct { + name string + content string +} + +func readZipEntries(t *testing.T, data []byte) []zipEntry { + t.Helper() + r, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + require.NoError(t, err) + + var entries []zipEntry + for _, f := range r.File { + rc, err := f.Open() + require.NoError(t, err) + body, err := io.ReadAll(rc) + require.NoError(t, err) + rc.Close() + entries = append(entries, zipEntry{name: f.Name, content: string(body)}) + } + return entries +} diff --git a/issue/issue.go b/issue/issue.go index 711f61f5..ed6a7647 100644 --- a/issue/issue.go +++ b/issue/issue.go @@ -4,140 +4,131 @@ import ( "bytes" "context" "fmt" + "io" "log/slog" "math/rand" "net/http" "net/http/httputil" - "strconv" + "path/filepath" + "runtime" "time" "github.com/getlantern/osversion" + "github.com/getlantern/timezone" "go.opentelemetry.io/otel" - "github.com/getlantern/radiance/backend" "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/kindling" + "github.com/getlantern/radiance/internal" "github.com/getlantern/radiance/traces" "google.golang.org/protobuf/proto" ) const ( - maxUncompressedLogSize = 50 * 1024 * 1024 // 50 MB - tracerName = "github.com/getlantern/radiance/issue" + maxCompressedSize = 20 * 1024 * 1024 // 20 MB + tracerName = "github.com/getlantern/radiance/issue" ) -// IssueReporter is used to send issue reports to backend -type IssueReporter struct{} +// IssueReporter is used to send issue reports to backend. +type IssueReporter struct { + httpClient *http.Client +} // NewIssueReporter creates a new IssueReporter that can be used to send issue reports // to the backend. -func NewIssueReporter() *IssueReporter { - return &IssueReporter{} +func NewIssueReporter(httpClient *http.Client) *IssueReporter { + return &IssueReporter{httpClient: httpClient} } -func randStr(n int) string { - r := rand.New(rand.NewSource(time.Now().UnixNano())) - var hexStr string - for i := 0; i < n; i++ { - hexStr += fmt.Sprintf("%x", r.Intn(16)) - } - return hexStr -} +type IssueType int -// Attachment is a file attachment -type Attachment struct { - Name string - Data []byte -} +const ( + CannotCompletePurchase IssueType = iota + CannotSignIn + SpinnerLoadsEndlessly + CannotAccessBlockedSites + Slow + CannotLinkDevice + ApplicationCrashes + Other IssueType = iota + 2 + UpdateFails +) + +// // issue text to type mapping +// var issueTypeMap = map[string]IssueType{ +// "Cannot complete purchase": CannotCompletePurchase, +// "Cannot sign in": CannotSignIn, +// "Spinner loads endlessly": SpinnerLoadsEndlessly, +// "Cannot access blocked sites": CannotAccessBlockedSites, +// "Slow": Slow, +// "Cannot link device": CannotLinkDevice, +// "Application crashes": ApplicationCrashes, +// "Other": Other, +// "Update fails": UpdateFails, +// } type IssueReport struct { // Type is one of the predefined issue type strings - Type string - // Issue description + Type IssueType Description string - // Attachment is a list of issue attachments - Attachments []*Attachment + Email string + CountryCode string // device common name - Device string + Device string + DeviceID string + UserID string + SubscriptionLevel string + Locale string // device alphanumeric name Model string -} - -// issue text to type mapping -var issueTypeMap = map[string]int{ - "Cannot complete purchase": 0, - "Cannot sign in": 1, - "Spinner loads endlessly": 2, - "Cannot access blocked sites": 3, - "Slow": 4, - "Cannot link device": 5, - "Application crashes": 6, - "Other": 9, - "Update fails": 10, + // AdditionalAttachments is a list of additional files to be attached. The log file will be + // automatically included. + AdditionalAttachments []string } // Report sends an issue report to lantern-cloud/issue, which is then forwarded to ticket system via API -func (ir *IssueReporter) Report(ctx context.Context, report IssueReport, userEmail, country string) error { +func (ir *IssueReporter) Report(ctx context.Context, report IssueReport) error { ctx, span := otel.Tracer(tracerName).Start(ctx, "Report") defer span.End() // set a random email if it's empty - if userEmail == "" { - userEmail = "support+" + randStr(8) + "@getlantern.org" + if report.Email == "" { + report.Email = "support+" + randStr(8) + "@getlantern.org" } - userStatus := settings.GetString(settings.UserLevelKey) + // userStatus := settings.GetString(settings.UserLevelKey) osVersion, err := osversion.GetHumanReadable() if err != nil { slog.Error("Unable to get OS version", "error", err) + osVersion = runtime.GOOS + " " + runtime.GOARCH } - // get issue type as integer - iType, ok := issueTypeMap[report.Type] - if !ok { - slog.Error("Unknown issue type, setting to 'Other'", "type", report.Type) - iType = 9 - } - r := &ReportIssueRequest{ - Type: ReportIssueRequest_ISSUE_TYPE(iType), - CountryCode: country, + Type: ReportIssueRequest_ISSUE_TYPE(report.Type), AppVersion: common.Version, - SubscriptionLevel: userStatus, Platform: common.Platform, + CountryCode: report.CountryCode, + SubscriptionLevel: report.SubscriptionLevel, Description: report.Description, - UserEmail: userEmail, - DeviceId: settings.GetString(settings.DeviceIDKey), - UserId: strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10), + UserEmail: report.Email, + DeviceId: report.DeviceID, + UserId: report.UserID, Device: report.Device, Model: report.Model, + Language: report.Locale, OsVersion: osVersion, - Language: settings.GetString(settings.LocaleKey), } - for _, attachment := range report.Attachments { - r.Attachments = append(r.Attachments, &ReportIssueRequest_Attachment{ - Type: "application/zip", - Name: attachment.Name, - Content: attachment.Data, - }) + logPath := filepath.Join(settings.GetString(settings.LogPathKey), internal.LogFileName) + archive, err := buildIssueArchive(logPath, report.AdditionalAttachments, maxCompressedSize) + if err != nil { + slog.Error("failed to build issue archive", "error", err) } - - // Zip logs - slog.Debug("zipping log files for issue report") - buf := &bytes.Buffer{} - // zip * under folder common.LogDir - logDir := settings.GetString(settings.LogPathKey) - slog.Debug("zipping log files", "logDir", logDir, "maxSize", maxUncompressedLogSize) - if _, zipErr := zipLogFiles(buf, logDir, maxUncompressedLogSize, int64(maxUncompressedLogSize)); zipErr == nil { - r.Attachments = append(r.Attachments, &ReportIssueRequest_Attachment{ + if len(archive) > 0 { + r.Attachments = []*ReportIssueRequest_Attachment{{ Type: "application/zip", Name: "logs.zip", - Content: buf.Bytes(), - }) - slog.Debug("log files zipped for issue report", "size", len(buf.Bytes())) - } else { - slog.Error("unable to zip log files", "error", err, "logDir", logDir, "maxSize", maxUncompressedLogSize) + Content: archive, + }} } // send message to lantern-cloud @@ -148,7 +139,7 @@ func (ir *IssueReporter) Report(ctx context.Context, report IssueReport, userEma } issueURL := common.GetBaseURL() + "/issue" - req, err := backend.NewIssueRequest( + req, err := newIssueRequest( ctx, http.MethodPost, issueURL, @@ -159,7 +150,7 @@ func (ir *IssueReporter) Report(ctx context.Context, report IssueReport, userEma return traces.RecordError(ctx, err) } - resp, err := kindling.HTTPClient().Do(req) + resp, err := ir.httpClient.Do(req) if err != nil { slog.Error("failed to send issue report", "error", err, "requestURL", issueURL) return traces.RecordError(ctx, err) @@ -178,3 +169,28 @@ func (ir *IssueReporter) Report(ctx context.Context, report IssueReport, userEma slog.Debug("issue report sent") return nil } + +// newIssueRequest creates a new HTTP request with the required headers for issue reporting. +func newIssueRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { + req, err := common.NewRequestWithHeaders(ctx, method, url, body) + if err != nil { + return nil, err + } + + req.Header.Set("content-type", "application/x-protobuf") + req.Header.Set(common.SupportedDataCapsHeader, "monthly,weekly,daily") + if tz, err := timezone.IANANameForTime(time.Now()); err == nil { + req.Header.Set(common.TimeZoneHeader, tz) + } + + return req, nil +} + +func randStr(n int) string { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + var hexStr string + for range n { + hexStr += fmt.Sprintf("%x", r.Intn(16)) + } + return hexStr +} diff --git a/issue/issue_test.go b/issue/issue_test.go index 7e6b4634..58609fa3 100644 --- a/issue/issue_test.go +++ b/issue/issue_test.go @@ -1,12 +1,15 @@ package issue import ( + "archive/zip" + "bytes" "context" "io" "net/http" "net/http/httptest" "net/url" - "strconv" + "os" + "path/filepath" "testing" "github.com/getlantern/osversion" @@ -16,7 +19,6 @@ import ( "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/kindling" ) func TestSendReport(t *testing.T) { @@ -26,7 +28,13 @@ func TestSendReport(t *testing.T) { osVer, err := osversion.GetHumanReadable() require.NoError(t, err) - // Build expected report + // Create a temp file to use as an additional attachment + tmpDir := t.TempDir() + attachPath := filepath.Join(tmpDir, "Hello.txt") + err = os.WriteFile(attachPath, []byte("Hello World"), 0644) + require.NoError(t, err) + + // Build expected report (without attachments — we verify those separately) want := &ReportIssueRequest{ Type: ReportIssueRequest_NO_ACCESS, CountryCode: "US", @@ -36,53 +44,40 @@ func TestSendReport(t *testing.T) { Description: "Description placeholder-test only", UserEmail: "radiancetest@getlantern.org", DeviceId: settings.GetString(settings.DeviceIDKey), - UserId: strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10), + UserId: settings.GetString(settings.UserIDKey), Device: "Samsung Galaxy S10", Model: "SM-G973F", OsVersion: osVer, Language: settings.GetString(settings.LocaleKey), - Attachments: []*ReportIssueRequest_Attachment{ - { - Type: "application/zip", - Name: "Hello.txt", - Content: []byte("Hello World"), - }, - }, } srv := newTestServer(t, want) defer srv.Close() - reporter := &IssueReporter{} - kindling.SetKindling(&mockKindling{newTestClient(t, srv.URL)}) - report := IssueReport{ - Type: "Cannot access blocked sites", - Description: "Description placeholder-test only", - Attachments: []*Attachment{ - { - Name: "Hello.txt", - Data: []byte("Hello World"), - }, - }, - Device: "Samsung Galaxy S10", - Model: "SM-G973F", - } - - err = reporter.Report(context.Background(), report, "radiancetest@getlantern.org", "US") - require.NoError(t, err) -} - -func newTestClient(t *testing.T, testURL string) *http.Client { - return &http.Client{ + reporter := NewIssueReporter(&http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - parsedURL, err := url.Parse(testURL) - if err != nil { - t.Fatalf("failed to parse testURL: %v", err) - } + parsedURL, err := url.Parse(srv.URL) + require.NoError(t, err, "failed to parse test server URL") req.URL = parsedURL return http.DefaultTransport.RoundTrip(req) }), + }) + report := IssueReport{ + Type: CannotAccessBlockedSites, + Description: "Description placeholder-test only", + Email: "radiancetest@getlantern.org", + CountryCode: "US", + SubscriptionLevel: "free", + DeviceID: settings.GetString(settings.DeviceIDKey), + UserID: settings.GetString(settings.UserIDKey), + Locale: settings.GetString(settings.LocaleKey), + Device: "Samsung Galaxy S10", + Model: "SM-G973F", + AdditionalAttachments: []string{attachPath}, } + + err = reporter.Report(context.Background(), report) + require.NoError(t, err) } // roundTripperFunc allows using a function as http.RoundTripper @@ -109,18 +104,29 @@ func newTestServer(t *testing.T, want *ReportIssueRequest) *testServer { err = proto.Unmarshal(body, &got) require.NoError(t, err, "should unmarshal protobuf request") - // Filter got.Attachments to only include the ones we're testing - // (exclude logs.zip and other dynamic attachments) - filteredAttachments := make([]*ReportIssueRequest_Attachment, 0) - for _, gotAtt := range got.Attachments { - for _, wantAtt := range ts.want.Attachments { - if gotAtt.Name == wantAtt.Name { - filteredAttachments = append(filteredAttachments, gotAtt) - break + // Verify logs.zip attachment contains the additional file + var foundHello bool + for _, att := range got.Attachments { + if att.Name == "logs.zip" { + zr, err := zip.NewReader(bytes.NewReader(att.Content), int64(len(att.Content))) + require.NoError(t, err, "should open logs.zip") + for _, f := range zr.File { + if f.Name == "attachments/Hello.txt" { + rc, err := f.Open() + require.NoError(t, err) + data, err := io.ReadAll(rc) + require.NoError(t, err) + rc.Close() + assert.Equal(t, "Hello World", string(data)) + foundHello = true + } } } } - got.Attachments = filteredAttachments + assert.True(t, foundHello, "logs.zip should contain attachments/Hello.txt") + + // Clear attachments for field-level comparison + got.Attachments = nil // Compare received report with expected report using proto.Equal if assert.True(t, proto.Equal(ts.want, &got), "received report should match expected report") { @@ -131,17 +137,3 @@ func newTestServer(t *testing.T, want *ReportIssueRequest) *testServer { })) return ts } - -type mockKindling struct { - c *http.Client -} - -// NewHTTPClient returns a new HTTP client that is configured to use kindling. -func (m *mockKindling) NewHTTPClient() *http.Client { - return m.c -} - -// ReplaceTransport replaces an existing transport RoundTripper generator with the provided one. -func (m *mockKindling) ReplaceTransport(name string, rt func(ctx context.Context, addr string) (http.RoundTripper, error)) error { - panic("not implemented") // TODO: Implement -} diff --git a/issue/logzipper.go b/issue/logzipper.go deleted file mode 100644 index 693a1a87..00000000 --- a/issue/logzipper.go +++ /dev/null @@ -1,111 +0,0 @@ -package issue - -// copied from flashlight/logging/logging.go - -import ( - "io" - "log/slog" - "os" - "path/filepath" - "sort" -) - -type fileInfo struct { - file string - size int64 - modTime int64 -} -type byDate []*fileInfo - -func (a byDate) Len() int { return len(a) } -func (a byDate) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a byDate) Less(i, j int) bool { return a[i].modTime > a[j].modTime } - -// zipLogFiles zips the Lantern log files to the writer. All files will be -// placed under the folder in the archieve. It will stop and return if the -// newly added file would make the extracted files exceed maxBytes in total. -// -// It also returns up to maxTextBytes of plain text from the end of the most recent log file. -func zipLogFiles(w io.Writer, logDir string, maxBytes int64, maxTextBytes int64) (string, error) { - return zipLogFilesFrom(w, maxBytes, maxTextBytes, map[string]string{"logs": logDir}) -} - -// zipLogFilesFrom zips the log files from the given dirs to the writer. It will -// stop and return if the newly added file would make the extracted files exceed -// maxBytes in total. -// -// It also returns up to maxTextBytes of plain text from the end of the most recent log file. -func zipLogFilesFrom(w io.Writer, maxBytes int64, maxTextBytes int64, dirs map[string]string) (string, error) { - globs := make(map[string]string, len(dirs)) - for baseDir, dir := range dirs { - globs[baseDir] = filepath.Join(dir, "*") - } - err := zipFiles(w, zipOptions{ - Globs: globs, - MaxBytes: maxBytes, - }) - if err != nil { - return "", err - } - - if maxTextBytes <= 0 { - return "", nil - } - - // Get info for all log files - allFiles := make(byDate, 0) - for _, glob := range globs { - matched, err := filepath.Glob(glob) - if err != nil { - slog.Error("Unable to glob log files", "glob", glob, "error", err) - continue - } - for _, file := range matched { - fi, err := os.Stat(file) - if err != nil { - slog.Error("Unable to stat log file", "file", file, "error", err) - continue - } - allFiles = append(allFiles, &fileInfo{ - file: file, - size: fi.Size(), - modTime: fi.ModTime().Unix(), - }) - } - } - - if len(allFiles) > 0 { - // Sort by recency - sort.Sort(allFiles) - - mostRecent := allFiles[0] - slog.Debug("Grabbing log tail", "file", mostRecent.file) - - mostRecentFile, err := os.Open(mostRecent.file) - if err != nil { - slog.Error("Unable to open most recent log file", "file", mostRecent.file, "error", err) - return "", nil - } - defer mostRecentFile.Close() - - seekTo := mostRecent.size - maxTextBytes - if seekTo > 0 { - slog.Debug("Seeking to tail of log file", "file", mostRecent.file, "seekTo", seekTo) - _, err = mostRecentFile.Seek(seekTo, io.SeekCurrent) - if err != nil { - slog.Error("Unable to seek to tail of log file", "file", mostRecent.file, "error", err) - return "", nil - } - } - tail, err := io.ReadAll(mostRecentFile) - if err != nil { - slog.Error("Unable to read tail of log file", "file", mostRecent.file, "error", err) - return "", nil - } - - slog.Debug("Returning log tail", "file", mostRecent.file, "tailSize", len(tail)) - return string(tail), nil - } - - return "", nil -} diff --git a/issue/zip.go b/issue/zip.go deleted file mode 100644 index 28731eb0..00000000 --- a/issue/zip.go +++ /dev/null @@ -1,118 +0,0 @@ -package issue - -import ( - "archive/zip" - "fmt" - "io" - "math" - "os" - "path/filepath" -) - -// zipOptions is a set of options for zipFiles. -type zipOptions struct { - // The search patterns for the files / directories to be zipped, keyed to the - // directory prefix used for storing the associated files in the ZIP, - // The search pattern is described at the comments of path/filepath.Match. - // As a special note, "**/*" doesn't match files not under a subdirectory. - Globs map[string]string - // The limit of total bytes of all the files in the archive. - // All remaining files will be ignored if the limit would be hit. - MaxBytes int64 -} - -// zipFiles creates a zip archive per the options and writes to the writer. -func zipFiles(writer io.Writer, opts zipOptions) (err error) { - w := zip.NewWriter(writer) - defer func() { - if e := w.Close(); e != nil { - err = e - } - }() - - maxBytes := opts.MaxBytes - if maxBytes == 0 { - maxBytes = math.MaxInt64 - } - - var totalBytes int64 - for baseDir, glob := range opts.Globs { - matched, e := filepath.Glob(glob) - if e != nil { - return e - } - for _, source := range matched { - nextTotal, e := zipFile(w, baseDir, source, maxBytes, totalBytes) - if e != nil || nextTotal > maxBytes { - return e - } - totalBytes = nextTotal - } - } - return -} - -func zipFile(w *zip.Writer, baseDir string, source string, limit int64, prevBytes int64) (newBytes int64, err error) { - _, e := os.Stat(source) - if e != nil { - return prevBytes, fmt.Errorf("%s: stat: %v", source, e) - } - - walkErr := filepath.Walk(source, func(fpath string, info os.FileInfo, err error) error { - if err != nil { - return fmt.Errorf("walking to %s: %v", fpath, err) - } - - newBytes = prevBytes + info.Size() - if newBytes > limit { - return filepath.SkipDir - } - header, err := zip.FileInfoHeader(info) - if err != nil { - return fmt.Errorf("%s: getting header: %v", fpath, err) - } - - dir, filename := filepath.Split(fpath) - if baseDir != "" { - dir = baseDir - } else { - dir = dir[:len(dir)-1] // strip trailing slash - } - if info.IsDir() { - header.Name = fmt.Sprintf("%v/", dir) - header.Method = zip.Store - } else { - header.Name = fmt.Sprintf("%v/%v", dir, filename) - header.Method = zip.Deflate - } - - writer, err := w.CreateHeader(header) - if err != nil { - return fmt.Errorf("%s: making header: %v", fpath, err) - } - - if info.IsDir() { - return nil - } - - if !header.Mode().IsRegular() { - return nil - } - file, err := os.Open(fpath) - if err != nil { - return fmt.Errorf("%s: opening: %v", fpath, err) - } - defer file.Close() - - _, err = io.Copy(writer, file) - if err != nil && err != io.EOF { - return fmt.Errorf("%s: copying contents: %v", fpath, err) - } - return nil - }) - - if walkErr != filepath.SkipDir { - return newBytes, walkErr - } - return newBytes, nil -} diff --git a/issue/zip_test.go b/issue/zip_test.go deleted file mode 100644 index 76a21238..00000000 --- a/issue/zip_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package issue - -import ( - "archive/zip" - "bytes" - "io" - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestZipFilesWithoutPath(t *testing.T) { - var buf bytes.Buffer - err := zipFiles(&buf, zipOptions{Globs: map[string]string{"": "**/*.txt*"}}) - if !assert.NoError(t, err) { - return - } - expectedFiles := []string{ - "test_data/hello.txt", - "test_data/hello.txt.1", - "test_data/large.txt", - "test_data/zzzz.txt.2", - } - testZipFiles(t, buf.Bytes(), expectedFiles) -} - -func TestZipFilesWithMaxBytes(t *testing.T) { - var buf bytes.Buffer - err := zipFiles(&buf, - zipOptions{ - Globs: map[string]string{"": "test_data/*.txt*"}, - MaxBytes: 1024, // 1KB - }, - ) - if !assert.NoError(t, err) { - return - } - expectedFiles := []string{ - "test_data/hello.txt", - "test_data/hello.txt.1", - } - testZipFiles(t, buf.Bytes(), expectedFiles) -} - -func TestZipFilesWithNewRoot(t *testing.T) { - var buf bytes.Buffer - err := zipFiles(&buf, zipOptions{Globs: map[string]string{"new_root": "**/*.txt*"}}) - if !assert.NoError(t, err) { - return - } - expectedFiles := []string{ - "new_root/hello.txt", - "new_root/hello.txt.1", - "new_root/large.txt", - "new_root/zzzz.txt.2", - } - testZipFiles(t, buf.Bytes(), expectedFiles) -} - -func testZipFiles(t *testing.T, zipped []byte, expectedFiles []string) { - reader, eread := zip.NewReader(bytes.NewReader(zipped), int64(len(zipped))) - if !assert.NoError(t, eread) { - return - } - if !assert.Equal(t, len(expectedFiles), len(reader.File), "should not include extra files and files that would exceed MaxBytes") { - return - } - for idx, file := range reader.File { - t.Log(file.Name) - assert.Equal(t, expectedFiles[idx], file.Name) - if !strings.Contains(file.Name, "hello.txt") { - continue - } - fileReader, err := file.Open() - if !assert.NoError(t, err) { - return - } - defer fileReader.Close() - actual, _ := io.ReadAll(fileReader) - assert.Equal(t, []byte("world\n"), actual) - } -} diff --git a/kindling/client.go b/kindling/client.go index 071c4861..66a6c10c 100644 --- a/kindling/client.go +++ b/kindling/client.go @@ -8,15 +8,16 @@ import ( "sync" "github.com/getlantern/kindling" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/reporting" "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/kindling/dnstt" "github.com/getlantern/radiance/kindling/fronted" "github.com/getlantern/radiance/traces" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" ) var ( @@ -36,7 +37,7 @@ var ( // HTTPClient returns a http client with kindling transport func HTTPClient() *http.Client { if k == nil { - SetKindling(NewKindling()) + SetKindling(NewKindling(settings.GetString(settings.DataPathKey))) } httpClient := k.NewHTTPClient() httpClient.Timeout = common.DefaultHTTPTimeout @@ -45,7 +46,7 @@ func HTTPClient() *http.Client { } // Close stop all concurrent config fetches that can be happening in background -func Close(_ context.Context) error { +func Close() error { if stopUpdater != nil { stopUpdater() } @@ -70,8 +71,7 @@ func SetKindling(a kindling.Kindling) { const tracerName = "github.com/getlantern/radiance/kindling" // NewKindling build a kindling client and bootstrap this package -func NewKindling() kindling.Kindling { - dataDir := settings.GetString(settings.DataPathKey) +func NewKindling(dataDir string) kindling.Kindling { logger := &slogWriter{Logger: slog.Default()} ctx, span := otel.Tracer(tracerName).Start( diff --git a/kindling/client_test.go b/kindling/client_test.go index 675a18a6..c57c3cda 100644 --- a/kindling/client_test.go +++ b/kindling/client_test.go @@ -1,27 +1,18 @@ package kindling import ( - "context" - "log/slog" "net/http" - "os" "testing" - "github.com/getlantern/radiance/common/settings" "github.com/stretchr/testify/assert" ) func TestNewClient(t *testing.T) { - slog.SetDefault(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - AddSource: true, - Level: slog.LevelDebug, - }))) - settings.Set(settings.DataPathKey, t.TempDir()) - k = NewKindling() + k = NewKindling(t.TempDir()) SetKindling(k) t.Cleanup(func() { - Close(context.Background()) + Close() k = nil }) diff --git a/kindling/dnstt/parser_test.go b/kindling/dnstt/parser_test.go index efa0f1e4..99d1dc0a 100644 --- a/kindling/dnstt/parser_test.go +++ b/kindling/dnstt/parser_test.go @@ -5,16 +5,16 @@ import ( "compress/gzip" "context" "io" - "log/slog" "net/http" "os" "path/filepath" "testing" "time" - "github.com/getlantern/radiance/events" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/getlantern/radiance/events" ) type roundTripperFunc func(*http.Request) (*http.Response, error) @@ -125,10 +125,6 @@ dnsttConfigs: func TestDNSTTOptions(t *testing.T) { logger := bytes.NewBuffer(nil) - slog.SetDefault(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - AddSource: true, - Level: slog.LevelDebug, - }))) waitFor = 15 * time.Second t.Run("embedded config only", func(t *testing.T) { dnst, err := DNSTTOptions(context.Background(), "", logger) diff --git a/log/log.go b/log/log.go new file mode 100644 index 00000000..931eff65 --- /dev/null +++ b/log/log.go @@ -0,0 +1,231 @@ +package log + +import ( + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "runtime" + "strings" + "time" + + "gopkg.in/natefinch/lumberjack.v2" + + "github.com/getlantern/radiance/common/env" +) + +const ( + // slog does not define trace and fatal levels, so we define them here. + LevelTrace = slog.LevelDebug - 4 + LevelDebug = slog.LevelDebug + LevelInfo = slog.LevelInfo + LevelWarn = slog.LevelWarn + LevelError = slog.LevelError + LevelFatal = slog.LevelError + 4 + LevelPanic = slog.LevelError + 8 + + Disable = slog.LevelInfo + 1000 // A level that disables logging, used for testing or no-op logger. +) + +// Config holds the configuration for creating a new logger. +type Config struct { + // LogPath is the full path to the log file. + LogPath string + // Level is the log level string (e.g., "info", "debug"). + Level string + // Prod indicates whether the application is running in production mode. + Prod bool + // DisablePublisher indicates whether to disable the log publisher which is used for real-time + // log streaming. + DisablePublisher bool +} + +// NewLogger creates and returns a configured *slog.Logger that writes to a rotating log file +// and optionally to stdout. +// Returns noop logger if log level is set to disable. +func NewLogger(cfg Config) *slog.Logger { + level := env.GetString(env.LogLevel) + if level == "" && cfg.Level != "" { + level = cfg.Level + } + slevel, err := ParseLogLevel(level) + if err != nil { + slog.Warn("Failed to parse log level", "error", err) + } + slog.SetLogLoggerLevel(slevel) + if slevel == Disable { + return NoOpLogger() + } + + // lumberjack will create the log file if it does not exist with permissions 0600 otherwise it + // carries over the existing permissions. So we create it here with 0644 so we don't need root/admin + // privileges or chown/chmod to read it. + f, err := os.OpenFile(cfg.LogPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + slog.Warn("Failed to pre-create log file", "error", err, "path", cfg.LogPath) + } else { + f.Close() + } + + logRotator := &lumberjack.Logger{ + Filename: cfg.LogPath, // Log file path + MaxSize: 25, // Rotate log when it reaches 25 MB + MaxBackups: 2, // Keep up to 2 rotated log files + MaxAge: 30, // Retain old log files for up to 30 days + Compress: cfg.Prod, // Compress rotated log files + } + + isWindows := runtime.GOOS == "windows" + isWindowsProd := isWindows && cfg.Prod + + loggingToStdOut := true + var logWriter io.Writer + if env.GetBool(env.DisableStdout) { + logWriter = logRotator + loggingToStdOut = false + } else if isWindowsProd { + // For some reason, logging to both stdout and a file on Windows + // causes issues with some Windows services where the logs + // do not get written to the file. So in prod mode on Windows, + // we log to file only. See: + // https://www.reddit.com/r/golang/comments/1fpo3cg/golang_windows_service_cannot_write_log_files/ + logWriter = logRotator + loggingToStdOut = false + } else { + logWriter = io.MultiWriter(os.Stdout, logRotator) + } + runtime.AddCleanup(&logWriter, func(f *os.File) { + f.Close() + }, f) + var handler slog.Handler = slog.NewTextHandler(logWriter, &slog.HandlerOptions{ + AddSource: true, + Level: slevel, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + switch a.Key { + case slog.TimeKey: + if t, ok := a.Value.Any().(time.Time); ok { + a.Value = slog.StringValue(t.UTC().Format("2006-01-02 15:04:05.000 UTC")) + } + return a + case slog.SourceKey: + source, ok := a.Value.Any().(*slog.Source) + if !ok { + return a + } + // remove github.com/ to get pkg name + var pkg, fn string + fields := strings.SplitN(source.Function, "/", 4) + switch len(fields) { + case 0, 1, 2: + file := filepath.Base(source.File) + a.Value = slog.StringValue(fmt.Sprintf("%s:%d", file, source.Line)) + return a + case 3: + pf := strings.SplitN(fields[2], ".", 2) + pkg, fn = pf[0], pf[1] + default: + pkg = fields[2] + fn = strings.SplitN(fields[3], ".", 2)[1] + } + + _, file, fnd := strings.Cut(source.File, pkg+"/") + if !fnd { + file = filepath.Base(source.File) + } + src := slog.GroupValue( + slog.String("func", fn), + slog.String("file", fmt.Sprintf("%s:%d", file, source.Line)), + ) + a.Value = slog.GroupValue( + slog.String("pkg", pkg), + slog.Any("source", src), + ) + a.Key = "" + case slog.LevelKey: + // format the log level to account for the custom levels defined in internal/util.go, i.e. trace + // otherwise, slog will print as "DEBUG-4" (trace) or similar + level := a.Value.Any().(slog.Level) + a.Value = slog.StringValue(FormatLogLevel(level)) + } + return a + }, + }) + handler = &Handler{Handler: handler, w: logWriter} + if !cfg.DisablePublisher { + pub := newPublisher(200) + handler = &PublishHandler{inner: handler, publisher: pub} + } + logger := slog.New(handler) + if !loggingToStdOut { + if isWindows { + fmt.Printf("Logging to file only on Windows prod -- run with RADIANCE_ENV=dev to enable stdout path: %s, level: %s\n", cfg.LogPath, FormatLogLevel(slevel)) + } else { + fmt.Printf("Logging to file only -- RADIANCE_DISABLE_STDOUT_LOG is set path: %s, level: %s\n", cfg.LogPath, FormatLogLevel(slevel)) + } + } else { + fmt.Printf("Logging to file and stdout path: %s, level: %s\n", cfg.LogPath, FormatLogLevel(slevel)) + } + return logger +} + +type Handler struct { + slog.Handler + w io.Writer +} + +func (h *Handler) Writer() io.Writer { + return h.w +} + +// ParseLogLevel parses a string representation of a log level and returns the corresponding slog.Level. +// If the level is not recognized, it returns LevelInfo. +func ParseLogLevel(level string) (slog.Level, error) { + switch strings.ToLower(level) { + case "trace": + return LevelTrace, nil + case "debug": + return LevelDebug, nil + case "info": + return LevelInfo, nil + case "warn", "warning": + return LevelWarn, nil + case "error": + return LevelError, nil + case "fatal": + return LevelFatal, nil + case "panic": + return LevelPanic, nil + case "disable", "none", "off": + return Disable, nil + default: + return LevelInfo, fmt.Errorf("unknown log level: %s", level) + } +} + +func FormatLogLevel(level slog.Level) string { + switch { + case level < LevelDebug: + return "TRACE" + case level < LevelInfo: + return "DEBUG" + case level < LevelWarn: + return "INFO" + case level < LevelError: + return "WARN" + case level < LevelFatal: + return "ERROR" + case level < LevelPanic: + return "FATAL" + default: + return "PANIC" + } +} + +// NoOpLogger returns a no-op logger that does not log anything. +func NoOpLogger() *slog.Logger { + // Create a no-op logger that does nothing. + return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ + Level: Disable, + })) +} diff --git a/log/publish_handler.go b/log/publish_handler.go new file mode 100644 index 00000000..386ff625 --- /dev/null +++ b/log/publish_handler.go @@ -0,0 +1,129 @@ +package log + +import ( + "context" + "log/slog" + "sync" +) + +// Subscribe returns a channel that receives log entries from the default logger +// and an unsubscribe function. Recent entries from the ring buffer are sent +// immediately. +func Subscribe() (chan LogEntry, func()) { + h, ok := slog.Default().Handler().(*PublishHandler) + if ok { + return h.Subscribe() + } + ph := &PublishHandler{inner: h, publisher: newPublisher(200)} + slog.SetDefault(slog.New(ph)) + return ph.Subscribe() +} + +// LogEntry is a structured log entry streamed to clients. +type LogEntry struct { + Time string `json:"time"` + Level string `json:"level"` + Message string `json:"msg"` + Source string `json:"source,omitempty"` + Attrs map[string]any `json:"attrs,omitempty"` +} + +// PublishHandler wraps an slog.Handler and broadcasts each record to an observer. +type PublishHandler struct { + inner slog.Handler + publisher *publisher +} + +func (h *PublishHandler) Inner() slog.Handler { + return h.inner +} + +func (h *PublishHandler) Enabled(ctx context.Context, level slog.Level) bool { + return h.inner.Enabled(ctx, level) +} + +func (h *PublishHandler) Handle(ctx context.Context, record slog.Record) error { + entry := LogEntry{ + Time: record.Time.UTC().Format("2006-01-02 15:04:05.000 UTC"), + Level: record.Level.String(), + Message: record.Message, + } + if record.NumAttrs() > 0 { + entry.Attrs = make(map[string]any, record.NumAttrs()) + record.Attrs(func(a slog.Attr) bool { + entry.Attrs[a.Key] = a.Value.String() + return true + }) + } + h.publisher.publish(entry) + return h.inner.Handle(ctx, record) +} + +func (h *PublishHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &PublishHandler{inner: h.inner.WithAttrs(attrs), publisher: h.publisher} +} + +func (h *PublishHandler) WithGroup(name string) slog.Handler { + return &PublishHandler{inner: h.inner.WithGroup(name), publisher: h.publisher} +} + +// Subscribe returns a channel that receives log entries and an unsubscribe function. +// Recent entries from the ring buffer are sent immediately. +func (h *PublishHandler) Subscribe() (chan LogEntry, func()) { + return h.publisher.subscribe() +} + +// publisher fans out log entries to connected SSE clients. It maintains a ring buffer +// of recent entries so new subscribers get immediate context. +type publisher struct { + clients map[chan LogEntry]struct{} + ring []LogEntry + ringSize int + ringIdx int + mu sync.RWMutex +} + +func newPublisher(ringSize int) *publisher { + return &publisher{ + clients: make(map[chan LogEntry]struct{}), + ring: make([]LogEntry, ringSize), + ringSize: ringSize, + } +} + +func (lb *publisher) publish(entry LogEntry) { + lb.mu.Lock() + lb.ring[lb.ringIdx%lb.ringSize] = entry + lb.ringIdx++ + lb.mu.Unlock() + + lb.mu.RLock() + defer lb.mu.RUnlock() + for ch := range lb.clients { + select { + case ch <- entry: + default: // drop if client is slow + } + } +} + +func (lb *publisher) subscribe() (chan LogEntry, func()) { + ch := make(chan LogEntry, lb.ringSize) + lb.mu.Lock() + start := max(0, lb.ringIdx-lb.ringSize) + for i := start; i < lb.ringIdx; i++ { + entry := lb.ring[i%lb.ringSize] + if entry.Time != "" { + ch <- entry + } + } + lb.clients[ch] = struct{}{} + lb.mu.Unlock() + + unsub := func() { + lb.mu.Lock() + delete(lb.clients, ch) + lb.mu.Unlock() + } + return ch, unsub +} diff --git a/log/publish_test.go b/log/publish_test.go new file mode 100644 index 00000000..c44a075b --- /dev/null +++ b/log/publish_test.go @@ -0,0 +1,120 @@ +package log + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPushlisher(t *testing.T) { + p := newPublisher(10) + + ch, unsub := p.subscribe() + defer unsub() + + entry := LogEntry{Time: "2025-01-01 00:00:00.000 UTC", Level: "INFO", Message: "hello"} + p.publish(entry) + + select { + case got := <-ch: + assert.Equal(t, entry, got) + case <-time.After(time.Second): + t.Fatal("timed out waiting for broadcast") + } +} + +func TestMultipleSubscribers(t *testing.T) { + p := newPublisher(10) + + ch1, unsub1 := p.subscribe() + defer unsub1() + ch2, unsub2 := p.subscribe() + defer unsub2() + + entry := LogEntry{Time: "2025-01-01 00:00:00.000 UTC", Level: "DEBUG", Message: "multi"} + p.publish(entry) + + for _, ch := range []chan LogEntry{ch1, ch2} { + select { + case got := <-ch: + assert.Equal(t, entry, got) + case <-time.After(time.Second): + t.Fatal("timed out waiting for broadcast") + } + } +} + +func TestUnsubscribe(t *testing.T) { + p := newPublisher(10) + + ch, unsub := p.subscribe() + unsub() + + p.publish(LogEntry{Time: "2025-01-01 00:00:00.000 UTC", Level: "INFO", Message: "after unsub"}) + + select { + case <-ch: + t.Fatal("should not receive after unsubscribe") + case <-time.After(50 * time.Millisecond): + // expected + } +} + +func TestRingBuffer(t *testing.T) { + p := newPublisher(3) + + // Fill the ring buffer with 5 entries, so only the last 3 should be available. + for i := range 5 { + p.publish(LogEntry{ + Time: "t", + Level: "INFO", + Message: string(rune('a' + i)), + }) + } + + ch, unsub := p.subscribe() + defer unsub() + + // New subscriber should get the 3 ring buffer entries. + var msgs []string + for range 3 { + select { + case e := <-ch: + msgs = append(msgs, e.Message) + case <-time.After(time.Second): + t.Fatal("timed out reading ring buffer entries") + } + } + assert.Equal(t, []string{"c", "d", "e"}, msgs) +} + +func TestConcurrentBroadcast(t *testing.T) { + p := newPublisher(100) + ch, unsub := p.subscribe() + defer unsub() + + var wg sync.WaitGroup + n := 50 + wg.Add(n) + for i := range n { + go func(i int) { + defer wg.Done() + p.publish(LogEntry{Time: "t", Level: "INFO", Message: "msg"}) + }(i) + } + wg.Wait() + + received := 0 + for { + select { + case <-ch: + received++ + default: + require.Equal(t, n, received) + return + } + } +} diff --git a/option/algeneva.go b/option/algeneva.go deleted file mode 100644 index dd638ee7..00000000 --- a/option/algeneva.go +++ /dev/null @@ -1,12 +0,0 @@ -package option - -import "github.com/sagernet/sing-box/option" - -type ALGenevaInboundOptions struct { - option.HTTPMixedInboundOptions -} - -type ALGenevaOutboundOptions struct { - option.HTTPOutboundOptions - Strategy string `json:"strategy,omitempty"` -} diff --git a/option/amnezia.go b/option/amnezia.go deleted file mode 100644 index 2cc4c95f..00000000 --- a/option/amnezia.go +++ /dev/null @@ -1,85 +0,0 @@ -package option - -import ( - O "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common/json/badoption" - "net/netip" -) - -/************* ADDED FOR AMNEZIA *************/ -/* -WireGuardAdvancedSecurityOptions provides advanced security options for WireGuard required to activate AmneziaWG. - -In AmneziaWG, random bytes are appended to every auth packet to alter their size. - -Thus, "init and response handshake packets" have added "junk" at the beginning of their data, the size of which -is determined by the values S1 and S2. - -By default, the initiating handshake packet has a fixed size (148 bytes). After adding the junk, its size becomes 148 bytes + S1. -AmneziaWG also incorporates another trick for more reliable masking. Before initiating a session, Amnezia sends a - -certain number of "junk" packets to thoroughly confuse DPI systems. The number of these packets and their -minimum and maximum byte sizes can also be adjusted in the settings, using parameters Jc, Jmin, and Jmax. - -*/ - -type WireGuardAdvancedSecurityOptions struct { - JunkPacketCount int `json:"junk_packet_count,omitempty"` // jc - JunkPacketMinSize int `json:"junk_packet_min_size,omitempty"` // jmin - JunkPacketMaxSize int `json:"junk_packet_max_size,omitempty"` // jmax - InitPacketJunkSize int `json:"init_packet_junk_size,omitempty"` // s1 - ResponsePacketJunkSize int `json:"response_packet_junk_size,omitempty"` // s2 - InitPacketMagicHeader uint32 `json:"init_packet_magic_header,omitempty"` // h1 - ResponsePacketMagicHeader uint32 `json:"response_packet_magic_header,omitempty"` // h2 - UnderloadPacketMagicHeader uint32 `json:"underload_packet_magic_header,omitempty"` // h3 - TransportPacketMagicHeader uint32 `json:"transport_packet_magic_header,omitempty"` // h4 -} -/******************** END ********************/ -type WireGuardEndpointOptions struct { - System bool `json:"system,omitempty"` - Name string `json:"name,omitempty"` - MTU uint32 `json:"mtu,omitempty"` - Address badoption.Listable[netip.Prefix] `json:"address"` - PrivateKey string `json:"private_key"` - ListenPort uint16 `json:"listen_port,omitempty"` - Peers []WireGuardPeer `json:"peers,omitempty"` - UDPTimeout badoption.Duration `json:"udp_timeout,omitempty"` - Workers int `json:"workers,omitempty"` - WireGuardAdvancedSecurityOptions /** ADDED FOR AMNEZIA **/ - O.DialerOptions -} - -type WireGuardPeer struct { - Address string `json:"address,omitempty"` - Port uint16 `json:"port,omitempty"` - PublicKey string `json:"public_key,omitempty"` - PreSharedKey string `json:"pre_shared_key,omitempty"` - AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"` - PersistentKeepaliveInterval uint16 `json:"persistent_keepalive_interval,omitempty"` - Reserved []uint8 `json:"reserved,omitempty"` -} - -type LegacyWireGuardOutboundOptions struct { - O.DialerOptions - SystemInterface bool `json:"system_interface,omitempty"` - GSO bool `json:"gso,omitempty"` - InterfaceName string `json:"interface_name,omitempty"` - LocalAddress badoption.Listable[netip.Prefix] `json:"local_address"` - PrivateKey string `json:"private_key"` - Peers []LegacyWireGuardPeer `json:"peers,omitempty"` - O.ServerOptions - PeerPublicKey string `json:"peer_public_key"` - PreSharedKey string `json:"pre_shared_key,omitempty"` - Reserved []uint8 `json:"reserved,omitempty"` - Workers int `json:"workers,omitempty"` - MTU uint32 `json:"mtu,omitempty"` - Network O.NetworkList `json:"network,omitempty"` -} - -type LegacyWireGuardPeer struct { - O.ServerOptions - PublicKey string `json:"public_key,omitempty"` - PreSharedKey string `json:"pre_shared_key,omitempty"` - AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"` - Reserved []uint8 `json:"reserved,omitempty"` -} diff --git a/option/outline.go b/option/outline.go deleted file mode 100644 index a5fb93aa..00000000 --- a/option/outline.go +++ /dev/null @@ -1,55 +0,0 @@ -package option - -import O "github.com/sagernet/sing-box/option" - -// OutboundOutlineOptions set the outbound options used by the outline-sdk -// smart dialer. You can find more details about the parameters by looking -// through the implementation: https://github.com/Jigsaw-Code/outline-sdk/blob/v0.0.18/x/smart/stream_dialer.go#L65-L100 -// Or check the documentation README: https://github.com/Jigsaw-Code/outline-sdk/tree/v0.0.18/x/smart -type OutboundOutlineOptions struct { - O.DialerOptions - DNSResolvers []DNSEntryConfig `json:"dns,omitempty" yaml:"dns,omitempty"` - TLS []string `json:"tls,omitempty" yaml:"tls,omitempty"` - TestTimeout string `json:"test_timeout" yaml:"-"` - Domains []string `json:"domains" yaml:"-"` -} - -// DNSEntryConfig specifies a list of resolvers to test and they can be one of -// the attributes (system, https, tls, udp or tcp) -type DNSEntryConfig struct { - // System is used for using the system as a resolver, if you want to use it - // provide an empty object. - System *struct{} `json:"system,omitempty"` - // HTTPS use an encrypted DNS over HTTPS (DoH) resolver. - HTTPS *HTTPSEntryConfig `json:"https,omitempty"` - // TLS use an encrypted DNS over TLS (DoT) resolver. - TLS *TLSEntryConfig `json:"tls,omitempty"` - // UDP use a UDP resolver - UDP *UDPEntryConfig `json:"udp,omitempty"` - // TCP use a TCP resolver - TCP *TCPEntryConfig `json:"tcp,omitempty"` -} - -type HTTPSEntryConfig struct { - // Domain name of the host. - Name string `json:"name,omitempty"` - // Host:port. Defaults to Name:443. - Address string `json:"address,omitempty"` -} - -type TLSEntryConfig struct { - // Domain name of the host. - Name string `json:"name,omitempty"` - // Host:port. Defaults to Name:853. - Address string `json:"address,omitempty"` -} - -type UDPEntryConfig struct { - // Host:port. - Address string `json:"address,omitempty"` -} - -type TCPEntryConfig struct { - // Host:port. - Address string `json:"address,omitempty"` -} diff --git a/radiance.go b/radiance.go deleted file mode 100644 index 6c32dadf..00000000 --- a/radiance.go +++ /dev/null @@ -1,333 +0,0 @@ -// Package radiance provides a local server that proxies all requests to a remote proxy server using different -// protocols meant to circumvent censorship. Radiance uses a [transport.StreamDialer] to dial the target server -// over the desired protocol. The [config.Config] is used to configure the dialer for a proxy server. -package radiance - -import ( - "context" - "fmt" - "log/slog" - "sync" - "sync/atomic" - "time" - - "github.com/Xuanwo/go-locale" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/metric/noop" - "go.opentelemetry.io/otel/trace" - traceNoop "go.opentelemetry.io/otel/trace/noop" - - lcommon "github.com/getlantern/common" - - "github.com/getlantern/radiance/api" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/deviceid" - "github.com/getlantern/radiance/common/env" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/config" - "github.com/getlantern/radiance/events" - "github.com/getlantern/radiance/issue" - "github.com/getlantern/radiance/kindling" - "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/telemetry" - "github.com/getlantern/radiance/traces" - "github.com/getlantern/radiance/vpn" -) - -const configPollInterval = 10 * time.Minute -const tracerName = "github.com/getlantern/radiance" - -//go:generate mockgen -destination=radiance_mock_test.go -package=radiance github.com/getlantern/radiance configHandler - -// configHandler is an interface that abstracts the config.ConfigHandler struct for easier testing. -type configHandler interface { - // Stop stops the config handler from fetching new configurations. - Stop() - // SetPreferredServerLocation sets the preferred server location. If not set - it's auto selected by the API - SetPreferredServerLocation(country, city string) - // GetConfig returns the current configuration. - // It returns an error if the configuration is not yet available. - GetConfig() (*config.Config, error) -} - -type issueReporter interface { - Report(ctx context.Context, report issue.IssueReport, userEmail, country string) error -} - -// Radiance is a local server that proxies all requests to a remote proxy server over a transport.StreamDialer. -type Radiance struct { - confHandler configHandler - issueReporter issueReporter - apiHandler *api.APIClient - srvManager *servers.Manager - shutdownFuncs []func(context.Context) error - closeOnce sync.Once - stopChan chan struct{} - telemetryConsent atomic.Bool -} - -type Options struct { - DataDir string - LogDir string - Locale string - DeviceID string - LogLevel string - // User choice for telemetry consent - TelemetryConsent bool -} - -// NewRadiance creates a new Radiance VPN client. opts includes the platform interface used to -// interact with the underlying platform on iOS, Android, and MacOS. On other platforms, it is -// ignored and can be nil. -func NewRadiance(opts Options) (*Radiance, error) { - if opts.Locale == "" { - // It is preferable to use the locale from the frontend, as locale is a requirement for lots - // of frontend code and therefore is more reliably supported there. - // However, if the frontend locale is not available, we can use the system locale as a fallback. - if tag, err := locale.Detect(); err != nil { - opts.Locale = "en-US" - } else { - opts.Locale = tag.String() - } - } - - var platformDeviceID string - switch common.Platform { - case "ios", "android": - platformDeviceID = opts.DeviceID - default: - platformDeviceID = deviceid.Get() - } - - shutdownFuncs := []func(context.Context) error{} - if err := common.Init(opts.DataDir, opts.LogDir, opts.LogLevel); err != nil { - return nil, fmt.Errorf("failed to initialize: %w", err) - } - settings.Set(settings.LocaleKey, opts.Locale) - - dataDir := settings.GetString(settings.DataPathKey) - kindling.SetKindling(kindling.NewKindling()) - setUserConfig(platformDeviceID, dataDir, opts.Locale) - apiHandler := api.NewAPIClient(dataDir) - issueReporter := issue.NewIssueReporter() - - svrMgr, err := servers.NewManager(dataDir) - if err != nil { - return nil, fmt.Errorf("failed to create server manager: %w", err) - } - cOpts := config.Options{ - PollInterval: configPollInterval, - SvrManager: svrMgr, - DataDir: dataDir, - Locale: opts.Locale, - APIHandler: apiHandler, - } - if disableFetch, ok := env.Get[bool](env.DisableFetch); ok && disableFetch { - cOpts.PollInterval = -1 - slog.Info("Disabling config fetch") - } - r := &Radiance{ - issueReporter: issueReporter, - apiHandler: apiHandler, - srvManager: svrMgr, - shutdownFuncs: shutdownFuncs, - stopChan: make(chan struct{}), - closeOnce: sync.Once{}, - } - r.telemetryConsent.Store(opts.TelemetryConsent) - events.Subscribe(func(evt config.NewConfigEvent) { - if r.telemetryConsent.Load() { - slog.Info("Telemetry consent given; handling new config for telemetry") - if err := telemetry.OnNewConfig(evt.Old, evt.New, platformDeviceID); err != nil { - slog.Error("Failed to handle new config for telemetry", "error", err) - } - } else { - slog.Info("Telemetry consent not given; skipping telemetry initialization") - } - }) - r.confHandler = config.NewConfigHandler(cOpts) - // Register AFTER NewConfigHandler so the disk-load event is already - // consumed. Runs whenever a new config is applied to provide continuous - // bandit callback data even when the VPN tunnel is not active. - sub := events.Subscribe(func(evt config.NewConfigEvent) { - vpn.RunURLTests(dataDir) - }) - r.addShutdownFunc(telemetry.Close, kindling.Close, func(_ context.Context) error { - sub.Unsubscribe() - return nil - }) - return r, nil -} - -// addShutdownFunc adds a shutdown function(s) to the Radiance instance. -// This function is called when the Radiance instance is closed to ensure that all -// resources are cleaned up properly. -func (r *Radiance) addShutdownFunc(fns ...func(context.Context) error) { - for _, fn := range fns { - if fn != nil { - r.shutdownFuncs = append(r.shutdownFuncs, fn) - } - } -} - -func (r *Radiance) Close() { - r.closeOnce.Do(func() { - slog.Debug("Closing Radiance") - r.confHandler.Stop() - close(r.stopChan) - for _, shutdown := range r.shutdownFuncs { - if err := shutdown(context.Background()); err != nil { - slog.Error("Failed to shutdown", "error", err) - } - } - }) - <-r.stopChan -} - -// APIHandler returns the API handler for the Radiance client. -func (r *Radiance) APIHandler() *api.APIClient { - return r.apiHandler -} - -// SetPreferredServer sets the preferred server location for the VPN connection. -// pass empty strings to auto select the server location -func (r *Radiance) SetPreferredServer(ctx context.Context, country, city string) { - r.confHandler.SetPreferredServerLocation(country, city) -} - -// ServerManager returns the server manager for the Radiance client. -func (r *Radiance) ServerManager() *servers.Manager { - return r.srvManager -} - -type IssueReport = issue.IssueReport - -// ReportIssue submits an issue report to the back-end with an optional user email -func (r *Radiance) ReportIssue(email string, report IssueReport) error { - ctx, span := otel.Tracer(tracerName).Start(context.Background(), "report_issue") - defer span.End() - if report.Type == "" && report.Description == "" { - return fmt.Errorf("issue report should contain at least type or description") - } - var country string - // get country from the config returned by the backend - cfg, err := r.confHandler.GetConfig() - if err != nil { - slog.Warn("Failed to get config", "error", err) - } else { - country = cfg.ConfigResponse.Country - } - - err = r.issueReporter.Report(ctx, report, email, country) - if err != nil { - slog.Error("Failed to report issue", "error", err) - return traces.RecordError(ctx, fmt.Errorf("failed to report issue: %w", err)) - } - slog.Info("Issue reported successfully") - return nil -} - -// Features returns the features available in the current configuration, returned from the server in the -// config response. -func (r *Radiance) Features() map[string]bool { - _, span := otel.Tracer(tracerName).Start(context.Background(), "features") - defer span.End() - cfg, err := r.confHandler.GetConfig() - if err != nil { - slog.Info("Failed to get config for features", "error", err) - return map[string]bool{} - } - if cfg == nil { - slog.Info("No config available for features, returning empty map") - return map[string]bool{} - } - slog.Debug("Returning features from config", "features", cfg.ConfigResponse.Features) - // Return the features from the config - if cfg.ConfigResponse.Features == nil { - slog.Info("No features available in config, returning empty map") - return map[string]bool{} - } - return cfg.ConfigResponse.Features -} - -// EnableTelemetry enable OpenTelemetry instrumentation for the Radiance client. -// After enabling it, it should initialize telemetry again once a new config -// is available -func (r *Radiance) EnableTelemetry() { - slog.Info("Enabling telemetry") - r.telemetryConsent.Store(true) - // If a config is already available, initialize telemetry immediately instead of - // waiting for the next config change event. - cfg, err := r.confHandler.GetConfig() - if err != nil { - slog.Warn("Failed to get config while enabling telemetry; telemetry will be initialized on next config update", "error", err) - return - } - if cfg == nil { - slog.Info("No config available while enabling telemetry; telemetry will be initialized on next config update") - return - } - cErr := telemetry.OnNewConfig(nil, cfg, settings.GetString(settings.DeviceIDKey)) - if cErr != nil { - slog.Warn("Failed to initialize telemetry on enabling", "error", cErr) - } -} - -// DisableTelemetry disables OpenTelemetry instrumentation for the Radiance client. -func (r *Radiance) DisableTelemetry() { - slog.Info("Disabling telemetry") - r.telemetryConsent.Store(false) - otel.SetTracerProvider(traceNoop.NewTracerProvider()) - otel.SetMeterProvider(noop.NewMeterProvider()) -} - -// ServerLocations returns the list of server locations where the user can connect to proxies. -func (r *Radiance) ServerLocations() ([]lcommon.ServerLocation, error) { - ctx, span := otel.Tracer(tracerName).Start(context.Background(), "server_locations") - defer span.End() - cfg, err := r.confHandler.GetConfig() - if err != nil { - slog.Error("Failed to get config for server locations", "error", err) - traces.RecordError(ctx, err, trace.WithStackTrace(true)) - return nil, fmt.Errorf("failed to get config: %w", err) - } - if cfg == nil { - slog.Info("No config available for server locations, returning error") - traces.RecordError(ctx, err, trace.WithStackTrace(true)) - return nil, fmt.Errorf("no config available") - } - slog.Debug("Returning server locations from config", "locations", cfg.ConfigResponse.Servers) - return cfg.ConfigResponse.Servers, nil -} - -type slogWriter struct { - *slog.Logger -} - -func (w *slogWriter) Write(p []byte) (n int, err error) { - // Convert the byte slice to a string and log it - w.Info(string(p)) - return len(p), nil -} - -// setUserConfig creates a new UserInfo object -func setUserConfig(deviceID, dataDir, locale string) { - if err := settings.Set(settings.DeviceIDKey, deviceID); err != nil { - slog.Error("failed to set device ID in settings", "error", err) - } - if err := settings.Set(settings.DataPathKey, dataDir); err != nil { - slog.Error("failed to set data path in settings", "error", err) - } - if err := settings.Set(settings.LocaleKey, locale); err != nil { - slog.Error("failed to set locale in settings", "error", err) - } - - events.SubscribeOnce(func(evt config.NewConfigEvent) { - if evt.New != nil && evt.New.ConfigResponse.Country != "" { - if err := settings.Set(settings.CountryCodeKey, evt.New.ConfigResponse.Country); err != nil { - slog.Error("failed to set country code in settings", "error", err) - } - slog.Info("Set country code from config response", "country_code", evt.New.ConfigResponse.Country) - } - }) -} diff --git a/radiance_mock_test.go b/radiance_mock_test.go deleted file mode 100644 index 1f455bae..00000000 --- a/radiance_mock_test.go +++ /dev/null @@ -1,96 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/getlantern/radiance (interfaces: configHandler) -// -// Generated by this command: -// -// mockgen -destination=radiance_mock_test.go -package=radiance github.com/getlantern/radiance configHandler -// - -// Package radiance is a generated GoMock package. -package radiance - -import ( - context "context" - reflect "reflect" - - config "github.com/getlantern/common" - gomock "go.uber.org/mock/gomock" -) - -// MockconfigHandler is a mock of configHandler interface. -type MockconfigHandler struct { - ctrl *gomock.Controller - recorder *MockconfigHandlerMockRecorder - isgomock struct{} -} - -// MockconfigHandlerMockRecorder is the mock recorder for MockconfigHandler. -type MockconfigHandlerMockRecorder struct { - mock *MockconfigHandler -} - -// NewMockconfigHandler creates a new mock instance. -func NewMockconfigHandler(ctrl *gomock.Controller) *MockconfigHandler { - mock := &MockconfigHandler{ctrl: ctrl} - mock.recorder = &MockconfigHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockconfigHandler) EXPECT() *MockconfigHandlerMockRecorder { - return m.recorder -} - -// GetConfig mocks base method. -func (m *MockconfigHandler) GetConfig(ctx context.Context) (*config.ConfigResponse, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetConfig", ctx) - ret0, _ := ret[0].(*config.ConfigResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetConfig indicates an expected call of GetConfig. -func (mr *MockconfigHandlerMockRecorder) GetConfig(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfig", reflect.TypeOf((*MockconfigHandler)(nil).GetConfig), ctx) -} - -// ListAvailableServers mocks base method. -func (m *MockconfigHandler) ListAvailableServers(ctx context.Context) ([]config.ServerLocation, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListAvailableServers", ctx) - ret0, _ := ret[0].([]config.ServerLocation) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListAvailableServers indicates an expected call of ListAvailableServers. -func (mr *MockconfigHandlerMockRecorder) ListAvailableServers(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAvailableServers", reflect.TypeOf((*MockconfigHandler)(nil).ListAvailableServers), ctx) -} - -// SetPreferredServerLocation mocks base method. -func (m *MockconfigHandler) SetPreferredServerLocation(country, city string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetPreferredServerLocation", country, city) -} - -// SetPreferredServerLocation indicates an expected call of SetPreferredServerLocation. -func (mr *MockconfigHandlerMockRecorder) SetPreferredServerLocation(country, city any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPreferredServerLocation", reflect.TypeOf((*MockconfigHandler)(nil).SetPreferredServerLocation), country, city) -} - -// Stop mocks base method. -func (m *MockconfigHandler) Stop() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Stop") -} - -// Stop indicates an expected call of Stop. -func (mr *MockconfigHandlerMockRecorder) Stop() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockconfigHandler)(nil).Stop)) -} diff --git a/radiance_test.go b/radiance_test.go deleted file mode 100644 index a153e9f3..00000000 --- a/radiance_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package radiance - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/getlantern/radiance/config" -) - -func TestNewRadiance(t *testing.T) { - t.Run("it should create a new Radiance instance successfully", func(t *testing.T) { - dir := t.TempDir() - r, err := NewRadiance(Options{ - DataDir: dir, - Locale: "en-US", - }) - assert.NoError(t, err) - r.Close() - - assert.NotNil(t, r) - assert.NotNil(t, r.confHandler) - assert.NotNil(t, r.stopChan) - assert.NotNil(t, r.issueReporter) - }) -} - -func TestReportIssue(t *testing.T) { - var tests = []struct { - name string - email string - report IssueReport - assert func(*testing.T, error) - }{ - { - name: "return error when missing type and description", - email: "", - report: IssueReport{}, - assert: func(t *testing.T, err error) { - assert.Error(t, err) - }, - }, - { - name: "return nil when issue report is valid", - email: "radiancetest@getlantern.org", - report: IssueReport{ - Type: "Application crashes", - Description: "internal test only", - Device: "test device", - Model: "a123", - }, - assert: func(t *testing.T, err error) { - assert.NoError(t, err) - }, - }, - { - name: "return nil when issue report is valid with empty email", - email: "", - report: IssueReport{ - Type: "Cannot sign in", - Description: "internal test only", - Device: "test device 2", - Model: "b456", - }, - assert: func(t *testing.T, err error) { - assert.NoError(t, err) - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := &Radiance{ - issueReporter: &mockIssueReporter{}, - confHandler: &mockConfigHandler{}, - } - err := r.ReportIssue(tt.email, tt.report) - tt.assert(t, err) - }) - } -} - -type mockIssueReporter struct{} - -func (m *mockIssueReporter) Report(_ context.Context, _ IssueReport, _, _ string) error { return nil } - -type mockConfigHandler struct{} - -func (m *mockConfigHandler) Stop() {} - -func (m *mockConfigHandler) SetPreferredServerLocation(country string, city string) {} - -func (m *mockConfigHandler) GetConfig() (*config.Config, error) { - return &config.Config{}, nil -} - -func (m *mockConfigHandler) AddConfigListener(listener config.ListenerFunc) { - listener(&config.Config{}, &config.Config{}) -} diff --git a/servers/manager.go b/servers/manager.go index 16690ae4..b7565eb8 100644 --- a/servers/manager.go +++ b/servers/manager.go @@ -17,6 +17,7 @@ import ( "path/filepath" "slices" "strconv" + "strings" "sync" "time" @@ -27,10 +28,9 @@ import ( C "github.com/getlantern/common" "github.com/getlantern/radiance/bypass" - "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/atomicfile" - "github.com/getlantern/radiance/events" "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/traces" "github.com/getlantern/pluriconfig" @@ -59,20 +59,13 @@ type ServerCredentials struct { } type Options struct { - Outbounds []option.Outbound `json:"outbounds,omitempty"` - Endpoints []option.Endpoint `json:"endpoints,omitempty"` - Locations map[string]C.ServerLocation `json:"locations,omitempty"` - URLOverrides map[string]string `json:"url_overrides,omitempty"` + Outbounds []option.Outbound `json:"outbounds,omitempty"` + Endpoints []option.Endpoint `json:"endpoints,omitempty"` + Locations map[string]C.ServerLocation `json:"locations,omitempty"` + URLOverrides map[string]string `json:"url_overrides,omitempty"` Credentials map[string]ServerCredentials `json:"credentials,omitempty"` } -// MarshalJSON encodes Options using the sing-box context so that type-specific outbound/endpoint -// options (server, port, password, etc.) are included in the output. -func (o Options) MarshalJSON() ([]byte, error) { - type Alias Options - return json.MarshalContext(box.BaseContext(), Alias(o)) -} - // AllTags returns a slice of all tags from both endpoints and outbounds in the Options. func (o Options) AllTags() []string { tags := make([]string, 0, len(o.Outbounds)+len(o.Endpoints)) @@ -87,18 +80,36 @@ func (o Options) AllTags() []string { type Servers map[ServerGroup]Options +type Server struct { + // Group indicates which group the server belongs to. + Group ServerGroup + // Tag is the tag/name of the server + Tag string + // Type is the type of the server, e.g. "http", "shadowsocks", etc. + Type string + Options any // will be either [option.Endpoint] or [option.Outbound] + Location C.ServerLocation +} + +type optsMap map[string]Server + +func (m optsMap) add(group, tag, typ string, options any, loc C.ServerLocation) { + m[tag] = Server{group, tag, typ, options, loc} +} + // Manager manages server configurations, including endpoints and outbounds. type Manager struct { - access sync.RWMutex - servers Servers - optsMaps map[ServerGroup]map[string]any // map of tag to option for quick access + access sync.RWMutex + servers Servers + optsMap optsMap // map of tag to option for quick access + logger *slog.Logger serversFile string httpClient *http.Client } // NewManager creates a new Manager instance, loading server options from disk. -func NewManager(dataPath string) (*Manager, error) { +func NewManager(dataPath string, logger *slog.Logger) (*Manager, error) { mgr := &Manager{ servers: Servers{ SGLantern: Options{ @@ -114,28 +125,24 @@ func NewManager(dataPath string) (*Manager, error) { Credentials: make(map[string]ServerCredentials), }, }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: make(map[string]any), - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - access: sync.RWMutex{}, - + optsMap: map[string]Server{}, + serversFile: filepath.Join(dataPath, internal.ServersFileName), + logger: logger, // Use the bypass proxy dialer to route requests outside the VPN tunnel. // This client is only used to access private servers the user has created. - httpClient: retryableHTTPClient().StandardClient(), + httpClient: retryableHTTPClient(logger).StandardClient(), } - slog.Debug("Loading servers", "file", mgr.serversFile) + mgr.logger.Debug("Loading servers", "file", mgr.serversFile) if err := mgr.loadServers(); err != nil { - slog.Error("Failed to load servers", "file", mgr.serversFile, "error", err) + mgr.logger.Error("Failed to load servers", "file", mgr.serversFile, "error", err) return nil, fmt.Errorf("failed to load servers from file: %w", err) } - slog.Log(nil, internal.LevelTrace, "Loaded servers", "servers", mgr.servers) + mgr.logger.Log(nil, log.LevelTrace, "Loaded servers", "servers", mgr.servers) return mgr, nil } -func retryableHTTPClient() *retryablehttp.Client { +func retryableHTTPClient(logger *slog.Logger) *retryablehttp.Client { transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: bypass.DialContext, @@ -154,6 +161,7 @@ func retryableHTTPClient() *retryablehttp.Client { client.RetryMax = 10 client.RetryWaitMin = 1 * time.Second client.RetryWaitMax = 10 * time.Second + client.Logger = logger return client } @@ -175,14 +183,6 @@ func (m *Manager) Servers() Servers { return result } -type Server struct { - Group ServerGroup - Tag string - Type string - Options any // will be either [option.Endpoint] or [option.Outbound] - Location C.ServerLocation -} - // GetServerByTag returns the server configuration for a given tag and a boolean indicating whether // the server was found. The returned Server contains pointer-rich sing-box types in its Options // field, so callers on a CGo callback stack should use [GetServerByTagJSON] instead. This method @@ -191,118 +191,38 @@ type Server struct { func (m *Manager) GetServerByTag(tag string) (Server, bool) { m.access.RLock() defer m.access.RUnlock() - return m.getServerByTagLocked(tag) -} - -// getServerByTagLocked performs the tag lookup. Caller must hold access.RLock. -func (m *Manager) getServerByTagLocked(tag string) (Server, bool) { - group := SGLantern - opts, ok := m.optsMaps[SGLantern][tag] - if !ok { - if opts, ok = m.optsMaps[SGUser][tag]; !ok { - return Server{}, false - } - group = SGUser - } - s := Server{ - Group: group, - Tag: tag, - Options: opts, - Location: m.servers[group].Locations[tag], - } - switch v := opts.(type) { - case option.Endpoint: - s.Type = v.Type - case option.Outbound: - s.Type = v.Type - } - return s, true -} - -// ServersJSON returns the current server configurations as pre-marshalled JSON. -// Safe to call from CGo callback stacks: the work runs on a dedicated Go goroutine -// (via [common.RunOffCgoStack]) so pointer-rich sing-box types never touch the C stack. -func (m *Manager) ServersJSON() ([]byte, error) { - return common.RunOffCgoStack(func() ([]byte, error) { - m.access.RLock() - defer m.access.RUnlock() - return json.MarshalContext(box.BaseContext(), m.servers) - }) -} - -// GetServerByTagJSON returns the server configuration for a given tag as pre-marshalled JSON. -// Like [ServersJSON], safe to call from CGo callback stacks. -func (m *Manager) GetServerByTagJSON(tag string) ([]byte, bool, error) { - type result struct { - data []byte - ok bool - } - r, err := common.RunOffCgoStack(func() (result, error) { - m.access.RLock() - defer m.access.RUnlock() - - s, ok := m.getServerByTagLocked(tag) - if !ok { - return result{}, nil - } - b, err := json.MarshalContext(box.BaseContext(), s) - if err != nil { - return result{}, fmt.Errorf("marshal server %q: %w", tag, err) - } - return result{data: b, ok: true}, nil - }) - return r.data, r.ok, err -} - -type ServersUpdatedEvent struct { - events.Event - Group ServerGroup - Options *Options -} - -type ServersAddedEvent struct { - events.Event - Group ServerGroup - Options *Options -} - -type ServersRemovedEvent struct { - events.Event - Group ServerGroup - Tag string + s, exists := m.optsMap[tag] + return s, exists } // SetServers sets the server options for a specific group. // Important: this will overwrite any existing servers for that group. To add new servers without // overwriting existing ones, use [AddServers] instead. func (m *Manager) SetServers(group ServerGroup, options Options) error { - if err := m.setServers(group, options); err != nil { - return fmt.Errorf("set servers: %w", err) + switch group { + case SGLantern, SGUser: + default: + return fmt.Errorf("invalid server group: %s", group) } m.access.Lock() defer m.access.Unlock() + if err := m.setServers(group, options); err != nil { + return fmt.Errorf("set servers: %w", err) + } + if err := m.saveServers(); err != nil { return fmt.Errorf("failed to save servers: %w", err) } - events.Emit(ServersUpdatedEvent{ - Group: group, - Options: &options, - }) + servers := make([]Server, 0, len(options.Outbounds)+len(options.Endpoints)) + for _, tag := range options.AllTags() { + servers = append(servers, m.optsMap[tag]) + } return nil } func (m *Manager) setServers(group ServerGroup, options Options) error { - switch group { - case SGLantern, SGUser: - default: - return fmt.Errorf("invalid server group: %s", group) - } - - m.access.Lock() - defer m.access.Unlock() - - slog.Log(nil, internal.LevelTrace, "Setting servers", "group", group, "options", options) + m.logger.Log(nil, log.LevelTrace, "Setting servers", "group", group, "options", options) opts := Options{ Outbounds: append([]option.Outbound{}, options.Outbounds...), Endpoints: append([]option.Endpoint{}, options.Endpoints...), @@ -312,80 +232,67 @@ func (m *Manager) setServers(group ServerGroup, options Options) error { } maps.Copy(opts.Locations, options.Locations) maps.Copy(opts.Credentials, options.Credentials) - - m.servers[group] = opts - oMap := make(map[string]any, len(options.Endpoints)+len(options.Outbounds)) - for _, ep := range options.Endpoints { - oMap[ep.Tag] = ep + for _, ep := range opts.Endpoints { + m.optsMap.add(group, ep.Tag, ep.Type, ep, options.Locations[ep.Tag]) } - for _, out := range options.Outbounds { - oMap[out.Tag] = out + for _, out := range opts.Outbounds { + m.optsMap.add(group, out.Tag, out.Type, out, options.Locations[out.Tag]) } - m.optsMaps[group] = oMap + m.servers[group] = opts return nil } -// AddServers adds new servers to the specified group. If a server with the same tag already exists, -// it will be skipped. -func (m *Manager) AddServers(group ServerGroup, opts Options) error { +// AddServers adds new servers to the specified group. If force is true, it will overwrite any +// existing servers with the same tags. +func (m *Manager) AddServers(group ServerGroup, options Options, force bool) error { switch group { case SGLantern, SGUser: default: return fmt.Errorf("invalid server group: %s", group) } + if len(options.Endpoints) == 0 && len(options.Outbounds) == 0 { + return nil + } m.access.Lock() defer m.access.Unlock() - slog.Log(nil, internal.LevelTrace, "Adding servers", "group", group, "options", opts) - existingTags := m.merge(group, opts) - if len(existingTags) > 0 { - slog.Warn("Some servers were not added because they already exist", "tags", existingTags) - } + m.logger.Log(nil, log.LevelTrace, "Adding servers", "group", group, "options", options) + added := m.merge(group, options, force) if err := m.saveServers(); err != nil { return fmt.Errorf("failed to save servers: %w", err) } - if len(existingTags) > 0 { - slog.Warn("Tried to add some servers that already exist", "tags", existingTags) - return fmt.Errorf("some servers were not added because they already exist: %v", existingTags) - } - slog.Debug("Server configs added", "group", group, "newCount", len(opts.AllTags())) - events.Emit(ServersAddedEvent{ - Group: group, - Options: &opts, - }) + m.logger.Info("Server configs added", "group", group, "newCount", len(added)) return nil } -// merge adds new endpoints and outbounds to the specified group, skipping any that already exist. -// It returns the tags that were skipped. -func (m *Manager) merge(group ServerGroup, options Options) []string { - if len(options.Endpoints) == 0 && len(options.Outbounds) == 0 { - return nil - } - var existingTags []string - opts := m.optsMaps[group] +func (m *Manager) merge(group ServerGroup, options Options, force bool) []Server { + var added []Server servers := m.servers[group] for _, ep := range options.Endpoints { - if _, exists := opts[ep.Tag]; exists { - existingTags = append(existingTags, ep.Tag) - continue + if !force { + if _, exists := m.optsMap[ep.Tag]; exists { + continue + } } - opts[ep.Tag] = ep servers.Endpoints = append(servers.Endpoints, ep) servers.Locations[ep.Tag] = options.Locations[ep.Tag] + m.optsMap.add(group, ep.Tag, ep.Type, ep, options.Locations[ep.Tag]) + added = append(added, m.optsMap[ep.Tag]) if creds, ok := options.Credentials[ep.Tag]; ok { servers.Credentials[ep.Tag] = creds } } for _, out := range options.Outbounds { - if _, exists := opts[out.Tag]; exists { - existingTags = append(existingTags, out.Tag) - continue + if !force { + if _, exists := m.optsMap[out.Tag]; exists { + continue + } } - opts[out.Tag] = out servers.Outbounds = append(servers.Outbounds, out) servers.Locations[out.Tag] = options.Locations[out.Tag] + m.optsMap.add(group, out.Tag, out.Type, out, options.Locations[out.Tag]) + added = append(added, m.optsMap[out.Tag]) if creds, ok := options.Credentials[out.Tag]; ok { servers.Credentials[out.Tag] = creds } @@ -396,59 +303,74 @@ func (m *Manager) merge(group ServerGroup, options Options) []string { } servers.URLOverrides[k] = v } + if force { + servers.Endpoints = slices.CompactFunc(servers.Endpoints, func(ep1, ep2 option.Endpoint) bool { + return ep1.Tag == ep2.Tag + }) + servers.Outbounds = slices.CompactFunc(servers.Outbounds, func(ob1, ob2 option.Outbound) bool { + return ob1.Tag == ob2.Tag + }) + } m.servers[group] = servers - return existingTags + return added } // RemoveServer removes a server config by its tag. func (m *Manager) RemoveServer(tag string) error { + _, err := m.removeServers([]string{tag}) + return err +} + +// RemoveServers removes multiple server configs by their tags and returns the removed servers. +func (m *Manager) RemoveServers(tags []string) ([]Server, error) { + return m.removeServers(tags) +} + +func (m *Manager) removeServers(tags []string) ([]Server, error) { m.access.Lock() defer m.access.Unlock() - slog.Log(nil, internal.LevelTrace, "Removing server", "tag", tag) - // check which group the server belongs to so we can get the correct optsMaps and servers - group := SGLantern - if _, exists := m.optsMaps[group][tag]; !exists { - group = SGUser - if _, exists := m.optsMaps[group][tag]; !exists { - slog.Warn("Tried to remove non-existent server", "tag", tag) - return fmt.Errorf("server with tag %q not found", tag) + removed := make([]Server, 0, len(tags)) + remove := func(it any) bool { + var tag string + switch v := it.(type) { + case option.Endpoint: + tag = v.Tag + case option.Outbound: + tag = v.Tag + } + server, exists := m.optsMap[tag] + if exists { + removed = append(removed, server) + } + return exists + } + for group, options := range m.servers { + removed := removed[len(removed):] + options.Outbounds = slices.DeleteFunc(options.Outbounds, func(out option.Outbound) bool { + return remove(out) + }) + options.Endpoints = slices.DeleteFunc(options.Endpoints, func(ep option.Endpoint) bool { + return remove(ep) + }) + for _, server := range removed { + delete(options.Locations, server.Tag) + delete(m.optsMap, server.Tag) + } + m.servers[group] = options + if len(removed) > 0 { + m.logger.Info("Server configs removed", "group", group, "tags", removed) } } - // remove the server from the optsMaps and servers - servers := m.servers[group] - switch v := m.optsMaps[group][tag].(type) { - case option.Endpoint: - servers.Endpoints = remove(servers.Endpoints, v) - case option.Outbound: - servers.Outbounds = remove(servers.Outbounds, v) - } - delete(m.optsMaps[group], tag) - delete(servers.Locations, tag) - delete(servers.Credentials, tag) - m.servers[group] = servers - if err := m.saveServers(); err != nil { - return fmt.Errorf("failed to save servers after removing %q: %w", tag, err) - } - slog.Debug("Server config removed", "group", group, "tag", tag) - events.Emit(ServersRemovedEvent{ - Group: group, - Tag: tag, - }) - return nil -} -func remove[T comparable](slice []T, item T) []T { - i := slices.Index(slice, item) - if i == -1 { - return slice + if err := m.saveServers(); err != nil { + return nil, fmt.Errorf("failed to save servers: %w", err) } - slice[i] = slice[len(slice)-1] - return slice[:len(slice)-1] + return removed, nil } func (m *Manager) saveServers() error { - slog.Log(nil, internal.LevelTrace, "Saving server configs to file", "file", m.serversFile, "servers", m.servers) + m.logger.Log(nil, log.LevelTrace, "Saving server configs to file", "file", m.serversFile, "servers", m.servers) ctx := box.BaseContext() buf, err := json.MarshalContext(ctx, m.servers) if err != nil { @@ -477,7 +399,7 @@ func (m *Manager) loadServers() error { // Lantern Server Manager Integration // AddPrivateServer fetches VPN connection info from a remote server manager and adds it as a server. -func (m *Manager) AddPrivateServer(tag string, ip string, port int, accessToken string, serverLocation *C.ServerLocation, isJoined bool) error { +func (m *Manager) AddPrivateServer(tag, ip string, port int, accessToken string, loc C.ServerLocation, joined bool) error { u := &url.URL{ Scheme: "https", Host: net.JoinHostPort(ip, strconv.Itoa(port)), @@ -508,19 +430,20 @@ func (m *Manager) AddPrivateServer(tag string, ip string, port int, accessToken return fmt.Errorf("no endpoints or outbounds in response") } + // TODO: update when we support endpoints servers.Outbounds[0].Tag = tag // If the server location is provided, set it for the server's tag. - if serverLocation != nil { + if loc != (C.ServerLocation{}) { servers.Locations = map[string]C.ServerLocation{ - tag: *serverLocation, + tag: loc, } } // Store the credentials for the server's tag. servers.Credentials = map[string]ServerCredentials{ - tag: {AccessToken: accessToken, Port: port, IsJoined: isJoined}, + tag: {AccessToken: accessToken, Port: port, IsJoined: joined}, } - slog.Info("Adding private server from remote manager", "tag", tag, "ip", ip, "port", port, "location", serverLocation, "is_joined", isJoined) - return m.AddServers(SGUser, servers) + slog.Info("Adding private server from remote manager", "tag", tag, "ip", ip, "port", port, "location", loc, "is_joined", joined) + return m.AddServers(SGUser, servers, true) } // InviteToPrivateServer invites another user to the server manager instance and returns a connection @@ -564,36 +487,32 @@ func (m *Manager) RevokePrivateServerInvite(ip string, port int, accessToken str return nil } -// AddServerWithSingboxJSON parse a value that can be a JSON sing-box config. -// It parses the config into a sing-box config and add it to the user managed group. -func (m *Manager) AddServerWithSingboxJSON(ctx context.Context, value []byte) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "Manager.AddServerWithSingboxJSON") +// AddServersByJSON adds any outbounds and endpoints defined in the provided sing-box JSON config. +func (m *Manager) AddServersByJSON(ctx context.Context, config []byte) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "Manager.AddServerBySingboxJSON") defer span.End() - var opts Options - if err := json.UnmarshalContext(box.BaseContext(), value, &opts); err != nil { + opts, err := json.UnmarshalExtendedContext[Options](box.BaseContext(), config) + if err != nil { return traces.RecordError(ctx, fmt.Errorf("failed to parse config: %w", err)) } if len(opts.Endpoints) == 0 && len(opts.Outbounds) == 0 { return traces.RecordError(ctx, fmt.Errorf("no endpoints or outbounds found in the provided configuration")) } - if err := m.AddServers(SGUser, opts); err != nil { + if err := m.AddServers(SGUser, opts, true); err != nil { return traces.RecordError(ctx, fmt.Errorf("failed to add servers: %w", err)) } return nil } -// AddServerBasedOnURLs adds a server(s) based on the provided URL string. -// The URL can be a comma-separated list of URLs, URLs separated by new lines, or a single URL. -// Note that the UI allows the user to specify a server name. If there is only one URL, the server name overrides -// the tag typically included in the URL. If there are multiple URLs, the server name is ignored. -func (m *Manager) AddServerBasedOnURLs(ctx context.Context, urls string, skipCertVerification bool, serverName string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "Manager.AddServerBasedOnURLs") +// AddServersByURL adds a server(s) by downloading and parsing the config from a list of URLs. +func (m *Manager) AddServersByURL(ctx context.Context, urls []string, skipCertVerification bool) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "Manager.AddServerByURLs") defer span.End() urlProvider, loaded := pluriconfig.GetProvider(string(model.ProviderURL)) if !loaded { return traces.RecordError(ctx, fmt.Errorf("URL config provider not loaded")) } - cfg, err := urlProvider.Parse(ctx, []byte(urls)) + cfg, err := urlProvider.Parse(ctx, []byte(strings.Join(urls, "\n"))) if err != nil { return traces.RecordError(ctx, fmt.Errorf("failed to parse URLs: %w", err)) } @@ -602,17 +521,6 @@ func (m *Manager) AddServerBasedOnURLs(ctx context.Context, urls string, skipCer return traces.RecordError(ctx, fmt.Errorf("no valid URLs found in the provided configuration")) } - // If we only have a single URL, and the server name is specified, use that - // to override the tag specified in the anchor hash fragment. - if len(cfgURLs) == 1 && serverName != "" { - // override the tag, which is specified in the anchor hash fragment or - // in the tag query parameter. - q := cfgURLs[0].Query() - q.Del("tag") - cfgURLs[0].Fragment = serverName - cfgURLs[0].RawQuery = q.Encode() - cfg.Options = cfgURLs - } if skipCertVerification { urlsWithCustomOptions := make([]url.URL, 0, len(cfgURLs)) for _, v := range cfgURLs { @@ -632,6 +540,6 @@ func (m *Manager) AddServerBasedOnURLs(ctx context.Context, urls string, skipCer if err != nil { return traces.RecordError(ctx, fmt.Errorf("failed to serialize sing-box config: %w", err)) } - slog.Info("Adding servers based on URLs", "serverCount", len(cfgURLs), "skipCertVerification", skipCertVerification, "serverName", serverName) - return m.AddServerWithSingboxJSON(ctx, singBoxCfg) + m.logger.Info("Added servers based on URLs", "serverCount", len(cfgURLs), "skipCertVerification", skipCertVerification) + return m.AddServersByJSON(ctx, singBoxCfg) } diff --git a/servers/manager_test.go b/servers/manager_test.go index 0df4965c..5280397e 100644 --- a/servers/manager_test.go +++ b/servers/manager_test.go @@ -1,9 +1,7 @@ package servers import ( - "context" "crypto/tls" - "encoding/json" "fmt" "net/http" "net/http/httptest" @@ -14,136 +12,24 @@ import ( "testing" C "github.com/getlantern/common" + box "github.com/getlantern/lantern-box" + + _ "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/getlantern/radiance/common" ) -func newTestManager(t *testing.T) *Manager { - t.Helper() - dataPath := t.TempDir() - mgr := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: []option.Outbound{ - {Tag: "ss-denver", Type: "shadowsocks", Options: &option.ShadowsocksOutboundOptions{ - ServerOptions: option.ServerOptions{ - Server: "1.2.3.4", - ServerPort: 1080, - }, - Method: "chacha20-ietf-poly1305", - Password: "testpass", - }}, - }, - Endpoints: make([]option.Endpoint, 0), - Locations: map[string]C.ServerLocation{ - "ss-denver": {Country: "US", City: "Denver", CountryCode: "US"}, - }, - Credentials: make(map[string]ServerCredentials), - }, - SGUser: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - Credentials: make(map[string]ServerCredentials), - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: {"ss-denver": option.Outbound{Tag: "ss-denver", Type: "shadowsocks", Options: &option.ShadowsocksOutboundOptions{ - ServerOptions: option.ServerOptions{Server: "1.2.3.4", ServerPort: 1080}, - Method: "chacha20-ietf-poly1305", - Password: "testpass", - }}}, - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - } - return mgr -} - -func TestServersJSON(t *testing.T) { - mgr := newTestManager(t) - - b, err := mgr.ServersJSON() - require.NoError(t, err) - require.NotEmpty(t, b) - - // Must be valid JSON - var raw map[string]json.RawMessage - require.NoError(t, json.Unmarshal(b, &raw), "ServersJSON must return valid JSON") - assert.Contains(t, raw, "lantern") - assert.Contains(t, raw, "user") - - // Lantern group must include the sing-box type-specific fields - lanternJSON := string(raw["lantern"]) - assert.Contains(t, lanternJSON, "shadowsocks", "should contain outbound type") - assert.Contains(t, lanternJSON, "1.2.3.4", "should contain server address") - assert.Contains(t, lanternJSON, "1080", "should contain server port") - assert.Contains(t, lanternJSON, "chacha20-ietf-poly1305", "should contain method") -} - -func TestGetServerByTagJSON(t *testing.T) { - mgr := newTestManager(t) - - t.Run("existing tag", func(t *testing.T) { - b, ok, err := mgr.GetServerByTagJSON("ss-denver") - require.NoError(t, err) - require.True(t, ok) - require.NotEmpty(t, b) - - // Must be valid JSON - var raw map[string]json.RawMessage - require.NoError(t, json.Unmarshal(b, &raw), "GetServerByTagJSON must return valid JSON") - assert.Contains(t, raw, "Tag") - assert.Contains(t, raw, "Type") - assert.Contains(t, raw, "Options") - assert.Contains(t, raw, "Location") - - // Verify the correct tag and type - fullJSON := string(b) - assert.Contains(t, fullJSON, "ss-denver") - assert.Contains(t, fullJSON, "shadowsocks") - assert.Contains(t, fullJSON, "Denver") - }) - - t.Run("missing tag", func(t *testing.T) { - b, ok, err := mgr.GetServerByTagJSON("nonexistent") - assert.NoError(t, err) - assert.False(t, ok) - assert.Nil(t, b) - }) -} - func TestPrivateServerIntegration(t *testing.T) { - dataPath := t.TempDir() - manager := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - Credentials: make(map[string]ServerCredentials), - }, - SGUser: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - Credentials: make(map[string]ServerCredentials), - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: make(map[string]any), - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - httpClient: &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, + manager := testManager(t) + manager.httpClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, }, }, } @@ -154,24 +40,24 @@ func TestPrivateServerIntegration(t *testing.T) { port, _ := strconv.Atoi(parsedURL.Port()) t.Run("convert a token into a custom server", func(t *testing.T) { - require.NoError(t, manager.AddPrivateServer("s1", parsedURL.Hostname(), port, "rootToken", nil, false)) - require.Contains(t, manager.optsMaps[SGUser], "s1", "server should be added to the manager") + require.NoError(t, manager.AddPrivateServer("s1", parsedURL.Hostname(), port, "rootToken", C.ServerLocation{}, false)) + require.Contains(t, manager.optsMap, "s1", "server should be added to the manager") }) - t.Run("invite a user", func(t *testing.T) { + t.Run("invite user", func(t *testing.T) { inviteToken, err := manager.InviteToPrivateServer(parsedURL.Hostname(), port, "rootToken", "invite1") assert.NoError(t, err) assert.NotEmpty(t, inviteToken) - require.NoError(t, manager.AddPrivateServer("s2", parsedURL.Hostname(), port, inviteToken, nil, true)) - require.Contains(t, manager.optsMaps[SGUser], "s2", "server should be added for the invited user") + require.NoError(t, manager.AddPrivateServer("s2", parsedURL.Hostname(), port, inviteToken, C.ServerLocation{}, true)) + require.Contains(t, manager.optsMap, "s2", "server should be added for the invited user") t.Run("revoke user access", func(t *testing.T) { - delete(manager.optsMaps[SGUser], "s1") + delete(manager.optsMap, "s1") require.NoError(t, manager.RevokePrivateServerInvite(parsedURL.Hostname(), port, "rootToken", "invite1")) // trying to access again with the same token should fail - assert.Error(t, manager.AddPrivateServer("s1", parsedURL.Hostname(), port, inviteToken, nil, true)) - assert.NotContains(t, manager.optsMaps[SGUser], "s1", "server should not be added after revoking invite") + assert.Error(t, manager.AddPrivateServer("s1", parsedURL.Hostname(), port, inviteToken, C.ServerLocation{}, true)) + assert.NotContains(t, manager.optsMap, "s1", "server should not be added after revoking invite") }) }) @@ -185,8 +71,6 @@ type lanternServerManagerMock struct { func newLanternServerManagerMock() *httptest.Server { testConfig := ` { - "inbounds": [ - ], "outbounds": [ { "tag": "testing-out", @@ -243,231 +127,101 @@ func (s *lanternServerManagerMock) ServeHTTP(w http.ResponseWriter, r *http.Requ w.WriteHeader(http.StatusNotFound) } -func TestAddServerWithSingBoxJSON(t *testing.T) { - dataPath := t.TempDir() - manager := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - SGUser: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: make(map[string]any), - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - } - - ctx := context.Background() - jsonConfig := ` - { - "outbounds": [ - { - "type": "shadowsocks", - "tag": "ss-out", - "server": "127.0.0.1", - "server_port": 8388, - "method": "chacha20-ietf-poly1305", - "password": "randompasswordwith24char", - "network": "tcp" - } - ] - }` - - t.Run("adding server with a sing-box json config should work", func(t *testing.T) { - require.NoError(t, manager.AddServerWithSingboxJSON(ctx, []byte(jsonConfig))) - }) - t.Run("using a empty config should return an error", func(t *testing.T) { - require.Error(t, manager.AddServerWithSingboxJSON(ctx, []byte{})) +func TestAddServersByJSON(t *testing.T) { + t.Run("valid config", func(t *testing.T) { + testConfig := []byte(` +{ + "outbounds": [ + { + "tag": "out", + "type": "shadowsocks", + "server": "127.0.0.1", + "server_port": 1080, + "method": "chacha20-ietf-poly1305", + "password": "", + } + ] +}`) + options, err := json.UnmarshalExtendedContext[Options](box.BaseContext(), testConfig) + require.NoError(t, err, "failed to unmarshal test config") + want := Server{ + Group: SGUser, + Tag: "out", + Type: "shadowsocks", + Options: options.Outbounds[0], + } + m := testManager(t) + require.NoError(t, m.AddServersByJSON(t.Context(), testConfig)) + got, exists := m.GetServerByTag("out") + assert.True(t, exists, "server was not added") + assert.Equal(t, want, got, "added server does not match expected configuration") }) - t.Run("providing a json that doesn't have any endpoints or outbounds should return a error", func(t *testing.T) { - require.Error(t, manager.AddServerWithSingboxJSON(ctx, json.RawMessage("{}"))) + t.Run("empty config", func(t *testing.T) { + m := testManager(t) + assert.Error(t, m.AddServersByJSON(t.Context(), []byte("{}"))) + assert.Empty(t, m.optsMap, "no servers should have been added") }) } -func TestAddServerBasedOnURLs(t *testing.T) { - dataPath := t.TempDir() - manager := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - SGUser: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: make(map[string]any), - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - } - ctx := context.Background() - after := func() { - manager.RemoveServer("VLESS+over+WS+with+TLS") - manager.RemoveServer("Trojan+with+TLS") - manager.RemoveServer("SpecialName") - } - - urls := strings.Join([]string{ +func TestAddServersByURL(t *testing.T) { + urls := []string{ "vless://uuid@host:443?encryption=none&security=tls&type=ws&host=example.com&path=/vless#VLESS+over+WS+with+TLS", "trojan://password@host:443?security=tls&sni=example.com#Trojan+with+TLS", - }, "\n") - t.Run("adding server based on URLs should work", func(t *testing.T) { - require.NoError(t, manager.AddServerBasedOnURLs(ctx, urls, false, "")) - assert.Contains(t, manager.optsMaps[SGUser], "VLESS+over+WS+with+TLS") - assert.Contains(t, manager.optsMaps[SGUser], "Trojan+with+TLS") - after() - }) - - t.Run("using empty URLs should return an error", func(t *testing.T) { - require.Error(t, manager.AddServerBasedOnURLs(ctx, "", false, "")) - }) - - t.Run("skip certificate verification option works", func(t *testing.T) { - require.NoError(t, manager.AddServerBasedOnURLs(ctx, urls, true, "")) - opts, isOutbound := manager.optsMaps[SGUser]["Trojan+with+TLS"].(option.Outbound) - require.True(t, isOutbound) - trojanSettings, ok := opts.Options.(*option.TrojanOutboundOptions) - require.True(t, ok) - require.NotNil(t, trojanSettings) - require.NotNil(t, trojanSettings.TLS) - assert.True(t, trojanSettings.OutboundTLSOptionsContainer.TLS.Insecure, trojanSettings.OutboundTLSOptionsContainer.TLS) - after() - }) - - url := "vless://uuid@host:443?encryption=none&security=tls&type=ws&host=example.com&path=/vless#VLESS+over+WS+with+TLS" - t.Run("adding single URL should work", func(t *testing.T) { - require.NoError(t, manager.AddServerBasedOnURLs(ctx, url, false, "SpecialName")) - assert.Contains(t, manager.optsMaps[SGUser], "SpecialName") - assert.NotContains(t, manager.optsMaps[SGUser], "VLESS+over+WS+with+TLS") - - require.NoError(t, manager.AddServerBasedOnURLs(ctx, url, false, "")) - assert.Contains(t, manager.optsMaps[SGUser], "VLESS+over+WS+with+TLS") - assert.Contains(t, manager.optsMaps[SGUser], "SpecialName") - after() - }) -} -func TestServers(t *testing.T) { - dataPath := t.TempDir() - manager := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: []option.Outbound{ - {Tag: "lantern-out", Type: "shadowsocks"}, - }, - Endpoints: []option.Endpoint{ - {Tag: "lantern-ep", Type: "shadowsocks"}, - }, - Locations: map[string]C.ServerLocation{ - "lantern-out": {City: "New York", Country: "US"}, - }, - }, - SGUser: Options{ - Outbounds: []option.Outbound{ - {Tag: "user-out", Type: "trojan"}, - }, - Endpoints: []option.Endpoint{ - {Tag: "user-ep", Type: "vless"}, - }, - Locations: map[string]C.ServerLocation{ - "user-out": {City: "London", Country: "GB"}, - }, - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: { - "lantern-out": option.Outbound{Tag: "lantern-out", Type: "shadowsocks"}, - "lantern-ep": option.Endpoint{Tag: "lantern-ep", Type: "shadowsocks"}, - }, - SGUser: { - "user-out": option.Outbound{Tag: "user-out", Type: "trojan"}, - "user-ep": option.Endpoint{Tag: "user-ep", Type: "vless"}, - }, - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), } - - t.Run("returns copy of servers", func(t *testing.T) { - servers := manager.Servers() - - require.NotNil(t, servers) - require.Contains(t, servers, SGLantern) - require.Contains(t, servers, SGUser) - - assert.Len(t, servers[SGLantern].Outbounds, 1) - assert.Len(t, servers[SGLantern].Endpoints, 1) - assert.Equal(t, "lantern-out", servers[SGLantern].Outbounds[0].Tag) - assert.Equal(t, "lantern-ep", servers[SGLantern].Endpoints[0].Tag) - - assert.Len(t, servers[SGUser].Outbounds, 1) - assert.Len(t, servers[SGUser].Endpoints, 1) - assert.Equal(t, "user-out", servers[SGUser].Outbounds[0].Tag) - assert.Equal(t, "user-ep", servers[SGUser].Endpoints[0].Tag) - - assert.Equal(t, "New York", servers[SGLantern].Locations["lantern-out"].City) - assert.Equal(t, "London", servers[SGUser].Locations["user-out"].City) + t.Run("valid urls", func(t *testing.T) { + m := testManager(t) + require.NoError(t, m.AddServersByURL(t.Context(), urls, false)) + _, exists := m.GetServerByTag("VLESS+over+WS+with+TLS") + assert.True(t, exists, "VLESS server should be added") + _, exists = m.GetServerByTag("Trojan+with+TLS") + assert.True(t, exists, "Trojan server should be added") }) - - t.Run("modifications to returned copy don't affect original", func(t *testing.T) { - servers := manager.Servers() - assert.Len(t, servers[SGLantern].Outbounds, 1) - assert.Len(t, servers[SGUser].Endpoints, 1) - - // Modify the copy - servers[SGLantern].Outbounds[0].Tag = "modified-out" - - // Original should remain unchanged - originalServers := manager.Servers() - assert.NotEqual(t, originalServers[SGLantern].Outbounds[0].Tag, "modified-out") + t.Run("skip certificate", func(t *testing.T) { + m := testManager(t) + require.NoError(t, m.AddServersByURL(t.Context(), urls, true)) + server, exists := m.GetServerByTag("Trojan+with+TLS") + require.True(t, exists, "Trojan server should be added") + + options := server.Options.(option.Outbound).Options + require.IsType(t, &option.TrojanOutboundOptions{}, options) + trojanOpts := options.(*option.TrojanOutboundOptions) + require.NotNil(t, trojanOpts.TLS) + assert.True(t, trojanOpts.TLS.Insecure, "TLS.Insecure should be true") }) - - t.Run("handles empty servers", func(t *testing.T) { - emptyManager := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - SGUser: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: make(map[string]any), - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - } - - servers := emptyManager.Servers() - require.NotNil(t, servers) - assert.Len(t, servers[SGLantern].Outbounds, 0) - assert.Len(t, servers[SGLantern].Endpoints, 0) - assert.Len(t, servers[SGUser].Outbounds, 0) - assert.Len(t, servers[SGUser].Endpoints, 0) + t.Run("empty urls", func(t *testing.T) { + m := testManager(t) + assert.Error(t, m.AddServersByURL(t.Context(), []string{}, false)) + assert.Empty(t, m.optsMap, "no servers should have been added") }) } func TestRetryableHTTPClient(t *testing.T) { - cli := retryableHTTPClient().StandardClient() + cli := retryableHTTPClient(log.NoOpLogger()).StandardClient() request, err := http.NewRequest(http.MethodGet, "https://www.gstatic.com/generate_204", http.NoBody) require.NoError(t, err) resp, err := cli.Do(request) require.NoError(t, err) assert.Equal(t, http.StatusNoContent, resp.StatusCode) } + +func testManager(t *testing.T) *Manager { + return &Manager{ + servers: Servers{ + SGLantern: Options{ + Outbounds: make([]option.Outbound, 0), + Endpoints: make([]option.Endpoint, 0), + Locations: make(map[string]C.ServerLocation), + Credentials: make(map[string]ServerCredentials), + }, + SGUser: Options{ + Outbounds: make([]option.Outbound, 0), + Endpoints: make([]option.Endpoint, 0), + Locations: make(map[string]C.ServerLocation), + Credentials: make(map[string]ServerCredentials), + }, + }, + optsMap: map[string]Server{}, + serversFile: filepath.Join(t.TempDir(), internal.ServersFileName), + logger: log.NoOpLogger(), + } +} diff --git a/telemetry/connections.go b/telemetry/connections.go index acc55d1e..5f131ea9 100644 --- a/telemetry/connections.go +++ b/telemetry/connections.go @@ -9,13 +9,21 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" - "github.com/getlantern/radiance/vpn/ipc" + "github.com/getlantern/radiance/vpn" ) -// harvestConnectionMetrics periodically polls the number of active connections and their total +// ConnectionSource provides access to the current VPN connections for metrics collection. +type ConnectionSource interface { + Connections() ([]vpn.Connection, error) +} + +// StartConnectionMetrics periodically polls the number of active connections and their total // upload and download bytes, setting the corresponding OpenTelemetry metrics. It returns a function // that can be called to stop the polling. -func harvestConnectionMetrics(pollInterval time.Duration) func() { +// +// The caller is responsible for only calling this when the VPN is connected and telemetry is +// enabled, and for calling the returned stop function when either condition changes. +func StartConnectionMetrics(ctx context.Context, src ConnectionSource, pollInterval time.Duration) func() { ticker := time.NewTicker(pollInterval) meter := otel.Meter("github.com/getlantern/radiance/metrics") currentActiveConnections, err := meter.Int64Counter("current_active_connections", metric.WithDescription("Current number of active connections")) @@ -34,7 +42,7 @@ func harvestConnectionMetrics(pollInterval time.Duration) func() { if err != nil { slog.Warn("failed to create uplink_bytes metric", slog.Any("error", err)) } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) go func() { seenConnections := make(map[string]bool) for { @@ -44,16 +52,9 @@ func harvestConnectionMetrics(pollInterval time.Duration) func() { return case <-ticker.C: slog.Debug("polling connections for metrics", slog.Int("seen_connections", len(seenConnections)), slog.Duration("poll_interval", pollInterval)) - vpnStatus, err := ipc.GetStatus(ctx) - if err != nil { - slog.Warn("failed to get service status", "error", err) - } - if vpnStatus != ipc.Connected { - continue - } - conns, err := ipc.GetConnections(ctx) + conns, err := src.Connections() if err != nil { - slog.Warn("failed to retrieve connections", slog.Any("error", err)) + slog.Debug("failed to retrieve connections for metrics", slog.Any("error", err)) continue } diff --git a/telemetry/otel.go b/telemetry/otel.go index 3875b1c8..13eb1621 100644 --- a/telemetry/otel.go +++ b/telemetry/otel.go @@ -18,6 +18,7 @@ import ( "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + "go.opentelemetry.io/otel/metric/noop" "go.opentelemetry.io/otel/propagation" sdkmetric "go.opentelemetry.io/otel/sdk/metric" "go.opentelemetry.io/otel/sdk/resource" @@ -31,10 +32,8 @@ import ( ) var ( - initMutex sync.Mutex - shutdownOTEL func(context.Context) error - harvestConnections sync.Once - harvestConnectionTickerStop func() + initMutex sync.Mutex + shutdownOTEL func(context.Context) error ) type Attributes struct { @@ -56,18 +55,18 @@ type Attributes struct { // OnNewConfig handles OpenTelemetry re-initialization when the configuration changes. func OnNewConfig(oldConfig, newConfig *config.Config, deviceID string) error { // Check if the old OTEL configuration is the same as the new one. - if oldConfig != nil && reflect.DeepEqual(oldConfig.ConfigResponse.OTEL, newConfig.ConfigResponse.OTEL) { + if oldConfig != nil && reflect.DeepEqual(oldConfig.OTEL, newConfig.OTEL) { slog.Debug("OpenTelemetry configuration has not changed, skipping initialization") return nil } - if err := initialize(deviceID, newConfig.ConfigResponse, settings.IsPro()); err != nil { + if err := Initialize(deviceID, *newConfig, settings.IsPro()); err != nil { slog.Error("Failed to initialize OpenTelemetry", "error", err) return fmt.Errorf("Failed to initialize OpenTelemetry: %w", err) } return nil } -func initialize(deviceID string, configResponse common.ConfigResponse, pro bool) error { +func Initialize(deviceID string, configResponse config.Config, pro bool) error { initMutex.Lock() defer initMutex.Unlock() @@ -107,24 +106,19 @@ func initialize(deviceID string, configResponse common.ConfigResponse, pro bool) } shutdownOTEL = shutdown - - harvestConnections.Do(func() { - harvestConnectionTickerStop = harvestConnectionMetrics(1 * time.Minute) - }) return nil } -func Close(ctx context.Context) error { +func Close() error { + return CloseContext(context.Background()) +} + +func CloseContext(ctx context.Context) error { initMutex.Lock() defer initMutex.Unlock() var errs error - // stop collecting connection metrics - if harvestConnectionTickerStop != nil { - harvestConnectionTickerStop() - } - if shutdownOTEL != nil { slog.Info("Shutting down existing OpenTelemetry SDK") if err := shutdownOTEL(ctx); err != nil { @@ -133,6 +127,8 @@ func Close(ctx context.Context) error { } shutdownOTEL = nil } + // otel.SetTracerProvider(traceNoop.NewTracerProvider()) + otel.SetMeterProvider(noop.NewMeterProvider()) return errs } @@ -157,7 +153,7 @@ func buildResources(serviceName string, a Attributes) []attribute.KeyValue { // setupOTelSDK bootstraps the OpenTelemetry pipeline. // If it does not return an error, make sure to call shutdown for proper cleanup. -func setupOTelSDK(ctx context.Context, attributes Attributes, cfg common.ConfigResponse) (func(context.Context) error, error) { +func setupOTelSDK(ctx context.Context, attributes Attributes, cfg config.Config) (func(context.Context) error, error) { if cfg.Features == nil { cfg.Features = make(map[string]bool) } diff --git a/tester/main.go b/tester/main.go index 10120f8c..7a0d7470 100644 --- a/tester/main.go +++ b/tester/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "log/slog" "os" @@ -8,10 +9,11 @@ import ( "strconv" "time" - "github.com/getlantern/radiance" + "github.com/getlantern/radiance/backend" "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/config" "github.com/getlantern/radiance/events" + "github.com/getlantern/radiance/ipc" "github.com/getlantern/radiance/vpn" ) @@ -20,7 +22,7 @@ func performLanternPing(urlToHit string, runId string, deviceId string, userId i os.RemoveAll(dataDir) } os.MkdirAll(dataDir, 0o755) - r, err := radiance.NewRadiance(radiance.Options{ + be, err := backend.NewLocalBackend(context.Background(), backend.Options{ DataDir: dataDir, LogDir: dataDir, Locale: "en-US", @@ -28,7 +30,7 @@ func performLanternPing(urlToHit string, runId string, deviceId string, userId i if err != nil { return fmt.Errorf("failed to create radiance instance: %w", err) } - defer r.Close() + defer be.Close() settings.Set(settings.UserIDKey, userId) settings.Set(settings.TokenKey, token) settings.Set(settings.UserLevelKey, "") @@ -40,14 +42,16 @@ func performLanternPing(urlToHit string, runId string, deviceId string, userId i }, }) - ipcServer, err := vpn.InitIPC(dataDir, "", "trace", nil) + be.Start() + + ipcServer := ipc.NewServer(be, false) + err = ipcServer.Start() if err != nil { return fmt.Errorf("failed to initialize IPC server: %w", err) } exit := func() { - status, _ := vpn.GetStatus() - if status.TunnelOpen { - vpn.Disconnect() + if be.VPNStatus() != vpn.Disconnected { + be.DisconnectVPN() } ipcServer.Close() } @@ -70,7 +74,7 @@ func performLanternPing(urlToHit string, runId string, deviceId string, userId i } } t1 := time.Now() - if err = vpn.QuickConnect("all", nil); err != nil { + if err = be.ConnectVPN(vpn.AutoSelectTag); err != nil { return fmt.Errorf("quick connect failed: %w", err) } fmt.Println("Quick connect successful") @@ -79,7 +83,7 @@ func performLanternPing(urlToHit string, runId string, deviceId string, userId i proxyAddr := os.Getenv("RADIANCE_SOCKS_ADDRESS") if proxyAddr == "" { - proxyAddr = "127.0.0.1:6666" + proxyAddr = "127.0.0.1:6666" } cmd := exec.Command("curl", "-v", "-x", proxyAddr, "-s", urlToHit) diff --git a/traces/errors.go b/traces/errors.go index cccef67e..6ed6c319 100644 --- a/traces/errors.go +++ b/traces/errors.go @@ -7,6 +7,7 @@ import ( "go.opentelemetry.io/otel/trace" ) +// RecordError records the given error in the current span. If error is nil, it is noop. func RecordError(ctx context.Context, err error, options ...trace.EventOption) error { if err == nil { return nil diff --git a/vpn/boxoptions.go b/vpn/boxoptions.go index ff0e1ed1..521a134b 100644 --- a/vpn/boxoptions.go +++ b/vpn/boxoptions.go @@ -9,13 +9,14 @@ import ( "log/slog" "net/netip" "path/filepath" + "slices" "time" - lcommon "github.com/getlantern/common" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + lcommon "github.com/getlantern/common" box "github.com/getlantern/lantern-box" lbC "github.com/getlantern/lantern-box/constant" lbO "github.com/getlantern/lantern-box/option" @@ -28,17 +29,14 @@ import ( "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/atomicfile" "github.com/getlantern/radiance/common/env" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/config" "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/servers" ) const ( - autoAllTag = "auto" - - autoLanternTag = "auto-lantern" - autoUserTag = "auto-user" + AutoSelectTag = "auto" + ManualSelectTag = "manual" urlTestInterval = 3 * time.Minute // must be less than urlTestIdleTimeout urlTestIdleTimeout = 15 * time.Minute @@ -47,11 +45,35 @@ const ( cacheFileName = "lantern.cache" ) +var reservedTags = []string{AutoSelectTag, ManualSelectTag, "direct", "block"} + +func ReservedTags() []string { + return slices.Clone(reservedTags) +} + +type BoxOptions struct { + BasePath string `json:"base_path,omitempty"` + // Options contains the main options that are merged into the base options with the exception of + // DNS, which overrides the base DNS options entirely instead of being merged. Options should + // contain all servers (both lantern and user). + Options O.Options `json:"options,omitempty"` + // SmartRouting contains smart routing rules to merge into the final options. + SmartRouting lcommon.SmartRoutingRules `json:"smart_routing,omitempty"` + // AdBlock contains ad block rules to merge into the final options. + AdBlock lcommon.AdBlockRules `json:"ad_block,omitempty"` + // BanditURLOverrides maps outbound tags to per-proxy callback URLs for + // the bandit Thompson sampling system. When set, these override the + // default MutableURLTest URL for each specific outbound, allowing the + // server to detect which proxies successfully connected. + BanditURLOverrides map[string]string `json:"bandit_url_overrides,omitempty"` + BanditThroughputURL string `json:"bandit_throughput_url,omitempty"` +} + // this is the base options that is need for everything to work correctly. this should not be // changed unless you know what you're doing. func baseOpts(basePath string) O.Options { splitTunnelPath := filepath.Join(basePath, splitTunnelFile) - + cacheFile := filepath.Join(basePath, cacheFileName) loopbackAddr := badoption.Addr(netip.MustParseAddr("127.0.0.1")) return O.Options{ Log: &O.LogOptions{ @@ -121,13 +143,13 @@ func baseOpts(basePath string) O.Options { }, Experimental: &O.ExperimentalOptions{ ClashAPI: &O.ClashAPIOptions{ - DefaultMode: autoAllTag, - ModeList: []string{servers.SGLantern, servers.SGUser, autoAllTag}, + DefaultMode: AutoSelectTag, + ModeList: []string{servers.SGLantern, servers.SGUser, AutoSelectTag}, ExternalController: "", // intentionally left empty }, CacheFile: &O.CacheFileOptions{ Enabled: true, - Path: cacheFileName, + Path: cacheFile, CacheID: cacheID, }, }, @@ -228,25 +250,21 @@ func baseRoutingRules() []O.Rule { } // buildOptions builds the box options using the config options and user servers. -func buildOptions(ctx context.Context, path string) (O.Options, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "buildOptions") +func buildOptions(bOptions BoxOptions) (O.Options, error) { + _, span := otel.Tracer(tracerName).Start(context.Background(), "buildOptions") defer span.End() - slog.Log(nil, internal.LevelTrace, "Starting buildOptions", "path", path) - - opts := baseOpts(path) - slog.Debug("Base options initialized") + if len(bOptions.Options.Outbounds) == 0 && len(bOptions.Options.Endpoints) == 0 { + return O.Options{}, errors.New("no outbounds or endpoints found in config or user servers") + } - // update default options and paths - opts.Experimental.CacheFile.Path = filepath.Join(path, cacheFileName) + slog.Log(nil, log.LevelTrace, "Starting buildOptions", "path", bOptions.BasePath) - slog.Log(nil, internal.LevelTrace, "Updated default options and paths", - "cacheFilePath", opts.Experimental.CacheFile.Path, - "clashAPIDefaultMode", opts.Experimental.ClashAPI.DefaultMode, - ) + opts := baseOpts(bOptions.BasePath) + slog.Debug("Base options initialized") - if _, useSocks := env.Get[bool](env.UseSocks); useSocks { - socksAddr, _ := env.Get[string](env.SocksAddress) + if env.GetBool(env.UseSocks) { + socksAddr, _ := env.Get(env.SocksAddress) slog.Info("Using SOCKS proxy for inbound as per environment variable", "socksAddr", socksAddr) addrPort, err := netip.ParseAddrPort(socksAddr) if err != nil { @@ -276,81 +294,39 @@ func buildOptions(ctx context.Context, path string) (O.Options, error) { } } - // Load config file - confPath := filepath.Join(path, common.ConfigFileName) - slog.Debug("Loading config file", "confPath", confPath) - cfg, err := loadConfig(confPath) - if err != nil { - slog.Error("Failed to load config options", "error", err) - return O.Options{}, err - } - // add smart routing and ad block rules - if settings.GetBool(settings.SmartRoutingKey) && len(cfg.SmartRouting) > 0 { + if len(bOptions.SmartRouting) > 0 { slog.Debug("Adding smart-routing rules") - outbounds, rules, rulesets := cfg.SmartRouting.ToOptions(urlTestInterval, urlTestIdleTimeout) + outbounds, rules, rulesets := bOptions.SmartRouting.ToOptions(urlTestInterval, urlTestIdleTimeout) opts.Outbounds = append(opts.Outbounds, outbounds...) opts.Route.Rules = append(opts.Route.Rules, rules...) opts.Route.RuleSet = append(opts.Route.RuleSet, rulesets...) } - if settings.GetBool(settings.AdBlockKey) && len(cfg.AdBlock) > 0 { + if len(bOptions.AdBlock) > 0 { slog.Debug("Adding ad-block rules") - rule, rulesets := cfg.AdBlock.ToOptions() + rule, rulesets := bOptions.AdBlock.ToOptions() opts.Route.Rules = append(opts.Route.Rules, rule) opts.Route.RuleSet = append(opts.Route.RuleSet, rulesets...) } - var lanternTags []string - configOpts := cfg.Options - if len(configOpts.Outbounds) == 0 && len(configOpts.Endpoints) == 0 { - slog.Warn("Config loaded but no outbounds or endpoints found") - } - lanternTags = mergeAndCollectTags(&opts, &configOpts) - slog.Debug("Merged config options", "tags", lanternTags) - - appendGroupOutbounds(&opts, servers.SGLantern, autoLanternTag, lanternTags, cfg.BanditURLOverrides) + tags := mergeAndCollectTags(&opts, &bOptions.Options) - // Load user servers - slog.Debug("Loading user servers") - userOpts, err := loadUserOptions(path) - if err != nil { - slog.Error("Failed to load user servers", "error", err) - return O.Options{}, err - } - var userTags []string - if len(userOpts.Outbounds) == 0 && len(userOpts.Endpoints) == 0 { - slog.Info("No user servers found") - } else { - userTags = mergeAndCollectTags(&opts, &userOpts) - slog.Debug("Merged user server options", "tags", userTags) - } - appendGroupOutbounds(&opts, servers.SGUser, autoUserTag, userTags, nil) - - if len(lanternTags) == 0 && len(userTags) == 0 { - return O.Options{}, errors.New("no outbounds or endpoints found in config or user servers") - } - - // Add auto all outbound - opts.Outbounds = append(opts.Outbounds, urlTestOutbound(autoAllTag, []string{autoLanternTag, autoUserTag}, nil)) - - // Add routing rules for the groups - opts.Route.Rules = append(opts.Route.Rules, groupRule(autoAllTag)) - opts.Route.Rules = append(opts.Route.Rules, groupRule(servers.SGLantern)) - opts.Route.Rules = append(opts.Route.Rules, groupRule(servers.SGUser)) + // add mode selector outbounds and rules + opts.Outbounds = append(opts.Outbounds, urlTestOutbound(AutoSelectTag, tags, bOptions.BanditURLOverrides)) + opts.Outbounds = append(opts.Outbounds, selectorOutbound(ManualSelectTag, tags)) + opts.Route.Rules = append(opts.Route.Rules, selectModeRule(AutoSelectTag)) + opts.Route.Rules = append(opts.Route.Rules, selectModeRule(ManualSelectTag)) // catch-all rule to ensure no fallthrough opts.Route.Rules = append(opts.Route.Rules, catchAllBlockerRule()) - slog.Debug("Finished building options", slog.String("env", common.Env())) + slog.Debug("Finished building options", "env", common.Env()) span.AddEvent("finished building options", trace.WithAttributes( - attribute.String("options", string(writeBoxOptions(path, opts))), - attribute.String("env", common.Env()), + attribute.String("options", string(writeBoxOptions(bOptions.BasePath, opts))), )) return opts, nil } -const debugLanternBoxOptionsFilename = "debug-lantern-box-options.json" - // writeBoxOptions marshals the options as JSON and stores them in a file so we can debug them // we can ignore the errors here since the tunnel will error out anyway if something is wrong func writeBoxOptions(path string, opts O.Options) []byte { @@ -365,37 +341,17 @@ func writeBoxOptions(path string, opts O.Options) []byte { slog.Warn("failed to indent marshaled options while writing debug box options", slog.Any("error", err)) return buf } - if err := atomicfile.WriteFile(filepath.Join(path, debugLanternBoxOptionsFilename), b.Bytes(), 0644); err != nil { + if err := atomicfile.WriteFile(filepath.Join(path, internal.DebugBoxOptionsFileName), b.Bytes(), 0644); err != nil { slog.Warn("failed to write options file", slog.Any("error", err)) return buf } return b.Bytes() } -/////////////////////// +////////////////////// // Helper functions // ////////////////////// -func loadConfig(path string) (lcommon.ConfigResponse, error) { - cfg, err := config.Load(path) - if err != nil { - return lcommon.ConfigResponse{}, fmt.Errorf("load config: %w", err) - } - if cfg == nil { - return lcommon.ConfigResponse{}, nil - } - return cfg.ConfigResponse, nil -} - -func loadUserOptions(path string) (O.Options, error) { - mgr, err := servers.NewManager(path) - if err != nil { - return O.Options{}, fmt.Errorf("server manager: %w", err) - } - u := mgr.Servers()[servers.SGUser] - return O.Options{Outbounds: u.Outbounds, Endpoints: u.Endpoints}, nil -} - // mergeAndCollectTags merges src into dst and returns all outbound/endpoint tags from src. func mergeAndCollectTags(dst, src *O.Options) []string { dst.Outbounds = append(dst.Outbounds, src.Outbounds...) @@ -429,30 +385,6 @@ func useIfNotZero[T comparable](newVal, oldVal T) T { return oldVal } -func appendGroupOutbounds(opts *O.Options, serverGroup, autoTag string, tags []string, urlOverrides map[string]string) { - opts.Outbounds = append(opts.Outbounds, urlTestOutbound(autoTag, tags, urlOverrides)) - opts.Outbounds = append(opts.Outbounds, selectorOutbound(serverGroup, append([]string{autoTag}, tags...))) - slog.Log( - nil, internal.LevelTrace, "Added group outbounds", - "serverGroup", serverGroup, - "tags", tags, - "outbounds", opts.Outbounds[len(opts.Outbounds)-2:], - ) -} - -func groupAutoTag(group string) string { - switch group { - case servers.SGLantern: - return autoLanternTag - case servers.SGUser: - return autoUserTag - case "all", "": - return autoAllTag - default: - return "" - } -} - func urlTestOutbound(tag string, outbounds []string, urlOverrides map[string]string) O.Outbound { return O.Outbound{ Type: lbC.TypeMutableURLTest, @@ -467,27 +399,27 @@ func urlTestOutbound(tag string, outbounds []string, urlOverrides map[string]str } } -func selectorOutbound(group string, outbounds []string) O.Outbound { +func selectorOutbound(tag string, outbounds []string) O.Outbound { return O.Outbound{ Type: lbC.TypeMutableSelector, - Tag: group, + Tag: tag, Options: &lbO.MutableSelectorOutboundOptions{ Outbounds: outbounds, }, } } -func groupRule(group string) O.Rule { +func selectModeRule(mode string) O.Rule { return O.Rule{ Type: C.RuleTypeDefault, DefaultOptions: O.DefaultRule{ RawDefaultRule: O.RawDefaultRule{ - ClashMode: group, + ClashMode: mode, }, RuleAction: O.RuleAction{ Action: C.RuleActionTypeRoute, RouteOptions: O.RouteActionOptions{ - Outbound: group, + Outbound: mode, }, }, }, @@ -506,7 +438,6 @@ func catchAllBlockerRule() O.Rule { } } - func newDNSServerOptions(typ, tag, server, domainResolver string) O.DNSServerOptions { var serverOpts any remoteOpts := O.RemoteDNSServerOptions{ diff --git a/vpn/boxoptions_test.go b/vpn/boxoptions_test.go index 4009ad62..18225bb4 100644 --- a/vpn/boxoptions_test.go +++ b/vpn/boxoptions_test.go @@ -1,118 +1,56 @@ package vpn import ( - "context" - "fmt" "os" - "path/filepath" "slices" "testing" - "github.com/sagernet/sing-box/constant" O "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/json" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - LC "github.com/getlantern/common" box "github.com/getlantern/lantern-box" lbO "github.com/getlantern/lantern-box/option" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/config" - "github.com/getlantern/radiance/servers" ) func TestBuildOptions(t *testing.T) { - testOpts, _, err := testBoxOptions("") - require.NoError(t, err, "get test box options") - lanternTags, lanternOuts := filterOutbounds(*testOpts, constant.TypeHTTP) - userTags, userOuts := filterOutbounds(*testOpts, constant.TypeSOCKS) - cfg := config.Config{ - ConfigResponse: LC.ConfigResponse{ - Options: O.Options{ - Outbounds: lanternOuts, - }, - }, - } - svrs := servers.Servers{ - servers.SGUser: servers.Options{ - Outbounds: userOuts, - }, - } + options, tags := testBoxOptions(t) tests := []struct { name string - lanternTags []string - userTags []string + boxOptions BoxOptions shouldError bool }{ { - name: "config without user servers", - lanternTags: lanternTags, - }, - { - name: "user servers without config", - userTags: userTags, - }, - { - name: "config and user servers", - lanternTags: lanternTags, - userTags: userTags, + name: "success", + boxOptions: BoxOptions{ + BasePath: t.TempDir(), + Options: options, + }, }, { - name: "neither config nor user servers", + name: "no servers available", + boxOptions: BoxOptions{ + BasePath: t.TempDir(), + }, shouldError: true, }, } - hasGroupWithTags := func(t *testing.T, outs []O.Outbound, group string, tags []string) { - out := findOutbound(outs, group) - if !assert.NotNilf(t, out, "group %s not found", group) { - return - } - switch opts := out.Options.(type) { - case *lbO.MutableSelectorOutboundOptions: - assert.ElementsMatchf(t, tags, opts.Outbounds, "group %s does not have correct outbounds", group) - case *O.SelectorOutboundOptions: - assert.ElementsMatchf(t, tags, opts.Outbounds, "group %s does not have correct outbounds", group) - case *lbO.MutableURLTestOutboundOptions: - assert.ElementsMatchf(t, tags, opts.Outbounds, "group %s does not have correct outbounds", group) - case *O.URLTestOutboundOptions: - assert.ElementsMatchf(t, tags, opts.Outbounds, "group %s does not have correct outbounds", group) - default: - assert.Failf(t, fmt.Sprintf("%s[%T] is not a group outbound", group, opts), "") - } - } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - path := t.TempDir() - if len(tt.lanternTags) > 0 { - testOptsToFile(t, cfg, filepath.Join(path, common.ConfigFileName)) - } - if len(tt.userTags) > 0 { - testOptsToFile(t, svrs, filepath.Join(path, common.ServersFileName)) - } - opts, err := buildOptions(context.Background(), path) + opts, err := buildOptions(tt.boxOptions) if tt.shouldError { require.Error(t, err, "expected error but got none") return } require.NoError(t, err) - gotOutbounds := opts.Outbounds - require.NotEmpty(t, gotOutbounds, "no outbounds in built options") - - assert.NotNil(t, findOutbound(gotOutbounds, constant.TypeDirect), "direct outbound not found") - assert.NotNil(t, findOutbound(gotOutbounds, constant.TypeBlock), "block outbound not found") - - hasGroupWithTags(t, gotOutbounds, servers.SGLantern, append(tt.lanternTags, autoLanternTag)) - hasGroupWithTags(t, gotOutbounds, servers.SGUser, append(tt.userTags, autoUserTag)) - - hasGroupWithTags(t, gotOutbounds, autoLanternTag, tt.lanternTags) - hasGroupWithTags(t, gotOutbounds, autoUserTag, tt.userTags) - hasGroupWithTags(t, gotOutbounds, autoAllTag, []string{autoLanternTag, autoUserTag}) - - assert.FileExists(t, filepath.Join(path, debugLanternBoxOptionsFilename), "debug option file must be written") + urlTest := urlTestOutbound(AutoSelectTag, tags, nil) + assert.Contains(t, opts.Outbounds, urlTest, "options should contain auto-select URL test outbound") + selector := selectorOutbound(ManualSelectTag, tags) + assert.Contains(t, opts.Outbounds, selector, "options should contain manual-select selector outbound") }) } } @@ -185,18 +123,14 @@ func TestBuildOptions_Rulesets(t *testing.T) { wantAdBlockOpts, err := json.UnmarshalExtendedContext[O.Options](box.BaseContext(), []byte(adBlockJSON)) require.NoError(t, err) - buf, err := os.ReadFile("testdata/config.json") - require.NoError(t, err, "read test config file") - + cfg := testConfig(t) + boxOptions := BoxOptions{ + BasePath: t.TempDir(), + Options: cfg.Options, + } t.Run("with smart routing", func(t *testing.T) { - tmp := t.TempDir() - require.NoError(t, os.WriteFile(filepath.Join(tmp, common.ConfigFileName), buf, 0644), "write test config file to temp dir") - - require.NoError(t, settings.InitSettings(tmp)) - t.Cleanup(settings.Reset) - - settings.Set(settings.SmartRoutingKey, true) - options, err := buildOptions(context.Background(), tmp) + boxOptions.SmartRouting = cfg.SmartRouting + options, err := buildOptions(boxOptions) require.NoError(t, err) // check rules, rulesets, and outbounds are correctly built into options assert.True(t, contains(t, options.Route.Rules, wantSmartRoutingOpts.Route.Rules[0]), "missing smart routing rule") @@ -204,14 +138,8 @@ func TestBuildOptions_Rulesets(t *testing.T) { assert.True(t, contains(t, options.Outbounds, wantSmartRoutingOpts.Outbounds[0]), "missing smart routing outbound") }) t.Run("with ad block", func(t *testing.T) { - tmp := t.TempDir() - require.NoError(t, os.WriteFile(filepath.Join(tmp, common.ConfigFileName), buf, 0644), "write test config file to temp dir") - - require.NoError(t, settings.InitSettings(tmp)) - t.Cleanup(settings.Reset) - - settings.Set(settings.AdBlockKey, true) - options, err := buildOptions(context.Background(), tmp) + boxOptions.AdBlock = cfg.AdBlock + options, err := buildOptions(boxOptions) require.NoError(t, err) // check reject rule and rulesets are correctly built into options for _, rs := range wantAdBlockOpts.Route.RuleSet { @@ -224,32 +152,23 @@ func TestBuildOptions_Rulesets(t *testing.T) { } func TestBuildOptions_BanditURLOverrides(t *testing.T) { - testOpts, _, err := testBoxOptions("") - require.NoError(t, err) - lanternTags, lanternOuts := filterOutbounds(*testOpts, constant.TypeHTTP) - require.NotEmpty(t, lanternTags, "need at least one HTTP outbound for test") - + cfg := testConfig(t) overrides := map[string]string{ - lanternTags[0]: "https://example.com/callback?token=abc", + cfg.Options.Outbounds[0].Tag: "https://example.com/callback?token=abc", } - cfg := config.Config{ - ConfigResponse: LC.ConfigResponse{ - Options: O.Options{Outbounds: lanternOuts}, - BanditURLOverrides: overrides, - }, + boxOptions := BoxOptions{ + BasePath: t.TempDir(), + Options: cfg.Options, + BanditURLOverrides: overrides, } - - path := t.TempDir() - testOptsToFile(t, cfg, filepath.Join(path, common.ConfigFileName)) - - opts, err := buildOptions(context.Background(), path) + opts, err := buildOptions(boxOptions) require.NoError(t, err) - out := findOutbound(opts.Outbounds, autoLanternTag) - require.NotNil(t, out, "auto-lantern outbound not found") + out := findOutbound(opts.Outbounds, AutoSelectTag) + require.NotNil(t, out, "missing auto-select outbound") - mutOpts, ok := out.Options.(*lbO.MutableURLTestOutboundOptions) - require.True(t, ok, "auto-lantern outbound should be MutableURLTestOutboundOptions") + require.IsType(t, &lbO.MutableURLTestOutboundOptions{}, out.Options, "auto-select outbound options should be MutableURLTestOutboundOptions") + mutOpts := out.Options.(*lbO.MutableURLTestOutboundOptions) assert.Equal(t, overrides, mutOpts.URLOverrides, "URLOverrides should be wired from config") } @@ -292,24 +211,23 @@ func findOutbound(outs []O.Outbound, tag string) *O.Outbound { return &outs[idx] } -func testOptsToFile[T any](t *testing.T, opts T, path string) { - buf, err := json.Marshal(opts) - require.NoError(t, err, "marshal options") - require.NoError(t, os.WriteFile(path, buf, 0644), "write options to file") +func testConfig(t *testing.T) config.Config { + buf, err := os.ReadFile("testdata/config.json") + require.NoError(t, err, "read test config file") + + cfg, err := json.UnmarshalExtendedContext[config.Config](box.BaseContext(), buf) + require.NoError(t, err, "unmarshal test config") + return cfg } -func testBoxOptions(tmpPath string) (*O.Options, string, error) { - content, err := os.ReadFile("testdata/boxopts.json") - if err != nil { - return nil, "", err +func testBoxOptions(t *testing.T) (O.Options, []string) { + cfg := testConfig(t) + var tags []string + for _, o := range cfg.Options.Outbounds { + tags = append(tags, o.Tag) } - opts, err := json.UnmarshalExtendedContext[O.Options](box.BaseContext(), content) - if err != nil { - return nil, "", err + for _, ep := range cfg.Options.Endpoints { + tags = append(tags, ep.Tag) } - - opts.Experimental.CacheFile.Path = filepath.Join(tmpPath, cacheFileName) - opts.Experimental.CacheFile.CacheID = cacheID - buf, _ := json.Marshal(opts) - return &opts, string(buf), nil + return cfg.Options, tags } diff --git a/vpn/dnsoptions_test.go b/vpn/dnsoptions_test.go index 06b49f1b..9f5866b8 100644 --- a/vpn/dnsoptions_test.go +++ b/vpn/dnsoptions_test.go @@ -3,6 +3,8 @@ package vpn import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/getlantern/radiance/common/settings" ) @@ -62,9 +64,7 @@ func TestNormalizeLocale(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := normalizeLocale(tt.locale) - if result != tt.expected { - t.Errorf("normalizeLocale(%q) = %q, expected %q", tt.locale, result, tt.expected) - } + assert.Equalf(t, tt.expected, result, "normalizeLocale(%q) should return %q", tt.locale, tt.expected) }) } } @@ -138,43 +138,7 @@ func TestLocalDNSIP(t *testing.T) { settings.Set(settings.LocaleKey, tt.locale) result := localDNSIP() - if result != tt.expected { - t.Errorf("localDNSIP() with locale %q = %q, expected %q", tt.locale, result, tt.expected) - } + assert.Equalf(t, tt.expected, result, "localDNSIP() with locale %q should return %q", tt.locale, tt.expected) }) } } -func TestBuildDNSRules(t *testing.T) { - rules := buildDNSRules() - - if len(rules) != 1 { - t.Fatalf("expected 1 DNS rule, got %d", len(rules)) - } - - rule := rules[0] - - if rule.Type != "default" { - t.Errorf("expected rule type 'default', got %q", rule.Type) - } - - if rule.DefaultOptions.DNSRuleAction.Action != "route" { - t.Errorf("expected action 'route', got %q", rule.DefaultOptions.DNSRuleAction.Action) - } - - if rule.DefaultOptions.DNSRuleAction.RouteOptions.Server != "dns_fakeip" { - t.Errorf("expected server 'dns_fakeip', got %q", rule.DefaultOptions.DNSRuleAction.RouteOptions.Server) - } - - queryTypes := rule.DefaultOptions.RawDefaultDNSRule.QueryType - if len(queryTypes) != 2 { - t.Fatalf("expected 2 query types, got %d", len(queryTypes)) - } - - if queryTypes[0] != 1 { // dns.TypeA - t.Errorf("expected first query type to be TypeA (1), got %d", queryTypes[0]) - } - - if queryTypes[1] != 28 { // dns.TypeAAAA - t.Errorf("expected second query type to be TypeAAAA (28), got %d", queryTypes[1]) - } -} diff --git a/vpn/ipc.go b/vpn/ipc.go deleted file mode 100644 index 795cbd88..00000000 --- a/vpn/ipc.go +++ /dev/null @@ -1,45 +0,0 @@ -package vpn - -import ( - "context" - "fmt" - "log/slog" - - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/traces" - "github.com/getlantern/radiance/vpn/ipc" - "github.com/getlantern/radiance/vpn/rvpn" -) - -// InitIPC initializes and returns a started IPC server. -func InitIPC(dataPath, logPath, logLevel string, platformIfce rvpn.PlatformInterface) (*ipc.Server, error) { - ctx, span := otel.Tracer(tracerName).Start( - context.Background(), - "initIPC", - trace.WithAttributes(attribute.String("dataPath", dataPath)), - ) - defer span.End() - - span.AddEvent("initializing IPC server") - - if err := common.InitReadOnly(dataPath, logPath, logLevel); err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("init common ro: %w", err)) - } - if path := settings.GetString(settings.DataPathKey); path != "" && path != dataPath { - dataPath = path - } - - server := ipc.NewServer(NewTunnelService(dataPath, slog.Default().With("service", "ipc"), platformIfce)) - slog.Debug("starting IPC server") - if err := server.Start(); err != nil { - slog.Error("failed to start IPC server", "error", err) - return nil, traces.RecordError(ctx, fmt.Errorf("start IPC server: %w", err)) - } - - return server, nil -} diff --git a/vpn/ipc/clash_mode.go b/vpn/ipc/clash_mode.go deleted file mode 100644 index ec0f9e97..00000000 --- a/vpn/ipc/clash_mode.go +++ /dev/null @@ -1,64 +0,0 @@ -package ipc - -import ( - "context" - "encoding/json" - "log/slog" - "net/http" - - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - - "github.com/getlantern/radiance/internal" -) - -type m struct { - Mode string `json:"mode"` -} - -// GetClashMode retrieves the current mode from the Clash server. -func GetClashMode(ctx context.Context) (string, error) { - res, err := sendRequest[m](ctx, "GET", clashModeEndpoint, nil) - if err != nil { - return "", err - } - return res.Mode, nil -} - -// SetClashMode sets the mode of the Clash server. -func SetClashMode(ctx context.Context, mode string) error { - _, err := sendRequest[empty](ctx, "POST", clashModeEndpoint, m{Mode: mode}) - return err -} - -// clashModeHandler handles HTTP requests for getting or setting the Clash server mode. -func (s *Server) clashModeHandler(w http.ResponseWriter, req *http.Request) { - span := trace.SpanFromContext(req.Context()) - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - cs := s.service.ClashServer() - switch req.Method { - case "GET": - mode := cs.Mode() - span.SetAttributes(attribute.String("mode", mode)) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(m{Mode: mode}); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - case "POST": - var mode m - if err := json.NewDecoder(req.Body).Decode(&mode); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - span.SetAttributes(attribute.String("mode", mode.Mode)) - slog.Log(nil, internal.LevelTrace, "Setting clash mode", "mode", mode.Mode) - cs.SetMode(mode.Mode) - w.WriteHeader(http.StatusOK) - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } -} diff --git a/vpn/ipc/connections.go b/vpn/ipc/connections.go deleted file mode 100644 index 125c8017..00000000 --- a/vpn/ipc/connections.go +++ /dev/null @@ -1,126 +0,0 @@ -package ipc - -import ( - "context" - "encoding/json" - "net/http" - runtimeDebug "runtime/debug" - "time" - - "github.com/gofrs/uuid/v5" - "github.com/sagernet/sing-box/common/conntrack" - "github.com/sagernet/sing-box/experimental/clashapi/trafficontrol" -) - -// CloseConnections closes connections by their IDs. If connIDs is empty, all connections will be closed. -func CloseConnections(ctx context.Context, connIDs []string) error { - _, err := sendRequest[empty](ctx, "POST", closeConnectionsEndpoint, connIDs) - return err -} - -func (s *Server) closeConnectionHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - var cids []string - err := json.NewDecoder(r.Body).Decode(&cids) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if len(cids) > 0 { - tm := s.service.ClashServer().TrafficManager() - for _, cid := range cids { - targetConn := tm.Connection(uuid.FromStringOrNil(cid)) - if targetConn == nil { - continue - } - targetConn.Close() - } - } else { - conntrack.Close() - } - go func() { - time.Sleep(time.Second) - runtimeDebug.FreeOSMemory() - }() - w.WriteHeader(http.StatusOK) -} - -// GetConnections retrieves the list of current and recently closed connections. -func GetConnections(ctx context.Context) ([]Connection, error) { - return sendRequest[[]Connection](ctx, "GET", connectionsEndpoint, nil) -} - -func (s *Server) connectionsHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - w.Header().Set("Content-Type", "application/json") - tm := s.service.ClashServer().TrafficManager() - activeConns := tm.Connections() - closedConns := tm.ClosedConnections() - connections := make([]Connection, 0, len(activeConns)+len(closedConns)) - for _, connection := range activeConns { - connections = append(connections, newConnection(connection)) - } - for _, connection := range closedConns { - connections = append(connections, newConnection(connection)) - } - if err := json.NewEncoder(w).Encode(connections); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} - -// Connection represents a network connection with relevant metadata. -type Connection struct { - ID string - Inbound string - IPVersion int - Network string - Source string - Destination string - Domain string - Protocol string - FromOutbound string - CreatedAt int64 - ClosedAt int64 - Uplink int64 - Downlink int64 - Rule string - Outbound string - ChainList []string -} - -func newConnection(metadata trafficontrol.TrackerMetadata) Connection { - var rule string - if metadata.Rule != nil { - rule = metadata.Rule.String() + " => " + metadata.Rule.Action().String() - } - var closedAt int64 - if !metadata.ClosedAt.IsZero() { - closedAt = metadata.ClosedAt.UnixMilli() - } - md := metadata.Metadata - return Connection{ - ID: metadata.ID.String(), - Inbound: md.InboundType + "/" + md.Inbound, - IPVersion: int(md.IPVersion), - Network: md.Network, - Source: md.Source.String(), - Destination: md.Destination.String(), - Domain: md.Domain, - Protocol: md.Protocol, - FromOutbound: md.Outbound, - CreatedAt: metadata.CreatedAt.UnixMilli(), - ClosedAt: closedAt, - Uplink: metadata.Upload.Load(), - Downlink: metadata.Download.Load(), - Rule: rule, - Outbound: metadata.OutboundType + "/" + metadata.Outbound, - ChainList: metadata.Chain, - } -} diff --git a/vpn/ipc/endpoints.go b/vpn/ipc/endpoints.go deleted file mode 100644 index b55c43d2..00000000 --- a/vpn/ipc/endpoints.go +++ /dev/null @@ -1,19 +0,0 @@ -package ipc - -const ( - statusEndpoint = "/status" - metricsEndpoint = "/metrics" - startServiceEndpoint = "/service/start" - stopServiceEndpoint = "/service/stop" - restartServiceEndpoint = "/service/restart" - groupsEndpoint = "/groups" - selectEndpoint = "/outbound/select" - activeEndpoint = "/outbound/active" - updateOutboundsEndpoint = "/outbound/update" - addOutboundsEndpoint = "/outbound/add" - removeOutboundsEndpoint = "/outbound/remove" - clashModeEndpoint = "/clash/mode" - connectionsEndpoint = "/connections" - closeConnectionsEndpoint = "/connections/close" - setSettingsPathEndpoint = "/set" -) diff --git a/vpn/ipc/group.go b/vpn/ipc/group.go deleted file mode 100644 index 48ede66a..00000000 --- a/vpn/ipc/group.go +++ /dev/null @@ -1,83 +0,0 @@ -package ipc - -import ( - "context" - "encoding/json" - "errors" - "net/http" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing/service" -) - -// GetGroups retrieves the list of group outbounds. -func GetGroups(ctx context.Context) ([]OutboundGroup, error) { - return sendRequest[[]OutboundGroup](ctx, "GET", groupsEndpoint, nil) -} - -func (s *Server) groupHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - groups, err := getGroups(s.service.Ctx()) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(groups); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} - -// OutboundGroup represents a group of outbounds. -type OutboundGroup struct { - Tag string - Type string - Selected string - Outbounds []Outbounds -} - -// Outbounds represents outbounds within a group. -type Outbounds struct { - Tag string - Type string -} - -func getGroups(ctx context.Context) ([]OutboundGroup, error) { - outboundMgr := service.FromContext[adapter.OutboundManager](ctx) - if outboundMgr == nil { - return nil, errors.New("outbound manager not found") - } - outbounds := outboundMgr.Outbounds() - var iGroups []adapter.OutboundGroup - for _, it := range outbounds { - if group, isGroup := it.(adapter.OutboundGroup); isGroup { - iGroups = append(iGroups, group) - } - } - var groups []OutboundGroup - for _, iGroup := range iGroups { - group := OutboundGroup{ - Tag: iGroup.Tag(), - Type: iGroup.Type(), - Selected: iGroup.Now(), - } - for _, itemTag := range iGroup.All() { - itemOutbound, isLoaded := outboundMgr.Outbound(itemTag) - if !isLoaded { - continue - } - - item := Outbounds{ - Tag: itemTag, - Type: itemOutbound.Type(), - } - group.Outbounds = append(group.Outbounds, item) - } - groups = append(groups, group) - } - return groups, nil -} diff --git a/vpn/ipc/http.go b/vpn/ipc/http.go deleted file mode 100644 index f2af307d..00000000 --- a/vpn/ipc/http.go +++ /dev/null @@ -1,74 +0,0 @@ -package ipc - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "os" - - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - - "github.com/getlantern/radiance/traces" -) - -const tracerName = "github.com/getlantern/radiance/vpn/ipc" - -// empty is a placeholder type for requests that do not expect a response body. -type empty struct{} - -// sendRequest sends an HTTP request to the specified endpoint with the given method and data. -func sendRequest[T any](ctx context.Context, method, endpoint string, data any) (T, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "vpn.ipc", - trace.WithAttributes(attribute.String("endpoint", endpoint)), - ) - defer span.End() - - buf, err := json.Marshal(data) - var res T - if err != nil { - return res, traces.RecordError(ctx, fmt.Errorf("failed to marshal payload: %w", err)) - } - req, err := http.NewRequestWithContext(ctx, method, apiURL+endpoint, bytes.NewReader(buf)) - if err != nil { - return res, err - } - client := &http.Client{ - Transport: &http.Transport{ - DialContext: dialContext, - }, - } - resp, err := client.Do(req) - if errors.Is(err, os.ErrNotExist) { - err = ErrIPCNotRunning - } - if err != nil { - return res, traces.RecordError(ctx, fmt.Errorf("request failed: %w", err)) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return res, traces.RecordError(ctx, readErrorResponse(resp)) - } - if _, ok := any(&res).(*empty); ok { - return res, nil - } - - err = json.NewDecoder(resp.Body).Decode(&res) - if err != nil { - return res, traces.RecordError(ctx, fmt.Errorf("failed to decode response: %w", err)) - } - return res, nil -} - -func readErrorResponse(resp *http.Response) error { - buf, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read error response body: %w, status: %s", err, resp.Status) - } - return fmt.Errorf("%s: %s", resp.Status, buf) -} diff --git a/vpn/ipc/outbound.go b/vpn/ipc/outbound.go deleted file mode 100644 index 88981f12..00000000 --- a/vpn/ipc/outbound.go +++ /dev/null @@ -1,254 +0,0 @@ -package ipc - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "net/http" - runtimeDebug "runtime/debug" - "time" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/conntrack" - "github.com/sagernet/sing/service" - - "github.com/getlantern/radiance/internal" - "github.com/getlantern/radiance/servers" -) - -type selection struct { - GroupTag string `json:"groupTag"` - OutboundTag string `json:"outboundTag"` -} - -// SelectOutbound selects an outbound within a group. -func SelectOutbound(ctx context.Context, groupTag, outboundTag string) error { - _, err := sendRequest[empty](ctx, "POST", selectEndpoint, selection{groupTag, outboundTag}) - return err -} - -func (s *Server) selectHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - var p selection - err := json.NewDecoder(r.Body).Decode(&p) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - defer func() { - if r := recover(); r != nil { - http.Error(w, fmt.Sprint(r), http.StatusInternalServerError) - } - }() - slog.Log(nil, internal.LevelTrace, "selecting outbound", "group", p.GroupTag, "outbound", p.OutboundTag) - outbound, err := getGroupOutbound(s.service.Ctx(), p.GroupTag) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - selector, isSelector := outbound.(Selector) - if !isSelector { - http.Error(w, fmt.Sprintf("outbound %q is not a selector", p.GroupTag), http.StatusBadRequest) - return - } - slog.Log(nil, internal.LevelTrace, "setting outbound", "outbound", p.OutboundTag) - if !selector.SelectOutbound(p.OutboundTag) { - http.Error(w, fmt.Sprintf("outbound %q not found in group", p.OutboundTag), http.StatusBadRequest) - return - } - cs := s.service.ClashServer() - if mode := cs.Mode(); mode != p.GroupTag { - slog.Log(nil, internal.LevelDebug, "changing clash mode", "new", p.GroupTag, "old", mode) - s.service.ClashServer().SetMode(p.GroupTag) - conntrack.Close() - go func() { - time.Sleep(time.Second) - runtimeDebug.FreeOSMemory() - }() - } - w.WriteHeader(http.StatusOK) -} - -// Selector is helper interface to check if an outbound is a selector or wrapper of selector. -type Selector interface { - adapter.OutboundGroup - SelectOutbound(tag string) bool -} - -// GetSelected retrieves the currently selected outbound and its group. -func GetSelected(ctx context.Context) (group, tag string, err error) { - res, err := sendRequest[selection](ctx, "GET", selectEndpoint, nil) - if err != nil { - return "", "", err - } - return res.GroupTag, res.OutboundTag, nil -} - -func (s *Server) selectedHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - cs := s.service.ClashServer() - mode := cs.Mode() - selector, err := getGroupOutbound(s.service.Ctx(), mode) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - res := selection{ - GroupTag: mode, - OutboundTag: selector.Now(), - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(res); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} - -// GetActiveOutbound retrieves the outbound that is actively being used, resolving nested groups -// if necessary. -func GetActiveOutbound(ctx context.Context) (group, tag string, err error) { - res, err := sendRequest[selection](ctx, "GET", activeEndpoint, nil) - if err != nil { - return "", "", err - } - return res.GroupTag, res.OutboundTag, nil -} - -func (s *Server) activeOutboundHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - cs := s.service.ClashServer() - mode := cs.Mode() - group, err := getGroupOutbound(s.service.Ctx(), mode) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - tag := group.Now() - // if the selected outbound is also a group, retrieve its selected outbound - // continue until we reach a non-group outbound - for { - group, err = getGroupOutbound(s.service.Ctx(), tag) - if err != nil { - break - } - tag = group.Now() - } - if tag == "" { - tag = "unavailable" - } - res := selection{ - GroupTag: mode, - OutboundTag: tag, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(res); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} - -func getGroupOutbound(ctx context.Context, tag string) (adapter.OutboundGroup, error) { - outboundMgr := service.FromContext[adapter.OutboundManager](ctx) - if outboundMgr == nil { - return nil, errors.New("outbound manager not found") - } - - outbound, loaded := outboundMgr.Outbound(tag) - if !loaded { - return nil, fmt.Errorf("group not found: %s", tag) - } - group, isGroup := outbound.(adapter.OutboundGroup) - if !isGroup { - return nil, fmt.Errorf("outbound is not a group: %s", tag) - } - return group, nil -} - -func UpdateOutbounds(ctx context.Context, servers servers.Servers) error { - _, err := sendRequest[empty](ctx, "POST", updateOutboundsEndpoint, servers) - return err -} - -func (s *Server) updateOutboundsHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - var data servers.Servers - if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - slog.Debug("Updating outbounds") - if err := s.service.UpdateOutbounds(data); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusOK) -} - -type newOutbounds struct { - Group string `json:"group"` - Servers servers.Options `json:"servers"` -} - -func AddOutbounds(ctx context.Context, group string, servers servers.Options) error { - _, err := sendRequest[empty](ctx, "POST", addOutboundsEndpoint, newOutbounds{Group: group, Servers: servers}) - return err -} - -func (s *Server) addOutboundsHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - var data newOutbounds - if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - slog.Debug("Adding outbounds", "group", data.Group) - if err := s.service.AddOutbounds(data.Group, data.Servers); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusOK) -} - -type outboundsToRemove struct { - Group string `json:"group"` - Tags []string `json:"tags"` -} - -func RemoveOutbounds(ctx context.Context, group string, tags []string) error { - _, err := sendRequest[empty](ctx, "POST", removeOutboundsEndpoint, outboundsToRemove{Group: group, Tags: tags}) - return err -} - -func (s *Server) removeOutboundsHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - var data outboundsToRemove - if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if err := s.service.RemoveOutbounds(data.Group, data.Tags); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusOK) -} diff --git a/vpn/ipc/server.go b/vpn/ipc/server.go deleted file mode 100644 index 187aa57d..00000000 --- a/vpn/ipc/server.go +++ /dev/null @@ -1,263 +0,0 @@ -// Package ipc implements the IPC server for communicating between the client and the VPN service. -// It provides HTTP endpoints for retrieving statistics, managing groups, selecting outbounds, -// changing modes, and closing connections. -package ipc - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "net" - "net/http" - "sync/atomic" - "time" - - "github.com/go-chi/chi/v5" - "github.com/sagernet/sing-box/experimental/clashapi" - "go.opentelemetry.io/otel" - - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/events" - "github.com/getlantern/radiance/servers" -) - -var ( - ErrServiceIsNotReady = errors.New("service is not ready") - ErrIPCNotRunning = errors.New("IPC not running") -) - -// Service defines the interface that the IPC server uses to interact with the underlying VPN service. -type Service interface { - Ctx() context.Context - Status() VPNStatus - Start(ctx context.Context, options string) error - Restart(ctx context.Context, options string) error - Close() error - ClashServer() *clashapi.Server - UpdateOutbounds(options servers.Servers) error - AddOutbounds(group string, options servers.Options) error - RemoveOutbounds(group string, tags []string) error -} - -// Server represents the IPC server that communicates over a Unix domain socket for Unix-like -// systems, and a named pipe for Windows. -type Server struct { - svr *http.Server - service Service - router chi.Router - vpnStatus atomic.Value // string - closed atomic.Bool -} - -// StatusUpdateEvent is emitted when the VPN status changes. -type StatusUpdateEvent struct { - events.Event - Status VPNStatus - Error error -} - -type VPNStatus string - -// Possible VPN statuses -const ( - Connected VPNStatus = "connected" - Disconnected VPNStatus = "disconnected" - Connecting VPNStatus = "connecting" - Disconnecting VPNStatus = "disconnecting" - ErrorStatus VPNStatus = "error" -) - -func (vpn *VPNStatus) String() string { - return string(*vpn) -} - -// NewServer creates a new Server instance with the provided Service. -func NewServer(service Service) *Server { - s := &Server{ - service: service, - router: chi.NewMux(), - } - s.vpnStatus.Store(Disconnected) - s.router.Use(log, tracer) - - // Only add auth middleware if not running on mobile, since mobile platforms have their own - // sandboxing and permission models. - addAuth := !common.IsMobile() && !_testing - if addAuth { - s.router.Use(authPeer) - } - - s.router.Get("/", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - s.router.Get(statusEndpoint, s.statusHandler) - s.router.Get(metricsEndpoint, s.metricsHandler) - s.router.Get(groupsEndpoint, s.groupHandler) - s.router.Get(connectionsEndpoint, s.connectionsHandler) - s.router.Get(selectEndpoint, s.selectedHandler) - s.router.Get(activeEndpoint, s.activeOutboundHandler) - s.router.Post(selectEndpoint, s.selectHandler) - s.router.Get(clashModeEndpoint, s.clashModeHandler) - s.router.Post(clashModeEndpoint, s.clashModeHandler) - s.router.Post(startServiceEndpoint, s.startServiceHandler) - s.router.Post(stopServiceEndpoint, s.stopServiceHandler) - s.router.Post(restartServiceEndpoint, s.restartServiceHandler) - s.router.Post(updateOutboundsEndpoint, s.updateOutboundsHandler) - s.router.Post(addOutboundsEndpoint, s.addOutboundsHandler) - s.router.Post(removeOutboundsEndpoint, s.removeOutboundsHandler) - s.router.Post(closeConnectionsEndpoint, s.closeConnectionHandler) - - svr := &http.Server{ - Handler: s.router, - ReadTimeout: time.Second * 5, - WriteTimeout: time.Second * 5, - } - if addAuth { - svr.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - peer, err := getConnPeer(c) - if err != nil { - slog.Error("Failed to get peer credentials", "error", err) - } - return contextWithUsr(ctx, peer) - } - } - s.svr = svr - return s -} - -// Start begins listening for incoming IPC requests. -func (s *Server) Start() error { - if s.closed.Load() { - return errors.New("IPC server is closed") - } - l, err := listen() - if err != nil { - return fmt.Errorf("IPC server: listen: %w", err) - } - go func() { - slog.Info("IPC server started", "address", l.Addr().String()) - err := s.svr.Serve(l) - if err != nil && err != http.ErrServerClosed { - slog.Error("IPC server", "error", err) - } - s.closed.Store(true) - if s.service.Status() != Disconnected { - slog.Warn("IPC server stopped unexpectedly, closing service") - s.service.Close() - s.setVPNStatus(ErrorStatus, errors.New("IPC server stopped unexpectedly")) - } - }() - - return nil -} - -// Close shuts down the IPC server. -func (s *Server) Close() error { - if s.closed.Swap(true) { - return nil - } - defer s.service.Close() - - slog.Info("Closing IPC server") - return s.svr.Close() -} - -func (s *Server) IsClosed() bool { - return s.closed.Load() -} - -type opts struct { - Options string `json:"options"` -} - -// StartService sends a request to start the service -func StartService(ctx context.Context, options string) error { - _, err := sendRequest[empty](ctx, "POST", startServiceEndpoint, opts{Options: options}) - return err -} - -func (s *Server) startServiceHandler(w http.ResponseWriter, r *http.Request) { - ctx, span := otel.Tracer(tracerName).Start(r.Context(), "ipc.Server.StartService") - defer span.End() - switch s.service.Status() { - case Disconnected: - // proceed to start - case Connected: - w.WriteHeader(http.StatusOK) - return - case Disconnecting: - http.Error(w, "service is disconnecting, please wait", http.StatusConflict) - return - default: - http.Error(w, "service is already starting", http.StatusConflict) - return - } - var p opts - if err := json.NewDecoder(r.Body).Decode(&p); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - s.setVPNStatus(Connecting, nil) - if err := s.service.Start(ctx, p.Options); err != nil { - s.setVPNStatus(ErrorStatus, err) - http.Error(w, err.Error(), http.StatusServiceUnavailable) - return - } - s.setVPNStatus(Connected, nil) - w.WriteHeader(http.StatusOK) -} - -// StopService sends a request to stop the service (IPC server stays up) -func StopService(ctx context.Context) error { - _, err := sendRequest[empty](ctx, "POST", stopServiceEndpoint, nil) - return err -} - -func (s *Server) stopServiceHandler(w http.ResponseWriter, r *http.Request) { - slog.Debug("Received request to stop service via IPC") - s.setVPNStatus(Disconnecting, nil) - if err := s.service.Close(); err != nil { - s.setVPNStatus(ErrorStatus, err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - s.setVPNStatus(Disconnected, nil) - w.WriteHeader(http.StatusOK) -} - -func RestartService(ctx context.Context, options string) error { - _, err := sendRequest[empty](ctx, "POST", restartServiceEndpoint, opts{Options: options}) - return err -} - -func (s *Server) restartServiceHandler(w http.ResponseWriter, r *http.Request) { - ctx, span := otel.Tracer(tracerName).Start(r.Context(), "ipc.Server.restartServiceHandler") - defer span.End() - - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusInternalServerError) - return - } - var p opts - if err := json.NewDecoder(r.Body).Decode(&p); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - s.setVPNStatus(Disconnected, nil) - if err := s.service.Restart(ctx, p.Options); err != nil { - s.setVPNStatus(ErrorStatus, err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - s.setVPNStatus(Connected, nil) - w.WriteHeader(http.StatusOK) -} - -func (s *Server) setVPNStatus(status VPNStatus, err error) { - s.vpnStatus.Store(status) - events.Emit(StatusUpdateEvent{Status: status, Error: err}) -} diff --git a/vpn/ipc/status.go b/vpn/ipc/status.go deleted file mode 100644 index aef35029..00000000 --- a/vpn/ipc/status.go +++ /dev/null @@ -1,99 +0,0 @@ -package ipc - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "os" - "runtime" - - "github.com/sagernet/sing-box/common/conntrack" - "github.com/sagernet/sing/common/memory" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" -) - -// Metrics represents the runtime metrics of the service. -type Metrics struct { - Memory uint64 - Goroutines int - Connections int - - // UplinkTotal and DownlinkTotal are only available when the service is running and there are - // active connections. - // In bytes. - UplinkTotal int64 - // In bytes. - DownlinkTotal int64 -} - -// GetMetrics retrieves the current runtime metrics of the service. -func GetMetrics(ctx context.Context) (Metrics, error) { - return sendRequest[Metrics](ctx, "GET", metricsEndpoint, nil) -} - -func (s *Server) metricsHandler(w http.ResponseWriter, r *http.Request) { - _, span := otel.Tracer(tracerName).Start(r.Context(), "server.metricsHandler") - defer span.End() - stats := Metrics{ - Memory: memory.Inuse(), - Goroutines: runtime.NumGoroutine(), - Connections: conntrack.Count(), - } - if s.service.Status() == Connected { - up, down := s.service.ClashServer().TrafficManager().Total() - stats.UplinkTotal, stats.DownlinkTotal = up, down - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(stats); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} - -type state struct { - State VPNStatus `json:"state"` -} - -// GetStatus retrieves the current status of the service. -func GetStatus(ctx context.Context) (VPNStatus, error) { - // try to dial first to check if IPC server is even running and avoid waiting for timeout - if canDial, err := tryDial(ctx); !canDial { - return Disconnected, err - } - - res, err := sendRequest[state](ctx, "GET", statusEndpoint, nil) - if errors.Is(err, ErrIPCNotRunning) || errors.Is(err, ErrServiceIsNotReady) { - return Disconnected, nil - } - if err != nil { - return "", fmt.Errorf("error getting status: %w", err) - } - return res.State, nil -} - -func tryDial(ctx context.Context) (bool, error) { - conn, err := dialContext(ctx, "", "") - if err == nil { - conn.Close() - return true, nil - } - if errors.Is(err, os.ErrNotExist) { - return false, nil // IPC server is not running so don't treat it as an error - } - return false, err -} - -func (s *Server) statusHandler(w http.ResponseWriter, r *http.Request) { - span := trace.SpanFromContext(r.Context()) - status := s.service.Status() - span.SetAttributes(attribute.String("status", string(status))) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(state{status}); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} diff --git a/vpn/rvpn/platform.go b/vpn/rvpn/platform.go deleted file mode 100644 index 72275218..00000000 --- a/vpn/rvpn/platform.go +++ /dev/null @@ -1,9 +0,0 @@ -package rvpn - -import "github.com/sagernet/sing-box/experimental/libbox" - -type PlatformInterface interface { - libbox.PlatformInterface - RestartService() error - PostServiceClose() -} diff --git a/vpn/service.go b/vpn/service.go deleted file mode 100644 index 4df2f133..00000000 --- a/vpn/service.go +++ /dev/null @@ -1,223 +0,0 @@ -package vpn - -import ( - "context" - "errors" - "fmt" - "io" - "log/slog" - "os" - "path/filepath" - "runtime" - "sync" - - "github.com/sagernet/sing-box/experimental/clashapi" - - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal" - "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/vpn/ipc" - "github.com/getlantern/radiance/vpn/rvpn" -) - -var _ ipc.Service = (*TunnelService)(nil) - -// TunnelService manages the lifecycle of the VPN tunnel. -type TunnelService struct { - tunnel *tunnel - - platformIfce rvpn.PlatformInterface - logger *slog.Logger - - mu sync.Mutex -} - -// NewTunnelService creates a new TunnelService instance with the provided configuration paths, log -// level, and platform interface. -func NewTunnelService(dataPath string, logger *slog.Logger, platformIfce rvpn.PlatformInterface) *TunnelService { - if logger == nil { - logger = slog.Default() - } - switch logger.Handler().(type) { - case *slog.TextHandler, *slog.JSONHandler: - default: - os.MkdirAll(dataPath, 0o755) - path := filepath.Join(dataPath, "radiance_vpn.log") - var writer io.Writer - f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - slog.Error("Failed to open log file", "error", err) - writer = os.Stdout - } else { - writer = f - } - logger = slog.New(slog.NewTextHandler(writer, &slog.HandlerOptions{AddSource: true, Level: internal.LevelTrace})) - runtime.AddCleanup(logger, func(file *os.File) { - file.Close() - }, f) - } - return &TunnelService{ - platformIfce: platformIfce, - logger: logger, - } -} - -// Start initializes and starts the tunnel with the specified options. Returns an error if the -// tunnel is already running or initialization fails. -func (s *TunnelService) Start(ctx context.Context, options string) error { - s.mu.Lock() - defer s.mu.Unlock() - if s.tunnel != nil { - s.logger.Warn("tunnel already started") - return errors.New("tunnel already started") - } - s.logger.Debug("Starting tunnel", "options", options) - if err := s.start(ctx, options); err != nil { - return err - } - return nil -} - -func (s *TunnelService) start(ctx context.Context, options string) error { - path := settings.GetString(settings.DataPathKey) - t := tunnel{ - dataPath: path, - } - if err := t.start(options, s.platformIfce); err != nil { - return fmt.Errorf("failed to start tunnel: %w", err) - } - s.tunnel = &t - return nil -} - -// Close shuts down the currently running tunnel, if any. Returns an error if closing the tunnel fails. -func (s *TunnelService) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.tunnel == nil { - return nil - } - if err := s.close(); err != nil { - return err - } - if s.platformIfce != nil { - s.platformIfce.PostServiceClose() - } - return nil -} - -func (s *TunnelService) close() error { - t := s.tunnel - s.tunnel = nil - - s.logger.Info("Closing tunnel") - if err := t.close(); err != nil { - return err - } - s.logger.Debug("Tunnel closed") - runtime.GC() - return nil -} - -// Restart closes and restarts the tunnel if it is currently running. Returns an error if the tunnel -// is not running or restart fails. -func (s *TunnelService) Restart(ctx context.Context, options string) error { - s.mu.Lock() - if s.tunnel == nil { - s.mu.Unlock() - return errors.New("tunnel not started") - } - if s.tunnel.Status() != ipc.Connected { - s.mu.Unlock() - return errors.New("tunnel not running") - } - - s.logger.Info("Restarting tunnel") - if s.platformIfce != nil { - s.mu.Unlock() - if err := s.platformIfce.RestartService(); err != nil { - s.logger.Error("Failed to restart tunnel via platform interface", "error", err) - return fmt.Errorf("platform interface restart failed: %w", err) - } - return nil - } - - defer s.mu.Unlock() - if err := s.close(); err != nil { - return fmt.Errorf("closing tunnel: %w", err) - } - if err := s.start(ctx, options); err != nil { - s.logger.Error("starting tunnel", "error", err) - return fmt.Errorf("starting tunnel: %w", err) - } - s.logger.Info("Tunnel restarted successfully") - return nil -} - -// Status returns the current status of the tunnel (e.g., running, closed). -func (s *TunnelService) Status() ipc.VPNStatus { - s.mu.Lock() - defer s.mu.Unlock() - if s.tunnel == nil { - return ipc.Disconnected - } - return s.tunnel.Status() -} - -// Ctx returns the context associated with the tunnel, or nil if no tunnel is running. -func (s *TunnelService) Ctx() context.Context { - s.mu.Lock() - defer s.mu.Unlock() - if s.tunnel == nil { - return nil - } - return s.tunnel.ctx -} - -// ClashServer returns the Clash server instance associated with the tunnel, or nil if no tunnel is -// running. -func (s *TunnelService) ClashServer() *clashapi.Server { - s.mu.Lock() - defer s.mu.Unlock() - if s.tunnel == nil { - return nil - } - return s.tunnel.clashServer -} - -var errTunnelNotStarted = errors.New("tunnel not started") - -// activeTunnel returns the running tunnel or errTunnelNotStarted. -func (s *TunnelService) activeTunnel() (*tunnel, error) { - s.mu.Lock() - t := s.tunnel - s.mu.Unlock() - if t == nil { - return nil, errTunnelNotStarted - } - return t, nil -} - -func (s *TunnelService) UpdateOutbounds(newOpts servers.Servers) error { - t, err := s.activeTunnel() - if err != nil { - return err - } - return t.updateOutbounds(newOpts) -} - -func (s *TunnelService) AddOutbounds(group string, options servers.Options) error { - t, err := s.activeTunnel() - if err != nil { - return err - } - return t.addOutbounds(group, options) -} - -func (s *TunnelService) RemoveOutbounds(group string, tags []string) error { - t, err := s.activeTunnel() - if err != nil { - return err - } - return t.removeOutbounds(group, tags) -} diff --git a/vpn/split_tunnel.go b/vpn/split_tunnel.go index 8550de5b..7d589949 100644 --- a/vpn/split_tunnel.go +++ b/vpn/split_tunnel.go @@ -2,7 +2,6 @@ package vpn import ( "context" - "encoding/json" "errors" "fmt" "io/fs" @@ -16,17 +15,16 @@ import ( C "github.com/sagernet/sing-box/constant" O "github.com/sagernet/sing-box/option" - singjson "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json" - "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/atomicfile" - "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" ) const ( splitTunnelTag = "split-tunnel" - splitTunnelFile = splitTunnelTag + ".json" + splitTunnelFile = internal.SplitTunnelFileName TypeDomain = "domain" TypeDomainSuffix = "domainSuffix" @@ -47,17 +45,18 @@ type SplitTunnel struct { ruleMap map[string]*O.DefaultHeadlessRule enabled *atomic.Bool access sync.Mutex + logger *slog.Logger } -func NewSplitTunnelHandler() (*SplitTunnel, error) { - s := newSplitTunnel(settings.GetString(settings.DataPathKey)) +func NewSplitTunnelHandler(dataPath string, logger *slog.Logger) (*SplitTunnel, error) { + s := newSplitTunnel(dataPath, logger) if err := s.loadRule(); err != nil { return nil, fmt.Errorf("loading split tunnel rule file %s: %w", s.ruleFile, err) } return s, nil } -func newSplitTunnel(path string) *SplitTunnel { +func newSplitTunnel(path string, logger *slog.Logger) *SplitTunnel { rule := defaultRule() s := &SplitTunnel{ rule: rule, @@ -65,24 +64,17 @@ func newSplitTunnel(path string) *SplitTunnel { activeFilter: &(rule.Rules[1].LogicalOptions), ruleMap: make(map[string]*O.DefaultHeadlessRule), enabled: &atomic.Bool{}, + logger: logger, } s.initRuleMap() if _, err := os.Stat(s.ruleFile); errors.Is(err, fs.ErrNotExist) { - slog.Debug("Creating initial split tunnel rule file", "file", s.ruleFile) + logger.Debug("Creating initial split tunnel rule file", "file", s.ruleFile) s.saveToFile() } return s } -func (s *SplitTunnel) Enable() error { - return s.setEnabled(true) -} - -func (s *SplitTunnel) Disable() error { - return s.setEnabled(false) -} - -func (s *SplitTunnel) setEnabled(enabled bool) error { +func (s *SplitTunnel) SetEnabled(enabled bool) error { if s.enabled.Load() == enabled { return nil } @@ -97,7 +89,7 @@ func (s *SplitTunnel) setEnabled(enabled bool) error { return fmt.Errorf("writing rule to %s: %w", s.ruleFile, err) } s.enabled.Store(enabled) - slog.Log(context.Background(), internal.LevelTrace, "Updated split-tunneling", "enabled", enabled) + s.logger.Log(context.Background(), log.LevelTrace, "Updated split-tunneling", "enabled", enabled) return nil } @@ -105,10 +97,10 @@ func (s *SplitTunnel) IsEnabled() bool { return s.enabled.Load() } -func (s *SplitTunnel) Filters() Filter { +func (s *SplitTunnel) Filters() SplitTunnelFilter { s.access.Lock() defer s.access.Unlock() - return Filter{ + return SplitTunnelFilter{ Domain: slices.Clone(s.ruleMap[TypeDomain].Domain), DomainSuffix: slices.Clone(s.ruleMap[TypeDomainSuffix].DomainSuffix), DomainKeyword: slices.Clone(s.ruleMap[TypeDomainKeyword].DomainKeyword), @@ -120,101 +112,12 @@ func (s *SplitTunnel) Filters() Filter { } } -// ItemsJSON returns the items for the given filter type as a JSON-encoded []string. -// It is safe to call from CGo callback stacks (uses RunOffCgoStack internally). -func (s *SplitTunnel) ItemsJSON(filterType string) (string, error) { - return common.RunOffCgoStack(func() (string, error) { - items, err := s.Filters().Items(filterType) - if err != nil { - return "", err - } - if items == nil { - items = []string{} - } - b, err := json.Marshal(items) - if err != nil { - return "", err - } - return string(b), nil - }) -} - -// EnabledAppsJSON returns all enabled app/process identifiers from the split -// tunnel configuration as a JSON-encoded []string. It first extracts values -// from the parsed rule set (current sing-box format with snake_case keys), -// then falls back to scanning the raw file for legacy camelCase keys. -// It is safe to call from CGo callback stacks. -func (s *SplitTunnel) EnabledAppsJSON() (string, error) { - return common.RunOffCgoStack(func() (string, error) { - seen := map[string]struct{}{} - out := make([]string, 0, 16) - isWindows := common.IsWindows() - - addString := func(str string) { - str = strings.TrimSpace(str) - if str == "" { - return - } - key := str - if isWindows { - key = strings.ToLower(str) - } - if _, exists := seen[key]; exists { - return - } - seen[key] = struct{}{} - out = append(out, str) - } - - addSlice := func(items []string) { - for _, str := range items { - addString(str) - } - } - - // Extract from the parsed rule set (current format). - f := s.Filters() - addSlice(f.ProcessPath) - addSlice(f.ProcessPathRegex) - addSlice(f.ProcessName) - addSlice(f.PackageName) - - // Fall back to legacy camelCase top-level keys in the raw file. - b, err := atomicfile.ReadFile(s.ruleFile) - if err == nil && len(b) > 0 { - m, parseErr := singjson.UnmarshalExtended[map[string]any](b) - if parseErr == nil { - legacyKeys := []string{ - "processPathRegex", "processPath", "packageName", - } - for _, k := range legacyKeys { - arr, ok := m[k].([]any) - if !ok { - continue - } - for _, it := range arr { - if str, ok := it.(string); ok { - addString(str) - } - } - } - } - } - - encoded, err := json.Marshal(out) - if err != nil { - return "", err - } - return string(encoded), nil - }) -} - // AddItem adds a new item to the filter of the given type. func (s *SplitTunnel) AddItem(filterType, item string) error { if err := s.updateFilter(filterType, item, merge); err != nil { return err } - slog.Debug("added item to filter", "filterType", filterType, "item", item) + s.logger.Debug("added item to filter", "filterType", filterType, "item", item) if err := s.saveToFile(); err != nil { return fmt.Errorf("writing rule to %s: %w", s.ruleFile, err) } @@ -226,7 +129,7 @@ func (s *SplitTunnel) RemoveItem(filterType, item string) error { if err := s.updateFilter(filterType, item, remove); err != nil { return err } - slog.Debug("removed item from filter", "filterType", filterType, "item", item) + s.logger.Debug("removed item from filter", "filterType", filterType, "item", item) if err := s.saveToFile(); err != nil { return fmt.Errorf("writing rule to %s: %w", s.ruleFile, err) } @@ -234,20 +137,20 @@ func (s *SplitTunnel) RemoveItem(filterType, item string) error { } // AddItems adds multiple items to the filter. -func (s *SplitTunnel) AddItems(items Filter) error { +func (s *SplitTunnel) AddItems(items SplitTunnelFilter) error { s.updateFilters(items, merge) - slog.Debug("added items to filter", "items", items.String()) + s.logger.Debug("added items to filter", "items", items.String()) return s.saveToFile() } // RemoveItems removes multiple items from the filter. -func (s *SplitTunnel) RemoveItems(items Filter) error { +func (s *SplitTunnel) RemoveItems(items SplitTunnelFilter) error { s.updateFilters(items, remove) - slog.Debug("removed items from filter", "items", items.String()) + s.logger.Debug("removed items from filter", "items", items.String()) return s.saveToFile() } -type Filter struct { +type SplitTunnelFilter struct { Domain []string DomainSuffix []string DomainKeyword []string @@ -258,31 +161,7 @@ type Filter struct { PackageName []string } -// Items returns the items for the given filter type. -func (f Filter) Items(filterType string) ([]string, error) { - switch filterType { - case TypeDomain: - return f.Domain, nil - case TypeDomainSuffix: - return f.DomainSuffix, nil - case TypeDomainKeyword: - return f.DomainKeyword, nil - case TypeDomainRegex: - return f.DomainRegex, nil - case TypeProcessName: - return f.ProcessName, nil - case TypeProcessPath: - return f.ProcessPath, nil - case TypeProcessPathRegex: - return f.ProcessPathRegex, nil - case TypePackageName: - return f.PackageName, nil - default: - return nil, fmt.Errorf("unsupported filter type: %s", filterType) - } -} - -func (f Filter) String() string { +func (f SplitTunnelFilter) String() string { var str []string if len(f.Domain) > 0 { str = append(str, fmt.Sprintf("domain: %v", f.Domain)) @@ -343,7 +222,7 @@ func (s *SplitTunnel) updateFilter(filterType string, item string, fn actionFn) return nil } -func (s *SplitTunnel) updateFilters(diff Filter, fn actionFn) { +func (s *SplitTunnel) updateFilters(diff SplitTunnelFilter, fn actionFn) { s.access.Lock() defer s.access.Unlock() @@ -400,7 +279,7 @@ func (s *SplitTunnel) saveToFile() error { }, }, } - buf, err := singjson.Marshal(rs) + buf, err := json.Marshal(rs) if err != nil { return fmt.Errorf("marshalling rule set: %w", err) } @@ -424,13 +303,13 @@ func (s *SplitTunnel) loadRule() error { if err != nil { return fmt.Errorf("reading rule file %s: %w", s.ruleFile, err) } - ruleSet, err := singjson.UnmarshalExtended[O.PlainRuleSetCompat](content) + ruleSet, err := json.UnmarshalExtended[O.PlainRuleSetCompat](content) if err != nil { return fmt.Errorf("unmarshalling rule file %s: %w", s.ruleFile, err) } rules := ruleSet.Options.Rules if len(rules) == 0 { - slog.Warn("split tunnel rule file format is invalid, using empty rule") + s.logger.Warn("split tunnel rule file format is invalid, using empty rule") return nil } @@ -446,7 +325,7 @@ func (s *SplitTunnel) loadRule() error { } else if len(s.rule.Rules) > 1 && s.rule.Rules[1].Type == C.RuleTypeDefault { // Migrate legacy format: wrap DefaultOptions into LogicalOptions // TODO(2/10): remove in future commit - slog.Debug("Migrating legacy split tunnel rule format") + s.logger.Debug("Migrating legacy split tunnel rule format") legacyRule := s.rule.Rules[1].DefaultOptions s.rule.Rules[1] = O.HeadlessRule{ Type: C.RuleTypeLogical, @@ -506,7 +385,7 @@ func (s *SplitTunnel) loadRule() error { s.initRuleMap() s.enabled.Store(s.rule.Mode == C.LogicalTypeOr) - slog.Log(context.Background(), internal.LevelTrace, "loaded split tunnel rules", + s.logger.Log(context.Background(), log.LevelTrace, "loaded split tunnel rules", "file", s.ruleFile, "filters", s.Filters().String(), "enabled", s.IsEnabled(), ) return nil diff --git a/vpn/split_tunnel_test.go b/vpn/split_tunnel_test.go index ba347c57..5aafd44c 100644 --- a/vpn/split_tunnel_test.go +++ b/vpn/split_tunnel_test.go @@ -2,7 +2,6 @@ package vpn import ( "context" - stdjson "encoding/json" "testing" "time" @@ -17,29 +16,22 @@ import ( "github.com/stretchr/testify/require" "github.com/getlantern/radiance/common/atomicfile" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal/testutil" + rlog "github.com/getlantern/radiance/log" ) -func setupTestSplitTunnel(t *testing.T) *SplitTunnel { - testutil.SetPathsForTesting(t) - s := newSplitTunnel(settings.GetString(settings.DataPathKey)) - return s -} - func TestEnableDisableIsEnabled(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) - if assert.NoError(t, st.Disable()) { + if assert.NoError(t, st.SetEnabled(false)) { assert.False(t, st.IsEnabled(), "split tunnel should be disabled") } - if assert.NoError(t, st.Enable()) { + if assert.NoError(t, st.SetEnabled(true)) { assert.True(t, st.IsEnabled(), "split tunnel should be enabled") } } func TestAddRemoveItem(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) domain := "example.com" domain2 := "example2.com" @@ -72,18 +64,18 @@ func TestAddRemoveItem(t *testing.T) { } func TestRemoveItems(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) - require.NoError(t, st.RemoveItems(Filter{Domain: []string{"a.com"}, ProcessName: []string{"proc"}})) + require.NoError(t, st.RemoveItems(SplitTunnelFilter{Domain: []string{"a.com"}, ProcessName: []string{"proc"}})) f := st.Filters() assert.Empty(t, f.Domain) assert.Empty(t, f.ProcessName) } func TestAddRemoveItems(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) - items := Filter{ + items := SplitTunnelFilter{ Domain: []string{"a.com", "b.com"}, DomainSuffix: []string{"suffix"}, ProcessName: []string{"proc"}, @@ -97,7 +89,7 @@ func TestAddRemoveItems(t *testing.T) { assert.Equal(t, []string{"proc"}, f.ProcessName) assert.Equal(t, []string{"pkg"}, f.PackageName) - err = st.RemoveItems(Filter{Domain: []string{"a.com"}, ProcessName: []string{"proc"}}) + err = st.RemoveItems(SplitTunnelFilter{Domain: []string{"a.com"}, ProcessName: []string{"proc"}}) require.NoError(t, err) f = st.Filters() assert.Equal(t, []string{"b.com"}, f.Domain) @@ -105,20 +97,21 @@ func TestAddRemoveItems(t *testing.T) { } func TestFilterPersistence(t *testing.T) { - st := setupTestSplitTunnel(t) + tmpDir := t.TempDir() + st := newSplitTunnel(tmpDir, rlog.NoOpLogger()) require.NoError(t, st.AddItem("domain", "example.com")) f := st.Filters() assert.Equal(t, []string{"example.com"}, f.Domain) - st = newSplitTunnel(settings.GetString(settings.DataPathKey)) + st = newSplitTunnel(tmpDir, rlog.NoOpLogger()) assert.NoError(t, st.loadRule()) f = st.Filters() assert.Equal(t, []string{"example.com"}, f.Domain, "expected filters to persist after reloading from file") } func TestUpdateFilterUnsupportedType(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) err := st.AddItem("unsupported", "foo") assert.Error(t, err) } @@ -143,7 +136,7 @@ func TestRemoveEdgeCases(t *testing.T) { } func TestMatch(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) require.NoError(t, st.AddItem("domain", "example.com")) ruleOpts := O.Rule{ @@ -191,7 +184,7 @@ func TestMatch(t *testing.T) { metadata := &adapter.InboundContext{Domain: "example.com"} rsStr := ruleSet.String() - require.NoError(t, st.Enable()) + require.NoError(t, st.SetEnabled(true)) require.Eventually(t, func() bool { return ruleSet.String() != rsStr }, time.Second, 50*time.Millisecond, "timed out waiting for rule reload") @@ -199,7 +192,7 @@ func TestMatch(t *testing.T) { assert.True(t, rule.Match(metadata), "rule should match when split tunnel is enabled") rsStr = ruleSet.String() - require.NoError(t, st.Disable()) + require.NoError(t, st.SetEnabled(false)) require.Eventually(t, func() bool { return ruleSet.String() != rsStr }, time.Second, 50*time.Millisecond, "timed out waiting for rule reload") @@ -217,7 +210,7 @@ func (r *mockRouter) RuleSet(tag string) (adapter.RuleSet, bool) { } func TestMigration(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) // Create a legacy format rule file legacyRule := O.LogicalHeadlessRule{ @@ -322,99 +315,3 @@ func TestMigration(t *testing.T) { rule, _ := json.UnmarshalExtended[O.LogicalHeadlessRule]([]byte(want)) assert.Equal(t, rule, st.rule) } - -// unmarshalItems is a test helper that unmarshals a JSON string into []string. -func unmarshalItems(t *testing.T, jsonStr string) []string { - t.Helper() - var items []string - require.NoError(t, stdjson.Unmarshal([]byte(jsonStr), &items)) - return items -} - -func TestItemsJSON(t *testing.T) { - st := setupTestSplitTunnel(t) - - t.Run("returns items for valid filter type", func(t *testing.T) { - require.NoError(t, st.AddItem(TypeDomain, "example.com")) - require.NoError(t, st.AddItem(TypeDomain, "test.org")) - - result, err := st.ItemsJSON(TypeDomain) - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Equal(t, []string{"example.com", "test.org"}, items) - }) - - t.Run("returns empty array when no items", func(t *testing.T) { - result, err := st.ItemsJSON(TypeDomainKeyword) - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Empty(t, items) - }) - - t.Run("returns error for unsupported filter type", func(t *testing.T) { - _, err := st.ItemsJSON("unsupported") - assert.Error(t, err) - assert.Contains(t, err.Error(), "unsupported filter type") - }) - - t.Run("returns items for package names", func(t *testing.T) { - require.NoError(t, st.AddItem(TypePackageName, "com.example.app")) - result, err := st.ItemsJSON(TypePackageName) - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Equal(t, []string{"com.example.app"}, items) - }) -} - -func TestEnabledAppsJSON(t *testing.T) { - st := setupTestSplitTunnel(t) - - t.Run("returns empty array when no apps configured", func(t *testing.T) { - result, err := st.EnabledAppsJSON() - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Empty(t, items) - }) - - t.Run("returns apps from current format", func(t *testing.T) { - require.NoError(t, st.AddItem(TypePackageName, "com.example.app")) - require.NoError(t, st.AddItem(TypeProcessPath, "/usr/bin/firefox")) - - result, err := st.EnabledAppsJSON() - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Contains(t, items, "com.example.app") - assert.Contains(t, items, "/usr/bin/firefox") - }) - - t.Run("picks up legacy camelCase keys from raw file", func(t *testing.T) { - st2 := setupTestSplitTunnel(t) - require.NoError(t, st2.AddItem(TypePackageName, "com.current.app")) - - // Patch the file with legacy camelCase keys alongside current format - b, err := atomicfile.ReadFile(st2.ruleFile) - require.NoError(t, err) - var raw map[string]any - require.NoError(t, stdjson.Unmarshal(b, &raw)) - raw["packageName"] = []string{"com.legacy.app"} - raw["processPath"] = []string{"/opt/legacy"} - patched, err := stdjson.Marshal(raw) - require.NoError(t, err) - require.NoError(t, atomicfile.WriteFile(st2.ruleFile, patched, 0644)) - - result, err := st2.EnabledAppsJSON() - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Contains(t, items, "com.current.app") - assert.Contains(t, items, "com.legacy.app") - assert.Contains(t, items, "/opt/legacy") - // Deduplication: com.current.app should appear exactly once - count := 0 - for _, app := range items { - if app == "com.current.app" { - count++ - } - } - assert.Equal(t, 1, count, "com.current.app should appear exactly once") - }) -} diff --git a/vpn/testdata/boxopts.json b/vpn/testdata/boxopts.json index fae1b2e1..cc593d72 100644 --- a/vpn/testdata/boxopts.json +++ b/vpn/testdata/boxopts.json @@ -14,90 +14,44 @@ "type": "direct", "tag": "direct" }, - { - "type": "block", - "tag": "block" - }, { "type": "http", - "tag": "http1-out", + "tag": "http-out", "server": "127.0.0.1", "server_port": 4080 }, - { - "type": "http", - "tag": "http2-out", - "server": "127.0.0.1", - "server_port": 4443 - }, { "type": "socks", - "tag": "socks1-out", + "tag": "socks-out", "server": "127.0.0.1", "server_port": 5080 }, - { - "type": "socks", - "tag": "socks2-out", - "server": "127.0.0.1", - "server_port": 5443 - }, { "type": "mutableurltest", - "tag": "auto-http", + "tag": "auto", "outbounds": [ - "http1-out", - "http2-out" - ] - }, - { - "type": "mutableurltest", - "tag": "auto-socks", - "outbounds": [ - "socks1-out", - "socks2-out" + "http-out", + "socks-out" ] }, { "type": "mutableselector", - "tag": "http", + "tag": "manual", "outbounds": [ - "auto-http", - "http1-out", - "http2-out" - ] - }, - { - "type": "mutableselector", - "tag": "socks", - "outbounds": [ - "auto-socks", - "socks1-out", - "socks2-out" - ] - }, - { - "type": "mutableurltest", - "tag": "auto-all", - "outbounds": [ - "auto-http", - "auto-socks" + "http-out", + "socks-out" ] } ], "route": { "rules": [ { - "clash_mode": "direct", - "outbound": "direct" - }, - { - "clash_mode": "http", - "outbound": "http" + "clash_mode": "auto", + "outbound": "auto" }, { - "clash_mode": "socks", - "outbound": "socks" + "clash_mode": "manual", + "outbound": "manual" } ] }, @@ -107,7 +61,7 @@ "cache_id": "test_cache" }, "clash_api": { - "default_mode": "Rule" + "default_mode": "auto" } } } diff --git a/vpn/testdata/config.json b/vpn/testdata/config.json index 8519af94..ef0b67b8 100644 --- a/vpn/testdata/config.json +++ b/vpn/testdata/config.json @@ -1,58 +1,55 @@ { - "ConfigResponse": { - "smart_routing": [ + "smart_routing": [ + { + "category": "openai", + "rule_sets": [ + { + "tag": "openai", + "url": "https://ruleset.com/openai.srs" + } + ], + "outbounds": [ + "http1-out", + "socks1-out" + ] + } + ], + "ad_block": [ + { + "tag": "adblock-1", + "url": "https://ruleset.com/adblock-1.srs" + }, + { + "tag": "adblock-2", + "url": "https://ruleset.com/adblock-2.srs" + } + ], + "options": { + "outbounds": [ { - "category": "openai", - "rule_sets": [ - { - "tag": "openai", - "url": "https://ruleset.com/openai.srs" - } - ], - "outbounds": [ - "http1-out", - "socks1-out" - ] - } - ], - "ad_block": [ + "type": "http", + "tag": "http1-out", + "server": "127.0.0.1", + "server_port": 4080 + }, + { + "type": "http", + "tag": "http2-out", + "server": "127.0.0.1", + "server_port": 4443 + }, { - "tag": "adblock-1", - "url": "https://ruleset.com/adblock-1.srs" + "type": "socks", + "tag": "socks1-out", + "server": "127.0.0.1", + "server_port": 5080 }, { - "tag": "adblock-2", - "url": "https://ruleset.com/adblock-2.srs" + "type": "socks", + "tag": "socks2-out", + "server": "127.0.0.1", + "server_port": 5443 } - ], - "options": { - "outbounds": [ - { - "type": "http", - "tag": "http1-out", - "server": "127.0.0.1", - "server_port": 4080 - }, - { - "type": "http", - "tag": "http2-out", - "server": "127.0.0.1", - "server_port": 4443 - }, - { - "type": "socks", - "tag": "socks1-out", - "server": "127.0.0.1", - "server_port": 5080 - }, - { - "type": "socks", - "tag": "socks2-out", - "server": "127.0.0.1", - "server_port": 5443 - } - ] - } - }, - "PreferredLocation": {} + ] + } } diff --git a/vpn/tunnel.go b/vpn/tunnel.go index ef9f3bc4..52936007 100644 --- a/vpn/tunnel.go +++ b/vpn/tunnel.go @@ -8,6 +8,7 @@ import ( "io" "log/slog" "path/filepath" + runtimeDebug "runtime/debug" "slices" "sync/atomic" "time" @@ -23,12 +24,12 @@ import ( "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/events" + rlog "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/vpn/ipc" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/urltest" + "github.com/sagernet/sing-box/common/conntrack" "github.com/sagernet/sing-box/experimental/clashapi" "github.com/sagernet/sing-box/experimental/libbox" sblog "github.com/sagernet/sing-box/log" @@ -52,14 +53,21 @@ type tunnel struct { clientContextTracker *clientcontext.ClientContextInjector - status atomic.Value + status atomic.Value // VPNStatus cancel context.CancelFunc closers []io.Closer } -func (t *tunnel) start(options string, platformIfce libbox.PlatformInterface) error { - t.status.Store(ipc.Connecting) +func (t *tunnel) start(options string, platformIfce libbox.PlatformInterface) (err error) { + if t.status.Load() != Restarting { + t.setStatus(Connecting, nil) + } t.ctx, t.cancel = context.WithCancel(box.BaseContext()) + defer func() { + if err != nil { + t.setStatus(ErrorStatus, err) + } + }() if err := t.init(options, platformIfce); err != nil { t.close() @@ -72,13 +80,13 @@ func (t *tunnel) start(options string, platformIfce libbox.PlatformInterface) er slog.Error("Failed to connect tunnel", "error", err) return fmt.Errorf("connecting tunnel: %w", err) } - t.status.Store(ipc.Connected) + t.setStatus(Connected, nil) t.optsMap = makeOutboundOptsMap(t.ctx, options) return nil } func (t *tunnel) init(options string, platformIfce libbox.PlatformInterface) error { - slog.Log(nil, internal.LevelTrace, "Initializing tunnel") + slog.Log(nil, rlog.LevelTrace, "Initializing tunnel") // setup libbox service dataPath := t.dataPath @@ -93,7 +101,7 @@ func (t *tunnel) init(options string, platformIfce libbox.PlatformInterface) err setupOpts.FixAndroidStack = true } - slog.Log(nil, internal.LevelTrace, "Setting up libbox", "setup_options", setupOpts) + slog.Log(nil, rlog.LevelTrace, "Setting up libbox", "setup_options", setupOpts) if err := libbox.Setup(setupOpts); err != nil { return fmt.Errorf("setup libbox: %w", err) } @@ -101,7 +109,7 @@ func (t *tunnel) init(options string, platformIfce libbox.PlatformInterface) err t.logFactory = lblog.NewFactory(slog.Default().Handler()) service.MustRegister[sblog.Factory](t.ctx, t.logFactory) - slog.Log(nil, internal.LevelTrace, "Creating libbox service") + slog.Log(nil, rlog.LevelTrace, "Creating libbox service") lb, err := libbox.NewServiceWithContext(t.ctx, options, platformIfce) if err != nil { return fmt.Errorf("create libbox service: %w", err) @@ -118,10 +126,10 @@ func (t *tunnel) init(options string, platformIfce libbox.PlatformInterface) err t.closers = append(t.closers, lb) t.lbService = lb - history := service.PtrFromContext[urltest.HistoryStorage](t.ctx) - if err := loadURLTestHistory(history, filepath.Join(dataPath, urlTestHistoryFileName)); err != nil { - return fmt.Errorf("load urltest history: %w", err) - } + // history := service.PtrFromContext[urltest.HistoryStorage](t.ctx) + // if err := loadURLTestHistory(history, filepath.Join(dataPath, urlTestHistoryFileName)); err != nil { + // return fmt.Errorf("load urltest history: %w", err) + // } // set memory limit for Android and iOS switch common.Platform { @@ -146,16 +154,12 @@ func newClientContextInjector(outboundMgr adapter.OutboundManager, dataPath stri Version: common.Version, } } + // Outbound match bounds start empty and are populated when lantern servers are added via + // addOutbounds. Only lantern servers support client context tracking. matchBounds := clientcontext.MatchBounds{ Inbound: []string{"any"}, Outbound: []string{}, } - if outbound, exists := outboundMgr.Outbound(servers.SGLantern); exists { - // Note: this should only contain lantern outbounds with servers that support client context - // tracking. otherwise, the connections will fail. - tags := outbound.(adapter.OutboundGroup).All() - matchBounds.Outbound = append(tags, servers.SGLantern, groupAutoTag(servers.SGLantern)) - } return clientcontext.NewClientContextInjector(infoFn, matchBounds) } @@ -180,7 +184,7 @@ func newMutableGroupManager( } func (t *tunnel) connect() (err error) { - slog.Log(nil, internal.LevelTrace, "Starting libbox service") + slog.Log(nil, rlog.LevelTrace, "Starting libbox service") defer func() { if r := recover(); r != nil { @@ -208,25 +212,40 @@ func (t *tunnel) connect() (err error) { return nil } -func (t *tunnel) selectOutbound(group, tag string) error { - if status := t.Status(); status != ipc.Connected { +func (t *tunnel) selectMode(mode string) error { + if status := t.Status(); status != Connected { return fmt.Errorf("tunnel not running: status %v", status) } - t.clashServer.SetMode(group) - if tag == "" { - return nil + if t.clashServer.Mode() != mode { + t.clashServer.SetMode(mode) + conntrack.Close() + go func() { + time.Sleep(time.Second) + runtimeDebug.FreeOSMemory() + }() + } + return nil +} + +func (t *tunnel) selectOutbound(tag string) error { + if err := t.selectMode(ManualSelectTag); err != nil { + return err } + outboundMgr := service.FromContext[adapter.OutboundManager](t.ctx) - outbound, loaded := outboundMgr.Outbound(group) + outbound, loaded := outboundMgr.Outbound(ManualSelectTag) if !loaded { - return fmt.Errorf("selector group not found: %s", group) + return fmt.Errorf("manual select group not found") } - outbound.(ipc.Selector).SelectOutbound(tag) + outbound.(Selector).SelectOutbound(tag) return nil } func (t *tunnel) close() error { + if t.status.Load() != Restarting { + t.setStatus(Disconnecting, nil) + } if t.cancel != nil { t.cancel() } @@ -235,7 +254,7 @@ func (t *tunnel) close() error { go func() { var errs []error for _, closer := range t.closers { - slog.Log(nil, internal.LevelTrace, "Closing tunnel resource", "type", fmt.Sprintf("%T", closer)) + slog.Log(nil, rlog.LevelTrace, "Closing tunnel resource", "type", fmt.Sprintf("%T", closer)) errs = append(errs, closer.Close()) } done <- errors.Join(errs...) @@ -249,25 +268,36 @@ func (t *tunnel) close() error { t.closers = nil t.lbService = nil - t.status.Store(ipc.Disconnected) + if t.status.Load() != Restarting { + t.setStatus(Disconnected, nil) + } return err } -func (t *tunnel) Status() ipc.VPNStatus { - return t.status.Load().(ipc.VPNStatus) +func (t *tunnel) Status() VPNStatus { + return t.status.Load().(VPNStatus) +} + +func (t *tunnel) setStatus(status VPNStatus, err error) { + t.status.Store(status) + evt := StatusUpdateEvent{Status: status} + if err != nil { + evt.Error = err.Error() + } + events.Emit(evt) } var errLibboxClosed = errors.New("libbox closed") func (t *tunnel) addOutbounds(group string, options servers.Options) (err error) { if len(options.Outbounds) == 0 && len(options.Endpoints) == 0 { - slog.Debug("No outbounds or endpoints to add", "group", group) + slog.Debug("No outbounds or endpoints to add") return nil } - slog.Info("Adding servers to group", "group", group, "tags", options.AllTags()) + slog.Info("Adding servers to group", "tags", options.AllTags()) // remove duplicates from newOpts before adding to avoid unnecessary reloads - newOptions := removeDuplicates(t.ctx, t.optsMap, options, group) + newOptions := removeDuplicates(t.ctx, t.optsMap, options) ctx := t.ctx router := service.FromContext[adapter.Router](ctx) @@ -277,32 +307,37 @@ func (t *tunnel) addOutbounds(group string, options servers.Options) (err error) // preemptively merge the new lantern tags into the clientContextInjector match bounds to // capture any new connections before finished adding the servers. if tags := options.AllTags(); len(tags) > 0 { - slog.Log(nil, internal.LevelTrace, "Temporarily merging new lantern tags into ClientContextInjector") + slog.Log(nil, rlog.LevelTrace, "Temporarily merging new lantern tags into ClientContextInjector") matchBounds := t.clientContextTracker.MatchBounds() matchBounds.Outbound = append(matchBounds.Outbound, tags...) t.clientContextTracker.SetBounds(matchBounds) } defer func() { if !errors.Is(err, errLibboxClosed) { - t.updateClientContextTracker() + // Rebuild bounds from the full set of lantern tags currently in the + // ManualSelectTag group, rather than just the tags from this call. + mb := t.clientContextTracker.MatchBounds() + mb.Outbound = append(mb.Outbound, options.AllTags()...) + // Deduplicate: the preemptive merge above may have already added these tags. + slices.Sort(mb.Outbound) + mb.Outbound = slices.Compact(mb.Outbound) + t.clientContextTracker.SetBounds(mb) } }() } var ( mutGrpMgr = t.mutGrpMgr - autoTag = groupAutoTag(group) added = 0 ) - // for each outbound/endpoint in new add to group for _, outbound := range newOptions.Outbounds { logger := t.logFactory.NewLogger("outbound/" + outbound.Tag + "[" + outbound.Type + "]") err := mutGrpMgr.CreateOutboundForGroup( - ctx, router, logger, group, outbound.Tag, outbound.Type, outbound.Options, + ctx, router, logger, ManualSelectTag, outbound.Tag, outbound.Type, outbound.Options, ) if err == nil { - // add to urltest - err = mutGrpMgr.AddToGroup(autoTag, outbound.Tag) + // add to autoselect + err = mutGrpMgr.AddToGroup(AutoSelectTag, outbound.Tag) } if errors.Is(err, groups.ErrIsClosed) { return errLibboxClosed @@ -323,11 +358,11 @@ func (t *tunnel) addOutbounds(group string, options servers.Options) (err error) for _, endpoint := range newOptions.Endpoints { logger := t.logFactory.NewLogger("endpoint/" + endpoint.Tag + "[" + endpoint.Type + "]") err := mutGrpMgr.CreateEndpointForGroup( - ctx, router, logger, group, endpoint.Tag, endpoint.Type, endpoint.Options, + ctx, router, logger, ManualSelectTag, endpoint.Tag, endpoint.Type, endpoint.Options, ) if err == nil { - // add to urltest - err = mutGrpMgr.AddToGroup(autoTag, endpoint.Tag) + // add to autoselect + err = mutGrpMgr.AddToGroup(AutoSelectTag, endpoint.Tag) } if errors.Is(err, groups.ErrIsClosed) { return errLibboxClosed @@ -343,32 +378,30 @@ func (t *tunnel) addOutbounds(group string, options servers.Options) (err error) if len(options.URLOverrides) > 0 { slog.Info("Applying bandit URL overrides to URL test group", - "group", autoTag, "override_count", len(options.URLOverrides), ) } - if err := t.mutGrpMgr.SetURLOverrides(autoTag, options.URLOverrides); err != nil { - slog.Warn("Failed to set URL overrides", "group", autoTag, "error", err) + if err := t.mutGrpMgr.SetURLOverrides(AutoSelectTag, options.URLOverrides); err != nil { + slog.Warn("Failed to set URL overrides", "error", err) } else if len(options.URLOverrides) > 0 { // Trigger an immediate URL test cycle when we have bandit overrides so // callback probes are hit within seconds of config receipt rather than // waiting for the next scheduled interval (3 min). - if err := t.mutGrpMgr.CheckOutbounds(autoTag); err != nil { - slog.Warn("Failed to trigger immediate URL test after bandit overrides", "group", autoTag, "error", err) + if err := t.mutGrpMgr.CheckOutbounds(AutoSelectTag); err != nil { + slog.Warn("Failed to trigger immediate URL test after bandit overrides", "error", err) } else { - slog.Info("Triggered immediate URL test for bandit callbacks", "group", autoTag) + slog.Info("Triggered immediate URL test for bandit callbacks") } } - slog.Debug("Added servers to group", "group", group, "added", added) + slog.Debug("Added servers", "added", added) return errors.Join(errs...) } func (t *tunnel) removeOutbounds(group string, tags []string) error { var ( mutGrpMgr = t.mutGrpMgr - autoTag = groupAutoTag(group) - removed = 0 + removed []string errs []error ) for _, tag := range tags { @@ -377,10 +410,10 @@ func (t *tunnel) removeOutbounds(group string, tags []string) error { continue // skip nested urltests } } - err := mutGrpMgr.RemoveFromGroup(group, tag) + err := mutGrpMgr.RemoveFromGroup(ManualSelectTag, tag) if err == nil { // remove from urltest - err = mutGrpMgr.RemoveFromGroup(autoTag, tag) + err = mutGrpMgr.RemoveFromGroup(AutoSelectTag, tag) } if errors.Is(err, groups.ErrIsClosed) { return errLibboxClosed @@ -389,87 +422,66 @@ func (t *tunnel) removeOutbounds(group string, tags []string) error { errs = append(errs, err) } else { t.optsMap.Delete(tag) - removed++ + removed = append(removed, tag) } } - if t.clientContextTracker != nil { - t.updateClientContextTracker() - } - slog.Debug("Removed servers from group", "group", group, "removed", removed) + if t.clientContextTracker != nil && group == servers.SGLantern { + mb := t.clientContextTracker.MatchBounds() + mb.Outbound = slices.DeleteFunc(mb.Outbound, func(s string) bool { + return slices.Contains(removed, s) + }) + t.clientContextTracker.SetBounds(clientcontext.MatchBounds{ + Inbound: []string{"any"}, + Outbound: mb.Outbound, + }) + } + slog.Debug("Removed servers", "removed", len(removed)) return errors.Join(errs...) } -func (t *tunnel) updateClientContextTracker() { - outboundMgr := service.FromContext[adapter.OutboundManager](t.ctx) - outbound, exists := outboundMgr.Outbound(servers.SGLantern) - if !exists { - return - } - outGroup := outbound.(adapter.OutboundGroup) - slog.Debug("Setting updated lantern tags into ClientContextInjector") - t.clientContextTracker.SetBounds(clientcontext.MatchBounds{ - Inbound: []string{"any"}, - Outbound: append(outGroup.All(), servers.SGLantern, groupAutoTag(servers.SGLantern)), - }) -} - -func (t *tunnel) updateOutbounds(new servers.Servers) error { +func (t *tunnel) updateOutbounds(group string, newOpts servers.Options) error { var errs []error - for _, group := range []string{servers.SGLantern, servers.SGUser} { - newOpts := new[group] - if len(newOpts.Outbounds) == 0 && len(newOpts.Endpoints) == 0 && len(newOpts.URLOverrides) == 0 { - slog.Debug("No outbounds, endpoints, or URL overrides to update, skipping", "group", group) - continue - } - slog.Log(nil, internal.LevelTrace, "Updating servers", "group", group) - - autoTag := groupAutoTag(group) - selector, selectorExists := t.mutGrpMgr.OutboundGroup(group) - _, urltestExists := t.mutGrpMgr.OutboundGroup(autoTag) - if !selectorExists || !urltestExists { - // Yes, panic. And, yes, it's intentional. Both selector and URLtest should always exist - // if the tunnel is running, so this is a "world no longer makes sense" situation. This - // should be caught during testing and will not panic in release builds. - slog.Log( - nil, internal.LevelPanic, "selector or urltest group missing", "group", group, - "selector_exists", selectorExists, "urltest_exists", urltestExists, - ) - panic(fmt.Errorf( - "selector or urltest group missing for %q. selector_exists=%v, urltest_exists=%v", - group, selectorExists, urltestExists, - )) - } + if len(newOpts.Outbounds) == 0 && len(newOpts.Endpoints) == 0 && len(newOpts.URLOverrides) == 0 { + slog.Debug("No outbounds, endpoints, or bandit overrides to update, skipping") + return nil + } + slog.Log(nil, rlog.LevelTrace, "Updating servers") - if contextDone(t.ctx) { - return t.ctx.Err() - } + selector, selectorExists := t.mutGrpMgr.OutboundGroup(ManualSelectTag) + _, urltestExists := t.mutGrpMgr.OutboundGroup(AutoSelectTag) + if !selectorExists || !urltestExists { + slog.Error("Selector or URL test group not found when updating outbounds") + return errors.New("selector or url test group not found") + } - // collect tags present in the current group but absent from the new config - newTags := newOpts.AllTags() - var toRemove []string - for _, tag := range selector.All() { - if !slices.Contains(newTags, tag) { - toRemove = append(toRemove, tag) - } - } + if contextDone(t.ctx) { + return t.ctx.Err() + } - if err := t.removeOutbounds(group, toRemove); errors.Is(err, errLibboxClosed) { - return err - } else if err != nil { - errs = append(errs, err) - } - if err := t.addOutbounds(group, newOpts); errors.Is(err, errLibboxClosed) { - return err - } else if err != nil { - errs = append(errs, err) + // collect current tags that are not in the new options + newTags := newOpts.AllTags() + var toRemove []string + for _, tag := range selector.All() { + if !slices.Contains(newTags, tag) { + toRemove = append(toRemove, tag) } } + if err := t.removeOutbounds(group, toRemove); errors.Is(err, errLibboxClosed) { + return err + } else if err != nil { + errs = append(errs, err) + } + if err := t.addOutbounds(group, newOpts); errors.Is(err, errLibboxClosed) { + return err + } else if err != nil { + errs = append(errs, err) + } return errors.Join(errs...) } -func removeDuplicates(ctx context.Context, curr *lsync.TypedMap[string, []byte], new servers.Options, group string) servers.Options { - slog.Log(nil, internal.LevelTrace, "Removing duplicate outbounds/endpoints", "group", group) +func removeDuplicates(ctx context.Context, curr *lsync.TypedMap[string, []byte], new servers.Options) servers.Options { + slog.Log(nil, rlog.LevelTrace, "Removing duplicate outbounds/endpoints") deduped := servers.Options{ Outbounds: []O.Outbound{}, Endpoints: []O.Endpoint{}, @@ -499,7 +511,7 @@ func removeDuplicates(ctx context.Context, curr *lsync.TypedMap[string, []byte], deduped.Locations[ep.Tag] = new.Locations[ep.Tag] } if len(dropped) > 0 { - slog.Log(nil, internal.LevelDebug, "Dropped duplicate outbounds/endpoints", "group", group, "tags", dropped) + slog.Debug("Dropped duplicate outbounds/endpoints", "tags", dropped) } return deduped } diff --git a/vpn/tunnel_test.go b/vpn/tunnel_test.go index 03dd03b2..1ddce9c5 100644 --- a/vpn/tunnel_test.go +++ b/vpn/tunnel_test.go @@ -1,163 +1,156 @@ package vpn import ( - "path/filepath" + "context" "testing" - "time" - sbA "github.com/sagernet/sing-box/adapter" - sbC "github.com/sagernet/sing-box/constant" - sbO "github.com/sagernet/sing-box/option" + lcommon "github.com/getlantern/common" + lsync "github.com/getlantern/common/sync" + box "github.com/getlantern/lantern-box" + O "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/json" - "github.com/sagernet/sing/service" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/getlantern/lantern-box/adapter" - - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal/testutil" "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/vpn/ipc" ) -func TestConnection(t *testing.T) { - testutil.SetPathsForTesting(t) - opts, optsStr, err := testBoxOptions(settings.GetString(settings.DataPathKey)) - require.NoError(t, err, "failed to get test box options") +func TestTunnelStatus(t *testing.T) { + tun := &tunnel{} + tun.status.Store(Disconnected) + assert.Equal(t, Disconnected, tun.Status()) - tmp := settings.GetString(settings.DataPathKey) + tun.setStatus(Connecting, nil) + assert.Equal(t, Connecting, tun.Status()) - opts.Route.RuleSet = baseOpts(settings.GetString(settings.DataPathKey)).Route.RuleSet - opts.Route.RuleSet[0].LocalOptions.Path = filepath.Join(tmp, splitTunnelFile) - opts.Route.Rules = append([]sbO.Rule{baseOpts(settings.GetString(settings.DataPathKey)).Route.Rules[2]}, opts.Route.Rules...) - newSplitTunnel(tmp) + tun.setStatus(Connected, nil) + assert.Equal(t, Connected, tun.Status()) +} - tun := &tunnel{ - dataPath: tmp, - } +func TestTunnelSetStatus_WithError(t *testing.T) { + tun := &tunnel{} + tun.status.Store(Disconnected) - require.NoError(t, tun.start(optsStr, nil), "failed to establish connection") - t.Cleanup(func() { - tun.close() - }) + testErr := assert.AnError + tun.setStatus(ErrorStatus, testErr) + assert.Equal(t, ErrorStatus, tun.Status()) +} - require.Equal(t, ipc.Connected, tun.Status(), "tunnel should be running") +func TestTunnelClose_NoResources(t *testing.T) { + tun := &tunnel{} + tun.status.Store(Connected) + err := tun.close() + assert.NoError(t, err) + assert.Equal(t, Disconnected, tun.Status()) + assert.Nil(t, tun.closers) + assert.Nil(t, tun.lbService) +} - assert.NoError(t, tun.selectOutbound("http", "http1-out"), "failed to select http outbound") - assert.NoError(t, tun.close(), "failed to close lbService") - assert.Equal(t, ipc.Disconnected, tun.Status(), "tun should be closed") +func TestTunnelClose_PreservesRestartingStatus(t *testing.T) { + tun := &tunnel{} + tun.status.Store(Restarting) + err := tun.close() + assert.NoError(t, err) + assert.Equal(t, Restarting, tun.Status(), "close should not override Restarting status") } -func TestUpdateServers(t *testing.T) { - testutil.SetPathsForTesting(t) - testOpts, _, err := testBoxOptions(settings.GetString(settings.DataPathKey)) - require.NoError(t, err, "failed to get test box options") +func TestTunnelClose_WithCancel(t *testing.T) { + tun := &tunnel{} + tun.status.Store(Connected) + ctx, cancel := context.WithCancel(context.Background()) + tun.cancel = cancel - baseOuts := baseOpts(settings.GetString(settings.DataPathKey)).Outbounds - allOutbounds := map[string]sbO.Outbound{ - "direct": baseOuts[0], - "block": baseOuts[1], - } - for _, out := range testOpts.Outbounds { - switch out.Type { - case sbC.TypeHTTP, sbC.TypeSOCKS: - allOutbounds[out.Tag] = out - default: - } - } + err := tun.close() + assert.NoError(t, err) + assert.Error(t, ctx.Err(), "context should be cancelled after close") +} - lanternTags := []string{"http1-out", "http2-out", "socks1-out"} - userTags := []string{} - outs := []sbO.Outbound{ - allOutbounds["direct"], allOutbounds["block"], - allOutbounds["http1-out"], allOutbounds["http2-out"], allOutbounds["socks1-out"], - urlTestOutbound(autoLanternTag, lanternTags, nil), urlTestOutbound(autoUserTag, userTags, nil), - selectorOutbound(servers.SGLantern, append(lanternTags, autoLanternTag)), - selectorOutbound(servers.SGUser, append(userTags, autoUserTag)), - urlTestOutbound(autoAllTag, []string{autoLanternTag, autoUserTag}, nil), - } +type errCloser struct{ err error } + +func (c errCloser) Close() error { return c.err } + +func TestTunnelClose_CloserErrors(t *testing.T) { + tun := &tunnel{} + tun.status.Store(Connected) + tun.closers = append(tun.closers, errCloser{err: assert.AnError}) - testOpts.Outbounds = outs - tun := testConnection(t, *testOpts) - defer func() { - tun.close() - }() + err := tun.close() + assert.ErrorIs(t, err, assert.AnError) +} + +func TestSelectMode_NotConnected(t *testing.T) { + tun := &tunnel{} + tun.status.Store(Disconnected) + err := tun.selectMode(AutoSelectTag) + require.Error(t, err) + assert.Contains(t, err.Error(), "tunnel not running") +} - time.Sleep(500 * time.Millisecond) +func TestRemoveDuplicates(t *testing.T) { + ctx := box.BaseContext() - err = tun.removeOutbounds(servers.SGLantern, []string{"http2-out", "socks1-out"}) - require.NoError(t, err, "failed to remove servers from lantern") + out1 := O.Outbound{Type: "http", Tag: "http-1", Options: &O.HTTPOutboundOptions{}} + out2 := O.Outbound{Type: "http", Tag: "http-2", Options: &O.HTTPOutboundOptions{}} + ep1 := O.Endpoint{Type: "wireguard", Tag: "wg-1", Options: &O.WireGuardEndpointOptions{}} + + // Build a current map with out1 and ep1. + var curr lsync.TypedMap[string, []byte] + b1, _ := json.MarshalContext(ctx, out1) + curr.Store(out1.Tag, b1) + bEp1, _ := json.MarshalContext(ctx, ep1) + curr.Store(ep1.Tag, bEp1) newOpts := servers.Options{ - Outbounds: []sbO.Outbound{ - allOutbounds["http1-out"], allOutbounds["socks2-out"], + Outbounds: []O.Outbound{out1, out2}, + Endpoints: []O.Endpoint{ep1}, + Locations: map[string]lcommon.ServerLocation{ + out1.Tag: {}, + out2.Tag: {}, + ep1.Tag: {}, }, } - err = tun.addOutbounds(servers.SGLantern, newOpts) - require.NoError(t, err, "failed to update servers for lantern") - time.Sleep(250 * time.Millisecond) + result := removeDuplicates(ctx, &curr, newOpts) - outboundMgr := service.FromContext[sbA.OutboundManager](tun.ctx) - require.NotNil(t, outboundMgr, "outbound manager should not be nil") + // out1 and ep1 are duplicates, only out2 should remain. + assert.Len(t, result.Outbounds, 1) + assert.Equal(t, "http-2", result.Outbounds[0].Tag) + assert.Empty(t, result.Endpoints) +} - groups := tun.mutGrpMgr.OutboundGroups() +func TestRemoveDuplicates_AllNew(t *testing.T) { + ctx := box.BaseContext() + var curr lsync.TypedMap[string, []byte] - want := map[string][]string{ - autoAllTag: {autoLanternTag, autoUserTag}, - servers.SGLantern: {autoLanternTag, "http1-out", "socks2-out"}, - autoLanternTag: {"http1-out", "socks2-out"}, - servers.SGUser: {autoUserTag}, - autoUserTag: {}, - } - got := make(map[string][]string) - allTags := []string{"direct", "block", autoAllTag, autoLanternTag, autoUserTag, servers.SGLantern, servers.SGUser} - for _, g := range groups { - tags := g.All() - got[g.Tag()] = tags - allTags = append(allTags, tags...) - } - for _, tag := range allTags { - if _, found := outboundMgr.Outbound(tag); !found { - assert.Failf(t, "outbound missing from outbound manager", "outbound %s not found", tag) - } - } - for group, tags := range want { - assert.ElementsMatchf(t, tags, got[group], "group %s does not have correct outbounds", group) - } -} + out1 := O.Outbound{Type: "http", Tag: "http-1", Options: &O.HTTPOutboundOptions{}} + out2 := O.Outbound{Type: "socks", Tag: "socks-1", Options: &O.SOCKSOutboundOptions{}} -func getGroups(outboundMgr sbA.OutboundManager) []adapter.MutableOutboundGroup { - outbounds := outboundMgr.Outbounds() - var iGroups []adapter.MutableOutboundGroup - for _, it := range outbounds { - if group, isGroup := it.(adapter.MutableOutboundGroup); isGroup { - iGroups = append(iGroups, group) - } + newOpts := servers.Options{ + Outbounds: []O.Outbound{out1, out2}, + Locations: map[string]lcommon.ServerLocation{ + out1.Tag: {}, + out2.Tag: {}, + }, } - return iGroups -} -func testConnection(t *testing.T, opts sbO.Options) *tunnel { - tmp := settings.GetString(settings.DataPathKey) + result := removeDuplicates(ctx, &curr, newOpts) + assert.Len(t, result.Outbounds, 2) +} - opts.Route.RuleSet = baseOpts(settings.GetString(settings.DataPathKey)).Route.RuleSet - opts.Route.RuleSet[0].LocalOptions.Path = filepath.Join(tmp, splitTunnelFile) - opts.Route.Rules = append([]sbO.Rule{baseOpts(settings.GetString(settings.DataPathKey)).Route.Rules[2]}, opts.Route.Rules...) - newSplitTunnel(tmp) +func TestRemoveDuplicates_Empty(t *testing.T) { + ctx := box.BaseContext() + var curr lsync.TypedMap[string, []byte] - tun := &tunnel{ - dataPath: tmp, - } + result := removeDuplicates(ctx, &curr, servers.Options{}) + assert.Empty(t, result.Outbounds) + assert.Empty(t, result.Endpoints) +} - options, _ := json.Marshal(opts) - err := tun.start(string(options), nil) - require.NoError(t, err, "failed to establish connection") - t.Cleanup(func() { - tun.close() - }) +func TestContextDone(t *testing.T) { + ctx := context.Background() + assert.False(t, contextDone(ctx)) - assert.Equal(t, ipc.Connected, tun.Status(), "tunnel should be running") - return tun + ctx, cancel := context.WithCancel(context.Background()) + cancel() + assert.True(t, contextDone(ctx)) } diff --git a/vpn/types.go b/vpn/types.go new file mode 100644 index 00000000..43d6f551 --- /dev/null +++ b/vpn/types.go @@ -0,0 +1,86 @@ +package vpn + +import ( + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/experimental/clashapi/trafficontrol" + + "github.com/getlantern/radiance/events" +) + +// StatusUpdateEvent is emitted when the VPN status changes. +type StatusUpdateEvent struct { + events.Event + Status VPNStatus `json:"status"` + Error string `json:"error,omitempty"` +} + +// Selector is helper interface to check if an outbound is a selector or wrapper of selector. +type Selector interface { + adapter.OutboundGroup + SelectOutbound(tag string) bool +} + +// OutboundGroup represents a group of outbounds. +type OutboundGroup struct { + Tag string + Type string + Selected string + Outbounds []Outbounds +} + +// Outbounds represents outbounds within a group. +type Outbounds struct { + Tag string + Type string +} + +// Connection represents a network connection with relevant metadata. +type Connection struct { + ID string + Inbound string + IPVersion int + Network string + Source string + Destination string + Domain string + Protocol string + FromOutbound string + CreatedAt int64 + ClosedAt int64 + Uplink int64 + Downlink int64 + Rule string + Outbound string + ChainList []string +} + +// NewConnection creates a Connection from tracker metadata. +func newConnection(metadata trafficontrol.TrackerMetadata) Connection { + var rule string + if metadata.Rule != nil { + rule = metadata.Rule.String() + " => " + metadata.Rule.Action().String() + } + var closedAt int64 + if !metadata.ClosedAt.IsZero() { + closedAt = metadata.ClosedAt.UnixMilli() + } + md := metadata.Metadata + return Connection{ + ID: metadata.ID.String(), + Inbound: md.InboundType + "/" + md.Inbound, + IPVersion: int(md.IPVersion), + Network: md.Network, + Source: md.Source.String(), + Destination: md.Destination.String(), + Domain: md.Domain, + Protocol: md.Protocol, + FromOutbound: md.Outbound, + CreatedAt: metadata.CreatedAt.UnixMilli(), + ClosedAt: closedAt, + Uplink: metadata.Upload.Load(), + Downlink: metadata.Download.Load(), + Rule: rule, + Outbound: metadata.OutboundType + "/" + metadata.Outbound, + ChainList: metadata.Chain, + } +} diff --git a/vpn/vpn.go b/vpn/vpn.go index 52506faa..8eda5dd3 100644 --- a/vpn/vpn.go +++ b/vpn/vpn.go @@ -5,13 +5,11 @@ package vpn import ( "context" - "encoding/json" "errors" "fmt" "log/slog" - "os" "path/filepath" - "slices" + "runtime" "strings" "sync" "time" @@ -19,6 +17,7 @@ import ( sbox "github.com/sagernet/sing-box" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/urltest" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/experimental/libbox" "github.com/sagernet/sing-box/option" sbjson "github.com/sagernet/sing/common/json" @@ -30,356 +29,337 @@ import ( box "github.com/getlantern/lantern-box" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/atomicfile" - "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/events" - "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/servers" "github.com/getlantern/radiance/traces" - "github.com/getlantern/radiance/vpn/ipc" ) const ( tracerName = "github.com/getlantern/radiance/vpn" ) -func init() { - forwardToTunnel := func(action func(ctx context.Context) error, desc string) { - ctx := context.Background() - status, err := ipc.GetStatus(ctx) - if err != nil { - slog.Warn("Event received but failed to get tunnel status", "event", desc, "error", err) - return - } - if status != ipc.Connected { - return - } - if err := action(ctx); err != nil { - slog.Error("Failed to forward event to tunnel", "event", desc, "error", err) - } - } +var ( + ErrTunnelNotConnected = errors.New("tunnel not connected") + ErrTunnelAlreadyConnected = errors.New("tunnel already connected") +) - events.Subscribe(func(e servers.ServersUpdatedEvent) { - forwardToTunnel(func(ctx context.Context) error { - svrs := map[string]servers.Options{e.Group: *e.Options} - return ipc.UpdateOutbounds(ctx, svrs) - }, "servers-updated") - }) - events.Subscribe(func(e servers.ServersAddedEvent) { - forwardToTunnel(func(ctx context.Context) error { - return ipc.AddOutbounds(ctx, e.Group, *e.Options) - }, "servers-added") - }) - events.Subscribe(func(e servers.ServersRemovedEvent) { - forwardToTunnel(func(ctx context.Context) error { - return ipc.RemoveOutbounds(ctx, e.Group, []string{e.Tag}) - }, "servers-removed") - }) +type VPNStatus string + +// Possible VPN statuses +const ( + Connecting VPNStatus = "connecting" + Connected VPNStatus = "connected" + Disconnecting VPNStatus = "disconnecting" + Disconnected VPNStatus = "disconnected" + Restarting VPNStatus = "restarting" + ErrorStatus VPNStatus = "error" +) + +func (s *VPNStatus) String() string { + return string(*s) } -// Deprecated: Use AutoConnect instead with the desired group. -func QuickConnect(group string, _ libbox.PlatformInterface) (err error) { - return AutoConnect(group) +// VPNClient manages the lifecycle of the VPN tunnel. +type VPNClient struct { + tunnel *tunnel + + platformIfce PlatformInterface + logger *slog.Logger + + preTestCancel context.CancelFunc + preTestDone chan struct{} + + mu sync.RWMutex } -// AutoConnect automatically connects to the best available server in the specified group. Valid -// groups are [servers.ServerGroupLantern], [servers.ServerGroupUser], "all", or the empty string. -// Using "all" or the empty string will connect to the best available server across all groups. -func AutoConnect(group string) error { +type PlatformInterface interface { + libbox.PlatformInterface + RestartService() error + PostServiceClose() +} + +// NewVPNClient creates a new VPNClient instance with the provided configuration paths, log +// level, and platform interface. +func NewVPNClient(dataPath string, logger *slog.Logger, platformIfce PlatformInterface) *VPNClient { + if logger == nil { + logger = slog.Default() + } + _ = newSplitTunnel(dataPath, logger) + done := make(chan struct{}) + close(done) + return &VPNClient{ + platformIfce: platformIfce, + logger: logger, + preTestCancel: func() {}, + preTestDone: done, + } +} + +func (c *VPNClient) Connect(boxOptions BoxOptions) error { ctx, span := otel.Tracer(tracerName).Start( context.Background(), - "quick_connect", - trace.WithAttributes(attribute.String("group", group))) + "connect", + ) defer span.End() - switch group { - case servers.SGLantern: - return traces.RecordError(ctx, ConnectToServer(servers.SGLantern, autoLanternTag, nil)) - case servers.SGUser: - return traces.RecordError(ctx, ConnectToServer(servers.SGUser, autoUserTag, nil)) - case autoAllTag, "all", "": - if isOpen(ctx) { - if err := ipc.SetClashMode(ctx, autoAllTag); err != nil { - return fmt.Errorf("failed to set auto mode: %w", err) - } - return nil + c.mu.Lock() + // Cancel any running pre-start tests and wait for them to finish. If no tests are running, + // preTestCancel is a no-op and preTestDone is already closed (returns immediately). + c.preTestCancel() + done := c.preTestDone + c.mu.Unlock() + <-done + + c.mu.Lock() + defer c.mu.Unlock() + if c.tunnel != nil { + switch status := c.tunnel.Status(); status { + case Connected: + return ErrTunnelAlreadyConnected + case Restarting, Connecting, Disconnecting: + return fmt.Errorf("tunnel is currently %s", status) + case Disconnected, ErrorStatus: + // Clean up the stale tunnel so we can reconnect. + c.tunnel = nil + default: + return fmt.Errorf("tunnel is in unexpected state: %s", status) } - return traces.RecordError(ctx, connect(autoAllTag, "")) - default: - return traces.RecordError(ctx, fmt.Errorf("invalid group: %s", group)) } -} -// Deprecated: Use Connect instead with the desired group and tag. -func ConnectToServer(group, tag string, _ libbox.PlatformInterface) error { - return Connect(group, tag) + options, err := buildOptions(boxOptions) + if err != nil { + return traces.RecordError(ctx, fmt.Errorf("failed to build options: %w", err)) + } + opts, err := sbjson.Marshal(options) + if err != nil { + return traces.RecordError(ctx, fmt.Errorf("failed to marshal options: %w", err)) + } + return traces.RecordError(ctx, c.start(boxOptions.BasePath, string(opts))) } -// Connect connects to a specific server identified by the group and tag. Valid groups are -// [servers.SGLantern] and [servers.SGUser]. -func Connect(group, tag string) error { - ctx, span := otel.Tracer(tracerName).Start( - context.Background(), - "connect_to_server", - trace.WithAttributes( - attribute.String("group", group), - attribute.String("tag", tag))) - defer span.End() - - switch group { - case servers.SGLantern, servers.SGUser: - default: - return traces.RecordError(ctx, fmt.Errorf("invalid group: %s", group)) +func (c *VPNClient) start(path, options string) error { + c.logger.Debug("Starting tunnel", "options", options) + t := tunnel{ + dataPath: path, } - if tag == "" { - return traces.RecordError(ctx, errors.New("tag must be specified")) + if err := t.start(options, c.platformIfce); err != nil { + return fmt.Errorf("failed to start tunnel: %w", err) } - return traces.RecordError(ctx, connect(group, tag)) + c.tunnel = &t + return nil } -func connect(group, tag string) error { - ctx := context.Background() - if isOpen(ctx) { - return SelectServer(ctx, group, tag) +// Close shuts down the currently running tunnel, if any. Returns an error if closing the tunnel fails. +func (c *VPNClient) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.tunnel == nil { + return nil } - dataPath := settings.GetString(settings.DataPathKey) - _ = newSplitTunnel(dataPath) - options, err := getOptions() - if err != nil { + if err := c.close(); err != nil { return err } - if err := ipc.StartService(ctx, options); err != nil { - return err + if c.platformIfce != nil { + c.platformIfce.PostServiceClose() } - return SelectServer(ctx, group, tag) + return nil } -// Restart restarts the tunnel by reconnecting to the currently selected server. -func Restart() error { - ctx, span := otel.Tracer(tracerName).Start(context.Background(), "restart") - defer span.End() +func (c *VPNClient) close() error { + t := c.tunnel + c.tunnel = nil - options, err := getOptions() - if err != nil { + c.logger.Info("Closing tunnel") + if err := t.close(); err != nil { return err } - return traces.RecordError(ctx, ipc.RestartService(ctx, options)) + c.logger.Debug("Tunnel closed") + runtime.GC() + return nil } -func getOptions() (string, error) { - dataPath := settings.GetString(settings.DataPathKey) - options, err := buildOptions(context.Background(), dataPath) +// Restart closes and restarts the tunnel if it is currently running. Returns an error if the tunnel +// is not running or restart fails. +func (c *VPNClient) Restart(boxOptions BoxOptions) error { + c.mu.Lock() + if c.tunnel == nil || c.tunnel.Status() != Connected { + c.mu.Unlock() + return ErrTunnelNotConnected + } + + t := c.tunnel + c.logger.Info("Restarting tunnel") + t.setStatus(Restarting, nil) + if c.platformIfce != nil { + c.mu.Unlock() + if err := c.platformIfce.RestartService(); err != nil { + c.logger.Error("Failed to restart tunnel via platform interface", "error", err) + err = fmt.Errorf("platform interface restart failed: %w", err) + t.setStatus(ErrorStatus, err) + return err + } + c.logger.Info("Tunnel restarted successfully") + return nil + } + + defer c.mu.Unlock() + if err := c.close(); err != nil { + return fmt.Errorf("closing tunnel: %w", err) + } + options, err := buildOptions(boxOptions) if err != nil { - return "", fmt.Errorf("failed to build options: %w", err) + return fmt.Errorf("failed to build options: %w", err) } opts, err := sbjson.Marshal(options) if err != nil { - return "", fmt.Errorf("failed to marshal options: %w", err) + return fmt.Errorf("failed to marshal options: %w", err) } - return string(opts), nil + if err := c.start(boxOptions.BasePath, string(opts)); err != nil { + c.logger.Error("starting tunnel", "error", err) + return fmt.Errorf("starting tunnel: %w", err) + } + c.logger.Info("Tunnel restarted successfully") + return nil +} + +// Status returns the current status of the tunnel (e.g., running, closed). +func (c *VPNClient) Status() VPNStatus { + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return Disconnected + } + return c.tunnel.Status() } // isOpen returns true if the tunnel is open, false otherwise. // Note, this does not check if the tunnel can connect to a server. -func isOpen(ctx context.Context) bool { - state, err := ipc.GetStatus(ctx) - if err != nil { - slog.Error("Failed to get tunnel state", "error", err) - } - return state == ipc.Connected +func (c *VPNClient) isOpen() bool { + return c.Status() == Connected } // Disconnect closes the tunnel and all active connections. -func Disconnect() error { +func (c *VPNClient) Disconnect() error { ctx, span := otel.Tracer(tracerName).Start(context.Background(), "disconnect") defer span.End() - slog.Info("Disconnecting VPN") - return traces.RecordError(ctx, ipc.StopService(ctx)) + c.logger.Info("Disconnecting VPN") + return traces.RecordError(ctx, c.Close()) } -// SelectServer selects the specified server for the tunnel. The tunnel must already be open. -func SelectServer(ctx context.Context, group, tag string) error { - if !isOpen(ctx) { - return errors.New("tunnel is not open") +// SelectServer changes the currently selected server to the one specified by tag. If tag is AutoSelectTag, +// the tunnel will switch to auto-select mode and automatically choose the best server. +func (c *VPNClient) SelectServer(tag string) error { + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil || c.tunnel.Status() != Connected { + return ErrTunnelNotConnected } - if group == autoAllTag { - slog.Info("Switching to auto mode", "group", group) - if err := ipc.SetClashMode(ctx, group); err != nil { - slog.Error("Failed to set auto mode", "group", group, "error", err) - return fmt.Errorf("failed to set auto mode: %w", err) - } - return nil + t := c.tunnel + if tag == AutoSelectTag { + return c.tunnel.selectMode(AutoSelectTag) } - slog.Info("Selecting server", "group", group, "tag", tag) - if err := ipc.SelectOutbound(ctx, group, tag); err != nil { - slog.Error("Failed to select server", "group", group, "tag", tag, "error", err) - return fmt.Errorf("failed to select server %s/%s: %w", group, tag, err) + + c.logger.Info("Selecting server", "tag", tag) + if err := t.selectOutbound(tag); err != nil { + c.logger.Error("Failed to select server", "tag", tag, "error", err) + return fmt.Errorf("failed to select server %s: %w", tag, err) } return nil } -// Status represents the current status of the tunnel, including whether it is open, the selected -// server, and the active server. Active is only set if the tunnel is open. -type Status struct { - TunnelOpen bool - // SelectedServer is the server that is currently selected for the tunnel. - SelectedServer string - // ActiveServer is the server that is currently active for the tunnel. This will differ from - // SelectedServer if using auto-select mode. - ActiveServer string -} - -func GetStatus() (Status, error) { - ctx, span := otel.Tracer(tracerName).Start(context.Background(), "get_status") - defer span.End() - slog.Debug("Retrieving tunnel status") - s := Status{ - TunnelOpen: isOpen(ctx), - } - if !s.TunnelOpen { - return s, nil - } - - slog.Log(nil, internal.LevelTrace, "Tunnel is open, retrieving selected and active servers") - group, tag, err := ipc.GetSelected(ctx) - if err != nil { - return s, fmt.Errorf("failed to get selected server: %w", err) - } - if group == autoAllTag { - s.SelectedServer = autoAllTag - } else { - s.SelectedServer = tag +func (c *VPNClient) UpdateOutbounds(group string, newOptions servers.Options) error { + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return ErrTunnelNotConnected } - - _, active, err := ipc.GetActiveOutbound(ctx) - if err != nil { - return s, fmt.Errorf("failed to get active server: %w", err) - } - s.ActiveServer = active - slog.Log(nil, internal.LevelTrace, "retrieved tunnel status", "tunnelOpen", s.TunnelOpen, "selectedServer", s.SelectedServer, "activeServer", s.ActiveServer) - return s, nil + return c.tunnel.updateOutbounds(group, newOptions) } -func ActiveServer(ctx context.Context) (group, tag string, err error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "active_server") - defer span.End() - slog.Log(nil, internal.LevelTrace, "Retrieving active server") - group, tag, err = ipc.GetActiveOutbound(ctx) - if err != nil { - return "", "", fmt.Errorf("failed to get active server: %w", err) +func (c *VPNClient) AddOutbounds(group string, options servers.Options) error { + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return ErrTunnelNotConnected } - return group, tag, nil + return c.tunnel.addOutbounds(group, options) } -// ActiveConnections returns a list of currently active connections, ordered from newest to oldest. -// A non-nil error is only returned if there was an error retrieving the connections, or if the -// tunnel is closed. If there are no active connections and the tunnel is open, an empty slice is -// returned without an error. -func ActiveConnections(ctx context.Context) ([]ipc.Connection, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "active_connections") - defer span.End() - connections, err := Connections(ctx) - if err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("failed to get active connections: %w", err)) +func (c *VPNClient) RemoveOutbounds(group string, tags []string) error { + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return ErrTunnelNotConnected } - - connections = slices.DeleteFunc(connections, func(c ipc.Connection) bool { - return c.ClosedAt != 0 - }) - slices.SortFunc(connections, func(a, b ipc.Connection) int { - return int(b.CreatedAt - a.CreatedAt) - }) - return connections, nil + return c.tunnel.removeOutbounds(group, tags) } // Connections returns a list of all connections, both active and recently closed. A non-nil error // is only returned if there was an error retrieving the connections, or if the tunnel is closed. // If there are no connections and the tunnel is open, an empty slice is returned without an error. -func Connections(ctx context.Context) ([]ipc.Connection, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "connections") +func (c *VPNClient) Connections() ([]Connection, error) { + _, span := otel.Tracer(tracerName).Start(context.Background(), "connections") defer span.End() - connections, err := ipc.GetConnections(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get connections: %w", err) + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return nil, fmt.Errorf("failed to get connections: %w", ErrTunnelNotConnected) + } + tm := c.tunnel.clashServer.TrafficManager() + activeConns := tm.Connections() + closedConns := tm.ClosedConnections() + connections := make([]Connection, 0, len(activeConns)+len(closedConns)) + for _, conn := range activeConns { + connections = append(connections, newConnection(conn)) + } + for _, conn := range closedConns { + connections = append(connections, newConnection(conn)) } return connections, nil } -// AutoSelections represents the currently active servers for each auto server group. -type AutoSelections struct { - Lantern string - User string - AutoAll string -} - -// AutoSelectionsEvent is emitted when server location changes for any auto server group. -type AutoSelectionsEvent struct { +// AutoSelectedEvent is emitted when the auto-selected server changes. +type AutoSelectedEvent struct { events.Event - Selections AutoSelections + Selected string `json:"selected"` } -// SelectionUnavailable is the sentinel value returned for an auto-selection -// group that has no active server (tunnel not running, group not found, etc.). -const SelectionUnavailable = "Unavailable" - -// AutoServerSelections returns the currently active server for each auto server group. If the group -// is not found or has no active server, SelectionUnavailable is returned for that group. -func AutoServerSelections() (AutoSelections, error) { - as := AutoSelections{ - Lantern: SelectionUnavailable, - User: SelectionUnavailable, - AutoAll: SelectionUnavailable, +// CurrentAutoSelectedServer returns the tag of the currently auto-selected server +func (c *VPNClient) CurrentAutoSelectedServer() (string, error) { + if !c.isOpen() { + c.logger.Log(nil, log.LevelTrace, "Tunnel not running, cannot get auto selections") + return "", nil } - ctx := context.Background() - if !isOpen(ctx) { - slog.Log(ctx, internal.LevelTrace, "Tunnel not running, cannot get auto selections") - return as, nil + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return "", ErrTunnelNotConnected } - groups, err := ipc.GetGroups(ctx) - if err != nil { - return as, fmt.Errorf("failed to get groups: %w", err) - } - slog.Log(ctx, internal.LevelTrace, "Retrieved groups", "groups", groups) - selected := func(tag string) string { - idx := slices.IndexFunc(groups, func(g ipc.OutboundGroup) bool { - return g.Tag == tag - }) - if idx < 0 || groups[idx].Selected == "" { - slog.Log(ctx, internal.LevelTrace, "Group not found or has no selection", "tag", tag) - return SelectionUnavailable - } - return groups[idx].Selected - } - auto := AutoSelections{ - Lantern: selected(autoLanternTag), - User: selected(autoUserTag), + outboundMgr := service.FromContext[adapter.OutboundManager](c.tunnel.ctx) + if outboundMgr == nil { + return "", errors.New("outbound manager not found") } - - switch all := selected(autoAllTag); all { - case autoLanternTag: - auto.AutoAll = auto.Lantern - case autoUserTag: - auto.AutoAll = auto.User - default: - auto.AutoAll = all + outbound, loaded := outboundMgr.Outbound(AutoSelectTag) + if !loaded { + return "", fmt.Errorf("auto select group not found") } - return auto, nil + return outbound.(adapter.OutboundGroup).Now(), nil } const ( - rapidPollInterval = 500 * time.Millisecond - rapidPollWindow = 15 * time.Second + rapidPollInterval = 500 * time.Millisecond + rapidPollWindow = 15 * time.Second steadyPollInterval = 10 * time.Second ) -// AutoSelectionsChangeListener polls for auto-selection changes and emits an -// AutoSelectionsEvent whenever the selection differs from the previous value. +// AutoSelectedChangeListener polls for auto-selection changes and emits an +// AutoSelectedEvent whenever the selection differs from the previous value. // It performs an initial rapid poll to catch the first selection soon after // tunnel connect, then settles into a slower steady-state interval. -func AutoSelectionsChangeListener(ctx context.Context) { +func (c *VPNClient) AutoSelectedChangeListener(ctx context.Context) { go func() { - var prev AutoSelections + var prev string // Rapid initial poll to emit the first selection promptly after connect. initialDeadline := time.NewTimer(rapidPollWindow) @@ -394,17 +374,15 @@ func AutoSelectionsChangeListener(ctx context.Context) { case <-initialDeadline.C: break initial case <-tick.C: - curr, err := AutoServerSelections() + curr, err := c.CurrentAutoSelectedServer() if err != nil { tick.Reset(rapidPollInterval) continue } if curr != prev { prev = curr - events.Emit(AutoSelectionsEvent{Selections: curr}) - if curr.Lantern != SelectionUnavailable || - curr.User != SelectionUnavailable || - curr.AutoAll != SelectionUnavailable { + events.Emit(AutoSelectedEvent{Selected: curr}) + if curr != "" { break initial } } @@ -427,14 +405,14 @@ func AutoSelectionsChangeListener(ctx context.Context) { case <-ctx.Done(): return case <-tick.C: - curr, err := AutoServerSelections() + curr, err := c.CurrentAutoSelectedServer() if err != nil { tick.Reset(steadyPollInterval) continue } if curr != prev { prev = curr - events.Emit(AutoSelectionsEvent{Selections: curr}) + events.Emit(AutoSelectedEvent{Selected: curr}) } tick.Reset(steadyPollInterval) } @@ -442,215 +420,145 @@ func AutoSelectionsChangeListener(ctx context.Context) { }() } -const urlTestHistoryFileName = "url_test_history.json" - -var urlTestMu sync.Mutex - -// RunURLTests performs URL tests for all outbounds defined in configs. It is intended to run in -// response to configuration updates to provide continuous bandit callback data even when the VPN -// tunnel is not active. When the tunnel IS active, its own CheckOutbounds handles URL testing, so -// this is skipped. -func RunURLTests(path string) { - // Skip if the tunnel is handling URL tests - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if isOpen(ctx) { - slog.Debug("Tunnel is active, skipping standalone URL tests") - return - } - - // Prevent overlapping runs - if !urlTestMu.TryLock() { - return - } - defer urlTestMu.Unlock() - - results, traceCtx, hasTrace, err := preTest(path) - if err != nil { - slog.Error("URL test failed", "error", err) - if len(results) == 0 { - return - } - // Tests ran but a non-critical step (e.g. saving history) failed. - // Continue to emit the span and log the results we do have. - } - - // Record URL test results in a span linked to the bandit's trace. - if hasTrace { - _, span := otel.Tracer(tracerName).Start(traceCtx, "radiance.url_tests_complete", - trace.WithAttributes( - attribute.Int("bandit.test_count", len(results)), - ), - ) - for tag, delay := range results { - span.AddEvent("url_test_result", trace.WithAttributes( - attribute.String("outbound", tag), - attribute.Int("latency_ms", int(delay)), - )) - } - span.End() - } - - var formattedResults []string - for tag, delay := range results { - formattedResults = append(formattedResults, fmt.Sprintf("%s: [%dms]", tag, delay)) - } - slog.Log(nil, internal.LevelTrace, "URL test complete", "results", strings.Join(formattedResults, "; ")) -} - -func preTest(path string) (map[string]uint16, context.Context, bool, error) { - slog.Info("Performing pre-start URL tests") - - confPath := filepath.Join(path, common.ConfigFileName) - slog.Debug("Loading config file", "confPath", confPath) - cfg, err := loadConfig(confPath) - if err != nil { - return nil, context.Background(), false, fmt.Errorf("failed to load config: %w", err) +// RunOfflineURLTests will run URL tests for all outbounds if the tunnel is not currently connected. +// This can improve initial connection times by pre-determining reachability and latency to servers. +// +// If [VPNClient.Connect] is called while RunOfflineURLTests is running, the tests will be cancelled and +// any results will be discarded. +func (c *VPNClient) RunOfflineURLTests(basePath string, outbounds []option.Outbound, banditURLs map[string]string) error { + c.mu.Lock() + if c.tunnel != nil { + c.mu.Unlock() + return ErrTunnelAlreadyConnected + } + select { + case <-c.preTestDone: + // no tests currently running, safe to start new tests + default: + c.mu.Unlock() + return errors.New("pre-start tests already running") } + ctx, cancel := context.WithCancel(box.BaseContext()) + c.preTestCancel = cancel + done := make(chan struct{}) + c.preTestDone = done + c.mu.Unlock() + defer close(done) // Extract bandit trace context for distributed tracing - traceCtx, hasTrace := traces.ExtractBanditTraceContext(cfg.BanditURLOverrides) - - cfgOpts := cfg.Options - - slog.Debug("Loading user servers") - userOpts, err := loadUserOptions(path) - if err != nil { - return nil, context.Background(), false, fmt.Errorf("failed to load user options: %w", err) - } + traceCtx, hasTrace := traces.ExtractBanditTraceContext(banditURLs) - // since we are only doing URL tests, we only need the outbounds from both configs; we skip - // endpoints as most/all require elevated privileges to use. just using outbounds is sufficient - // to improve initial connect times. - outbounds := append(cfgOpts.Outbounds, userOpts.Outbounds...) + c.logger.Info("Performing pre-start URL tests") tags := make([]string, 0, len(outbounds)) for _, ob := range outbounds { tags = append(tags, ob.Tag) } - outbounds = append(outbounds, urlTestOutbound("preTest", tags, cfg.BanditURLOverrides)) + outbounds = append(outbounds, urlTestOutbound("preTest", tags, banditURLs)) options := option.Options{ Log: &option.LogOptions{Disabled: true}, Outbounds: outbounds, + Experimental: &option.ExperimentalOptions{ + CacheFile: &option.CacheFileOptions{ + Enabled: true, + Path: filepath.Join(basePath, cacheFileName), + CacheID: cacheID, + }, + }, } // create pre-started box instance. we just use the standard box since we don't need a // platform interface for testing. - ctx := box.BaseContext() ctx = service.ContextWith[filemanager.Manager](ctx, nil) urlTestHistoryStorage := urltest.NewHistoryStorage() ctx = service.ContextWithPtr(ctx, urlTestHistoryStorage) service.MustRegister[adapter.URLTestHistoryStorage](ctx, urlTestHistoryStorage) // for good measure - ctx, cancel := context.WithTimeout(ctx, 15*time.Second) // enough time for bandit callback tests through proxies + ctx, cancel = context.WithTimeout(ctx, 5*time.Second) // enough time for tests to complete or fail defer cancel() instance, err := sbox.New(sbox.Options{ Context: ctx, Options: options, }) if err != nil { - return nil, context.Background(), false, fmt.Errorf("failed to create sing-box instance: %w", err) + return fmt.Errorf("failed to create sing-box instance: %w", err) } defer instance.Close() - if err := instance.PreStart(); err != nil { - return nil, context.Background(), false, fmt.Errorf("failed to start sing-box instance: %w", err) - } - outbound, ok := instance.Outbound().Outbound("preTest") - if !ok { - return nil, context.Background(), false, errors.New("preTest outbound not found") + // connect may have been called while we were setting up, so check if we should abort before + // starting the instance. + select { + case <-ctx.Done(): + return fmt.Errorf("pre-start tests cancelled: %w", ctx.Err()) + default: } - tester, ok := outbound.(adapter.URLTestGroup) - if !ok { - return nil, context.Background(), false, errors.New("preTest outbound is not a URLTestGroup") + if err := instance.PreStart(); err != nil { + return fmt.Errorf("failed to start sing-box instance: %w", err) } + outbound, _ := instance.Outbound().Outbound("preTest") + tester, _ := outbound.(adapter.URLTestGroup) // run URL tests results, err := tester.URLTest(ctx) if err != nil { - return nil, context.Background(), false, fmt.Errorf("failed to perform URL tests: %w", err) - } - - historyPath := filepath.Join(path, urlTestHistoryFileName) - if err := saveURLTestResults(urlTestHistoryStorage, historyPath, results); err != nil { - return results, traceCtx, hasTrace, fmt.Errorf("failed to save URL test results: %w", err) - } - return results, traceCtx, hasTrace, nil -} - - -func saveURLTestResults(storage *urltest.HistoryStorage, path string, results map[string]uint16) error { - slog.Debug("Saving URL test history", "path", path) - history := make(map[string]*adapter.URLTestHistory, len(results)) - for tag := range results { - history[tag] = storage.LoadURLTestHistory(tag) - } - buf, err := json.Marshal(history) - if err != nil { - return fmt.Errorf("failed to marshal URL test history: %w", err) + c.logger.Error("Pre-start URL test failed", "error", err) + return fmt.Errorf("pre-start URL test failed: %w", err) } - return atomicfile.WriteFile(path, buf, 0o644) -} -func loadURLTestHistory(storage *urltest.HistoryStorage, path string) error { - slog.Debug("Loading URL test history", "path", path) - buf, err := atomicfile.ReadFile(path) - if errors.Is(err, os.ErrNotExist) { - return nil - } - if err != nil { - return fmt.Errorf("failed to read URL test history file: %w", err) + // Record URL test results in a span linked to the bandit's trace. + if hasTrace { + _, span := otel.Tracer(tracerName).Start(traceCtx, "radiance.url_tests_complete", + trace.WithAttributes( + attribute.Int("bandit.test_count", len(results)), + ), + ) + for tag, delay := range results { + span.AddEvent("url_test_result", trace.WithAttributes( + attribute.String("outbound", tag), + attribute.Int("latency_ms", int(delay)), + )) + } + span.End() } - history := make(map[string]*adapter.URLTestHistory) - if err := json.Unmarshal(buf, &history); err != nil { - return fmt.Errorf("failed to unmarshal URL test history: %w", err) - } - for tag, result := range history { - storage.StoreURLTestHistory(tag, result) + var fmttedResults []string + for tag, delay := range results { + fmttedResults = append(fmttedResults, fmt.Sprintf("%s: [%dms]", tag, delay)) } + c.logger.Log(nil, log.LevelTrace, "Pre-start URL test complete", "results", strings.Join(fmttedResults, "; ")) return nil } -func SmartRoutingEnabled() bool { - return settings.GetBool(settings.SmartRoutingKey) -} - -func SetSmartRouting(enable bool) error { - if SmartRoutingEnabled() == enable { - return nil - } - if err := settings.Set(settings.SmartRoutingKey, enable); err != nil { - return err - } - slog.Info("Updated Smart-Routing", "enabled", enable) - return restartTunnel() -} - -func AdBlockEnabled() bool { - return settings.GetBool(settings.AdBlockKey) -} - -func SetAdBlock(enable bool) error { - if AdBlockEnabled() == enable { - return nil - } - if err := settings.Set(settings.AdBlockKey, enable); err != nil { - return err - } - slog.Info("Updated Ad-Block", "enabled", enable) - return restartTunnel() -} - -func restartTunnel() error { - ctx := context.Background() - if !isOpen(ctx) { - return nil - } - slog.Info("Restarting tunnel") - options, err := getOptions() +// ClearNetErrorState attempts to clear any error state left by a previous unclean shutdown, such +// as from a crash. No errors are returned and this fails silently. +func ClearNetErrorState() { + options := baseOpts("") + options = option.Options{ + DNS: options.DNS, + Inbounds: options.Inbounds, + Route: &option.RouteOptions{ + AutoDetectInterface: true, + Rules: []option.Rule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultRule{ + RawDefaultRule: option.RawDefaultRule{ + Protocol: []string{"dns"}, + }, + RuleAction: option.RuleAction{ + Action: C.RuleActionTypeHijackDNS, + }, + }, + }, + }, + }, + } + ctx, cancel := context.WithCancel(box.BaseContext()) + defer cancel() + b, err := sbox.New(sbox.Options{ + Context: ctx, + Options: options, + }) if err != nil { - return err - } - if err := ipc.RestartService(ctx, options); err != nil { - return fmt.Errorf("failed to restart tunnel: %w", err) + return } - return nil + defer b.Close() + b.Start() } diff --git a/vpn/vpn_test.go b/vpn/vpn_test.go index a3b2c8fc..8a4caf32 100644 --- a/vpn/vpn_test.go +++ b/vpn/vpn_test.go @@ -1,235 +1,249 @@ package vpn import ( - "context" - "slices" + "errors" + "log/slog" + "sync" "testing" - box "github.com/getlantern/lantern-box" - - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal/testutil" - "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/vpn/ipc" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/experimental/cachefile" - "github.com/sagernet/sing-box/experimental/clashapi" "github.com/sagernet/sing-box/experimental/libbox" - "github.com/sagernet/sing/service" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + rlog "github.com/getlantern/radiance/log" + "github.com/getlantern/radiance/servers" ) -func TestSelectServer(t *testing.T) { - var tests = []struct { - name string - initialGroup string - wantGroup string - wantTag string - }{ - { - name: "select in same group", - initialGroup: "socks", - wantGroup: "socks", - wantTag: "socks2-out", - }, - { - name: "select in different group", - initialGroup: "socks", - wantGroup: "http", - wantTag: "http2-out", - }, - } +// stubPlatform implements PlatformInterface for testing without real VPN operations. +type stubPlatform struct { + libbox.PlatformInterface - testutil.SetPathsForTesting(t) - mservice := setupVpnTest(t) + restartErr error + restartCalled bool + postCloseCalled bool + mu sync.Mutex +} - ctx := mservice.Ctx() - clashServer := service.FromContext[adapter.ClashServer](ctx).(*clashapi.Server) - outboundMgr := service.FromContext[adapter.OutboundManager](ctx) +func (s *stubPlatform) RestartService() error { + s.mu.Lock() + defer s.mu.Unlock() + s.restartCalled = true + return s.restartErr +} - type _selector interface { - adapter.OutboundGroup - Start() error - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // set initial group - clashServer.SetMode(tt.initialGroup) - - // start the selector - outbound, ok := outboundMgr.Outbound(tt.wantGroup) - require.True(t, ok, tt.wantGroup+" selector should exist") - selector := outbound.(_selector) - require.NoError(t, selector.Start(), "failed to start selector") - - mservice.status = ipc.Connected - require.NoError(t, SelectServer(context.Background(), tt.wantGroup, tt.wantTag)) - assert.Equal(t, tt.wantTag, selector.Now(), tt.wantTag+" should be selected") - assert.Equal(t, tt.wantGroup, clashServer.Mode(), "clash mode should be "+tt.wantGroup) - }) - } +func (s *stubPlatform) PostServiceClose() { + s.mu.Lock() + defer s.mu.Unlock() + s.postCloseCalled = true } -func TestSelectedServer(t *testing.T) { - wantGroup := "socks" - wantTag := "socks2-out" +func TestNewVPNClient(t *testing.T) { + t.Run("with nil logger uses default", func(t *testing.T) { + c := NewVPNClient(t.TempDir(), nil, nil) + require.NotNil(t, c) + assert.Equal(t, slog.Default(), c.logger) + assert.Equal(t, Disconnected, c.Status()) + }) - testutil.SetPathsForTesting(t) - opts, _, err := testBoxOptions(settings.GetString(settings.DataPathKey)) - require.NoError(t, err, "failed to load test box options") - cacheFile := cachefile.New(context.Background(), *opts.Experimental.CacheFile) - require.NoError(t, cacheFile.Start(adapter.StartStateInitialize)) + t.Run("with custom logger", func(t *testing.T) { + logger := rlog.NoOpLogger() + c := NewVPNClient(t.TempDir(), logger, nil) + require.NotNil(t, c) + assert.Equal(t, logger, c.logger) + }) +} - require.NoError(t, cacheFile.StoreMode(wantGroup)) - require.NoError(t, cacheFile.StoreSelected(wantGroup, wantTag)) - _ = cacheFile.Close() +func TestStatus_DisconnectedWhenNoTunnel(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + assert.Equal(t, Disconnected, c.Status()) + assert.False(t, c.isOpen()) +} - t.Run("with tunnel open", func(t *testing.T) { - mservice := setupVpnTest(t) - outboundMgr := service.FromContext[adapter.OutboundManager](mservice.Ctx()) - require.NoError(t, outboundMgr.Start(adapter.StartStateStart), "failed to start outbound manager") +func TestClose_NilTunnel(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + // Closing when no tunnel is open should succeed without error. + assert.NoError(t, c.Close()) +} - group, tag, err := ipc.GetSelected(context.Background()) - require.NoError(t, err, "should not error when getting selected server") - assert.Equal(t, wantGroup, group, "group should match") - assert.Equal(t, wantTag, tag, "tag should match") - }) +func TestClose_CallsPostServiceClose(t *testing.T) { + p := &stubPlatform{} + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), p) + + // Set up a minimal tunnel that can be closed. + tun := &tunnel{} + tun.status.Store(Connected) + c.tunnel = tun + + err := c.Close() + assert.NoError(t, err) + assert.Nil(t, c.tunnel) + + p.mu.Lock() + assert.True(t, p.postCloseCalled, "PostServiceClose should be called after closing") + p.mu.Unlock() } -func TestAutoServerSelections(t *testing.T) { - testutil.SetPathsForTesting(t) - mgr := &mockOutMgr{ - outbounds: []adapter.Outbound{ - &mockOutbound{tag: "socks1-out"}, - &mockOutbound{tag: "socks2-out"}, - &mockOutbound{tag: "http1-out"}, - &mockOutbound{tag: "http2-out"}, - &mockOutboundGroup{ - mockOutbound: mockOutbound{tag: autoLanternTag}, - now: "socks1-out", - all: []string{"socks1-out", "socks2-out"}, - }, - &mockOutboundGroup{ - mockOutbound: mockOutbound{tag: autoUserTag}, - now: "http2-out", - all: []string{"http1-out", "http2-out"}, - }, - &mockOutboundGroup{ - mockOutbound: mockOutbound{tag: autoAllTag}, - now: autoLanternTag, - all: []string{autoLanternTag, autoUserTag}, - }, - }, - } - want := AutoSelections{ - Lantern: "socks1-out", - User: "http2-out", - AutoAll: "socks1-out", +func TestDisconnect_NoTunnel(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + assert.NoError(t, c.Disconnect()) +} + +func TestConnect_AlreadyConnected(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + tun := &tunnel{} + tun.status.Store(Connected) + c.tunnel = tun + + err := c.Connect(BoxOptions{}) + assert.ErrorIs(t, err, ErrTunnelAlreadyConnected) +} + +func TestConnect_TransientStates(t *testing.T) { + for _, status := range []VPNStatus{Restarting, Connecting, Disconnecting} { + t.Run(string(status), func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + tun := &tunnel{} + tun.status.Store(status) + c.tunnel = tun + + err := c.Connect(BoxOptions{}) + require.Error(t, err) + assert.Contains(t, err.Error(), string(status)) + }) } - ctx := box.BaseContext() - service.MustRegister[adapter.OutboundManager](ctx, mgr) - m := &mockService{ - ctx: ctx, - status: ipc.Connected, +} + +func TestConnect_CleansUpStaleTunnel(t *testing.T) { + for _, status := range []VPNStatus{Disconnected, ErrorStatus} { + t.Run(string(status), func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + tun := &tunnel{} + tun.status.Store(status) + c.tunnel = tun + + // Connect will fail because BoxOptions has no outbounds, but the stale + // tunnel should be cleared first (the error comes from buildOptions). + err := c.Connect(BoxOptions{BasePath: t.TempDir()}) + require.Error(t, err) + // The tunnel should have been nilled out before buildOptions was called + assert.Contains(t, err.Error(), "no outbounds") + }) } - ipcServer := ipc.NewServer(m) - require.NoError(t, ipcServer.Start()) +} - got, err := AutoServerSelections() - require.NoError(t, err, "should not error when getting auto server selections") - require.Equal(t, want, got, "selections should match") +func TestRestart_NotConnected(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + err := c.Restart(BoxOptions{}) + assert.ErrorIs(t, err, ErrTunnelNotConnected) } -type mockOutMgr struct { - adapter.OutboundManager - outbounds []adapter.Outbound +func TestRestart_NotConnectedStatus(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + tun := &tunnel{} + tun.status.Store(Disconnected) + c.tunnel = tun + + err := c.Restart(BoxOptions{}) + assert.ErrorIs(t, err, ErrTunnelNotConnected) } -func (o *mockOutMgr) Outbounds() []adapter.Outbound { - return o.outbounds +func TestRestart_WithPlatformInterface(t *testing.T) { + p := &stubPlatform{} + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), p) + tun := &tunnel{} + tun.status.Store(Connected) + c.tunnel = tun + + err := c.Restart(BoxOptions{}) + assert.NoError(t, err) + + p.mu.Lock() + assert.True(t, p.restartCalled) + p.mu.Unlock() + assert.Equal(t, Restarting, tun.Status()) } -func (o *mockOutMgr) Outbound(tag string) (adapter.Outbound, bool) { - idx := slices.IndexFunc(o.outbounds, func(ob adapter.Outbound) bool { - return ob.Tag() == tag - }) - if idx == -1 { - return nil, false - } - return o.outbounds[idx], true +func TestRestart_PlatformInterfaceError(t *testing.T) { + restartErr := errors.New("restart failed") + p := &stubPlatform{restartErr: restartErr} + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), p) + tun := &tunnel{} + tun.status.Store(Connected) + c.tunnel = tun + + err := c.Restart(BoxOptions{}) + require.Error(t, err) + assert.ErrorIs(t, err, restartErr) + assert.Equal(t, ErrorStatus, tun.Status()) } -type mockOutbound struct { - adapter.Outbound - tag string +func TestSelectServer_NotConnected(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + err := c.SelectServer("some-tag") + assert.ErrorIs(t, err, ErrTunnelNotConnected) } -func (o *mockOutbound) Tag() string { return o.tag } -func (o *mockOutbound) Type() string { return "mock" } +func TestSelectServer_DisconnectedTunnel(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + tun := &tunnel{} + tun.status.Store(Disconnected) + c.tunnel = tun -type mockOutboundGroup struct { - mockOutbound - now string - all []string + err := c.SelectServer("some-tag") + assert.ErrorIs(t, err, ErrTunnelNotConnected) } -func (o *mockOutboundGroup) Now() string { return o.now } -func (o *mockOutboundGroup) All() []string { return o.all } - -var _ ipc.Service = (*mockService)(nil) +func TestUpdateOutbounds_NilTunnel(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + err := c.UpdateOutbounds("lantern", servers.Options{}) + assert.ErrorIs(t, err, ErrTunnelNotConnected) +} -type mockService struct { - ctx context.Context - status ipc.VPNStatus - clash *clashapi.Server +func TestAddOutbounds_NilTunnel(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + err := c.AddOutbounds("lantern", servers.Options{}) + assert.ErrorIs(t, err, ErrTunnelNotConnected) } -func (m *mockService) Ctx() context.Context { return m.ctx } -func (m *mockService) Status() ipc.VPNStatus { return m.status } -func (m *mockService) ClashServer() *clashapi.Server { return m.clash } -func (m *mockService) Close() error { return nil } -func (m *mockService) Start(context.Context, string) error { return nil } -func (m *mockService) Restart(context.Context, string) error { return nil } -func (m *mockService) UpdateOutbounds(options servers.Servers) error { return nil } -func (m *mockService) AddOutbounds(group string, options servers.Options) error { return nil } -func (m *mockService) RemoveOutbounds(group string, tags []string) error { return nil } +func TestRemoveOutbounds_NilTunnel(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + err := c.RemoveOutbounds("lantern", []string{"tag1"}) + assert.ErrorIs(t, err, ErrTunnelNotConnected) +} -func setupVpnTest(t *testing.T) *mockService { - path := settings.GetString(settings.DataPathKey) - setupOpts := libbox.SetupOptions{ - BasePath: path, - WorkingPath: path, - TempPath: path, - } - require.NoError(t, libbox.Setup(&setupOpts)) +func TestConnections_NilTunnel(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + conns, err := c.Connections() + assert.Nil(t, conns) + assert.ErrorIs(t, err, ErrTunnelNotConnected) +} - _, boxOpts, err := testBoxOptions(path) - require.NoError(t, err, "failed to load test box options") +func TestCurrentAutoSelectedServer_NotOpen(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + selected, err := c.CurrentAutoSelectedServer() + assert.NoError(t, err) + assert.Empty(t, selected) +} - ctx := box.BaseContext() +func TestRunOfflineURLTests_AlreadyConnected(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + tun := &tunnel{} + tun.status.Store(Connected) + c.tunnel = tun - lb, err := libbox.NewServiceWithContext(ctx, boxOpts, nil) - require.NoError(t, err) - clashServer := service.FromContext[adapter.ClashServer](ctx) - cacheFile := service.FromContext[adapter.CacheFile](ctx) + err := c.RunOfflineURLTests("", nil, nil) + assert.ErrorIs(t, err, ErrTunnelAlreadyConnected) +} - m := &mockService{ - ctx: ctx, - status: ipc.Connected, - clash: clashServer.(*clashapi.Server), +func TestConcurrentStatusAccess(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = c.Status() + }() } - ipcServer := ipc.NewServer(m) - require.NoError(t, ipcServer.Start()) - - t.Cleanup(func() { - lb.Close() - ipcServer.Close() - cacheFile.Close() - clashServer.Close() - }) - require.NoError(t, cacheFile.Start(adapter.StartStateInitialize)) - require.NoError(t, clashServer.Start(adapter.StartStateStart)) - return m + wg.Wait() }