diff --git a/internal/account/account_test.go b/internal/account/account_test.go new file mode 100644 index 0000000..9c66019 --- /dev/null +++ b/internal/account/account_test.go @@ -0,0 +1,178 @@ +package account_test + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/krakenkey/cli/internal/account" + "github.com/krakenkey/cli/internal/api" + "github.com/krakenkey/cli/internal/output" +) + +func newTestClient(baseURL string) *api.Client { + return api.NewClient(baseURL, "kk_test", "v0.0.0", "linux", "amd64") +} + +func newPrinter() (*output.Printer, *bytes.Buffer, *bytes.Buffer) { + out := &bytes.Buffer{} + errOut := &bytes.Buffer{} + return output.NewWithWriters("text", true, out, errOut), out, errOut +} + +func TestRunShow_Success(t *testing.T) { + now := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC) + profile := api.UserProfile{ + ID: "u1", + Username: "alice", + Email: "alice@example.com", + DisplayName: "Alice", + Plan: "pro", + CreatedAt: now, + ResourceCounts: api.ResourceCounts{ + Domains: 3, + Certificates: 7, + APIKeys: 2, + }, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(profile) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := account.RunShow(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunShow: %v", err) + } + + got := out.String() + for _, want := range []string{"u1", "alice", "alice@example.com", "Alice", "pro", "3", "7", "2", "2026"} { + if !strings.Contains(got, want) { + t.Errorf("output missing %q:\n%s", want, got) + } + } +} + +func TestRunShow_AuthError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(api.APIError{StatusCode: 401, Message: "Unauthorized"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := account.RunShow(context.Background(), client, printer) + if _, ok := err.(*api.ErrAuth); !ok { + t.Errorf("err type = %T, want *api.ErrAuth", err) + } +} + +func TestRunPlan_ProPlan(t *testing.T) { + now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + end := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC) + sub := api.Subscription{ + Plan: "pro", + Status: "active", + CurrentPeriodEnd: &end, + CancelAtPeriodEnd: false, + CreatedAt: now, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(sub) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := account.RunPlan(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunPlan: %v", err) + } + + got := out.String() + for _, want := range []string{"pro", "active", "2026-02-01", "Subscribed"} { + if !strings.Contains(got, want) { + t.Errorf("output missing %q:\n%s", want, got) + } + } +} + +func TestRunPlan_FreeTier_ZeroCreatedAt(t *testing.T) { + sub := api.Subscription{ + Plan: "free", + Status: "active", + // No CurrentPeriodEnd, no CreatedAt (zero value) for free tier. + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(sub) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := account.RunPlan(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunPlan: %v", err) + } + + got := out.String() + if !strings.Contains(got, "free") { + t.Errorf("output missing 'free':\n%s", got) + } + // Should NOT print "Subscribed:" line when createdAt is zero. + if strings.Contains(got, "Subscribed") { + t.Errorf("output should not contain 'Subscribed' for free tier:\n%s", got) + } + // Should NOT print "Current period ends:" when nil. + if strings.Contains(got, "period ends") { + t.Errorf("output should not contain period end for free tier:\n%s", got) + } +} + +func TestRunPlan_CancelAtPeriodEnd(t *testing.T) { + end := time.Date(2026, 4, 1, 0, 0, 0, 0, time.UTC) + now := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) + sub := api.Subscription{ + Plan: "pro", + Status: "active", + CurrentPeriodEnd: &end, + CancelAtPeriodEnd: true, + CreatedAt: now, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(sub) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := account.RunPlan(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunPlan: %v", err) + } + + if !strings.Contains(out.String(), "cancel") { + t.Errorf("output should mention cancellation:\n%s", out.String()) + } +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..851bfa9 --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,256 @@ +package auth_test + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/krakenkey/cli/internal/api" + "github.com/krakenkey/cli/internal/auth" + "github.com/krakenkey/cli/internal/config" + "github.com/krakenkey/cli/internal/output" +) + +func newTestClient(baseURL string) *api.Client { + return api.NewClient(baseURL, "kk_test", "v0.0.0", "linux", "amd64") +} + +func newPrinter() (*output.Printer, *bytes.Buffer, *bytes.Buffer) { + out := &bytes.Buffer{} + errOut := &bytes.Buffer{} + return output.NewWithWriters("text", true, out, errOut), out, errOut +} + +func TestRunLogin_Success(t *testing.T) { + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/auth/profile" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.UserProfile{ + ID: "u1", + DisplayName: "Alice", + Email: "alice@example.com", + }) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := auth.RunLogin(context.Background(), client, printer, "kk_test_key") + if err != nil { + t.Fatalf("RunLogin: %v", err) + } + if got := out.String(); !contains(got, "Alice") || !contains(got, "alice@example.com") { + t.Errorf("output = %q, want to contain user info", got) + } + + // Verify key was saved to config. + t.Setenv("KK_API_KEY", "") + t.Setenv("KK_API_URL", "") + t.Setenv("KK_OUTPUT", "") + cfg, err := config.Load(config.Flags{}) + if err != nil { + t.Fatalf("config.Load: %v", err) + } + if cfg.APIKey != "kk_test_key" { + t.Errorf("saved API key = %q, want kk_test_key", cfg.APIKey) + } +} + +func TestRunLogin_AuthError(t *testing.T) { + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(api.APIError{StatusCode: 401, Message: "Unauthorized"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := auth.RunLogin(context.Background(), client, printer, "kk_bad_key") + if _, ok := err.(*api.ErrAuth); !ok { + t.Errorf("err type = %T, want *api.ErrAuth", err) + } +} + +func TestRunLogout_Success(t *testing.T) { + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + + // Save a key first so there's something to remove. + if err := config.Save("", "kk_to_remove", ""); err != nil { + t.Fatalf("config.Save: %v", err) + } + + printer, out, _ := newPrinter() + err := auth.RunLogout(printer) + if err != nil { + t.Fatalf("RunLogout: %v", err) + } + if !contains(out.String(), "Logged out") { + t.Errorf("output = %q, want to contain 'Logged out'", out.String()) + } + + // Verify key was removed. + t.Setenv("KK_API_KEY", "") + t.Setenv("KK_API_URL", "") + t.Setenv("KK_OUTPUT", "") + cfg, _ := config.Load(config.Flags{}) + if cfg.APIKey != "" { + t.Errorf("API key still present after logout: %q", cfg.APIKey) + } +} + +func TestRunStatus_Success(t *testing.T) { + profile := api.UserProfile{ + ID: "u1", + DisplayName: "Bob", + Email: "bob@example.com", + Plan: "pro", + ResourceCounts: api.ResourceCounts{ + Domains: 3, + Certificates: 5, + APIKeys: 2, + }, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(profile) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := auth.RunStatus(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunStatus: %v", err) + } + + got := out.String() + for _, want := range []string{"Bob", "bob@example.com", "pro", "3", "5", "2"} { + if !contains(got, want) { + t.Errorf("output missing %q:\n%s", want, got) + } + } +} + +func TestRunKeysList_Empty(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode([]api.APIKey{}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := auth.RunKeysList(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunKeysList: %v", err) + } + if !contains(out.String(), "No API keys") { + t.Errorf("output = %q, want 'No API keys' message", out.String()) + } +} + +func TestRunKeysList_WithKeys(t *testing.T) { + now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + exp := now.Add(30 * 24 * time.Hour) + keys := []api.APIKey{ + {ID: "k1", Name: "dev-key", CreatedAt: now, ExpiresAt: &exp}, + {ID: "k2", Name: "ci-key", CreatedAt: now, ExpiresAt: nil}, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(keys) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := auth.RunKeysList(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunKeysList: %v", err) + } + + got := out.String() + for _, want := range []string{"dev-key", "ci-key", "k1", "k2", "never"} { + if !contains(got, want) { + t.Errorf("output missing %q:\n%s", want, got) + } + } +} + +func TestRunKeysCreate_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["name"] != "my-key" { + t.Errorf("request body name = %v, want my-key", body["name"]) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.CreateAPIKeyResponse{ + ID: "k1", + Name: "my-key", + APIKey: "kk_secret123", + }) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := auth.RunKeysCreate(context.Background(), client, printer, "my-key", nil) + if err != nil { + t.Fatalf("RunKeysCreate: %v", err) + } + + got := out.String() + if !contains(got, "kk_secret123") { + t.Errorf("output missing API key secret:\n%s", got) + } + if !contains(got, "k1") { + t.Errorf("output missing key ID:\n%s", got) + } +} + +func TestRunKeysDelete_Success(t *testing.T) { + var gotPath string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := auth.RunKeysDelete(context.Background(), client, printer, "k1") + if err != nil { + t.Fatalf("RunKeysDelete: %v", err) + } + if gotPath != "/auth/api-keys/k1" { + t.Errorf("path = %q, want /auth/api-keys/k1", gotPath) + } + if !contains(out.String(), "deleted") { + t.Errorf("output = %q, want 'deleted'", out.String()) + } +} + +func contains(s, substr string) bool { + return len(s) > 0 && len(substr) > 0 && bytes.Contains([]byte(s), []byte(substr)) +} diff --git a/internal/cert/cert_test.go b/internal/cert/cert_test.go new file mode 100644 index 0000000..16bbf01 --- /dev/null +++ b/internal/cert/cert_test.go @@ -0,0 +1,538 @@ +package cert_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/krakenkey/cli/internal/api" + "github.com/krakenkey/cli/internal/cert" + "github.com/krakenkey/cli/internal/output" +) + +func newTestClient(baseURL string) *api.Client { + return api.NewClient(baseURL, "kk_test", "v0.0.0", "linux", "amd64") +} + +func newPrinter() (*output.Printer, *bytes.Buffer, *bytes.Buffer) { + out := &bytes.Buffer{} + errOut := &bytes.Buffer{} + return output.NewWithWriters("text", true, out, errOut), out, errOut +} + +func TestRunList_Empty(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode([]api.TlsCert{}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunList(context.Background(), client, printer, "") + if err != nil { + t.Fatalf("RunList: %v", err) + } + if !strings.Contains(out.String(), "No certificates") { + t.Errorf("output = %q, want 'No certificates' message", out.String()) + } +} + +func TestRunList_WithCerts(t *testing.T) { + now := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) + exp := now.Add(90 * 24 * time.Hour) + certs := []api.TlsCert{ + { + ID: 1, + Status: "issued", + ExpiresAt: &exp, + AutoRenew: true, + CreatedAt: now, + ParsedCsr: &api.ParsedCsr{ + Subject: []api.CsrSubjectField{{Name: "commonName", Value: "example.com"}}, + }, + }, + { + ID: 2, + Status: "pending", + AutoRenew: false, + CreatedAt: now, + ParsedCsr: &api.ParsedCsr{ + Subject: []api.CsrSubjectField{{Name: "commonName", Value: "test.io"}}, + }, + }, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(certs) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunList(context.Background(), client, printer, "") + if err != nil { + t.Fatalf("RunList: %v", err) + } + + got := out.String() + for _, want := range []string{"example.com", "test.io", "issued", "pending", "yes", "no"} { + if !strings.Contains(got, want) { + t.Errorf("output missing %q:\n%s", want, got) + } + } +} + +func TestRunList_StatusFilter(t *testing.T) { + var gotQuery string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotQuery = r.URL.RawQuery + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode([]api.TlsCert{}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + _ = cert.RunList(context.Background(), client, printer, "issued") + if gotQuery != "status=issued" { + t.Errorf("query = %q, want status=issued", gotQuery) + } +} + +func TestRunShow_Success(t *testing.T) { + now := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) + exp := now.Add(90 * 24 * time.Hour) + tlsCert := api.TlsCert{ + ID: 42, + Status: "issued", + ExpiresAt: &exp, + AutoRenew: true, + CreatedAt: now, + ParsedCsr: &api.ParsedCsr{ + Subject: []api.CsrSubjectField{{Name: "commonName", Value: "example.com"}}, + PublicKey: &api.CsrPublicKey{KeyType: "ECDSA", BitLength: 256}, + }, + } + details := api.TlsCertDetails{ + SerialNumber: "ABCDEF", + Issuer: "Let's Encrypt", + ValidFrom: now, + ValidTo: exp, + Fingerprint: "AA:BB:CC", + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.URL.Path == "/certs/tls/42/details" { + json.NewEncoder(w).Encode(details) + return + } + json.NewEncoder(w).Encode(tlsCert) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunShow(context.Background(), client, printer, 42) + if err != nil { + t.Fatalf("RunShow: %v", err) + } + + got := out.String() + for _, want := range []string{"42", "issued", "example.com", "ECDSA 256", "ABCDEF", "Let's Encrypt", "AA:BB:CC"} { + if !strings.Contains(got, want) { + t.Errorf("output missing %q:\n%s", want, got) + } + } +} + +func TestRunDownload_Success(t *testing.T) { + dir := t.TempDir() + outPath := filepath.Join(dir, "test.crt") + + certPem := "-----BEGIN CERTIFICATE-----\nMIIB...\n-----END CERTIFICATE-----\n" + tlsCert := api.TlsCert{ + ID: 1, + Status: "issued", + CrtPem: certPem, + ParsedCsr: &api.ParsedCsr{ + Subject: []api.CsrSubjectField{{Name: "commonName", Value: "example.com"}}, + }, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tlsCert) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunDownload(context.Background(), client, printer, 1, outPath) + if err != nil { + t.Fatalf("RunDownload: %v", err) + } + + data, err := os.ReadFile(outPath) + if err != nil { + t.Fatalf("read output file: %v", err) + } + if string(data) != certPem { + t.Errorf("file content = %q, want PEM", string(data)) + } + if !strings.Contains(out.String(), "saved") { + t.Errorf("output = %q, want 'saved'", out.String()) + } +} + +func TestRunDownload_NotIssued(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.TlsCert{ID: 1, Status: "pending"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := cert.RunDownload(context.Background(), client, printer, 1, "") + if err == nil { + t.Fatal("expected error for non-issued cert") + } + if !strings.Contains(err.Error(), "not issued") { + t.Errorf("error = %q, want to contain 'not issued'", err.Error()) + } +} + +func TestRunDownload_DefaultFilename(t *testing.T) { + dir := t.TempDir() + // Change to temp dir so default filename writes there. + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("Chdir: %v", err) + } + defer func() { _ = os.Chdir(origDir) }() + + certPem := "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----\n" + tlsCert := api.TlsCert{ + ID: 1, + Status: "issued", + CrtPem: certPem, + ParsedCsr: &api.ParsedCsr{ + Subject: []api.CsrSubjectField{{Name: "commonName", Value: "mysite.com"}}, + }, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tlsCert) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err = cert.RunDownload(context.Background(), client, printer, 1, "") + if err != nil { + t.Fatalf("RunDownload: %v", err) + } + + if _, err := os.Stat(filepath.Join(dir, "mysite.com.crt")); err != nil { + t.Errorf("expected default output file mysite.com.crt: %v", err) + } +} + +func TestRunRenew_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("method = %q, want POST", r.Method) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.CertResponse{ID: 5, Status: "renewing"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunRenew(context.Background(), client, printer, 5, false, 0, 0) + if err != nil { + t.Fatalf("RunRenew: %v", err) + } + if !strings.Contains(out.String(), "Renewal triggered") { + t.Errorf("output = %q, want 'Renewal triggered'", out.String()) + } +} + +func TestRunRevoke_Success(t *testing.T) { + var gotBody map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&gotBody) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.CertResponse{ID: 3, Status: "revoking"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + reason := 1 + err := cert.RunRevoke(context.Background(), client, printer, 3, &reason) + if err != nil { + t.Fatalf("RunRevoke: %v", err) + } + if !strings.Contains(out.String(), "revocation") { + t.Errorf("output = %q, want 'revocation'", out.String()) + } + if gotBody["reason"] != float64(1) { + t.Errorf("request body reason = %v, want 1", gotBody["reason"]) + } +} + +func TestRunRevoke_NoReason(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.CertResponse{ID: 3, Status: "revoking"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := cert.RunRevoke(context.Background(), client, printer, 3, nil) + if err != nil { + t.Fatalf("RunRevoke with nil reason: %v", err) + } +} + +func TestRunRetry_Success(t *testing.T) { + var gotPath string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.CertResponse{ID: 7, Status: "pending"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunRetry(context.Background(), client, printer, 7, false, 0, 0) + if err != nil { + t.Fatalf("RunRetry: %v", err) + } + if gotPath != "/certs/tls/7/retry" { + t.Errorf("path = %q, want /certs/tls/7/retry", gotPath) + } + if !strings.Contains(out.String(), "Retry triggered") { + t.Errorf("output = %q, want 'Retry triggered'", out.String()) + } +} + +func TestRunDelete_Success(t *testing.T) { + var gotMethod, gotPath string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunDelete(context.Background(), client, printer, 10) + if err != nil { + t.Fatalf("RunDelete: %v", err) + } + if gotMethod != http.MethodDelete { + t.Errorf("method = %q, want DELETE", gotMethod) + } + if gotPath != "/certs/tls/10" { + t.Errorf("path = %q, want /certs/tls/10", gotPath) + } + if !strings.Contains(out.String(), "deleted") { + t.Errorf("output = %q, want 'deleted'", out.String()) + } +} + +func TestRunUpdate_AutoRenew(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.TlsCert{ + ID: 4, + AutoRenew: body["autoRenew"].(bool), + Status: "issued", + }) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + enable := true + err := cert.RunUpdate(context.Background(), client, printer, 4, &enable) + if err != nil { + t.Fatalf("RunUpdate: %v", err) + } + + got := out.String() + if !strings.Contains(got, "updated") { + t.Errorf("output missing 'updated':\n%s", got) + } + if !strings.Contains(got, "true") { + t.Errorf("output missing auto-renew value:\n%s", got) + } +} + +func TestPollUntilDone_ImmediateSuccess(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.TlsCert{ID: 1, Status: "issued"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + c, err := cert.PollUntilDone(context.Background(), client, printer, 1, 50*time.Millisecond, 2*time.Second) + if err != nil { + t.Fatalf("PollUntilDone: %v", err) + } + if c.Status != "issued" { + t.Errorf("status = %q, want issued", c.Status) + } +} + +func TestPollUntilDone_TransitionsToIssued(t *testing.T) { + var callCount atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := callCount.Add(1) + w.Header().Set("Content-Type", "application/json") + status := "pending" + if n >= 3 { + status = "issued" + } + json.NewEncoder(w).Encode(api.TlsCert{ID: 1, Status: status}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + c, err := cert.PollUntilDone(context.Background(), client, printer, 1, 50*time.Millisecond, 5*time.Second) + if err != nil { + t.Fatalf("PollUntilDone: %v", err) + } + if c.Status != "issued" { + t.Errorf("status = %q, want issued", c.Status) + } + if callCount.Load() < 3 { + t.Errorf("expected at least 3 poll calls, got %d", callCount.Load()) + } +} + +func TestPollUntilDone_Timeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.TlsCert{ID: 1, Status: "pending"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + _, err := cert.PollUntilDone(context.Background(), client, printer, 1, 50*time.Millisecond, 200*time.Millisecond) + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if !strings.Contains(err.Error(), "timed out") { + t.Errorf("error = %q, want to contain 'timed out'", err.Error()) + } +} + +func TestPollUntilDone_ContextCancelled(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.TlsCert{ID: 1, Status: "pending"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately. + + _, err := cert.PollUntilDone(ctx, client, printer, 1, 50*time.Millisecond, 5*time.Second) + if err == nil { + t.Fatal("expected context cancelled error, got nil") + } +} + +func TestPollUntilDone_Failed(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.TlsCert{ID: 1, Status: "failed"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + c, err := cert.PollUntilDone(context.Background(), client, printer, 1, 50*time.Millisecond, 2*time.Second) + if err != nil { + t.Fatalf("PollUntilDone: %v", err) + } + // Failed is a terminal status — PollUntilDone returns it, caller decides the error. + if c.Status != "failed" { + t.Errorf("status = %q, want failed", c.Status) + } +} + +func TestCnFromCert_NoParsedCsr(t *testing.T) { + // Use RunShow to indirectly verify cnFromCert returns the ID as fallback. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.TlsCert{ + ID: 99, + Status: "pending", + CreatedAt: time.Now(), + }) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunShow(context.Background(), client, printer, 99) + if err != nil { + t.Fatalf("RunShow: %v", err) + } + // When parsedCsr is nil, domain should show as the cert ID "99". + if !strings.Contains(out.String(), fmt.Sprintf("Domain: %d", 99)) { + t.Errorf("expected fallback domain '99' in output:\n%s", out.String()) + } +} diff --git a/internal/cert/issue_test.go b/internal/cert/issue_test.go new file mode 100644 index 0000000..fe8fd4f --- /dev/null +++ b/internal/cert/issue_test.go @@ -0,0 +1,224 @@ +package cert_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/krakenkey/cli/internal/api" + "github.com/krakenkey/cli/internal/cert" +) + +func TestRunIssue_NoWait(t *testing.T) { + dir := t.TempDir() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/certs/tls") { + var body map[string]string + json.NewDecoder(r.Body).Decode(&body) + if body["csrPem"] == "" { + t.Error("csrPem is empty in request body") + } + json.NewEncoder(w).Encode(api.CertResponse{ID: 10, Status: "pending"}) + return + } + http.NotFound(w, r) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunIssue(context.Background(), client, printer, cert.IssueOptions{ + Domain: "test.example.com", + KeyType: "ecdsa-p256", + KeyOut: filepath.Join(dir, "test.key"), + CSROut: filepath.Join(dir, "test.csr"), + Out: filepath.Join(dir, "test.crt"), + Wait: false, + }) + if err != nil { + t.Fatalf("RunIssue: %v", err) + } + + // Verify private key was written with restricted permissions. + keyInfo, err := os.Stat(filepath.Join(dir, "test.key")) + if err != nil { + t.Fatalf("stat key file: %v", err) + } + if perm := keyInfo.Mode().Perm(); perm != 0o600 { + t.Errorf("key file permissions = %o, want 0600", perm) + } + + // Verify CSR file exists. + if _, err := os.Stat(filepath.Join(dir, "test.csr")); err != nil { + t.Errorf("CSR file not found: %v", err) + } + + // Cert file should NOT exist (no --wait, cert not issued yet). + if _, err := os.Stat(filepath.Join(dir, "test.crt")); !os.IsNotExist(err) { + t.Errorf("cert file should not exist without --wait, err = %v", err) + } + + if !strings.Contains(out.String(), "submitted") { + t.Errorf("output missing 'submitted':\n%s", out.String()) + } +} + +func TestRunIssue_WithWait(t *testing.T) { + dir := t.TempDir() + certPem := "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/certs/tls"): + json.NewEncoder(w).Encode(api.CertResponse{ID: 10, Status: "pending"}) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/certs/tls/10"): + json.NewEncoder(w).Encode(api.TlsCert{ + ID: 10, + Status: "issued", + CrtPem: certPem, + ParsedCsr: &api.ParsedCsr{ + Subject: []api.CsrSubjectField{{Name: "commonName", Value: "test.example.com"}}, + }, + }) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunIssue(context.Background(), client, printer, cert.IssueOptions{ + Domain: "test.example.com", + KeyType: "ecdsa-p256", + KeyOut: filepath.Join(dir, "test.key"), + CSROut: filepath.Join(dir, "test.csr"), + Out: filepath.Join(dir, "test.crt"), + Wait: true, + PollInterval: 50 * time.Millisecond, + PollTimeout: 2 * time.Second, + }) + if err != nil { + t.Fatalf("RunIssue: %v", err) + } + + // Verify cert was written. + data, err := os.ReadFile(filepath.Join(dir, "test.crt")) + if err != nil { + t.Fatalf("read cert file: %v", err) + } + if string(data) != certPem { + t.Errorf("cert content = %q, want PEM", string(data)) + } + + if !strings.Contains(out.String(), "issued") { + t.Errorf("output missing 'issued':\n%s", out.String()) + } +} + +func TestRunIssue_WithAutoRenew(t *testing.T) { + dir := t.TempDir() + var autoRenewSet bool + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/certs/tls"): + json.NewEncoder(w).Encode(api.CertResponse{ID: 10, Status: "pending"}) + case r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/certs/tls/10"): + autoRenewSet = true + json.NewEncoder(w).Encode(api.TlsCert{ID: 10, AutoRenew: true}) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := cert.RunIssue(context.Background(), client, printer, cert.IssueOptions{ + Domain: "example.com", + KeyOut: filepath.Join(dir, "test.key"), + CSROut: filepath.Join(dir, "test.csr"), + Out: filepath.Join(dir, "test.crt"), + AutoRenew: true, + Wait: false, + }) + if err != nil { + t.Fatalf("RunIssue: %v", err) + } + if !autoRenewSet { + t.Error("auto-renew PATCH was not called") + } +} + +func TestRunIssue_DefaultFilenames(t *testing.T) { + dir := t.TempDir() + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("Chdir: %v", err) + } + defer func() { _ = os.Chdir(origDir) }() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.CertResponse{ID: 1, Status: "pending"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err = cert.RunIssue(context.Background(), client, printer, cert.IssueOptions{ + Domain: "mysite.io", + Wait: false, + }) + if err != nil { + t.Fatalf("RunIssue: %v", err) + } + + // Default filenames should be .key and .csr. + for _, name := range []string{"mysite.io.key", "mysite.io.csr"} { + if _, err := os.Stat(filepath.Join(dir, name)); err != nil { + t.Errorf("expected default file %s: %v", name, err) + } + } +} + +func TestRunIssue_InvalidKeyType(t *testing.T) { + dir := t.TempDir() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Error("API should not be called for invalid key type") + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := cert.RunIssue(context.Background(), client, printer, cert.IssueOptions{ + Domain: "example.com", + KeyType: "rsa-1024", + KeyOut: filepath.Join(dir, "test.key"), + CSROut: filepath.Join(dir, "test.csr"), + }) + if err == nil { + t.Fatal("expected error for invalid key type") + } +} diff --git a/internal/cert/submit_test.go b/internal/cert/submit_test.go new file mode 100644 index 0000000..c973aff --- /dev/null +++ b/internal/cert/submit_test.go @@ -0,0 +1,206 @@ +package cert_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/krakenkey/cli/internal/api" + "github.com/krakenkey/cli/internal/cert" +) + +func writeCSRFile(t *testing.T, dir, name string) string { + t.Helper() + p := filepath.Join(dir, name) + if err := os.WriteFile(p, []byte("-----BEGIN CERTIFICATE REQUEST-----\ntest\n-----END CERTIFICATE REQUEST-----\n"), 0o644); err != nil { + t.Fatalf("write CSR file: %v", err) + } + return p +} + +func TestRunSubmit_NoWait(t *testing.T) { + dir := t.TempDir() + csrPath := filepath.Join(dir, "test.csr") + csrPem := "-----BEGIN CERTIFICATE REQUEST-----\nMIIB...\n-----END CERTIFICATE REQUEST-----\n" + if err := os.WriteFile(csrPath, []byte(csrPem), 0o644); err != nil { + t.Fatalf("write CSR: %v", err) + } + + var gotCSR string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.Method == http.MethodPost { + var body map[string]string + json.NewDecoder(r.Body).Decode(&body) + gotCSR = body["csrPem"] + json.NewEncoder(w).Encode(api.CertResponse{ID: 20, Status: "pending"}) + return + } + http.NotFound(w, r) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunSubmit(context.Background(), client, printer, cert.SubmitOptions{ + CSRPath: csrPath, + Wait: false, + }) + if err != nil { + t.Fatalf("RunSubmit: %v", err) + } + + if gotCSR != csrPem { + t.Errorf("submitted CSR = %q, want original PEM", gotCSR) + } + if !strings.Contains(out.String(), "submitted") { + t.Errorf("output missing 'submitted':\n%s", out.String()) + } +} + +func TestRunSubmit_WithWait(t *testing.T) { + dir := t.TempDir() + csrPath := writeCSRFile(t, dir, "test.csr") + + certPem := "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.Method { + case http.MethodPost: + json.NewEncoder(w).Encode(api.CertResponse{ID: 20, Status: "pending"}) + case http.MethodGet: + json.NewEncoder(w).Encode(api.TlsCert{ + ID: 20, + Status: "issued", + CrtPem: certPem, + ParsedCsr: &api.ParsedCsr{ + Subject: []api.CsrSubjectField{{Name: "commonName", Value: "submit.example.com"}}, + }, + }) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := cert.RunSubmit(context.Background(), client, printer, cert.SubmitOptions{ + CSRPath: csrPath, + Out: filepath.Join(dir, "result.crt"), + Wait: true, + PollInterval: 50 * time.Millisecond, + PollTimeout: 2 * time.Second, + }) + if err != nil { + t.Fatalf("RunSubmit: %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "result.crt")) + if err != nil { + t.Fatalf("read cert: %v", err) + } + if string(data) != certPem { + t.Errorf("cert content = %q, want PEM", string(data)) + } + if !strings.Contains(out.String(), "issued") { + t.Errorf("output missing 'issued':\n%s", out.String()) + } +} + +func TestRunSubmit_FileNotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Error("API should not be called when CSR file is missing") + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := cert.RunSubmit(context.Background(), client, printer, cert.SubmitOptions{ + CSRPath: "/nonexistent/path/test.csr", + }) + if err == nil { + t.Fatal("expected error for missing CSR file") + } + if !strings.Contains(err.Error(), "read CSR file") { + t.Errorf("error = %q, want to contain 'read CSR file'", err.Error()) + } +} + +func TestRunSubmit_WithAutoRenew(t *testing.T) { + dir := t.TempDir() + csrPath := writeCSRFile(t, dir, "test.csr") + + var autoRenewSet bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.Method { + case http.MethodPost: + json.NewEncoder(w).Encode(api.CertResponse{ID: 20, Status: "pending"}) + case http.MethodPatch: + autoRenewSet = true + json.NewEncoder(w).Encode(api.TlsCert{ID: 20, AutoRenew: true}) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := cert.RunSubmit(context.Background(), client, printer, cert.SubmitOptions{ + CSRPath: csrPath, + AutoRenew: true, + Wait: false, + }) + if err != nil { + t.Fatalf("RunSubmit: %v", err) + } + if !autoRenewSet { + t.Error("auto-renew PATCH was not called") + } +} + +func TestRunSubmit_IssuanceFailed(t *testing.T) { + dir := t.TempDir() + csrPath := writeCSRFile(t, dir, "test.csr") + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.Method { + case http.MethodPost: + json.NewEncoder(w).Encode(api.CertResponse{ID: 20, Status: "pending"}) + case http.MethodGet: + json.NewEncoder(w).Encode(api.TlsCert{ID: 20, Status: "failed"}) + } + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := cert.RunSubmit(context.Background(), client, printer, cert.SubmitOptions{ + CSRPath: csrPath, + Wait: true, + PollInterval: 50 * time.Millisecond, + PollTimeout: 2 * time.Second, + }) + if err == nil { + t.Fatal("expected error for failed issuance") + } + if !strings.Contains(err.Error(), "failed") { + t.Errorf("error = %q, want to contain 'failed'", err.Error()) + } +} diff --git a/internal/domain/domain_test.go b/internal/domain/domain_test.go new file mode 100644 index 0000000..fc4acad --- /dev/null +++ b/internal/domain/domain_test.go @@ -0,0 +1,254 @@ +package domain_test + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/krakenkey/cli/internal/api" + "github.com/krakenkey/cli/internal/domain" + "github.com/krakenkey/cli/internal/output" +) + +func newTestClient(baseURL string) *api.Client { + return api.NewClient(baseURL, "kk_test", "v0.0.0", "linux", "amd64") +} + +func newPrinter() (*output.Printer, *bytes.Buffer, *bytes.Buffer) { + out := &bytes.Buffer{} + errOut := &bytes.Buffer{} + return output.NewWithWriters("text", true, out, errOut), out, errOut +} + +func TestRunAdd_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("method = %q, want POST", r.Method) + } + var body map[string]string + json.NewDecoder(r.Body).Decode(&body) + if body["hostname"] != "example.com" { + t.Errorf("hostname = %q, want example.com", body["hostname"]) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.Domain{ + ID: "d1", + Hostname: "example.com", + VerificationCode: "krakenkey-site-verification=abc123", + }) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := domain.RunAdd(context.Background(), client, printer, "example.com") + if err != nil { + t.Fatalf("RunAdd: %v", err) + } + + got := out.String() + for _, want := range []string{"example.com", "krakenkey-site-verification=abc123", "TXT"} { + if !strings.Contains(got, want) { + t.Errorf("output missing %q:\n%s", want, got) + } + } +} + +func TestRunAdd_APIError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(api.APIError{StatusCode: 400, Message: "invalid hostname"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := domain.RunAdd(context.Background(), client, printer, "bad..host") + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestRunList_Empty(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode([]api.Domain{}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := domain.RunList(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunList: %v", err) + } + if !strings.Contains(out.String(), "No domains") { + t.Errorf("output = %q, want 'No domains' message", out.String()) + } +} + +func TestRunList_WithDomains(t *testing.T) { + now := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) + domains := []api.Domain{ + {ID: "d1", Hostname: "example.com", IsVerified: true, CreatedAt: now}, + {ID: "d2", Hostname: "test.io", IsVerified: false, CreatedAt: now}, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(domains) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := domain.RunList(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunList: %v", err) + } + + got := out.String() + for _, want := range []string{"example.com", "test.io", "yes", "no"} { + if !strings.Contains(got, want) { + t.Errorf("output missing %q:\n%s", want, got) + } + } +} + +func TestRunShow_Success(t *testing.T) { + now := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) + d := api.Domain{ + ID: "d1", + Hostname: "example.com", + IsVerified: true, + VerificationCode: "krakenkey-site-verification=xyz789", + CreatedAt: now, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/domains/d1" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(d) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := domain.RunShow(context.Background(), client, printer, "d1") + if err != nil { + t.Fatalf("RunShow: %v", err) + } + + got := out.String() + for _, want := range []string{"d1", "example.com", "true", "krakenkey-site-verification=xyz789"} { + if !strings.Contains(got, want) { + t.Errorf("output missing %q:\n%s", want, got) + } + } +} + +func TestRunVerify_Verified(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/domains/d1/verify" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.Domain{ + ID: "d1", + Hostname: "example.com", + IsVerified: true, + }) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := domain.RunVerify(context.Background(), client, printer, "d1") + if err != nil { + t.Fatalf("RunVerify: %v", err) + } + if !strings.Contains(out.String(), "verified") { + t.Errorf("output = %q, want 'verified'", out.String()) + } +} + +func TestRunVerify_NotVerified(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.Domain{ + ID: "d1", + Hostname: "example.com", + IsVerified: false, + }) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := domain.RunVerify(context.Background(), client, printer, "d1") + if err == nil { + t.Fatal("expected error for unverified domain, got nil") + } + if !strings.Contains(err.Error(), "verification failed") { + t.Errorf("error = %q, want to contain 'verification failed'", err.Error()) + } +} + +func TestRunDelete_Success(t *testing.T) { + var gotPath, gotMethod string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotMethod = r.Method + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, out, _ := newPrinter() + + err := domain.RunDelete(context.Background(), client, printer, "d1") + if err != nil { + t.Fatalf("RunDelete: %v", err) + } + if gotMethod != http.MethodDelete { + t.Errorf("method = %q, want DELETE", gotMethod) + } + if gotPath != "/domains/d1" { + t.Errorf("path = %q, want /domains/d1", gotPath) + } + if !strings.Contains(out.String(), "deleted") { + t.Errorf("output = %q, want 'deleted'", out.String()) + } +} + +func TestRunDelete_NotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(api.APIError{StatusCode: 404, Message: "Not found"}) + })) + defer srv.Close() + + client := newTestClient(srv.URL) + printer, _, _ := newPrinter() + + err := domain.RunDelete(context.Background(), client, printer, "nonexistent") + if _, ok := err.(*api.ErrNotFound); !ok { + t.Errorf("err type = %T, want *api.ErrNotFound", err) + } +} diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go new file mode 100644 index 0000000..aa421b4 --- /dev/null +++ b/internal/integration/integration_test.go @@ -0,0 +1,397 @@ +//go:build integration + +// Package integration contains end-to-end tests that run against a live +// KrakenKey API. They are skipped in normal CI runs. +// +// Usage: +// +// KK_API_URL=https://api-dev.krakenkey.io KK_API_KEY=kk_... go test -tags integration ./internal/integration/ -v +package integration + +import ( + "bytes" + "context" + "crypto/x509" + "encoding/pem" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/krakenkey/cli/internal/api" + "github.com/krakenkey/cli/internal/auth" + "github.com/krakenkey/cli/internal/account" + "github.com/krakenkey/cli/internal/cert" + "github.com/krakenkey/cli/internal/csr" + "github.com/krakenkey/cli/internal/domain" + "github.com/krakenkey/cli/internal/output" +) + +func envOrSkip(t *testing.T, key string) string { + t.Helper() + v := os.Getenv(key) + if v == "" { + t.Skipf("skipping: %s not set", key) + } + return v +} + +func setup(t *testing.T) (*api.Client, *output.Printer, *bytes.Buffer) { + t.Helper() + apiURL := envOrSkip(t, "KK_API_URL") + apiKey := envOrSkip(t, "KK_API_KEY") + out := &bytes.Buffer{} + errOut := &bytes.Buffer{} + printer := output.NewWithWriters("text", true, out, errOut) + client := api.NewClient(apiURL, apiKey, "test", "linux", "amd64") + return client, printer, out +} + +// TestIntegration_AuthStatus verifies that the API key is valid and returns +// profile data. +func TestIntegration_AuthStatus(t *testing.T) { + client, printer, out := setup(t) + + err := auth.RunStatus(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunStatus: %v", err) + } + + got := out.String() + if !strings.Contains(got, "User:") { + t.Errorf("expected user info in output:\n%s", got) + } + if !strings.Contains(got, "Plan:") { + t.Errorf("expected plan info in output:\n%s", got) + } + t.Logf("Auth status output:\n%s", got) +} + +// TestIntegration_AccountShow verifies account profile retrieval. +func TestIntegration_AccountShow(t *testing.T) { + client, printer, out := setup(t) + + err := account.RunShow(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunShow: %v", err) + } + + got := out.String() + for _, field := range []string{"ID:", "Username:", "Email:", "Plan:"} { + if !strings.Contains(got, field) { + t.Errorf("output missing %q:\n%s", field, got) + } + } + t.Logf("Account show output:\n%s", got) +} + +// TestIntegration_AccountPlan verifies billing subscription retrieval. +func TestIntegration_AccountPlan(t *testing.T) { + client, printer, out := setup(t) + + err := account.RunPlan(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunPlan: %v", err) + } + + got := out.String() + if !strings.Contains(got, "Plan:") || !strings.Contains(got, "Status:") { + t.Errorf("output missing plan/status:\n%s", got) + } + t.Logf("Account plan output:\n%s", got) +} + +// TestIntegration_DomainLifecycle tests add, list, show, and delete for a domain. +// It uses a throwaway subdomain to avoid conflicts. +func TestIntegration_DomainLifecycle(t *testing.T) { + client, printer, out := setup(t) + ctx := context.Background() + + hostname := "cli-test-" + time.Now().Format("20060102150405") + ".example.com" + + // Add domain. + out.Reset() + err := domain.RunAdd(ctx, client, printer, hostname) + if err != nil { + t.Fatalf("RunAdd: %v", err) + } + addOutput := out.String() + if !strings.Contains(addOutput, hostname) { + t.Errorf("RunAdd output missing hostname:\n%s", addOutput) + } + if !strings.Contains(addOutput, "krakenkey-site-verification=") { + t.Errorf("RunAdd output missing verification code:\n%s", addOutput) + } + t.Logf("Domain add output:\n%s", addOutput) + + // List domains — should contain the newly added one. + out.Reset() + err = domain.RunList(ctx, client, printer) + if err != nil { + t.Fatalf("RunList: %v", err) + } + if !strings.Contains(out.String(), hostname) { + t.Errorf("RunList output missing new domain:\n%s", out.String()) + } + + // Get the domain ID from the API for show/delete. + domains, err := client.ListDomains(ctx) + if err != nil { + t.Fatalf("client.ListDomains: %v", err) + } + var domainID string + for _, d := range domains { + if d.Hostname == hostname { + domainID = d.ID + break + } + } + if domainID == "" { + t.Fatalf("could not find domain %s in list", hostname) + } + + // Show domain. + out.Reset() + err = domain.RunShow(ctx, client, printer, domainID) + if err != nil { + t.Fatalf("RunShow: %v", err) + } + if !strings.Contains(out.String(), hostname) { + t.Errorf("RunShow output missing hostname:\n%s", out.String()) + } + + // Delete domain (cleanup). + out.Reset() + err = domain.RunDelete(ctx, client, printer, domainID) + if err != nil { + t.Fatalf("RunDelete: %v", err) + } + if !strings.Contains(out.String(), "deleted") { + t.Errorf("RunDelete output missing confirmation:\n%s", out.String()) + } + + t.Log("Domain lifecycle test passed") +} + +// TestIntegration_CertList verifies certificate listing works. +func TestIntegration_CertList(t *testing.T) { + client, printer, out := setup(t) + + err := cert.RunList(context.Background(), client, printer, "") + if err != nil { + t.Fatalf("RunList: %v", err) + } + // Output should either show a table or "No certificates" — both are valid. + t.Logf("Cert list output:\n%s", out.String()) +} + +// TestIntegration_CertIssueAndCleanup generates a CSR locally, submits it, +// checks status, then deletes. It does NOT wait for issuance (that requires +// a verified domain with real DNS). +func TestIntegration_CertIssueAndCleanup(t *testing.T) { + client, printer, out := setup(t) + ctx := context.Background() + dir := t.TempDir() + + // Generate a CSR locally. + result, err := csr.Generate(csr.KeyTypeECDSAP256, csr.Subject{ + CommonName: "cli-integration-test.example.com", + }, nil) + if err != nil { + t.Fatalf("csr.Generate: %v", err) + } + + // Verify the generated CSR is valid. + block, _ := pem.Decode(result.CSRPem) + if block == nil { + t.Fatal("generated CSR is not valid PEM") + } + if _, err := x509.ParseCertificateRequest(block.Bytes); err != nil { + t.Fatalf("invalid CSR: %v", err) + } + + // Write CSR to file for submit command. + csrPath := filepath.Join(dir, "test.csr") + if err := os.WriteFile(csrPath, result.CSRPem, 0o644); err != nil { + t.Fatalf("write CSR: %v", err) + } + + // Write key to file. + keyPath := filepath.Join(dir, "test.key") + if err := os.WriteFile(keyPath, result.PrivateKeyPem, 0o600); err != nil { + t.Fatalf("write key: %v", err) + } + + // Verify key permissions. + keyInfo, err := os.Stat(keyPath) + if err != nil { + t.Fatalf("stat key: %v", err) + } + if perm := keyInfo.Mode().Perm(); perm != 0o600 { + t.Errorf("key permissions = %o, want 0600", perm) + } + + // Submit the CSR (without waiting). + out.Reset() + err = cert.RunSubmit(ctx, client, printer, cert.SubmitOptions{ + CSRPath: csrPath, + Wait: false, + }) + if err != nil { + t.Fatalf("RunSubmit: %v", err) + } + submitOutput := out.String() + if !strings.Contains(submitOutput, "submitted") { + t.Errorf("RunSubmit output missing 'submitted':\n%s", submitOutput) + } + t.Logf("Submit output:\n%s", submitOutput) + + // List certs to find the one we just submitted. + certs, err := client.ListCerts(ctx, "") + if err != nil { + t.Fatalf("client.ListCerts: %v", err) + } + + var certID int + for _, c := range certs { + if c.ParsedCsr != nil { + for _, f := range c.ParsedCsr.Subject { + if (f.Name == "commonName" || f.Name == "CN") && f.Value == "cli-integration-test.example.com" { + certID = c.ID + break + } + } + } + if certID > 0 { + break + } + } + if certID == 0 { + t.Fatal("could not find submitted cert in list") + } + + // Show the cert. + out.Reset() + err = cert.RunShow(ctx, client, printer, certID) + if err != nil { + t.Fatalf("RunShow: %v", err) + } + t.Logf("Cert show output:\n%s", out.String()) + + // Delete the cert (cleanup). The cert will likely be in "failed" state + // since the domain is unverified, which should be deletable. + // Wait briefly for the async processing to move past pending. + time.Sleep(2 * time.Second) + + out.Reset() + err = cert.RunDelete(ctx, client, printer, certID) + if err != nil { + // If delete fails (e.g. cert in non-deletable state), just log it. + t.Logf("RunDelete failed (may be in non-deletable state): %v", err) + } else { + t.Log("Cert deleted successfully") + } +} + +// TestIntegration_APIKeysLifecycle creates, lists, and deletes an API key. +func TestIntegration_APIKeysLifecycle(t *testing.T) { + client, printer, out := setup(t) + ctx := context.Background() + + keyName := "cli-test-" + time.Now().Format("150405") + + // Create a key. + out.Reset() + err := auth.RunKeysCreate(ctx, client, printer, keyName, nil) + if err != nil { + t.Fatalf("RunKeysCreate: %v", err) + } + createOutput := out.String() + if !strings.Contains(createOutput, "kk_") { + t.Errorf("RunKeysCreate output missing key secret:\n%s", createOutput) + } + t.Logf("Key create output:\n%s", createOutput) + + // List keys — should contain the new one. + out.Reset() + err = auth.RunKeysList(ctx, client, printer) + if err != nil { + t.Fatalf("RunKeysList: %v", err) + } + if !strings.Contains(out.String(), keyName) { + t.Errorf("RunKeysList output missing new key:\n%s", out.String()) + } + + // Find the key ID. + keys, err := client.ListAPIKeys(ctx) + if err != nil { + t.Fatalf("client.ListAPIKeys: %v", err) + } + var keyID string + for _, k := range keys { + if k.Name == keyName { + keyID = k.ID + break + } + } + if keyID == "" { + t.Fatalf("could not find key %s in list", keyName) + } + + // Delete the key. + out.Reset() + err = auth.RunKeysDelete(ctx, client, printer, keyID) + if err != nil { + t.Fatalf("RunKeysDelete: %v", err) + } + if !strings.Contains(out.String(), "deleted") { + t.Errorf("RunKeysDelete output missing confirmation:\n%s", out.String()) + } + + t.Log("API key lifecycle test passed") +} + +// TestIntegration_JSONOutput verifies JSON mode works end-to-end. +func TestIntegration_JSONOutput(t *testing.T) { + apiURL := envOrSkip(t, "KK_API_URL") + apiKey := envOrSkip(t, "KK_API_KEY") + + out := &bytes.Buffer{} + errOut := &bytes.Buffer{} + printer := output.NewWithWriters("json", true, out, errOut) + client := api.NewClient(apiURL, apiKey, "test", "linux", "amd64") + + err := auth.RunStatus(context.Background(), client, printer) + if err != nil { + t.Fatalf("RunStatus: %v", err) + } + + // In JSON mode, output should be valid JSON. + got := strings.TrimSpace(out.String()) + if !strings.HasPrefix(got, "{") { + t.Errorf("JSON output doesn't start with '{': %q", got) + } + if !strings.HasSuffix(got, "}") { + t.Errorf("JSON output doesn't end with '}': %q", got) + } + t.Logf("JSON output:\n%s", got) +} + +// TestIntegration_InvalidKey verifies that an invalid API key returns ErrAuth. +func TestIntegration_InvalidKey(t *testing.T) { + apiURL := envOrSkip(t, "KK_API_URL") + + out := &bytes.Buffer{} + errOut := &bytes.Buffer{} + printer := output.NewWithWriters("text", true, out, errOut) + client := api.NewClient(apiURL, "kk_invalid_key_12345", "test", "linux", "amd64") + + err := auth.RunStatus(context.Background(), client, printer) + if err == nil { + t.Fatal("expected auth error with invalid key") + } + if _, ok := err.(*api.ErrAuth); !ok { + t.Errorf("err type = %T, want *api.ErrAuth", err) + } +}