diff --git a/cmd/config/init.go b/cmd/config/init.go index b505ce8c7..f7f9a9648 100644 --- a/cmd/config/init.go +++ b/cmd/config/init.go @@ -317,6 +317,9 @@ func configInitRun(opts *ConfigInitOptions) error { output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Configuration saved to %s", core.GetConfigPath())) printLangPreferenceConfirmation(opts) output.PrintJson(f.IOStreams.Out, map[string]interface{}{"appId": opts.AppID, "appSecret": "****", "brand": brand}) + if err := runProbe(opts.Ctx, f, opts.AppID, opts.appSecret, brand); err != nil { + return err + } return nil } @@ -356,6 +359,9 @@ func configInitRun(opts *ConfigInitOptions) error { } printLangPreferenceConfirmation(opts) output.PrintJson(f.IOStreams.Out, map[string]interface{}{"appId": result.AppID, "appSecret": "****", "brand": result.Brand}) + if err := runProbe(opts.Ctx, f, result.AppID, result.AppSecret, result.Brand); err != nil { + return err + } return nil } @@ -398,6 +404,11 @@ func configInitRun(opts *ConfigInitOptions) error { output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf(msg.ConfigSaved, result.AppID)) } printLangPreferenceConfirmation(opts) + if result.AppSecret != "" { + if err := runProbe(opts.Ctx, f, result.AppID, result.AppSecret, result.Brand); err != nil { + return err + } + } return nil } @@ -485,5 +496,10 @@ func configInitRun(opts *ConfigInitOptions) error { } output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Configuration saved to %s", core.GetConfigPath())) printLangPreferenceConfirmation(opts) + if appSecretInput != "" { + if err := runProbe(opts.Ctx, f, resolvedAppId, appSecretInput, parseBrand(resolvedBrand)); err != nil { + return err + } + } return nil } diff --git a/cmd/config/init_probe.go b/cmd/config/init_probe.go new file mode 100644 index 000000000..96bafb426 --- /dev/null +++ b/cmd/config/init_probe.go @@ -0,0 +1,98 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package config + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "time" + + "github.com/larksuite/cli/errs" + "github.com/larksuite/cli/internal/build" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" +) + +// probeTimeout is the total wall-clock budget for the credential probe step +// (covering both TAT acquisition and the subsequent probe request). +const probeTimeout = 3 * time.Second + +// runProbe runs a best-effort credential validation after config init has +// persisted the App ID and App Secret. It returns a non-nil error only for a +// deterministic credential-rejection signal; every other outcome returns nil +// so that valid configurations and transient/upstream noise never block the +// command. +// +// The function performs up to two HTTP calls in series, bounded by +// probeTimeout: +// +// 1. A TAT request using the just-saved credentials. FetchTAT surfaces a +// deterministic credential-rejection signal — a non-zero TAT body code or +// HTTP 401/403 — as *errs.AuthenticationError. That is the only outcome +// propagated to the caller, re-wrapped with an actionable, jargon-free +// message so the root dispatcher renders a typed error envelope and +// `config init` exits non-zero. Ambiguous failures (transport errors, +// 5xx, JSON parse errors, timeouts) surface as *errs.NetworkError / +// *errs.InternalError and are swallowed (return nil). +// +// 2. If TAT succeeded, a POST to the probe endpoint is fired. The outcome of +// that call (success, server error, timeout, parse failure) is always +// ignored — return nil regardless. +func runProbe(parent context.Context, factory *cmdutil.Factory, appID, appSecret string, brand core.LarkBrand) error { + if factory == nil { + return nil + } + httpClient, err := factory.HttpClient() + if err != nil { + return nil + } + + ctx, cancel := context.WithTimeout(parent, probeTimeout) + defer cancel() + + token, err := credential.FetchTAT(ctx, httpClient, brand, appID, appSecret) + if err != nil { + var authErr *errs.AuthenticationError + if errors.As(err, &authErr) { + // Deterministic credential rejection: re-surface with an + // actionable, jargon-free message while preserving the upstream + // code (10003 / 10014 / HTTP 401 / ...) and the error chain. + return &errs.AuthenticationError{ + Problem: errs.Problem{ + Category: errs.CategoryAuthentication, + Subtype: authErr.Subtype, + Code: authErr.Code, + Message: "configured credentials may be invalid: please verify the App ID and App Secret", + Hint: "re-run `lark-cli config init` with the correct App ID and App Secret", + }, + Cause: err, + } + } + // Ambiguous failure (transport / 5xx / parse / timeout) — stay silent. + return nil + } + + // TAT succeeded — fire the probe call. Any outcome is ignored. + url := core.ResolveEndpoints(brand).Open + "/open-apis/application/v6/larksuite_cli_app/probe" + body := []byte(fmt.Sprintf(`{"from":"lark-cli/%s"}`, build.Version)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + return nil + } + defer resp.Body.Close() + _, _ = io.Copy(io.Discard, resp.Body) + return nil +} diff --git a/cmd/config/init_probe_test.go b/cmd/config/init_probe_test.go new file mode 100644 index 000000000..ecf18bbab --- /dev/null +++ b/cmd/config/init_probe_test.go @@ -0,0 +1,326 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package config + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/larksuite/cli/errs" + "github.com/larksuite/cli/internal/build" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" +) + +// fakeRT routes requests to per-path handlers and records what it saw. +type fakeRT struct { + tatHandler func(req *http.Request) (*http.Response, error) + probeHandler func(req *http.Request) (*http.Response, error) + tatCalls int + probeCalls int + probeReq *http.Request + probeBody string +} + +func (f *fakeRT) RoundTrip(req *http.Request) (*http.Response, error) { + switch { + case strings.HasSuffix(req.URL.Path, "/auth/v3/tenant_access_token/internal"): + f.tatCalls++ + if f.tatHandler == nil { + return jsonResp(200, `{"code":0,"tenant_access_token":"t-ok"}`), nil + } + return f.tatHandler(req) + case strings.HasSuffix(req.URL.Path, "/application/v6/larksuite_cli_app/probe"): + f.probeCalls++ + f.probeReq = req + if req.Body != nil { + b, _ := io.ReadAll(req.Body) + f.probeBody = string(b) + } + if f.probeHandler == nil { + return jsonResp(200, `{"code":0,"data":{},"msg":"success"}`), nil + } + return f.probeHandler(req) + } + return nil, errors.New("unexpected URL: " + req.URL.String()) +} + +func jsonResp(code int, body string) *http.Response { + return &http.Response{ + StatusCode: code, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + } +} + +// fakeFactory builds a test Factory whose HttpClient is overridden to use +// the caller-supplied RoundTripper. +// +// Wired through cmdutil.TestFactory(t, nil) so the canonical IOStreams, +// Credential, Keychain and FileIO wiring is in place (per repo test-factory +// guidance). The HttpClient is then swapped to our stub so we can drive +// exact HTTP responses for the probe. Config-dir isolation is set up via +// t.Setenv(LARKSUITE_CLI_CONFIG_DIR, t.TempDir()) so any incidental config +// touch lands in a temp dir rather than the developer's real config. +// +// The returned buffer is the Factory's stderr. runProbe no longer writes any +// warning to stderr (it propagates a typed error instead), so every test +// asserts this buffer stays empty as an invariant. +func fakeFactory(t *testing.T, rt http.RoundTripper) (*cmdutil.Factory, *bytes.Buffer) { + t.Helper() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + f, _, errBuf, _ := cmdutil.TestFactory(t, nil) + f.HttpClient = func() (*http.Client, error) { + return &http.Client{Transport: rt}, nil + } + return f, errBuf +} + +// assertAuthError asserts err is a deterministic credential-rejection signal: +// a non-nil *errs.AuthenticationError carrying the expected upstream code and +// the actionable, jargon-free message runProbe re-wraps it with. +func assertAuthError(t *testing.T, err error, wantCode int) { + t.Helper() + if err == nil { + t.Fatalf("expected *errs.AuthenticationError (code %d), got nil", wantCode) + } + var authErr *errs.AuthenticationError + if !errors.As(err, &authErr) { + t.Fatalf("expected *errs.AuthenticationError, got %T: %v", err, err) + } + if authErr.Category != errs.CategoryAuthentication { + t.Errorf("Category = %q, want %q", authErr.Category, errs.CategoryAuthentication) + } + if authErr.Code != wantCode { + t.Errorf("Code = %d, want %d", authErr.Code, wantCode) + } + if !strings.Contains(authErr.Message, "configured credentials may be invalid") { + t.Errorf("Message = %q, want it to mention 'configured credentials may be invalid'", authErr.Message) + } + if authErr.Hint == "" { + t.Error("expected a non-empty Hint guiding the user to re-run config init") + } +} + +func TestRunProbe_TATCode10003_ReturnsAuthError(t *testing.T) { + rt := &fakeRT{ + tatHandler: func(req *http.Request) (*http.Response, error) { + return jsonResp(200, `{"code":10003,"msg":"invalid app_id"}`), nil + }, + } + f, errBuf := fakeFactory(t, rt) + + err := runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu) + + if rt.probeCalls != 0 { + t.Error("probe endpoint must not be called when TAT fails") + } + assertAuthError(t, err, 10003) + if errBuf.Len() != 0 { + t.Errorf("runProbe must not write to stderr, got: %q", errBuf.String()) + } +} + +func TestRunProbe_TATCode10005_ReturnsAuthError(t *testing.T) { + rt := &fakeRT{ + tatHandler: func(req *http.Request) (*http.Response, error) { + return jsonResp(200, `{"code":10005,"msg":"invalid app_secret"}`), nil + }, + } + f, _ := fakeFactory(t, rt) + assertAuthError(t, runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu), 10005) +} + +func TestRunProbe_TATCode10009_ReturnsAuthError(t *testing.T) { + rt := &fakeRT{ + tatHandler: func(req *http.Request) (*http.Response, error) { + return jsonResp(200, `{"code":10009,"msg":"app_id and app_secret mismatch"}`), nil + }, + } + f, _ := fakeFactory(t, rt) + assertAuthError(t, runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu), 10009) +} + +// 10014 ("app secret invalid") is the code feishu's TAT v3 endpoint actually +// returns when an existing appID is paired with the wrong secret — the +// single most common real-world failure. Locked in by this test because the +// original whitelist missed it. +func TestRunProbe_TATCode10014_ReturnsAuthError(t *testing.T) { + rt := &fakeRT{ + tatHandler: func(req *http.Request) (*http.Response, error) { + return jsonResp(200, `{"code":10014,"msg":"app secret invalid"}`), nil + }, + } + f, _ := fakeFactory(t, rt) + assertAuthError(t, runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu), 10014) +} + +func TestRunProbe_TATHTTP401_ReturnsAuthError(t *testing.T) { + rt := &fakeRT{ + tatHandler: func(req *http.Request) (*http.Response, error) { + return jsonResp(401, `unauthorized`), nil + }, + } + f, _ := fakeFactory(t, rt) + assertAuthError(t, runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu), 401) +} + +func TestRunProbe_TATHTTP403_ReturnsAuthError(t *testing.T) { + rt := &fakeRT{ + tatHandler: func(req *http.Request) (*http.Response, error) { + return jsonResp(403, `forbidden`), nil + }, + } + f, _ := fakeFactory(t, rt) + assertAuthError(t, runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu), 403) +} + +// Policy: any non-zero TAT body code is treated as a deterministic +// credential-rejection signal (the server saw the payload and refused it), +// regardless of whether the specific code is one we recognize. This guards +// against upstream adding new credential-class codes without us noticing — +// previously, missing 10014 from a hand-rolled allowlist caused real-world +// "wrong app_secret" failures to go silent. The trade-off: ambiguous +// failures stay in NetworkError / InternalError lanes (transport / 5xx / +// parse / timeout) and remain silent, so valid configs are not disturbed. +func TestRunProbe_TATUnknownBodyCode_ReturnsAuthError(t *testing.T) { + rt := &fakeRT{ + tatHandler: func(req *http.Request) (*http.Response, error) { + return jsonResp(200, `{"code":99999,"msg":"future-unknown"}`), nil + }, + } + f, _ := fakeFactory(t, rt) + assertAuthError(t, runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu), 99999) +} + +// assertSilent asserts runProbe stayed quiet: no propagated error and nothing +// written to stderr. Used for every ambiguous (non-credential) outcome. +func assertSilent(t *testing.T, err error, errBuf *bytes.Buffer) { + t.Helper() + if err != nil { + t.Errorf("expected nil (silent), got error: %v", err) + } + if errBuf.Len() != 0 { + t.Errorf("expected no stderr output, got: %q", errBuf.String()) + } +} + +func TestRunProbe_TATHTTP500_Silent(t *testing.T) { + rt := &fakeRT{ + tatHandler: func(req *http.Request) (*http.Response, error) { + return jsonResp(500, `server error`), nil + }, + } + f, errBuf := fakeFactory(t, rt) + assertSilent(t, runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu), errBuf) +} + +func TestRunProbe_TATTransportError_Silent(t *testing.T) { + rt := &fakeRT{ + tatHandler: func(req *http.Request) (*http.Response, error) { + return nil, errors.New("network down") + }, + } + f, errBuf := fakeFactory(t, rt) + assertSilent(t, runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu), errBuf) +} + +func TestRunProbe_TATSuccess_ProbeFails_Silent(t *testing.T) { + rt := &fakeRT{ + probeHandler: func(req *http.Request) (*http.Response, error) { + return jsonResp(500, `server error`), nil + }, + } + f, errBuf := fakeFactory(t, rt) + err := runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu) + if rt.probeCalls != 1 { + t.Errorf("probe should be called once, got %d", rt.probeCalls) + } + assertSilent(t, err, errBuf) +} + +func TestRunProbe_TATSuccess_ProbeOK_Silent(t *testing.T) { + rt := &fakeRT{} + f, errBuf := fakeFactory(t, rt) + err := runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu) + if rt.tatCalls != 1 || rt.probeCalls != 1 { + t.Errorf("expected 1/1 calls, got tat=%d probe=%d", rt.tatCalls, rt.probeCalls) + } + assertSilent(t, err, errBuf) +} + +func TestRunProbe_ProbeRequestShape(t *testing.T) { + rt := &fakeRT{} + f, _ := fakeFactory(t, rt) + if err := runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if rt.probeReq == nil { + t.Fatal("probe request not captured") + } + if rt.probeReq.Method != http.MethodPost { + t.Errorf("probe method = %s, want POST", rt.probeReq.Method) + } + if got := rt.probeReq.URL.String(); got != "https://open.feishu.cn/open-apis/application/v6/larksuite_cli_app/probe" { + t.Errorf("probe URL = %s", got) + } + if got := rt.probeReq.Header.Get("Authorization"); got != "Bearer t-ok" { + t.Errorf("Authorization = %q, want Bearer t-ok", got) + } + if !strings.Contains(rt.probeBody, `"from":"lark-cli/`+build.Version+`"`) { + t.Errorf("probe body missing from field: %s", rt.probeBody) + } +} + +func TestRunProbe_LarkBrand_HostRoutedCorrectly(t *testing.T) { + rt := &fakeRT{} + f, _ := fakeFactory(t, rt) + if err := runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandLark); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rt.probeReq == nil { + t.Fatal("probe request not captured") + } + if !strings.Contains(rt.probeReq.URL.Host, "larksuite.com") { + t.Errorf("probe host = %s, want larksuite.com", rt.probeReq.URL.Host) + } +} + +func TestRunProbe_HTTPClientError_Silent(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + f, _, errBuf, _ := cmdutil.TestFactory(t, nil) + f.HttpClient = func() (*http.Client, error) { + return nil, errors.New("client init failed") + } + assertSilent(t, runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu), errBuf) +} + +func TestRunProbe_TimeoutHonored(t *testing.T) { + rt := &fakeRT{ + tatHandler: func(req *http.Request) (*http.Response, error) { + // Block until context is canceled. + <-req.Context().Done() + return nil, req.Context().Err() + }, + } + f, errBuf := fakeFactory(t, rt) + + start := time.Now() + err := runProbe(context.Background(), f, "cli_x", "secret_y", core.BrandFeishu) + elapsed := time.Since(start) + + if elapsed > 4*time.Second { + t.Errorf("runProbe took %v, expected <= ~3s", elapsed) + } + // A timeout is an ambiguous failure (context deadline → not an + // AuthenticationError), so it must stay silent and not block. + assertSilent(t, err, errBuf) +} diff --git a/internal/credential/default_provider.go b/internal/credential/default_provider.go index 45ec3ef00..4b4050b6d 100644 --- a/internal/credential/default_provider.go +++ b/internal/credential/default_provider.go @@ -4,9 +4,7 @@ package credential import ( - "bytes" "context" - "encoding/json" "fmt" "io" "net/http" @@ -135,42 +133,9 @@ func (p *DefaultTokenProvider) doResolveTAT(ctx context.Context) (*TokenResult, if err != nil { return nil, err } - ep := core.ResolveEndpoints(acct.Brand) - url := ep.Open + "/open-apis/auth/v3/tenant_access_token/internal" - - body, err := json.Marshal(map[string]string{ - "app_id": acct.AppID, - "app_secret": acct.AppSecret, - }) - if err != nil { - return nil, fmt.Errorf("failed to marshal TAT request: %w", err) - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - - resp, err := httpClient.Do(req) + token, err := FetchTAT(ctx, httpClient, acct.Brand, acct.AppID, acct.AppSecret) if err != nil { return nil, err } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("TAT API returned HTTP %d", resp.StatusCode) - } - - var result struct { - Code int `json:"code"` - Msg string `json:"msg"` - TenantAccessToken string `json:"tenant_access_token"` - } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, fmt.Errorf("failed to parse TAT response: %w", err) - } - if result.Code != 0 { - return nil, fmt.Errorf("TAT API error: [%d] %s", result.Code, result.Msg) - } - return &TokenResult{Token: result.TenantAccessToken}, nil + return &TokenResult{Token: token}, nil } diff --git a/internal/credential/tat_fetch.go b/internal/credential/tat_fetch.go new file mode 100644 index 000000000..a3634b67a --- /dev/null +++ b/internal/credential/tat_fetch.go @@ -0,0 +1,164 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package credential + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/larksuite/cli/errs" + "github.com/larksuite/cli/internal/core" +) + +// FetchTAT performs a single HTTP POST to obtain a tenant access token using +// the provided credentials. It does not read configuration or keychain. +// +// On failure it returns a typed error from the github.com/larksuite/cli/errs +// taxonomy. The caller is responsible for context timeouts. +// +// Error classification: +// +// - HTTP 401/403, or non-zero body code → *errs.AuthenticationError, with +// Problem.Code carrying the upstream signal (HTTP status when the request +// was rejected pre-body; body "code" when the response decoded successfully). +// - HTTP 5xx → *errs.NetworkError with CauseKind "5xx" and Retryable=true. +// - HTTP 4xx other than 401/403 (or any other non-2xx) → *errs.InternalError +// with subtype "sdk_error" — these indicate caller-side request shape +// problems, not transport failures. +// - Transport / dial / TLS / DNS failure → *errs.NetworkError with the +// underlying error wrapped via Cause. +// - JSON encode/decode failure, or 2xx with empty tenant_access_token → +// *errs.InternalError with subtype "sdk_error". +func FetchTAT( + ctx context.Context, + httpClient *http.Client, + brand core.LarkBrand, + appID, appSecret string, +) (string, error) { + ep := core.ResolveEndpoints(brand) + url := ep.Open + "/open-apis/auth/v3/tenant_access_token/internal" + + body, err := json.Marshal(map[string]string{ + "app_id": appID, + "app_secret": appSecret, + }) + if err != nil { + return "", &errs.InternalError{ + Problem: errs.Problem{ + Category: errs.CategoryInternal, + Subtype: errs.SubtypeSDKError, + Message: fmt.Sprintf("marshal TAT request: %v", err), + }, + Cause: err, + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return "", &errs.InternalError{ + Problem: errs.Problem{ + Category: errs.CategoryInternal, + Subtype: errs.SubtypeSDKError, + Message: fmt.Sprintf("build TAT request: %v", err), + }, + Cause: err, + } + } + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + return "", &errs.NetworkError{ + Problem: errs.Problem{ + Category: errs.CategoryNetwork, + Subtype: errs.SubtypeNetworkTransport, + Message: fmt.Sprintf("TAT request transport error: %v", err), + Retryable: true, + }, + Cause: err, + } + } + defer resp.Body.Close() + + // HTTP 401/403 are unambiguous credential-rejection signals at the TAT + // endpoint — surface them as authentication errors so callers can + // distinguish them from generic network failures. + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + return "", &errs.AuthenticationError{ + Problem: errs.Problem{ + Category: errs.CategoryAuthentication, + Subtype: errs.SubtypeTokenInvalid, + Code: resp.StatusCode, + Message: fmt.Sprintf("TAT endpoint rejected credentials with HTTP %d", resp.StatusCode), + }, + } + } + // 5xx is upstream/server side — transport-class with retry hint. + if resp.StatusCode >= 500 { + return "", &errs.NetworkError{ + Problem: errs.Problem{ + Category: errs.CategoryNetwork, + Subtype: errs.SubtypeNetworkTransport, + Message: fmt.Sprintf("TAT endpoint returned HTTP %d", resp.StatusCode), + Retryable: true, + }, + CauseKind: "5xx", + } + } + // Any other non-2xx (4xx other than 401/403, 1xx, 3xx without redirect + // handling) indicates a client-side request shape problem — treat as an + // internal error rather than overloading NetworkError, whose taxonomy is + // reserved for transport-layer failures (timeout / DNS / TLS / 5xx). + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", &errs.InternalError{ + Problem: errs.Problem{ + Category: errs.CategoryInternal, + Subtype: errs.SubtypeSDKError, + Message: fmt.Sprintf("TAT endpoint returned unexpected HTTP %d", resp.StatusCode), + }, + } + } + + var parsed struct { + Code int `json:"code"` + Msg string `json:"msg"` + TenantAccessToken string `json:"tenant_access_token"` + } + if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { + return "", &errs.InternalError{ + Problem: errs.Problem{ + Category: errs.CategoryInternal, + Subtype: errs.SubtypeSDKError, + Message: fmt.Sprintf("parse TAT response: %v", err), + }, + Cause: err, + } + } + if parsed.Code != 0 { + return "", &errs.AuthenticationError{ + Problem: errs.Problem{ + Category: errs.CategoryAuthentication, + Subtype: errs.SubtypeTokenInvalid, + Code: parsed.Code, + Message: fmt.Sprintf("TAT API rejected credentials: [%d] %s", parsed.Code, parsed.Msg), + }, + } + } + // Defensive: a 2xx response with code=0 but no token is a contract + // violation by upstream — surface it as an internal SDK error so callers + // fail fast rather than receiving a silent empty token. + if parsed.TenantAccessToken == "" { + return "", &errs.InternalError{ + Problem: errs.Problem{ + Category: errs.CategoryInternal, + Subtype: errs.SubtypeSDKError, + Message: "TAT response missing tenant_access_token despite code=0", + }, + } + } + return parsed.TenantAccessToken, nil +} diff --git a/internal/credential/tat_fetch_test.go b/internal/credential/tat_fetch_test.go new file mode 100644 index 000000000..3517620bc --- /dev/null +++ b/internal/credential/tat_fetch_test.go @@ -0,0 +1,303 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package credential + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/larksuite/cli/errs" + "github.com/larksuite/cli/internal/core" +) + +// stubRoundTripper lets us assert request shape and return canned responses. +type stubRoundTripper struct { + gotReq *http.Request + gotBody string + respCode int + respBody string + err error +} + +func (s *stubRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + s.gotReq = req + if req.Body != nil { + b, _ := io.ReadAll(req.Body) + s.gotBody = string(b) + } + if s.err != nil { + return nil, s.err + } + return &http.Response{ + StatusCode: s.respCode, + Body: io.NopCloser(strings.NewReader(s.respBody)), + Header: make(http.Header), + }, nil +} + +func TestFetchTAT_Success(t *testing.T) { + rt := &stubRoundTripper{ + respCode: 200, + respBody: `{"code":0,"tenant_access_token":"t-abc","msg":"ok"}`, + } + hc := &http.Client{Transport: rt} + + token, err := FetchTAT(context.Background(), hc, core.BrandFeishu, "cli_app", "secret_x") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token != "t-abc" { + t.Errorf("token = %q, want t-abc", token) + } + if rt.gotReq.URL.String() != "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" { + t.Errorf("url = %s", rt.gotReq.URL.String()) + } + if !strings.Contains(rt.gotBody, `"app_id":"cli_app"`) || !strings.Contains(rt.gotBody, `"app_secret":"secret_x"`) { + t.Errorf("request body missing credentials: %s", rt.gotBody) + } +} + +func TestFetchTAT_BodyCodeNonZero_AuthenticationError(t *testing.T) { + rt := &stubRoundTripper{ + respCode: 200, + respBody: `{"code":10003,"msg":"invalid app_id"}`, + } + hc := &http.Client{Transport: rt} + + token, err := FetchTAT(context.Background(), hc, core.BrandFeishu, "cli_app", "secret_x") + if err == nil { + t.Fatal("expected error for code 10003") + } + if token != "" { + t.Errorf("token = %q, want empty", token) + } + var authErr *errs.AuthenticationError + if !errors.As(err, &authErr) { + t.Fatalf("error not *errs.AuthenticationError: %T %v", err, err) + } + if authErr.Code != 10003 { + t.Errorf("Code = %d, want 10003", authErr.Code) + } + if authErr.Category != errs.CategoryAuthentication { + t.Errorf("Category = %q, want %q", authErr.Category, errs.CategoryAuthentication) + } + if authErr.Subtype != errs.SubtypeTokenInvalid { + t.Errorf("Subtype = %q, want %q", authErr.Subtype, errs.SubtypeTokenInvalid) + } +} + +func TestFetchTAT_HTTPUnauthorized_AuthenticationError(t *testing.T) { + rt := &stubRoundTripper{respCode: 401, respBody: `unauthorized`} + hc := &http.Client{Transport: rt} + + token, err := FetchTAT(context.Background(), hc, core.BrandFeishu, "cli_app", "secret_x") + if err == nil { + t.Fatal("expected error for HTTP 401") + } + if token != "" { + t.Errorf("token = %q, want empty", token) + } + var authErr *errs.AuthenticationError + if !errors.As(err, &authErr) { + t.Fatalf("error not *errs.AuthenticationError: %T %v", err, err) + } + if authErr.Code != 401 { + t.Errorf("Code = %d, want 401 (HTTP status preserved)", authErr.Code) + } +} + +func TestFetchTAT_HTTPForbidden_AuthenticationError(t *testing.T) { + rt := &stubRoundTripper{respCode: 403, respBody: `forbidden`} + hc := &http.Client{Transport: rt} + + _, err := FetchTAT(context.Background(), hc, core.BrandFeishu, "cli_app", "secret_x") + var authErr *errs.AuthenticationError + if !errors.As(err, &authErr) || authErr.Code != 403 { + t.Fatalf("want *errs.AuthenticationError Code=403, got %T %v", err, err) + } +} + +func TestFetchTAT_HTTP5xx_NetworkError(t *testing.T) { + rt := &stubRoundTripper{respCode: 503, respBody: `service unavailable`} + hc := &http.Client{Transport: rt} + + _, err := FetchTAT(context.Background(), hc, core.BrandFeishu, "cli_app", "secret_x") + var netErr *errs.NetworkError + if !errors.As(err, &netErr) { + t.Fatalf("want *errs.NetworkError, got %T %v", err, err) + } + if !netErr.Retryable { + t.Errorf("5xx should be retryable") + } + if netErr.CauseKind != "5xx" { + t.Errorf("CauseKind = %q, want \"5xx\"", netErr.CauseKind) + } +} + +// Non-401/403 4xx must NOT be classified as NetworkError. NetworkError's +// taxonomy is reserved for transport-layer failures; a client-side HTTP +// shape problem (e.g. 400 / 404 / 429) is the caller's responsibility and +// belongs in InternalError. +func TestFetchTAT_HTTP4xxOther_InternalError(t *testing.T) { + tests := []int{ + http.StatusBadRequest, // 400 + http.StatusNotFound, // 404 + http.StatusTooManyRequests, // 429 + http.StatusUnsupportedMediaType, // 415 + } + for _, code := range tests { + t.Run(fmt.Sprintf("HTTP %d", code), func(t *testing.T) { + rt := &stubRoundTripper{respCode: code, respBody: `whatever`} + hc := &http.Client{Transport: rt} + + _, err := FetchTAT(context.Background(), hc, core.BrandFeishu, "cli_app", "secret_x") + if err == nil { + t.Fatalf("expected error for HTTP %d", code) + } + var intErr *errs.InternalError + if !errors.As(err, &intErr) { + t.Fatalf("want *errs.InternalError, got %T %v", err, err) + } + var netErr *errs.NetworkError + if errors.As(err, &netErr) { + t.Errorf("HTTP %d must not be classified as NetworkError", code) + } + }) + } +} + +func TestFetchTAT_TransportError_NetworkError(t *testing.T) { + sentinel := errors.New("network down") + rt := &stubRoundTripper{err: sentinel} + hc := &http.Client{Transport: rt} + + _, err := FetchTAT(context.Background(), hc, core.BrandFeishu, "cli_app", "secret_x") + if err == nil { + t.Fatal("expected error") + } + var netErr *errs.NetworkError + if !errors.As(err, &netErr) { + t.Fatalf("want *errs.NetworkError, got %T %v", err, err) + } + // Underlying transport error must still be reachable for diagnostics. + if !errors.Is(err, sentinel) { + t.Errorf("error chain does not include sentinel: %v", err) + } +} + +// 2xx with code=0 but empty tenant_access_token is treated as a contract +// violation by upstream; FetchTAT must fail fast rather than return a +// silent empty token that callers cannot debug. +func TestFetchTAT_EmptyTokenOnSuccess_InternalError(t *testing.T) { + rt := &stubRoundTripper{ + respCode: 200, + respBody: `{"code":0,"msg":"ok"}`, // tenant_access_token field absent + } + hc := &http.Client{Transport: rt} + + token, err := FetchTAT(context.Background(), hc, core.BrandFeishu, "cli_app", "secret_x") + if err == nil { + t.Fatal("expected error when tenant_access_token is empty") + } + if token != "" { + t.Errorf("token = %q, want empty", token) + } + var intErr *errs.InternalError + if !errors.As(err, &intErr) { + t.Fatalf("want *errs.InternalError, got %T %v", err, err) + } + if intErr.Subtype != errs.SubtypeSDKError { + t.Errorf("Subtype = %q, want %q", intErr.Subtype, errs.SubtypeSDKError) + } +} + +func TestFetchTAT_BodyParseError_InternalError(t *testing.T) { + rt := &stubRoundTripper{respCode: 200, respBody: `not json`} + hc := &http.Client{Transport: rt} + + _, err := FetchTAT(context.Background(), hc, core.BrandFeishu, "cli_app", "secret_x") + if err == nil { + t.Fatal("expected parse error") + } + var intErr *errs.InternalError + if !errors.As(err, &intErr) { + t.Fatalf("want *errs.InternalError, got %T %v", err, err) + } + if intErr.Subtype != errs.SubtypeSDKError { + t.Errorf("Subtype = %q, want %q", intErr.Subtype, errs.SubtypeSDKError) + } +} + +func TestFetchTAT_BrandRouting(t *testing.T) { + tests := []struct { + brand core.LarkBrand + wantURL string + }{ + {core.BrandFeishu, "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"}, + {core.BrandLark, "https://open.larksuite.com/open-apis/auth/v3/tenant_access_token/internal"}, + } + for _, tc := range tests { + t.Run(string(tc.brand), func(t *testing.T) { + rt := &stubRoundTripper{ + respCode: 200, + respBody: `{"code":0,"tenant_access_token":"t"}`, + } + hc := &http.Client{Transport: rt} + if _, err := FetchTAT(context.Background(), hc, tc.brand, "a", "b"); err != nil { + t.Fatal(err) + } + if got := rt.gotReq.URL.String(); got != tc.wantURL { + t.Errorf("url = %s, want %s", got, tc.wantURL) + } + }) + } +} + +func TestFetchTAT_ContextCanceled(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // will not respond before cancel + <-r.Context().Done() + })) + defer srv.Close() + + // Point at the test server by using a stub that rewrites URL. + rt := &urlRewriteRT{base: srv.URL} + hc := &http.Client{Transport: rt} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // pre-canceled + + _, err := FetchTAT(ctx, hc, core.BrandFeishu, "a", "b") + if err == nil { + t.Fatal("expected error for canceled context") + } + // Canceled context surfaces as a transport-class NetworkError. + var netErr *errs.NetworkError + if !errors.As(err, &netErr) { + t.Fatalf("want *errs.NetworkError, got %T %v", err, err) + } + if !errors.Is(err, context.Canceled) { + t.Errorf("error chain missing context.Canceled: %v", err) + } +} + +// urlRewriteRT forwards requests to a fixed base URL (test server). +type urlRewriteRT struct{ base string } + +func (r *urlRewriteRT) RoundTrip(req *http.Request) (*http.Response, error) { + // Rewrite to test server origin. + newURL := r.base + req.URL.Path + req2, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body) + if err != nil { + return nil, err + } + req2.Header = req.Header + return http.DefaultTransport.RoundTrip(req2) +}