diff --git a/internal/adapters/ai/base_client.go b/internal/adapters/ai/base_client.go index a6a6fbf..e3be42d 100644 --- a/internal/adapters/ai/base_client.go +++ b/internal/adapters/ai/base_client.go @@ -12,6 +12,7 @@ import ( "time" "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/httputil" ) // maxErrorBodyBytes caps how much of an error response body is read into @@ -28,17 +29,18 @@ type BaseClient struct { // NewBaseClient creates a new base client with common configuration. func NewBaseClient(apiKey, model, baseURL string, timeout time.Duration) *BaseClient { - if timeout == 0 { - timeout = domain.TimeoutAI + // The default (timeout == 0) reuses the shared 120s client; a non-zero + // timeout gets its own client. + client := httputil.DefaultClient + if timeout != 0 { + client = httputil.NewClient(timeout) } return &BaseClient{ apiKey: apiKey, model: model, baseURL: baseURL, - client: &http.Client{ - Timeout: timeout, - }, + client: client, } } diff --git a/internal/adapters/ai/function_tools.go b/internal/adapters/ai/function_tools.go index a39dab3..bc662f3 100644 --- a/internal/adapters/ai/function_tools.go +++ b/internal/adapters/ai/function_tools.go @@ -176,59 +176,3 @@ func GetSchedulingTools() []domain.Tool { }, } } - -// GetSmartSchedulingTools returns additional tools for advanced AI scheduling features. -func GetSmartSchedulingTools() []domain.Tool { - return []domain.Tool{ - { - Name: "analyzeMeetingContext", - Description: "Analyze the context of a meeting request to determine priority, optimal duration, and required participants based on historical patterns.", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "meetingType": map[string]any{ - "type": "string", - "description": "Type of meeting (e.g., '1-on-1', 'team', 'planning', 'client')", - }, - "participants": map[string]any{ - "type": "array", - "description": "Participant email addresses", - "items": map[string]any{ - "type": "string", - }, - }, - "subject": map[string]any{ - "type": "string", - "description": "Meeting subject/topic", - }, - }, - "required": []string{"meetingType"}, - }, - }, - { - Name: "suggestRotatingSchedule", - Description: "Suggest a rotating meeting schedule for recurring meetings with participants across multiple timezones to ensure fairness.", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "participants": map[string]any{ - "type": "array", - "description": "Participant email addresses with their timezones", - "items": map[string]any{ - "type": "string", - }, - }, - "duration": map[string]any{ - "type": "integer", - "description": "Meeting duration in minutes", - }, - "frequency": map[string]any{ - "type": "string", - "description": "Meeting frequency (e.g., 'weekly', 'biweekly', 'monthly')", - }, - }, - "required": []string{"participants", "duration", "frequency"}, - }, - }, - } -} diff --git a/internal/adapters/ai/function_tools_test.go b/internal/adapters/ai/function_tools_test.go index 663b1ef..7e3801f 100644 --- a/internal/adapters/ai/function_tools_test.go +++ b/internal/adapters/ai/function_tools_test.go @@ -303,136 +303,6 @@ func TestGetSchedulingTools_GetTimezoneInfo(t *testing.T) { assert.Len(t, required, 1) } -func TestGetSmartSchedulingTools(t *testing.T) { - t.Parallel() - - tools := GetSmartSchedulingTools() - - require.NotEmpty(t, tools) - assert.Len(t, tools, 2, "Should have 2 smart scheduling tools") - - // Verify expected tool names - expectedNames := map[string]bool{ - "analyzeMeetingContext": false, - "suggestRotatingSchedule": false, - } - - for _, tool := range tools { - assert.NotEmpty(t, tool.Name) - assert.NotEmpty(t, tool.Description) - assert.NotNil(t, tool.Parameters) - - if _, exists := expectedNames[tool.Name]; exists { - expectedNames[tool.Name] = true - } - } - - // Verify all expected tools were found - for name, found := range expectedNames { - assert.True(t, found, "Expected tool %s not found", name) - } -} - -func TestGetSmartSchedulingTools_AnalyzeMeetingContext(t *testing.T) { - t.Parallel() - - tools := GetSmartSchedulingTools() - - var analyzeMeetingContext *struct { - Name string - Description string - Parameters map[string]any - } - - for i := range tools { - if tools[i].Name == "analyzeMeetingContext" { - analyzeMeetingContext = &struct { - Name string - Description string - Parameters map[string]any - }{ - Name: tools[i].Name, - Description: tools[i].Description, - Parameters: tools[i].Parameters, - } - break - } - } - - require.NotNil(t, analyzeMeetingContext) - - // Check description - assert.Contains(t, analyzeMeetingContext.Description, "context") - assert.Contains(t, analyzeMeetingContext.Description, "priority") - - // Check parameters - params := analyzeMeetingContext.Parameters - properties, ok := params["properties"].(map[string]any) - require.True(t, ok) - - // Check meetingType parameter - meetingType, ok := properties["meetingType"].(map[string]any) - require.True(t, ok) - assert.Equal(t, "string", meetingType["type"]) - assert.Contains(t, meetingType["description"].(string), "1-on-1") - - // Check required only includes meetingType - required, ok := params["required"].([]string) - require.True(t, ok) - assert.Contains(t, required, "meetingType") -} - -func TestGetSmartSchedulingTools_SuggestRotatingSchedule(t *testing.T) { - t.Parallel() - - tools := GetSmartSchedulingTools() - - var suggestRotatingSchedule *struct { - Name string - Description string - Parameters map[string]any - } - - for i := range tools { - if tools[i].Name == "suggestRotatingSchedule" { - suggestRotatingSchedule = &struct { - Name string - Description string - Parameters map[string]any - }{ - Name: tools[i].Name, - Description: tools[i].Description, - Parameters: tools[i].Parameters, - } - break - } - } - - require.NotNil(t, suggestRotatingSchedule) - - // Check description - assert.Contains(t, suggestRotatingSchedule.Description, "rotating") - assert.Contains(t, suggestRotatingSchedule.Description, "fairness") - - // Check parameters - params := suggestRotatingSchedule.Parameters - properties, ok := params["properties"].(map[string]any) - require.True(t, ok) - - // Check frequency parameter - frequency, ok := properties["frequency"].(map[string]any) - require.True(t, ok) - assert.Equal(t, "string", frequency["type"]) - assert.Contains(t, frequency["description"].(string), "weekly") - - // Check required fields - required, ok := params["required"].([]string) - require.True(t, ok) - assert.Contains(t, required, "participants") - assert.Contains(t, required, "duration") - assert.Contains(t, required, "frequency") -} - func TestToolParameterTypes(t *testing.T) { t.Parallel() diff --git a/internal/adapters/ai/pattern_learner.go b/internal/adapters/ai/pattern_learner.go deleted file mode 100644 index d6e355c..0000000 --- a/internal/adapters/ai/pattern_learner.go +++ /dev/null @@ -1,353 +0,0 @@ -package ai - -import ( - "context" - "encoding/json" - "fmt" - "os" - "strings" - "time" - - "github.com/nylas/cli/internal/domain" - "github.com/nylas/cli/internal/ports" -) - -// PatternLearner learns from calendar history to predict scheduling patterns. -type PatternLearner struct { - nylasClient ports.NylasClient - llmRouter ports.LLMRouter -} - -// NewPatternLearner creates a new pattern learner. -func NewPatternLearner(nylasClient ports.NylasClient, llmRouter ports.LLMRouter) *PatternLearner { - return &PatternLearner{ - nylasClient: nylasClient, - llmRouter: llmRouter, - } -} - -// SchedulingPatterns represents discovered patterns from calendar history. -type SchedulingPatterns struct { - UserID string `json:"user_id"` - AnalysisPeriod AnalysisPeriod `json:"analysis_period"` - AcceptancePatterns []AcceptancePattern `json:"acceptance_patterns"` - DurationPatterns []DurationPattern `json:"duration_patterns"` - TimezonePatterns []TimezonePattern `json:"timezone_patterns"` - ProductivityInsights []ProductivityInsight `json:"productivity_insights"` - Recommendations []string `json:"recommendations"` - TotalEventsAnalyzed int `json:"total_events_analyzed"` - GeneratedAt time.Time `json:"generated_at"` -} - -// AnalysisPeriod defines the time period analyzed. -type AnalysisPeriod struct { - StartDate time.Time `json:"start_date"` - EndDate time.Time `json:"end_date"` - Days int `json:"days"` -} - -// AcceptancePattern represents meeting acceptance rates by time/day. -type AcceptancePattern struct { - TimeSlot string `json:"time_slot"` // e.g., "Monday 9-11 AM" - AcceptRate float64 `json:"accept_rate"` // 0-1 - EventCount int `json:"event_count"` // Number of events in this slot - Description string `json:"description"` // Human-readable explanation - Confidence float64 `json:"confidence"` // 0-1, based on sample size -} - -// DurationPattern represents typical meeting duration patterns. -type DurationPattern struct { - MeetingType string `json:"meeting_type"` // e.g., "1-on-1", "Team standup" - ScheduledDuration int `json:"scheduled_duration"` // In minutes - ActualDuration int `json:"actual_duration"` // In minutes - Variance int `json:"variance"` // Difference - EventCount int `json:"event_count"` // Sample size - Description string `json:"description"` // Pattern description -} - -// TimezonePattern represents timezone preferences. -type TimezonePattern struct { - Timezone string `json:"timezone"` // e.g., "America/New_York" - EventCount int `json:"event_count"` // Number of events - Percentage float64 `json:"percentage"` // % of total events - PreferredTime string `json:"preferred_time"` // e.g., "2-4 PM PST" - Description string `json:"description"` // Pattern description -} - -// ProductivityInsight represents productivity patterns. -type ProductivityInsight struct { - InsightType string `json:"insight_type"` // e.g., "peak_focus", "low_energy" - TimeSlot string `json:"time_slot"` // e.g., "Tuesday 10 AM - 12 PM" - Score int `json:"score"` // 0-100 - Description string `json:"description"` // Explanation - BasedOn []string `json:"based_on"` // What data this is based on -} - -// LearnPatternsRequest represents a request to learn patterns. -type LearnPatternsRequest struct { - GrantID string `json:"grant_id"` - LookbackDays int `json:"lookback_days"` // How far back to analyze - MinConfidence float64 `json:"min_confidence"` // Minimum confidence threshold - IncludeRecurring bool `json:"include_recurring"` // Include recurring events -} - -// LearnPatterns analyzes calendar history and learns scheduling patterns. -func (p *PatternLearner) LearnPatterns(ctx context.Context, req *LearnPatternsRequest) (*SchedulingPatterns, error) { - // 1. Fetch historical events - events, err := p.fetchHistoricalEvents(ctx, req) - if err != nil { - return nil, fmt.Errorf("fetch historical events: %w", err) - } - - if len(events) == 0 { - return nil, fmt.Errorf("no events found in the specified period") - } - - // 2. Calculate analysis period - analysisPeriod := p.calculateAnalysisPeriod(events, req.LookbackDays) - - // 3. Analyze acceptance patterns - acceptancePatterns := p.analyzeAcceptancePatterns(events) - - // 4. Analyze duration patterns - durationPatterns := p.analyzeDurationPatterns(events) - - // 5. Analyze timezone patterns - timezonePatterns := p.analyzeTimezonePatterns(events) - - // 6. Analyze productivity patterns - productivityInsights := p.analyzeProductivityPatterns(events) - - // 7. Use LLM to generate recommendations - recommendations, err := p.generateRecommendations(ctx, events, acceptancePatterns, durationPatterns, timezonePatterns, productivityInsights) - if err != nil { - // Non-fatal: continue without LLM recommendations - recommendations = []string{"Unable to generate AI recommendations"} - } - - patterns := &SchedulingPatterns{ - UserID: req.GrantID, - AnalysisPeriod: analysisPeriod, - AcceptancePatterns: acceptancePatterns, - DurationPatterns: durationPatterns, - TimezonePatterns: timezonePatterns, - ProductivityInsights: productivityInsights, - Recommendations: recommendations, - TotalEventsAnalyzed: len(events), - GeneratedAt: time.Now(), - } - - return patterns, nil -} - -// fetchHistoricalEvents fetches calendar events for pattern analysis. -func (p *PatternLearner) fetchHistoricalEvents(ctx context.Context, req *LearnPatternsRequest) ([]domain.Event, error) { - now := time.Now() - startDate := now.AddDate(0, 0, -req.LookbackDays) - - // First get list of calendars to fetch events from all - calendars, err := p.nylasClient.GetCalendars(ctx, req.GrantID) - if err != nil { - return nil, fmt.Errorf("fetch calendars: %w", err) - } - - allEvents := []domain.Event{} - skipped := []string{} - - // Fetch events from each calendar - for _, calendar := range calendars { - events, err := p.nylasClient.GetEvents(ctx, req.GrantID, calendar.ID, &domain.EventQueryParams{ - Start: startDate.Unix(), - End: now.Unix(), - Limit: 200, // Maximum allowed by Nylas API v3 - }) - - if err != nil { - // Some calendars are read-only or temporarily unavailable. Record - // the skip with the underlying error so the caller (and test - // harness) can see analysis was partial — silently dropping the - // calendar gives the user "patterns" computed from incomplete - // data and no way to know. - skipped = append(skipped, fmt.Sprintf("%s: %v", calendar.ID, err)) - continue - } - - allEvents = append(allEvents, events...) - } - - if len(skipped) > 0 { - // Log to stderr; downstream callers that already check for - // PartialAnalysis on the returned struct (set below) get the same - // signal without depending on a logger interface. - fmt.Fprintf(os.Stderr, "warn: pattern analysis skipped %d calendar(s): %s\n", - len(skipped), strings.Join(skipped, "; ")) - } - - // Filter out recurring events if not requested - if !req.IncludeRecurring { - filtered := []domain.Event{} - for _, event := range allEvents { - // Check if event is recurring (has recurrence or is part of series) - if len(event.Recurrence) == 0 && event.MasterEventID == "" { - filtered = append(filtered, event) - } - } - return filtered, nil - } - - return allEvents, nil -} - -// calculateAnalysisPeriod calculates the actual period analyzed. -func (p *PatternLearner) calculateAnalysisPeriod(events []domain.Event, _ int) AnalysisPeriod { - if len(events) == 0 { - return AnalysisPeriod{} - } - - earliest := events[0].When.StartTime - latest := events[0].When.EndTime - - for _, event := range events { - if event.When.StartTime < earliest { - earliest = event.When.StartTime - } - if event.When.EndTime > latest { - latest = event.When.EndTime - } - } - - // Convert Unix timestamps to time.Time - earliestTime := time.Unix(earliest, 0) - latestTime := time.Unix(latest, 0) - - days := int(latestTime.Sub(earliestTime).Hours() / 24) - - return AnalysisPeriod{ - StartDate: earliestTime, - EndDate: latestTime, - Days: days, - } -} - -// generateRecommendations uses LLM to generate actionable recommendations. -func (p *PatternLearner) generateRecommendations(ctx context.Context, events []domain.Event, acceptance []AcceptancePattern, duration []DurationPattern, timezone []TimezonePattern, productivity []ProductivityInsight) ([]string, error) { - // Build context for LLM - patternContext := p.buildPatternContext(events, acceptance, duration, timezone, productivity) - - // Create chat request - chatReq := &domain.ChatRequest{ - Messages: []domain.ChatMessage{ - { - Role: "system", - Content: "You are an expert productivity coach analyzing calendar patterns. Provide 3-5 actionable recommendations to improve scheduling and productivity.", - }, - { - Role: "user", - Content: fmt.Sprintf("Based on the following calendar analysis, provide specific recommendations:\n\n%s", patternContext), - }, - }, - Temperature: 0.7, - MaxTokens: 500, - } - - // Call LLM - response, err := p.llmRouter.Chat(ctx, chatReq) - if err != nil { - return nil, err - } - - // Parse recommendations (simple line-based parsing) - recommendations := []string{} - lines := strings.Split(response.Content, "\n") - for _, line := range lines { - trimmed := strings.TrimSpace(line) - if trimmed != "" && len(trimmed) > 10 { - // Remove numbering if present - if len(trimmed) > 3 && trimmed[0] >= '1' && trimmed[0] <= '9' && trimmed[1] == '.' { - trimmed = strings.TrimSpace(trimmed[3:]) - } - recommendations = append(recommendations, trimmed) - } - } - - if len(recommendations) == 0 { - recommendations = []string{"No specific recommendations available"} - } - - return recommendations, nil -} - -// buildPatternContext builds context string for LLM. -// Uses strings.Builder for efficient string concatenation. -func (p *PatternLearner) buildPatternContext(events []domain.Event, acceptance []AcceptancePattern, duration []DurationPattern, timezone []TimezonePattern, productivity []ProductivityInsight) string { - var sb strings.Builder - // Pre-allocate estimated capacity (header + patterns) - sb.Grow(512) - - fmt.Fprintf(&sb, "Calendar Analysis (%d events analyzed):\n\n", len(events)) - - // Acceptance patterns - if len(acceptance) > 0 { - sb.WriteString("Meeting Acceptance Patterns:\n") - for i, pattern := range acceptance { - if i >= 5 { - break // Top 5 - } - fmt.Fprintf(&sb, "- %s: %.0f%% acceptance (%d events) - %s\n", - pattern.TimeSlot, pattern.AcceptRate*100, pattern.EventCount, pattern.Description) - } - sb.WriteByte('\n') - } - - // Duration patterns - if len(duration) > 0 { - sb.WriteString("Meeting Duration Patterns:\n") - for _, pattern := range duration { - fmt.Fprintf(&sb, "- %s: avg %d minutes (%d events)\n", - pattern.MeetingType, pattern.ScheduledDuration, pattern.EventCount) - } - sb.WriteByte('\n') - } - - // Timezone patterns - if len(timezone) > 0 { - sb.WriteString("Timezone Patterns:\n") - for i, pattern := range timezone { - if i >= 3 { - break // Top 3 - } - fmt.Fprintf(&sb, "- %s: %.0f%% of meetings (%d events)\n", - pattern.Timezone, pattern.Percentage*100, pattern.EventCount) - } - sb.WriteByte('\n') - } - - // Productivity insights - if len(productivity) > 0 { - sb.WriteString("Productivity Insights:\n") - for _, insight := range productivity { - fmt.Fprintf(&sb, "- %s\n", insight.Description) - } - sb.WriteByte('\n') - } - - return sb.String() -} - -// SavePatterns saves learned patterns (stub for future storage implementation). -// Returns an error rather than nil so callers can't mistake the no-op for a -// successful persist — pairs with LoadPatterns which already errors. -func (p *PatternLearner) SavePatterns(ctx context.Context, patterns *SchedulingPatterns) error { - return fmt.Errorf("pattern storage not yet implemented") -} - -// LoadPatterns loads previously learned patterns (stub for future storage implementation). -func (p *PatternLearner) LoadPatterns(ctx context.Context, userID string) (*SchedulingPatterns, error) { - // Future: Load from local storage/database - return nil, fmt.Errorf("pattern storage not yet implemented") -} - -// ExportPatterns exports patterns to JSON. -func (p *PatternLearner) ExportPatterns(patterns *SchedulingPatterns) ([]byte, error) { - return json.MarshalIndent(patterns, "", " ") -} diff --git a/internal/adapters/ai/pattern_learner_analysis.go b/internal/adapters/ai/pattern_learner_analysis.go deleted file mode 100644 index 1d1113e..0000000 --- a/internal/adapters/ai/pattern_learner_analysis.go +++ /dev/null @@ -1,264 +0,0 @@ -package ai - -import ( - "cmp" - "fmt" - "slices" - "strings" - "time" - - "github.com/nylas/cli/internal/domain" -) - -// analyzeAcceptancePatterns analyzes meeting acceptance rates by time slots. -func (p *PatternLearner) analyzeAcceptancePatterns(events []domain.Event) []AcceptancePattern { - // Group events by day and time slot - slotCounts := make(map[string]int) - slotTotal := make(map[string]int) - - for _, event := range events { - // Convert Unix timestamp to time.Time - startTime := time.Unix(event.When.StartTime, 0) - day := startTime.Weekday().String() - hour := startTime.Hour() - - // Categorize into time blocks - var timeBlock string - if hour >= 9 && hour < 11 { - timeBlock = "9-11 AM" - } else if hour >= 11 && hour < 13 { - timeBlock = "11 AM-1 PM" - } else if hour >= 13 && hour < 15 { - timeBlock = "1-3 PM" - } else if hour >= 15 && hour < 17 { - timeBlock = "3-5 PM" - } else { - timeBlock = "Outside hours" - } - - slot := fmt.Sprintf("%s %s", day, timeBlock) - slotTotal[slot]++ - - // Consider event "accepted" if status is confirmed or busy is true - if event.Status == "confirmed" || event.Busy { - slotCounts[slot]++ - } - } - - // Calculate acceptance rates - patterns := []AcceptancePattern{} - for slot, total := range slotTotal { - if total < 3 { - // Skip slots with too few samples - continue - } - - accepted := slotCounts[slot] - acceptRate := float64(accepted) / float64(total) - - // Confidence based on sample size (higher samples = higher confidence) - confidence := float64(total) / 20.0 - if confidence > 1.0 { - confidence = 1.0 - } - - description := "" - if acceptRate > 0.8 { - description = "You prefer meetings during this time" - } else if acceptRate < 0.4 { - description = "You tend to avoid meetings during this time" - } else { - description = "Moderate acceptance rate" - } - - patterns = append(patterns, AcceptancePattern{ - TimeSlot: slot, - AcceptRate: acceptRate, - EventCount: total, - Description: description, - Confidence: confidence, - }) - } - - // Sort by accept rate (highest first) - slices.SortFunc(patterns, func(a, b AcceptancePattern) int { - return cmp.Compare(b.AcceptRate, a.AcceptRate) // Descending order - }) - - return patterns -} - -// analyzeDurationPatterns analyzes meeting duration patterns. -func (p *PatternLearner) analyzeDurationPatterns(events []domain.Event) []DurationPattern { - // Group events by type (inferred from title patterns) - typeMap := make(map[string][]domain.Event) - - for _, event := range events { - meetingType := p.inferMeetingType(event.Title) - typeMap[meetingType] = append(typeMap[meetingType], event) - } - - patterns := []DurationPattern{} - - for meetingType, typeEvents := range typeMap { - if len(typeEvents) < 3 { - // Skip types with too few samples - continue - } - - // Calculate average scheduled duration - var totalScheduled, totalActual int - for _, event := range typeEvents { - // Calculate duration from Unix timestamps (in seconds) - durationSec := event.When.EndTime - event.When.StartTime - scheduledDuration := int(durationSec / 60) // Convert to minutes - totalScheduled += scheduledDuration - - // Actual duration is same as scheduled (we don't have end-time tracking) - totalActual += scheduledDuration - } - - avgScheduled := totalScheduled / len(typeEvents) - avgActual := totalActual / len(typeEvents) - - patterns = append(patterns, DurationPattern{ - MeetingType: meetingType, - ScheduledDuration: avgScheduled, - ActualDuration: avgActual, - Variance: avgActual - avgScheduled, - EventCount: len(typeEvents), - Description: fmt.Sprintf("Average %d-minute %s meetings", avgScheduled, meetingType), - }) - } - - return patterns -} - -// inferMeetingType infers meeting type from title. -func (p *PatternLearner) inferMeetingType(title string) string { - titleLower := strings.ToLower(title) - - if containsAny(titleLower, []string{"1:1", "1-on-1", "one-on-one"}) { - return "1-on-1" - } - if containsAny(titleLower, []string{"standup", "daily", "scrum"}) { - return "Standup" - } - if containsAny(titleLower, []string{"review", "retrospective", "retro"}) { - return "Review" - } - if containsAny(titleLower, []string{"planning", "plan"}) { - return "Planning" - } - if containsAny(titleLower, []string{"interview", "candidate"}) { - return "Interview" - } - if containsAny(titleLower, []string{"client", "customer"}) { - return "Client call" - } - - return "General meeting" -} - -// containsAny checks if string contains any of the substrings. -func containsAny(s string, substrs []string) bool { - for _, substr := range substrs { - if len(s) >= len(substr) { - // Simple substring check - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - } - } - return false -} - -// analyzeTimezonePatterns analyzes timezone preferences. -func (p *PatternLearner) analyzeTimezonePatterns(events []domain.Event) []TimezonePattern { - tzCounts := make(map[string]int) - totalEvents := len(events) - - for _, event := range events { - tz := event.When.StartTimezone - if tz == "" { - tz = "UTC" - } - tzCounts[tz]++ - } - - patterns := []TimezonePattern{} - for tz, count := range tzCounts { - percentage := float64(count) / float64(totalEvents) - - description := fmt.Sprintf("%d%% of meetings in this timezone", int(percentage*100)) - - patterns = append(patterns, TimezonePattern{ - Timezone: tz, - EventCount: count, - Percentage: percentage, - PreferredTime: "Varies", // Would need more analysis - Description: description, - }) - } - - // Sort by event count (most common first) - slices.SortFunc(patterns, func(a, b TimezonePattern) int { - return cmp.Compare(b.EventCount, a.EventCount) // Descending order - }) - - return patterns -} - -// analyzeProductivityPatterns analyzes productivity patterns. -func (p *PatternLearner) analyzeProductivityPatterns(events []domain.Event) []ProductivityInsight { - // Analyze meeting density by day - dayDensity := make(map[string]int) - for _, event := range events { - startTime := time.Unix(event.When.StartTime, 0) - day := startTime.Weekday().String() - dayDensity[day]++ - } - - insights := []ProductivityInsight{} - - // Find peak and low days - maxDay := "" - maxCount := 0 - minDay := "" - minCount := len(events) + 1 - - for day, count := range dayDensity { - if count > maxCount { - maxCount = count - maxDay = day - } - if count < minCount { - minCount = count - minDay = day - } - } - - if maxDay != "" { - insights = append(insights, ProductivityInsight{ - InsightType: "high_meeting_density", - TimeSlot: maxDay, - Score: 30, // Lower score for high meeting days - Description: fmt.Sprintf("%s has the most meetings (%d) - may impact focus time", maxDay, maxCount), - BasedOn: []string{"Meeting count by day"}, - }) - } - - if minDay != "" { - insights = append(insights, ProductivityInsight{ - InsightType: "low_meeting_density", - TimeSlot: minDay, - Score: 90, // Higher score for low meeting days - Description: fmt.Sprintf("%s has the fewest meetings (%d) - good for deep work", minDay, minCount), - BasedOn: []string{"Meeting count by day"}, - }) - } - - return insights -} diff --git a/internal/adapters/ai/pattern_learner_analysis_test.go b/internal/adapters/ai/pattern_learner_analysis_test.go deleted file mode 100644 index 43dd012..0000000 --- a/internal/adapters/ai/pattern_learner_analysis_test.go +++ /dev/null @@ -1,460 +0,0 @@ -//go:build !integration - -package ai - -import ( - "testing" - "time" - - "github.com/nylas/cli/internal/domain" - "github.com/stretchr/testify/assert" -) - -func TestPatternLearner_AnalyzeAcceptancePatterns(t *testing.T) { - learner := &PatternLearner{} - - // Create events at specific local times for testing - // Use Local timezone to match what analyzeAcceptancePatterns expects - makeEventLocal := func(year, month, day, hour int, status string, busy bool) domain.Event { - eventTime := time.Date(year, time.Month(month), day, hour, 0, 0, 0, time.Local) - - return domain.Event{ - When: domain.EventWhen{ - StartTime: eventTime.Unix(), - EndTime: eventTime.Add(time.Hour).Unix(), - }, - Status: status, - Busy: busy, - } - } - - t.Run("returns empty for insufficient samples", func(t *testing.T) { - events := []domain.Event{ - makeEventLocal(2024, 1, 1, 10, "confirmed", true), - makeEventLocal(2024, 1, 1, 10, "confirmed", true), - // Only 2 events - should be skipped - } - result := learner.analyzeAcceptancePatterns(events) - assert.Empty(t, result) - }) - - t.Run("calculates acceptance rate correctly", func(t *testing.T) { - // 4 events same slot: 3 confirmed/busy, 1 tentative/not busy - events := []domain.Event{ - makeEventLocal(2024, 1, 1, 10, "confirmed", true), - makeEventLocal(2024, 1, 1, 10, "confirmed", true), - makeEventLocal(2024, 1, 1, 10, "confirmed", true), - makeEventLocal(2024, 1, 1, 10, "tentative", false), - } - result := learner.analyzeAcceptancePatterns(events) - - assert.Len(t, result, 1) - assert.InDelta(t, 0.75, result[0].AcceptRate, 0.01) // 3/4 = 0.75 - }) - - t.Run("sorts by acceptance rate descending", func(t *testing.T) { - // Create events in two different slots with different acceptance rates - // Note: We use different hours that fall into distinct time blocks - events := []domain.Event{ - // First slot: low acceptance (25%) - makeEventLocal(2024, 1, 1, 9, "confirmed", true), - makeEventLocal(2024, 1, 1, 9, "tentative", false), - makeEventLocal(2024, 1, 1, 9, "tentative", false), - makeEventLocal(2024, 1, 1, 9, "tentative", false), - // Second slot: high acceptance (100%) - makeEventLocal(2024, 1, 1, 15, "confirmed", true), - makeEventLocal(2024, 1, 1, 15, "confirmed", true), - makeEventLocal(2024, 1, 1, 15, "confirmed", true), - } - result := learner.analyzeAcceptancePatterns(events) - - // If we have multiple slots, highest rate should be first - if len(result) > 1 { - assert.Greater(t, result[0].AcceptRate, result[1].AcceptRate) - } - }) - - t.Run("generates descriptions based on acceptance rate", func(t *testing.T) { - // High acceptance (>80%) should say "prefer meetings" - events := []domain.Event{ - makeEventLocal(2024, 1, 1, 10, "confirmed", true), - makeEventLocal(2024, 1, 1, 10, "confirmed", true), - makeEventLocal(2024, 1, 1, 10, "confirmed", true), - makeEventLocal(2024, 1, 1, 10, "confirmed", true), - makeEventLocal(2024, 1, 1, 10, "confirmed", true), - } - result := learner.analyzeAcceptancePatterns(events) - - assert.Len(t, result, 1) - assert.Contains(t, result[0].Description, "prefer meetings") - }) -} - -func TestPatternLearner_AnalyzeAcceptancePatterns_Descriptions(t *testing.T) { - learner := &PatternLearner{} - - makeEvents := func(count int, accepted int) []domain.Event { - events := make([]domain.Event, count) - baseTime := time.Date(2024, 1, 1, 10, 0, 0, 0, time.UTC) // Monday 10 AM - - for i := 0; i < count; i++ { - status := "tentative" - busy := false - if i < accepted { - status = "confirmed" - busy = true - } - events[i] = domain.Event{ - When: domain.EventWhen{ - StartTime: baseTime.Unix(), - EndTime: baseTime.Add(time.Hour).Unix(), - }, - Status: status, - Busy: busy, - } - } - return events - } - - tests := []struct { - name string - events []domain.Event - wantDesc string - }{ - { - name: "high acceptance description", - events: makeEvents(5, 5), // 100% acceptance - wantDesc: "prefer meetings", - }, - { - name: "low acceptance description", - events: makeEvents(5, 1), // 20% acceptance - wantDesc: "avoid meetings", - }, - { - name: "moderate acceptance description", - events: makeEvents(4, 2), // 50% acceptance - wantDesc: "Moderate acceptance", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := learner.analyzeAcceptancePatterns(tt.events) - - if len(result) > 0 { - assert.Contains(t, result[0].Description, tt.wantDesc) - } - }) - } -} - -func TestPatternLearner_AnalyzeDurationPatterns(t *testing.T) { - learner := &PatternLearner{} - - makeEvent := func(title string, durationMinutes int) domain.Event { - start := time.Now() - return domain.Event{ - Title: title, - When: domain.EventWhen{ - StartTime: start.Unix(), - EndTime: start.Add(time.Duration(durationMinutes) * time.Minute).Unix(), - }, - } - } - - tests := []struct { - name string - events []domain.Event - wantTypes []string - wantDurations map[string]int // type -> expected avg duration - }{ - { - name: "groups by meeting type", - events: []domain.Event{ - makeEvent("Team Standup", 15), - makeEvent("Daily Standup", 15), - makeEvent("Standup meeting", 20), - makeEvent("1:1 with Alice", 30), - makeEvent("1-on-1 with Bob", 30), - makeEvent("One-on-one review", 45), - }, - wantTypes: []string{"Standup", "1-on-1"}, - wantDurations: map[string]int{ - "Standup": 16, // (15+15+20)/3 = 16.67 - "1-on-1": 35, // (30+30+45)/3 = 35 - }, - }, - { - name: "skips types with fewer than 3 samples", - events: []domain.Event{ - makeEvent("Interview candidate", 60), - makeEvent("Interview with John", 60), - // Only 2 interviews - should be skipped - }, - wantTypes: []string{}, // No patterns should be returned - }, - { - name: "categorizes general meetings", - events: []domain.Event{ - makeEvent("Random meeting", 30), - makeEvent("Discussion", 45), - makeEvent("Sync up", 30), - }, - wantTypes: []string{"General meeting"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := learner.analyzeDurationPatterns(tt.events) - - typeNames := make([]string, len(result)) - for i, p := range result { - typeNames[i] = p.MeetingType - } - - for _, wantType := range tt.wantTypes { - assert.Contains(t, typeNames, wantType, "expected type %s to be present", wantType) - } - - for meetingType, wantDuration := range tt.wantDurations { - for _, p := range result { - if p.MeetingType == meetingType { - assert.InDelta(t, wantDuration, p.ScheduledDuration, 2, "duration for %s", meetingType) - } - } - } - }) - } -} - -func TestPatternLearner_InferMeetingType(t *testing.T) { - learner := &PatternLearner{} - - tests := []struct { - title string - want string - }{ - {"1:1 with Alice", "1-on-1"}, - {"1-on-1 review", "1-on-1"}, - {"One-on-one meeting", "1-on-1"}, - {"Daily Standup", "Standup"}, - {"Team Scrum", "Standup"}, - {"Daily sync", "Standup"}, - {"Sprint Review", "Review"}, - {"Retrospective", "Review"}, - {"Retro meeting", "Review"}, - {"Sprint Planning", "Planning"}, - {"Q4 Plan session", "Planning"}, - {"Interview with candidate", "Interview"}, - {"Candidate screening", "Interview"}, - {"Client call", "Client call"}, - {"Customer meeting", "Client call"}, - {"Random meeting", "General meeting"}, - {"", "General meeting"}, - } - - for _, tt := range tests { - t.Run(tt.title, func(t *testing.T) { - result := learner.inferMeetingType(tt.title) - assert.Equal(t, tt.want, result) - }) - } -} - -func TestContainsAny(t *testing.T) { - tests := []struct { - s string - substrs []string - want bool - }{ - {"hello world", []string{"hello"}, true}, - {"hello world", []string{"world"}, true}, - {"hello world", []string{"foo", "bar"}, false}, - {"standup meeting", []string{"standup", "daily"}, true}, - {"daily scrum", []string{"standup", "scrum"}, true}, - {"", []string{"test"}, false}, - {"test", []string{}, false}, - {"short", []string{"longer"}, false}, - } - - for _, tt := range tests { - t.Run(tt.s, func(t *testing.T) { - result := containsAny(tt.s, tt.substrs) - assert.Equal(t, tt.want, result) - }) - } -} - -func TestPatternLearner_AnalyzeTimezonePatterns(t *testing.T) { - learner := &PatternLearner{} - - tests := []struct { - name string - events []domain.Event - wantTop string // Expected top timezone - wantPct float64 // Expected percentage for top - }{ - { - name: "calculates timezone distribution", - events: []domain.Event{ - {When: domain.EventWhen{StartTimezone: "America/New_York"}}, - {When: domain.EventWhen{StartTimezone: "America/New_York"}}, - {When: domain.EventWhen{StartTimezone: "America/New_York"}}, - {When: domain.EventWhen{StartTimezone: "America/Los_Angeles"}}, - {When: domain.EventWhen{StartTimezone: "America/Los_Angeles"}}, - }, - wantTop: "America/New_York", - wantPct: 0.6, // 3/5 - }, - { - name: "defaults empty timezone to UTC", - events: []domain.Event{ - {When: domain.EventWhen{StartTimezone: ""}}, - {When: domain.EventWhen{StartTimezone: ""}}, - {When: domain.EventWhen{StartTimezone: "Europe/London"}}, - }, - wantTop: "UTC", - wantPct: 0.67, // 2/3 - }, - { - name: "sorts by event count", - events: []domain.Event{ - {When: domain.EventWhen{StartTimezone: "Asia/Tokyo"}}, - {When: domain.EventWhen{StartTimezone: "Europe/Paris"}}, - {When: domain.EventWhen{StartTimezone: "Europe/Paris"}}, - {When: domain.EventWhen{StartTimezone: "Europe/Paris"}}, - }, - wantTop: "Europe/Paris", - wantPct: 0.75, // 3/4 - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := learner.analyzeTimezonePatterns(tt.events) - - assert.NotEmpty(t, result) - assert.Equal(t, tt.wantTop, result[0].Timezone) - assert.InDelta(t, tt.wantPct, result[0].Percentage, 0.01) - }) - } -} - -func TestPatternLearner_AnalyzeProductivityPatterns(t *testing.T) { - learner := &PatternLearner{} - - makeEvent := func(weekday time.Weekday) domain.Event { - baseTime := time.Date(2024, 1, 1, 10, 0, 0, 0, time.UTC) - daysToAdd := int(weekday) - int(baseTime.Weekday()) - if daysToAdd < 0 { - daysToAdd += 7 - } - eventTime := baseTime.AddDate(0, 0, daysToAdd) - - return domain.Event{ - When: domain.EventWhen{ - StartTime: eventTime.Unix(), - EndTime: eventTime.Add(time.Hour).Unix(), - }, - } - } - - tests := []struct { - name string - events []domain.Event - wantHighDay string - wantLowDay string - }{ - { - name: "identifies high and low meeting days", - events: []domain.Event{ - makeEvent(time.Monday), - makeEvent(time.Monday), - makeEvent(time.Monday), - makeEvent(time.Monday), - makeEvent(time.Monday), // 5 on Monday - makeEvent(time.Tuesday), - makeEvent(time.Tuesday), // 2 on Tuesday - makeEvent(time.Wednesday), // 1 on Wednesday - }, - wantHighDay: "Monday", - wantLowDay: "Wednesday", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := learner.analyzeProductivityPatterns(tt.events) - - assert.NotEmpty(t, result) - - // Find high and low density insights - var highDensity, lowDensity *ProductivityInsight - for i := range result { - if result[i].InsightType == "high_meeting_density" { - highDensity = &result[i] - } - if result[i].InsightType == "low_meeting_density" { - lowDensity = &result[i] - } - } - - if tt.wantHighDay != "" { - assert.NotNil(t, highDensity) - assert.Equal(t, tt.wantHighDay, highDensity.TimeSlot) - assert.Equal(t, 30, highDensity.Score) // Low score for busy day - } - - if tt.wantLowDay != "" { - assert.NotNil(t, lowDensity) - assert.Equal(t, tt.wantLowDay, lowDensity.TimeSlot) - assert.Equal(t, 90, lowDensity.Score) // High score for quiet day - } - }) - } -} - -// Note: splitLines/trimSpace helpers were removed in favor of stdlib -// strings.Split/strings.TrimSpace (used directly in pattern_learner.go). - -func TestPatternLearner_ConfidenceCalculation(t *testing.T) { - learner := &PatternLearner{} - - // Test that confidence is calculated based on sample size - tests := []struct { - name string - sampleSize int - wantConfidence float64 - }{ - {"low sample", 5, 0.25}, // 5/20 = 0.25 - {"medium sample", 10, 0.5}, // 10/20 = 0.5 - {"high sample", 20, 1.0}, // 20/20 = 1.0, capped at 1.0 - {"very high sample", 40, 1.0}, // Capped at 1.0 - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create events with the specified sample size - events := make([]domain.Event, tt.sampleSize) - for i := 0; i < tt.sampleSize; i++ { - baseTime := time.Date(2024, 1, 1, 10, 0, 0, 0, time.UTC) - events[i] = domain.Event{ - When: domain.EventWhen{ - StartTime: baseTime.Unix(), - EndTime: baseTime.Add(time.Hour).Unix(), - }, - Status: "confirmed", - Busy: true, - } - } - - result := learner.analyzeAcceptancePatterns(events) - - if len(result) > 0 { - assert.InDelta(t, tt.wantConfidence, result[0].Confidence, 0.01) - } - }) - } -} diff --git a/internal/adapters/ai/pattern_learner_test.go b/internal/adapters/ai/pattern_learner_test.go deleted file mode 100644 index 06e564c..0000000 --- a/internal/adapters/ai/pattern_learner_test.go +++ /dev/null @@ -1,304 +0,0 @@ -//go:build !integration - -package ai - -import ( - "context" - "encoding/json" - "testing" - "time" - - "github.com/nylas/cli/internal/domain" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewPatternLearner(t *testing.T) { - t.Run("creates learner with nil dependencies", func(t *testing.T) { - learner := NewPatternLearner(nil, nil) - - assert.NotNil(t, learner) - assert.Nil(t, learner.nylasClient) - assert.Nil(t, learner.llmRouter) - }) -} - -func TestPatternLearner_CalculateAnalysisPeriod(t *testing.T) { - learner := &PatternLearner{} - - tests := []struct { - name string - events []domain.Event - requestedDays int - wantDays int - wantEmpty bool - }{ - { - name: "returns empty for no events", - events: []domain.Event{}, - requestedDays: 30, - wantEmpty: true, - }, - { - name: "calculates correct period", - events: []domain.Event{ - {When: domain.EventWhen{StartTime: time.Now().Add(-7 * 24 * time.Hour).Unix(), EndTime: time.Now().Add(-7*24*time.Hour + time.Hour).Unix()}}, - {When: domain.EventWhen{StartTime: time.Now().Unix(), EndTime: time.Now().Add(time.Hour).Unix()}}, - }, - requestedDays: 30, - wantDays: 7, // Approximately 7 days span - }, - { - name: "handles single event", - events: []domain.Event{ - {When: domain.EventWhen{StartTime: time.Now().Unix(), EndTime: time.Now().Add(time.Hour).Unix()}}, - }, - requestedDays: 30, - wantDays: 0, // Same day - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := learner.calculateAnalysisPeriod(tt.events, tt.requestedDays) - - if tt.wantEmpty { - assert.True(t, result.StartDate.IsZero()) - return - } - - assert.False(t, result.StartDate.IsZero()) - assert.False(t, result.EndDate.IsZero()) - // Allow some tolerance in days calculation - assert.InDelta(t, tt.wantDays, result.Days, 1) - }) - } -} - -// Note: Tests requiring NylasClient/LLMRouter mocks are in integration tests - -func TestPatternLearner_BuildPatternContext(t *testing.T) { - learner := &PatternLearner{} - - events := []domain.Event{{ID: "test-event"}} - acceptance := []AcceptancePattern{ - {TimeSlot: "Monday 9-11 AM", AcceptRate: 0.85, EventCount: 10, Description: "High acceptance"}, - } - duration := []DurationPattern{ - {MeetingType: "1-on-1", ScheduledDuration: 30, EventCount: 5}, - } - timezone := []TimezonePattern{ - {Timezone: "America/New_York", EventCount: 20, Percentage: 0.8}, - } - productivity := []ProductivityInsight{ - {Description: "Peak productivity in morning"}, - } - - result := learner.buildPatternContext(events, acceptance, duration, timezone, productivity) - - assert.Contains(t, result, "Calendar Analysis (1 events analyzed)") - assert.Contains(t, result, "Meeting Acceptance Patterns") - assert.Contains(t, result, "Monday 9-11 AM") - assert.Contains(t, result, "85%") - assert.Contains(t, result, "Meeting Duration Patterns") - assert.Contains(t, result, "1-on-1") - assert.Contains(t, result, "Timezone Patterns") - assert.Contains(t, result, "America/New_York") - assert.Contains(t, result, "Productivity Insights") - assert.Contains(t, result, "Peak productivity") -} - -func TestPatternLearner_BuildPatternContext_LimitsOutput(t *testing.T) { - learner := &PatternLearner{} - - // Create more than 5 acceptance patterns - acceptance := make([]AcceptancePattern, 10) - for i := 0; i < 10; i++ { - acceptance[i] = AcceptancePattern{ - TimeSlot: "Slot " + string(rune('A'+i)), - AcceptRate: 0.5, - EventCount: 5, - } - } - - // Create more than 3 timezone patterns - timezone := make([]TimezonePattern, 5) - for i := 0; i < 5; i++ { - timezone[i] = TimezonePattern{ - Timezone: "TZ" + string(rune('0'+i)), - EventCount: 10, - Percentage: 0.2, - } - } - - result := learner.buildPatternContext( - []domain.Event{{ID: "test"}}, - acceptance, - []DurationPattern{}, - timezone, - []ProductivityInsight{}, - ) - - // Should only include top 5 acceptance patterns - assert.Contains(t, result, "Slot A") - assert.Contains(t, result, "Slot E") - assert.NotContains(t, result, "Slot F") // 6th pattern should be excluded - - // Should only include top 3 timezone patterns - assert.Contains(t, result, "TZ0") - assert.Contains(t, result, "TZ2") - assert.NotContains(t, result, "TZ3") // 4th pattern should be excluded -} - -func TestPatternLearner_SaveLoadPatterns(t *testing.T) { - ctx := context.Background() - learner := &PatternLearner{} - - t.Run("SavePatterns returns not implemented error", func(t *testing.T) { - // SavePatterns is a stub; returning a real error keeps a caller - // from mistaking the no-op for a successful persist. - err := learner.SavePatterns(ctx, &SchedulingPatterns{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "not yet implemented") - }) - - t.Run("LoadPatterns returns not implemented error", func(t *testing.T) { - _, err := learner.LoadPatterns(ctx, "user-123") - assert.Error(t, err) - assert.Contains(t, err.Error(), "not yet implemented") - }) -} - -func TestPatternLearner_ExportPatterns(t *testing.T) { - learner := &PatternLearner{} - - patterns := &SchedulingPatterns{ - UserID: "user-123", - AnalysisPeriod: AnalysisPeriod{ - StartDate: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), - EndDate: time.Date(2024, 1, 31, 0, 0, 0, 0, time.UTC), - Days: 30, - }, - AcceptancePatterns: []AcceptancePattern{ - {TimeSlot: "Monday 9-11 AM", AcceptRate: 0.9}, - }, - DurationPatterns: []DurationPattern{ - {MeetingType: "Standup", ScheduledDuration: 15}, - }, - TimezonePatterns: []TimezonePattern{ - {Timezone: "UTC", Percentage: 1.0}, - }, - ProductivityInsights: []ProductivityInsight{ - {InsightType: "peak_focus", TimeSlot: "Morning"}, - }, - Recommendations: []string{"Schedule focus blocks"}, - TotalEventsAnalyzed: 100, - GeneratedAt: time.Now(), - } - - data, err := learner.ExportPatterns(patterns) - - require.NoError(t, err) - assert.NotEmpty(t, data) - - // Verify it's valid JSON - var parsed SchedulingPatterns - err = json.Unmarshal(data, &parsed) - require.NoError(t, err) - assert.Equal(t, "user-123", parsed.UserID) - assert.Equal(t, 100, parsed.TotalEventsAnalyzed) - assert.Len(t, parsed.AcceptancePatterns, 1) -} - -func TestSchedulingPatternsTypes(t *testing.T) { - t.Run("AnalysisPeriod serializes correctly", func(t *testing.T) { - period := AnalysisPeriod{ - StartDate: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), - EndDate: time.Date(2024, 1, 31, 0, 0, 0, 0, time.UTC), - Days: 30, - } - - data, err := json.Marshal(period) - require.NoError(t, err) - assert.Contains(t, string(data), `"days":30`) - }) - - t.Run("AcceptancePattern serializes correctly", func(t *testing.T) { - pattern := AcceptancePattern{ - TimeSlot: "Monday 9-11 AM", - AcceptRate: 0.85, - EventCount: 10, - Description: "High acceptance", - Confidence: 0.9, - } - - data, err := json.Marshal(pattern) - require.NoError(t, err) - assert.Contains(t, string(data), `"time_slot":"Monday 9-11 AM"`) - assert.Contains(t, string(data), `"accept_rate":0.85`) - assert.Contains(t, string(data), `"confidence":0.9`) - }) - - t.Run("DurationPattern serializes correctly", func(t *testing.T) { - pattern := DurationPattern{ - MeetingType: "1-on-1", - ScheduledDuration: 30, - ActualDuration: 35, - Variance: 5, - EventCount: 20, - Description: "Typically runs over", - } - - data, err := json.Marshal(pattern) - require.NoError(t, err) - assert.Contains(t, string(data), `"meeting_type":"1-on-1"`) - assert.Contains(t, string(data), `"variance":5`) - }) - - t.Run("TimezonePattern serializes correctly", func(t *testing.T) { - pattern := TimezonePattern{ - Timezone: "America/Los_Angeles", - EventCount: 50, - Percentage: 0.6, - PreferredTime: "2-4 PM PST", - Description: "Most meetings in PST", - } - - data, err := json.Marshal(pattern) - require.NoError(t, err) - assert.Contains(t, string(data), `"timezone":"America/Los_Angeles"`) - assert.Contains(t, string(data), `"preferred_time":"2-4 PM PST"`) - }) - - t.Run("ProductivityInsight serializes correctly", func(t *testing.T) { - insight := ProductivityInsight{ - InsightType: "peak_focus", - TimeSlot: "Tuesday 10 AM - 12 PM", - Score: 90, - Description: "Best time for deep work", - BasedOn: []string{"Meeting density", "Focus blocks"}, - } - - data, err := json.Marshal(insight) - require.NoError(t, err) - assert.Contains(t, string(data), `"insight_type":"peak_focus"`) - assert.Contains(t, string(data), `"score":90`) - assert.Contains(t, string(data), `"based_on"`) - }) - - t.Run("LearnPatternsRequest contains all fields", func(t *testing.T) { - req := LearnPatternsRequest{ - GrantID: "grant-123", - LookbackDays: 30, - MinConfidence: 0.8, - IncludeRecurring: true, - } - - data, err := json.Marshal(req) - require.NoError(t, err) - assert.Contains(t, string(data), `"grant_id":"grant-123"`) - assert.Contains(t, string(data), `"lookback_days":30`) - assert.Contains(t, string(data), `"min_confidence":0.8`) - assert.Contains(t, string(data), `"include_recurring":true`) - }) -} diff --git a/internal/adapters/browser/browser.go b/internal/adapters/browser/browser.go index 69d7887..7a70264 100644 --- a/internal/adapters/browser/browser.go +++ b/internal/adapters/browser/browser.go @@ -38,25 +38,3 @@ func createCommand(url string) *exec.Cmd { return exec.Command("xdg-open", url) } } - -// MockBrowser is a mock implementation for testing. -type MockBrowser struct { - OpenCalled bool - LastURL string - OpenFunc func(url string) error -} - -// NewMockBrowser creates a new MockBrowser. -func NewMockBrowser() *MockBrowser { - return &MockBrowser{} -} - -// Open records the call and optionally calls the custom function. -func (m *MockBrowser) Open(url string) error { - m.OpenCalled = true - m.LastURL = url - if m.OpenFunc != nil { - return m.OpenFunc(url) - } - return nil -} diff --git a/internal/adapters/browser/browser_test.go b/internal/adapters/browser/browser_test.go index 7496d1c..64f296f 100644 --- a/internal/adapters/browser/browser_test.go +++ b/internal/adapters/browser/browser_test.go @@ -1,7 +1,6 @@ package browser import ( - "errors" "slices" "testing" ) @@ -27,88 +26,6 @@ func TestDefaultBrowser_Open(t *testing.T) { _ = err } -func TestNewMockBrowser(t *testing.T) { - mock := NewMockBrowser() - if mock == nil { - t.Error("NewMockBrowser() returned nil") - return - } - if mock.OpenCalled { - t.Error("OpenCalled should be false initially") - } - if mock.LastURL != "" { - t.Error("LastURL should be empty initially") - } -} - -func TestMockBrowser_Open(t *testing.T) { - tests := []struct { - name string - url string - openFunc func(url string) error - wantErr bool - }{ - { - name: "successful open", - url: "https://example.com", - openFunc: nil, - wantErr: false, - }, - { - name: "open with error", - url: "https://example.com", - openFunc: func(url string) error { - return errors.New("failed to open") - }, - wantErr: true, - }, - { - name: "open with custom func", - url: "https://custom.com", - openFunc: func(url string) error { - if url != "https://custom.com" { - return errors.New("wrong URL") - } - return nil - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mock := NewMockBrowser() - mock.OpenFunc = tt.openFunc - - err := mock.Open(tt.url) - - if !mock.OpenCalled { - t.Error("OpenCalled should be true") - } - if mock.LastURL != tt.url { - t.Errorf("LastURL = %q, want %q", mock.LastURL, tt.url) - } - if (err != nil) != tt.wantErr { - t.Errorf("Open() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestMockBrowser_RecordsMultipleCalls(t *testing.T) { - mock := NewMockBrowser() - - _ = mock.Open("https://first.com") - if mock.LastURL != "https://first.com" { - t.Errorf("First URL = %q, want %q", mock.LastURL, "https://first.com") - } - - _ = mock.Open("https://second.com") - if mock.LastURL != "https://second.com" { - t.Errorf("Second URL = %q, want %q", mock.LastURL, "https://second.com") - } -} - func TestCreateCommand(t *testing.T) { tests := []struct { name string @@ -149,36 +66,3 @@ func TestCreateCommand_ReturnsNonNil(t *testing.T) { t.Fatal("createCommand() should never return nil") } } - -func TestMockBrowser_DefaultBehavior(t *testing.T) { - mock := NewMockBrowser() - - // Should not panic when OpenFunc is nil - err := mock.Open("https://example.com") - if err != nil { - t.Errorf("Open() with nil OpenFunc should return nil, got %v", err) - } -} - -func TestMockBrowser_MultipleDifferentURLs(t *testing.T) { - mock := NewMockBrowser() - - urls := []string{ - "https://nylas.com", - "http://localhost:8080", - "https://example.com/auth/callback", - } - - for _, url := range urls { - err := mock.Open(url) - if err != nil { - t.Errorf("Open(%q) returned unexpected error: %v", url, err) - } - if mock.LastURL != url { - t.Errorf("After Open(%q), LastURL = %q", url, mock.LastURL) - } - if !mock.OpenCalled { - t.Error("OpenCalled should remain true") - } - } -} diff --git a/internal/adapters/config/validation.go b/internal/adapters/config/validation.go deleted file mode 100644 index 4df6f01..0000000 --- a/internal/adapters/config/validation.go +++ /dev/null @@ -1,78 +0,0 @@ -package config - -import ( - "fmt" - "os" - "strings" -) - -// RequiredEnvVar represents a required environment variable. -type RequiredEnvVar struct { - Name string - Description string - Optional bool -} - -// ValidateRequiredEnvVars validates that all required environment variables are set. -// Returns a slice of missing variables. -func ValidateRequiredEnvVars(vars []RequiredEnvVar) []string { - var missing []string - - for _, v := range vars { - if v.Optional { - continue - } - - value := os.Getenv(v.Name) - if value == "" { - missing = append(missing, v.Name) - } - } - - return missing -} - -// FormatMissingEnvVars formats missing environment variables into a user-friendly error message. -func FormatMissingEnvVars(missing []string, vars []RequiredEnvVar) string { - if len(missing) == 0 { - return "" - } - - var sb strings.Builder - sb.WriteString("Missing required environment variables:\n\n") - - // Create a map of descriptions - descriptions := make(map[string]string) - for _, v := range vars { - descriptions[v.Name] = v.Description - } - - for _, name := range missing { - desc := descriptions[name] - if desc == "" { - desc = "No description available" - } - _, _ = fmt.Fprintf(&sb, " %s - %s\n", name, desc) - } - - return sb.String() -} - -// ValidateAPICredentials validates Nylas API credentials from environment. -func ValidateAPICredentials() error { - requiredVars := []RequiredEnvVar{ - { - Name: "NYLAS_API_KEY", - Description: "Your Nylas API key", - Optional: false, - }, - } - - missing := ValidateRequiredEnvVars(requiredVars) - if len(missing) > 0 { - return fmt.Errorf("%s\nSet via: export NYLAS_API_KEY=your_key_here\nOr run: nylas auth config", - FormatMissingEnvVars(missing, requiredVars)) - } - - return nil -} diff --git a/internal/adapters/config/validation_test.go b/internal/adapters/config/validation_test.go deleted file mode 100644 index 424c3a6..0000000 --- a/internal/adapters/config/validation_test.go +++ /dev/null @@ -1,207 +0,0 @@ -package config - -import ( - "os" - "strings" - "testing" -) - -func TestValidateRequiredEnvVars(t *testing.T) { - tests := []struct { - name string - vars []RequiredEnvVar - envVars map[string]string - expected []string - }{ - { - name: "all required vars present", - vars: []RequiredEnvVar{ - {Name: "TEST_VAR1", Description: "Test var 1", Optional: false}, - {Name: "TEST_VAR2", Description: "Test var 2", Optional: false}, - }, - envVars: map[string]string{ - "TEST_VAR1": "value1", - "TEST_VAR2": "value2", - }, - expected: []string{}, - }, - { - name: "one required var missing", - vars: []RequiredEnvVar{ - {Name: "TEST_VAR1", Description: "Test var 1", Optional: false}, - {Name: "TEST_VAR2", Description: "Test var 2", Optional: false}, - }, - envVars: map[string]string{ - "TEST_VAR1": "value1", - }, - expected: []string{"TEST_VAR2"}, - }, - { - name: "all required vars missing", - vars: []RequiredEnvVar{ - {Name: "TEST_VAR1", Description: "Test var 1", Optional: false}, - {Name: "TEST_VAR2", Description: "Test var 2", Optional: false}, - }, - envVars: map[string]string{}, - expected: []string{"TEST_VAR1", "TEST_VAR2"}, - }, - { - name: "optional vars ignored", - vars: []RequiredEnvVar{ - {Name: "TEST_VAR1", Description: "Test var 1", Optional: false}, - {Name: "TEST_VAR2", Description: "Test var 2", Optional: true}, - }, - envVars: map[string]string{ - "TEST_VAR1": "value1", - }, - expected: []string{}, - }, - { - name: "no vars to validate", - vars: []RequiredEnvVar{}, - envVars: map[string]string{}, - expected: []string{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Set up environment - for k, v := range tt.envVars { - _ = os.Setenv(k, v) - } - defer func() { - // Clean up - for k := range tt.envVars { - _ = os.Unsetenv(k) - } - for _, v := range tt.vars { - _ = os.Unsetenv(v.Name) - } - }() - - missing := ValidateRequiredEnvVars(tt.vars) - - if len(missing) != len(tt.expected) { - t.Errorf("ValidateRequiredEnvVars() returned %d missing vars, want %d", len(missing), len(tt.expected)) - } - - for i, m := range missing { - if i >= len(tt.expected) { - break - } - if m != tt.expected[i] { - t.Errorf("ValidateRequiredEnvVars() missing[%d] = %q, want %q", i, m, tt.expected[i]) - } - } - }) - } -} - -func TestFormatMissingEnvVars(t *testing.T) { - tests := []struct { - name string - missing []string - vars []RequiredEnvVar - contains []string - }{ - { - name: "no missing vars", - missing: []string{}, - vars: []RequiredEnvVar{}, - contains: []string{}, - }, - { - name: "one missing var", - missing: []string{"TEST_VAR"}, - vars: []RequiredEnvVar{ - {Name: "TEST_VAR", Description: "A test variable"}, - }, - contains: []string{"Missing required environment variables", "TEST_VAR", "A test variable"}, - }, - { - name: "multiple missing vars", - missing: []string{"VAR1", "VAR2"}, - vars: []RequiredEnvVar{ - {Name: "VAR1", Description: "First var"}, - {Name: "VAR2", Description: "Second var"}, - }, - contains: []string{"VAR1", "VAR2", "First var", "Second var"}, - }, - { - name: "missing var with no description", - missing: []string{"NO_DESC"}, - vars: []RequiredEnvVar{ - {Name: "NO_DESC", Description: ""}, - }, - contains: []string{"NO_DESC", "No description available"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := FormatMissingEnvVars(tt.missing, tt.vars) - - if len(tt.missing) == 0 && result != "" { - t.Errorf("FormatMissingEnvVars() with no missing vars should return empty string, got %q", result) - } - - for _, expected := range tt.contains { - if !strings.Contains(result, expected) { - t.Errorf("FormatMissingEnvVars() result should contain %q\nGot: %s", expected, result) - } - } - }) - } -} - -func TestValidateAPICredentials(t *testing.T) { - tests := []struct { - name string - apiKey string - wantError bool - }{ - { - name: "API key present", - apiKey: "test-api-key", - wantError: false, - }, - { - name: "API key missing", - apiKey: "", - wantError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Set up environment - if tt.apiKey != "" { - _ = os.Setenv("NYLAS_API_KEY", tt.apiKey) - } else { - _ = os.Unsetenv("NYLAS_API_KEY") - } - defer func() { - _ = os.Unsetenv("NYLAS_API_KEY") - }() - - err := ValidateAPICredentials() - - if tt.wantError && err == nil { - t.Error("ValidateAPICredentials() expected error, got nil") - } - - if !tt.wantError && err != nil { - t.Errorf("ValidateAPICredentials() unexpected error: %v", err) - } - - if tt.wantError && err != nil { - // Check that error message contains helpful info - errMsg := err.Error() - if !strings.Contains(errMsg, "NYLAS_API_KEY") { - t.Errorf("Error message should mention NYLAS_API_KEY, got: %s", errMsg) - } - } - }) - } -} diff --git a/internal/adapters/dashboard/http.go b/internal/adapters/dashboard/http.go index 9638e9a..b115e01 100644 --- a/internal/adapters/dashboard/http.go +++ b/internal/adapters/dashboard/http.go @@ -144,18 +144,6 @@ func (c *AccountClient) doPostRaw(ctx context.Context, path string, body any, ex return resp.Data, nil } -// unwrapEnvelope extracts the "data" field from the API response envelope. -// The dashboard-account API wraps all successful responses in: -// -// {"request_id": "...", "success": true, "data": {...}} -func unwrapEnvelope(body []byte) ([]byte, error) { - resp, err := unwrapRawResponse(body) - if err != nil { - return nil, err - } - return resp.Data, nil -} - func unwrapRawResponse(body []byte) (rawResponse, error) { var envelope struct { Data json.RawMessage `json:"data"` diff --git a/internal/adapters/dashboard/http_test.go b/internal/adapters/dashboard/http_test.go index 596fcf0..a392784 100644 --- a/internal/adapters/dashboard/http_test.go +++ b/internal/adapters/dashboard/http_test.go @@ -149,54 +149,6 @@ func TestParseErrorResponse(t *testing.T) { } } -func TestUnwrapEnvelope(t *testing.T) { - tests := []struct { - name string - body string - wantKey string - wantErr bool - }{ - { - name: "unwraps data field", - body: `{"request_id":"abc","success":true,"data":{"name":"test"}}`, - wantKey: "name", - }, - { - name: "returns body as-is when no data field", - body: `{"name":"test"}`, - wantKey: "name", - }, - { - name: "returns error on invalid JSON", - body: "not json", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := unwrapEnvelope([]byte(tt.body)) - if tt.wantErr { - if err == nil { - t.Fatal("expected error, got nil") - } - return - } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - var parsed map[string]any - if jsonErr := json.Unmarshal(result, &parsed); jsonErr != nil { - t.Fatalf("result is not valid JSON: %v", jsonErr) - } - if _, ok := parsed[tt.wantKey]; !ok { - t.Errorf("result missing key %q: %s", tt.wantKey, string(result)) - } - }) - } -} - func TestDashboardAPIError_Error(t *testing.T) { tests := []struct { name string diff --git a/internal/adapters/gpg/encrypt.go b/internal/adapters/gpg/encrypt.go index 2a4186c..b22520c 100644 --- a/internal/adapters/gpg/encrypt.go +++ b/internal/adapters/gpg/encrypt.go @@ -14,7 +14,7 @@ import ( const keyserverFetchTimeout = 5 * time.Second // ListPublicKeys lists all public keys in the keyring. -func (s *service) ListPublicKeys(ctx context.Context) ([]KeyInfo, error) { +func (s *Service) ListPublicKeys(ctx context.Context) ([]KeyInfo, error) { cmd := exec.CommandContext(ctx, "gpg", "--list-keys", "--with-colons", "--with-fingerprint") output, err := cmd.Output() if err != nil { @@ -29,7 +29,7 @@ func (s *service) ListPublicKeys(ctx context.Context) ([]KeyInfo, error) { } // FindPublicKeyByEmail finds a public key by email, auto-fetching from key servers if not found locally. -func (s *service) FindPublicKeyByEmail(ctx context.Context, email string) (*KeyInfo, error) { +func (s *Service) FindPublicKeyByEmail(ctx context.Context, email string) (*KeyInfo, error) { // Normalize email for comparison email = strings.ToLower(strings.TrimSpace(email)) @@ -77,7 +77,7 @@ func (s *service) FindPublicKeyByEmail(ctx context.Context, email string) (*KeyI } // fetchKeyByEmail tries to fetch a public key by email from key servers. -func (s *service) fetchKeyByEmail(ctx context.Context, email string) error { +func (s *Service) fetchKeyByEmail(ctx context.Context, email string) error { // Validate email format parsed, err := mail.ParseAddress(email) if err != nil { @@ -150,7 +150,7 @@ func keyMatchesEmail(key *KeyInfo, email string) bool { } // EncryptData encrypts data for one or more recipients using their public keys. -func (s *service) EncryptData(ctx context.Context, recipientKeyIDs []string, data []byte) (*EncryptResult, error) { +func (s *Service) EncryptData(ctx context.Context, recipientKeyIDs []string, data []byte) (*EncryptResult, error) { if len(recipientKeyIDs) == 0 { return nil, fmt.Errorf("at least one recipient key ID is required") } @@ -208,7 +208,7 @@ func (s *service) EncryptData(ctx context.Context, recipientKeyIDs []string, dat // SignAndEncryptData signs data with the sender's private key and encrypts for recipients. // This provides maximum security: only recipients can decrypt, and they can verify the sender. -func (s *service) SignAndEncryptData(ctx context.Context, signerKeyID string, recipientKeyIDs []string, data []byte, senderEmail string) (*EncryptResult, error) { +func (s *Service) SignAndEncryptData(ctx context.Context, signerKeyID string, recipientKeyIDs []string, data []byte, senderEmail string) (*EncryptResult, error) { if signerKeyID == "" { return nil, fmt.Errorf("signer key ID is required for sign+encrypt") } @@ -294,7 +294,7 @@ func (s *service) SignAndEncryptData(ctx context.Context, signerKeyID string, re // DecryptData decrypts PGP encrypted data using the user's private key. // It also handles signed+encrypted messages, returning signature verification info. -func (s *service) DecryptData(ctx context.Context, ciphertext []byte) (*DecryptResult, error) { +func (s *Service) DecryptData(ctx context.Context, ciphertext []byte) (*DecryptResult, error) { if len(ciphertext) == 0 { return nil, fmt.Errorf("ciphertext is empty") } diff --git a/internal/adapters/gpg/service.go b/internal/adapters/gpg/service.go index 039e8ec..89d60d2 100644 --- a/internal/adapters/gpg/service.go +++ b/internal/adapters/gpg/service.go @@ -46,57 +46,17 @@ func isValidGPGKeyID(keyID string) bool { return false } -// Service provides GPG signing, verification, and encryption operations. -type Service interface { - // CheckGPGAvailable verifies GPG is installed and accessible. - CheckGPGAvailable(ctx context.Context) error - - // ListSigningKeys lists all available secret keys for signing. - ListSigningKeys(ctx context.Context) ([]KeyInfo, error) - - // GetDefaultSigningKey gets the default signing key from git config. - GetDefaultSigningKey(ctx context.Context) (*KeyInfo, error) - - // FindKeyByEmail finds a signing key that contains the given email in its UIDs. - // Returns the key ID (not the email) for use with --local-user. - FindKeyByEmail(ctx context.Context, email string) (*KeyInfo, error) - - // SignData signs data with the specified key and returns a detached signature. - // senderEmail is optional - when provided, it embeds that email in the Signer's User ID subpacket. - SignData(ctx context.Context, keyID string, data []byte, senderEmail string) (*SignResult, error) - - // VerifyDetachedSignature verifies a detached signature against data. - // Returns verification result including signer info and trust level. - VerifyDetachedSignature(ctx context.Context, data []byte, signature []byte) (*VerifyResult, error) - - // ListPublicKeys lists all public keys in the keyring. - ListPublicKeys(ctx context.Context) ([]KeyInfo, error) - - // FindPublicKeyByEmail finds a public key by email, auto-fetching from key servers if not found locally. - FindPublicKeyByEmail(ctx context.Context, email string) (*KeyInfo, error) - - // EncryptData encrypts data for one or more recipients using their public keys. - EncryptData(ctx context.Context, recipientKeyIDs []string, data []byte) (*EncryptResult, error) - - // SignAndEncryptData signs data with the sender's private key and encrypts for recipients. - // This provides maximum security: only recipients can decrypt, and they can verify the sender. - SignAndEncryptData(ctx context.Context, signerKeyID string, recipientKeyIDs []string, data []byte, senderEmail string) (*EncryptResult, error) - - // DecryptData decrypts PGP encrypted data using the user's private key. - // Returns the decrypted plaintext along with optional signature verification info. - DecryptData(ctx context.Context, ciphertext []byte) (*DecryptResult, error) -} - -// service implements Service using the system GPG command. -type service struct{} +// Service provides GPG signing, verification, and encryption operations +// using the system GPG command. +type Service struct{} // NewService creates a new GPG service. -func NewService() Service { - return &service{} +func NewService() *Service { + return &Service{} } // CheckGPGAvailable verifies GPG is installed. -func (s *service) CheckGPGAvailable(ctx context.Context) error { +func (s *Service) CheckGPGAvailable(ctx context.Context) error { cmd := exec.CommandContext(ctx, "gpg", "--version") if err := cmd.Run(); err != nil { return fmt.Errorf("GPG not found. Install with: sudo apt install gnupg (Linux) or brew install gnupg (macOS)") @@ -105,7 +65,7 @@ func (s *service) CheckGPGAvailable(ctx context.Context) error { } // ListSigningKeys lists all secret keys available for signing. -func (s *service) ListSigningKeys(ctx context.Context) ([]KeyInfo, error) { +func (s *Service) ListSigningKeys(ctx context.Context) ([]KeyInfo, error) { // Use --with-colons format for reliable parsing cmd := exec.CommandContext(ctx, "gpg", "--list-secret-keys", "--with-colons", "--with-fingerprint") output, err := cmd.Output() @@ -120,7 +80,7 @@ func (s *service) ListSigningKeys(ctx context.Context) ([]KeyInfo, error) { } // GetDefaultSigningKey retrieves the default signing key from git config. -func (s *service) GetDefaultSigningKey(ctx context.Context) (*KeyInfo, error) { +func (s *Service) GetDefaultSigningKey(ctx context.Context) (*KeyInfo, error) { // Try to get key from git config cmd := exec.CommandContext(ctx, "git", "config", "--get", "user.signingkey") output, err := cmd.Output() @@ -153,7 +113,7 @@ func (s *service) GetDefaultSigningKey(ctx context.Context) (*KeyInfo, error) { // Returns the KeyInfo with the actual key ID for use with --local-user. // This is important because GPG's --sender option only works correctly when // --local-user is a key ID, not an email address. -func (s *service) FindKeyByEmail(ctx context.Context, email string) (*KeyInfo, error) { +func (s *Service) FindKeyByEmail(ctx context.Context, email string) (*KeyInfo, error) { keys, err := s.ListSigningKeys(ctx) if err != nil { return nil, err @@ -178,7 +138,7 @@ func (s *service) FindKeyByEmail(ctx context.Context, email string) (*KeyInfo, e // SignData creates a detached signature for the given data. // senderEmail is optional - when provided, it embeds that email in the Signer's User ID subpacket. -func (s *service) SignData(ctx context.Context, keyID string, data []byte, senderEmail string) (*SignResult, error) { +func (s *Service) SignData(ctx context.Context, keyID string, data []byte, senderEmail string) (*SignResult, error) { // Validate keyID to prevent command injection (SEC-001) if !isValidGPGKeyID(keyID) { return nil, fmt.Errorf("invalid GPG key ID format: %q", keyID) @@ -253,7 +213,7 @@ var KeyServers = []string{ } // VerifyDetachedSignature verifies a detached signature against data. -func (s *service) VerifyDetachedSignature(ctx context.Context, data []byte, signature []byte) (*VerifyResult, error) { +func (s *Service) VerifyDetachedSignature(ctx context.Context, data []byte, signature []byte) (*VerifyResult, error) { // Create temporary files for data and signature dataFile, err := createTempFile("gpg-verify-data-", data) if err != nil { @@ -305,7 +265,7 @@ func (s *service) VerifyDetachedSignature(ctx context.Context, data []byte, sign } // runVerify executes gpg --verify and returns the parsed result. -func (s *service) runVerify(ctx context.Context, sigFile, dataFile string) (*VerifyResult, string, error) { +func (s *Service) runVerify(ctx context.Context, sigFile, dataFile string) (*VerifyResult, string, error) { cmd := exec.CommandContext(ctx, "gpg", "--verify", "--status-fd", "1", sigFile, dataFile) var stdout, stderr bytes.Buffer cmd.Stdout = &stdout @@ -321,7 +281,7 @@ func (s *service) runVerify(ctx context.Context, sigFile, dataFile string) (*Ver // fetchKeyFromServer attempts to fetch a public key from multiple key servers. // It tries each server in order until one succeeds. -func (s *service) fetchKeyFromServer(ctx context.Context, keyID string) error { +func (s *Service) fetchKeyFromServer(ctx context.Context, keyID string) error { var lastErr error for _, server := range KeyServers { diff --git a/internal/adapters/keyring/crossplatform_test.go b/internal/adapters/keyring/crossplatform_test.go index b2b9980..6456865 100644 --- a/internal/adapters/keyring/crossplatform_test.go +++ b/internal/adapters/keyring/crossplatform_test.go @@ -656,74 +656,3 @@ func TestEncryptedFileStore_MigratesOnFirstGet(t *testing.T) { t.Fatalf("failed to decrypt migrated file with passphrase-derived key: %v", err) } } - -// TestDetectKeyType verifies the detectKeyType helper across the expected states. -func TestDetectKeyType(t *testing.T) { - t.Run("none_when_no_file", func(t *testing.T) { - tmpDir := t.TempDir() - setFileStorePassphrase(t) - - store, err := NewEncryptedFileStore(tmpDir) - // Fresh install with passphrase set — construction should succeed. - if err != nil { - t.Fatalf("NewEncryptedFileStore failed: %v", err) - } - // No file written yet. - kt, err := store.detectKeyType() - if err != nil { - t.Fatalf("detectKeyType failed: %v", err) - } - if kt != fileStoreKeyNone { - t.Fatalf("detectKeyType = %d, want fileStoreKeyNone (%d)", kt, fileStoreKeyNone) - } - }) - - t.Run("passphrase_only_after_write", func(t *testing.T) { - tmpDir := t.TempDir() - setFileStorePassphrase(t) - - store, err := NewEncryptedFileStore(tmpDir) - if err != nil { - t.Fatalf("NewEncryptedFileStore failed: %v", err) - } - if err := store.Set("k", "v"); err != nil { - t.Fatalf("Set failed: %v", err) - } - kt, err := store.detectKeyType() - if err != nil { - t.Fatalf("detectKeyType failed: %v", err) - } - if kt != fileStoreKeyPassphraseOnly { - t.Fatalf("detectKeyType = %d, want fileStoreKeyPassphraseOnly (%d)", kt, fileStoreKeyPassphraseOnly) - } - }) - - t.Run("legacy_only_before_migration", func(t *testing.T) { - tmpDir := t.TempDir() - setFileStorePassphrase(t) - - legacyKey, err := deriveLegacyKey() - if err != nil { - t.Fatalf("deriveLegacyKey failed: %v", err) - } - ct, err := encryptWithKey(legacyKey, []byte(`{"k":"v"}`)) - if err != nil { - t.Fatalf("encryptWithKey failed: %v", err) - } - if err := os.WriteFile(filepath.Join(tmpDir, ".secrets.enc"), ct, 0600); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - - store, err := NewEncryptedFileStore(tmpDir) - if err != nil { - t.Fatalf("NewEncryptedFileStore failed: %v", err) - } - kt, err := store.detectKeyType() - if err != nil { - t.Fatalf("detectKeyType failed: %v", err) - } - if kt != fileStoreKeyLegacyOnly { - t.Fatalf("detectKeyType = %d, want fileStoreKeyLegacyOnly (%d)", kt, fileStoreKeyLegacyOnly) - } - }) -} diff --git a/internal/adapters/keyring/file.go b/internal/adapters/keyring/file.go index a9fef6d..8943ac8 100644 --- a/internal/adapters/keyring/file.go +++ b/internal/adapters/keyring/file.go @@ -20,16 +20,6 @@ const ( fileStoreSaltSize = 16 ) -// fileStoreKeyType describes which key(s) can decrypt the on-disk .secrets.enc file. -type fileStoreKeyType int - -const ( - fileStoreKeyNone fileStoreKeyType = iota // file does not exist or neither key decrypts it - fileStoreKeyLegacyOnly // decryptable only with the legacy machine-derived key - fileStoreKeyPassphraseOnly // decryptable only with the passphrase-derived key - fileStoreKeyBoth // decryptable with either key -) - // EncryptedFileStore implements SecretStore using an encrypted file. // This is a fallback for environments where the system keyring is unavailable. // Uses AES-256-GCM encryption with an Argon2id key derived from a user-supplied @@ -179,56 +169,6 @@ func (f *EncryptedFileStore) Name() string { return "encrypted file" } -// detectKeyType returns which key(s) can currently decrypt the on-disk file. -// It reads the file once and probes each key in order. If the file does not -// exist, fileStoreKeyNone is returned with no error. -func (f *EncryptedFileStore) detectKeyType() (fileStoreKeyType, error) { - data, err := os.ReadFile(f.path) - if err != nil { - if os.IsNotExist(err) { - return fileStoreKeyNone, nil - } - return fileStoreKeyNone, err - } - - hasPassphrase := false - if key, err := f.passphraseKey(false); err == nil { - if _, err := decryptWithKey(key, data); err == nil { - hasPassphrase = true - } - zeroBytes(key) - } - - hasLegacy := f.canDecryptWithLegacyKeys(data) - - switch { - case hasPassphrase && hasLegacy: - return fileStoreKeyBoth, nil - case hasPassphrase: - return fileStoreKeyPassphraseOnly, nil - case hasLegacy: - return fileStoreKeyLegacyOnly, nil - default: - return fileStoreKeyNone, nil - } -} - -// canDecryptWithLegacyKeys returns true when the ciphertext can be opened by -// either the migration master key or the legacy machine-derived key. -func (f *EncryptedFileStore) canDecryptWithLegacyKeys(data []byte) bool { - if len(f.migrationKey) > 0 { - if _, err := decryptWithKey(f.migrationKey, data); err == nil { - return true - } - } - if len(f.legacyKey) > 0 { - if _, err := decryptWithKey(f.legacyKey, data); err == nil { - return true - } - } - return false -} - // loadSecrets loads and decrypts the secrets file. func (f *EncryptedFileStore) loadSecrets() (map[string]string, error) { data, err := os.ReadFile(f.path) diff --git a/internal/adapters/mcp/proxy.go b/internal/adapters/mcp/proxy.go index cde8ccd..16a5ecf 100644 --- a/internal/adapters/mcp/proxy.go +++ b/internal/adapters/mcp/proxy.go @@ -14,7 +14,7 @@ import ( "strings" "sync" - "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/httputil" "github.com/nylas/cli/internal/ports" ) @@ -25,9 +25,6 @@ const ( NylasMCPEndpointEU = "https://mcp.eu.nylas.com" ) -// DefaultTimeout for HTTP requests - uses centralized domain constant. -var DefaultTimeout = domain.TimeoutMCP - // GetMCPEndpoint returns the appropriate MCP endpoint for the given region. func GetMCPEndpoint(region string) string { switch strings.ToLower(region) { @@ -69,9 +66,7 @@ func NewProxy(apiKey, region string) *Proxy { endpoint: GetMCPEndpoint(region), apiKey: apiKey, authHeader: "Bearer " + apiKey, // Cache auth header - httpClient: &http.Client{ - Timeout: DefaultTimeout, - }, + httpClient: httputil.DefaultClient, } } diff --git a/internal/adapters/mime/builder.go b/internal/adapters/mime/builder.go index 286bcd9..66fe5a6 100644 --- a/internal/adapters/mime/builder.go +++ b/internal/adapters/mime/builder.go @@ -14,23 +14,6 @@ import ( "github.com/nylas/cli/internal/domain" ) -// Builder constructs MIME messages. -type Builder interface { - // BuildSignedMessage builds a PGP/MIME signed message (RFC 3156). - BuildSignedMessage(req *SignedMessageRequest) ([]byte, error) - - // PrepareContentToSign prepares the MIME content part that will be signed. - // Returns the exact bytes that should be signed with GPG. - // This includes the part headers and encoded body with CRLF line endings. - PrepareContentToSign(body, contentType string, attachments []domain.Attachment) ([]byte, error) - - // BuildEncryptedMessage builds a PGP/MIME encrypted message (RFC 3156 Section 4). - BuildEncryptedMessage(req *EncryptedMessageRequest) ([]byte, error) - - // PrepareContentToEncrypt prepares the MIME content that will be encrypted. - PrepareContentToEncrypt(body, contentType string, attachments []domain.Attachment) ([]byte, error) -} - // messageRequest is an interface for shared email header fields. // Both SignedMessageRequest and EncryptedMessageRequest implement this. type messageRequest interface { @@ -82,16 +65,16 @@ func (r *SignedMessageRequest) getHeaders() map[string]string { return r func (r *SignedMessageRequest) getMessageID() string { return r.MessageID } func (r *SignedMessageRequest) getDate() time.Time { return r.Date } -// builder implements Builder. -type builder struct{} +// Builder constructs MIME messages. +type Builder struct{} // NewBuilder creates a new MIME builder. -func NewBuilder() Builder { - return &builder{} +func NewBuilder() *Builder { + return &Builder{} } // BuildSignedMessage constructs a PGP/MIME signed message per RFC 3156. -func (b *builder) BuildSignedMessage(req *SignedMessageRequest) ([]byte, error) { +func (b *Builder) BuildSignedMessage(req *SignedMessageRequest) ([]byte, error) { if err := validateSignedRequest(req); err != nil { return nil, err } @@ -187,13 +170,13 @@ func writeCommonHeaders(buf *bytes.Buffer, req messageRequest) { } // writeHeaders writes RFC 822 headers for signed messages. -func (b *builder) writeHeaders(buf *bytes.Buffer, req *SignedMessageRequest) error { +func (b *Builder) writeHeaders(buf *bytes.Buffer, req *SignedMessageRequest) error { writeCommonHeaders(buf, req) return nil } // writeContentPart writes the signed content part (body + attachments if any). -func (b *builder) writeContentPart(buf *bytes.Buffer, req *SignedMessageRequest) error { +func (b *Builder) writeContentPart(buf *bytes.Buffer, req *SignedMessageRequest) error { if len(req.Attachments) == 0 { // Simple case: just body return b.writeBodyPart(buf, req) @@ -223,7 +206,7 @@ func (b *builder) writeContentPart(buf *bytes.Buffer, req *SignedMessageRequest) } // writeBodyPart writes the email body. -func (b *builder) writeBodyPart(buf *bytes.Buffer, req *SignedMessageRequest) error { +func (b *Builder) writeBodyPart(buf *bytes.Buffer, req *SignedMessageRequest) error { contentType := req.ContentType if contentType == "" { contentType = "text/plain" @@ -251,7 +234,7 @@ func (b *builder) writeBodyPart(buf *bytes.Buffer, req *SignedMessageRequest) er } // writeAttachmentPart writes an attachment part. -func (b *builder) writeAttachmentPart(buf *bytes.Buffer, att *domain.Attachment) error { +func (b *Builder) writeAttachmentPart(buf *bytes.Buffer, att *domain.Attachment) error { contentType := att.ContentType if contentType == "" { contentType = "application/octet-stream" @@ -391,7 +374,7 @@ func getMicAlg(hashAlgo string) string { // PrepareContentToSign prepares the MIME content part for signing. // This returns the exact bytes that should be signed with GPG. -func (b *builder) PrepareContentToSign(body, contentType string, attachments []domain.Attachment) ([]byte, error) { +func (b *Builder) PrepareContentToSign(body, contentType string, attachments []domain.Attachment) ([]byte, error) { if contentType == "" { contentType = "text/plain" } diff --git a/internal/adapters/mime/encrypted.go b/internal/adapters/mime/encrypted.go index cc1316b..cfad030 100644 --- a/internal/adapters/mime/encrypted.go +++ b/internal/adapters/mime/encrypted.go @@ -57,7 +57,7 @@ func (r *EncryptedMessageRequest) getDate() time.Time { retur // [Encrypted content] // -----END PGP MESSAGE----- // --boundary-- -func (b *builder) BuildEncryptedMessage(req *EncryptedMessageRequest) ([]byte, error) { +func (b *Builder) BuildEncryptedMessage(req *EncryptedMessageRequest) ([]byte, error) { if err := validateEncryptedRequest(req); err != nil { return nil, err } @@ -98,7 +98,7 @@ func (b *builder) BuildEncryptedMessage(req *EncryptedMessageRequest) ([]byte, e } // writeEncryptedHeaders writes RFC 822 headers for encrypted messages. -func (b *builder) writeEncryptedHeaders(buf *bytes.Buffer, req *EncryptedMessageRequest) error { +func (b *Builder) writeEncryptedHeaders(buf *bytes.Buffer, req *EncryptedMessageRequest) error { writeCommonHeaders(buf, req) return nil } @@ -116,7 +116,7 @@ func validateEncryptedRequest(req *EncryptedMessageRequest) error { // PrepareContentToEncrypt prepares the MIME content that will be encrypted. // This builds a complete MIME body (with attachments if any) that gets encrypted as a whole. -func (b *builder) PrepareContentToEncrypt(body, contentType string, attachments []domain.Attachment) ([]byte, error) { +func (b *Builder) PrepareContentToEncrypt(body, contentType string, attachments []domain.Attachment) ([]byte, error) { // Reuse the PrepareContentToSign logic since the content structure is the same // The only difference is what we do with the result (encrypt vs sign) return b.PrepareContentToSign(body, contentType, attachments) diff --git a/internal/adapters/nylas/attachments.go b/internal/adapters/nylas/attachments.go index ab788ba..9496c3c 100644 --- a/internal/adapters/nylas/attachments.go +++ b/internal/adapters/nylas/attachments.go @@ -47,30 +47,20 @@ func (c *HTTPClient) GetAttachment(ctx context.Context, grantID, messageID, atta func (c *HTTPClient) DownloadAttachment(ctx context.Context, grantID, messageID, attachmentID string) (io.ReadCloser, error) { queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s/attachments/%s/download", c.baseURL, url.PathEscape(grantID), url.PathEscape(messageID), url.PathEscape(attachmentID)) - // The response body streams under the request context, so the default - // API timeout would cut off large/slow downloads mid-stream. Apply the - // dedicated download timeout when the caller hasn't set a deadline. - cancel := context.CancelFunc(func() {}) - if _, hasDeadline := ctx.Deadline(); !hasDeadline { - ctx, cancel = context.WithTimeout(ctx, domain.TimeoutDownload) - } - req, err := http.NewRequestWithContext(ctx, "GET", queryURL, nil) if err != nil { - cancel() return nil, err } c.setAuthHeader(req) + // doRequest applies the per-request timeout and wraps the streaming body + // to release its context on close. The shared client also caps the whole + // transfer at the server-side 120s ceiling. resp, err := c.doRequest(ctx, req) if err != nil { - cancel() return nil, fmt.Errorf("%w: %v", domain.ErrNetworkError, err) } - // Release the download context when the body is closed or fully read. - resp.Body = &cancelOnCloseBody{ReadCloser: resp.Body, cancel: cancel} - if resp.StatusCode == http.StatusNotFound { _ = resp.Body.Close() return nil, domain.ErrAttachmentNotFound diff --git a/internal/adapters/nylas/attachments_test.go b/internal/adapters/nylas/attachments_test.go index e724858..a1588ea 100644 --- a/internal/adapters/nylas/attachments_test.go +++ b/internal/adapters/nylas/attachments_test.go @@ -207,14 +207,14 @@ func TestHTTPClient_DownloadAttachment(t *testing.T) { assert.Nil(t, reader) }) - t.Run("streams past the default request timeout", func(t *testing.T) { + t.Run("download is bound by the request timeout", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("first-chunk-")) if f, ok := w.(http.Flusher); ok { f.Flush() } - // Stall mid-stream for longer than the default request timeout. + // Stall mid-stream for longer than the request timeout. time.Sleep(300 * time.Millisecond) _, _ = w.Write([]byte("second-chunk")) })) @@ -223,17 +223,16 @@ func TestHTTPClient_DownloadAttachment(t *testing.T) { client := nylas.NewHTTPClient() client.SetCredentials("client-id", "secret", "api-key") client.SetBaseURL(server.URL) - // Shrink the default API timeout below the server's mid-stream stall to - // prove downloads use the dedicated (longer) download timeout instead. + // Downloads are not exempt from the timeout: a stream that stalls past + // the request timeout is cut off (the shared client also caps at 120s). client.SetRequestTimeout(50 * time.Millisecond) reader, err := client.DownloadAttachment(context.Background(), "grant-123", "msg-456", "attach-789") require.NoError(t, err) defer func() { _ = reader.Close() }() - content, err := io.ReadAll(reader) - require.NoError(t, err) - assert.Equal(t, "first-chunk-second-chunk", string(content)) + _, err = io.ReadAll(reader) + require.Error(t, err) }) } diff --git a/internal/adapters/nylas/client.go b/internal/adapters/nylas/client.go index aa7c61b..2e97f5e 100644 --- a/internal/adapters/nylas/client.go +++ b/internal/adapters/nylas/client.go @@ -16,6 +16,7 @@ import ( "github.com/nylas/cli/internal/adapters/providers" "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/httputil" "github.com/nylas/cli/internal/ports" "github.com/nylas/cli/internal/version" "golang.org/x/time/rate" @@ -74,11 +75,11 @@ type HTTPClient struct { // Retry logic handles transient errors with exponential backoff and Retry-After header support. func NewHTTPClient() *HTTPClient { return &HTTPClient{ - httpClient: &http.Client{ - // Remove global timeout since we use per-request context timeouts - Timeout: 0, - }, - baseURL: baseURLUS, + // Nylas enforces a server-side 120s request ceiling, which matches the + // shared client's default timeout. Per-request context deadlines + // (requestTimeout) still apply as the tighter bound for normal calls. + httpClient: httputil.DefaultClient, + baseURL: baseURLUS, // Create token bucket rate limiter: 10 requests/second, burst of 20 rateLimiter: rate.NewLimiter(rate.Limit(defaultRateLimit), defaultRateLimit*2), requestTimeout: domain.TimeoutAPI, diff --git a/internal/adapters/nylas/client_helpers.go b/internal/adapters/nylas/client_helpers.go index 5095e08..c762add 100644 --- a/internal/adapters/nylas/client_helpers.go +++ b/internal/adapters/nylas/client_helpers.go @@ -230,16 +230,6 @@ func (qb *QueryBuilder) AddSlice(key string, values []string) *QueryBuilder { return qb } -// Encode returns the encoded query string. -func (qb *QueryBuilder) Encode() string { - return qb.values.Encode() -} - -// Values returns the underlying url.Values. -func (qb *QueryBuilder) Values() url.Values { - return qb.values -} - // BuildURL appends the query string to the base URL if there are parameters. func (qb *QueryBuilder) BuildURL(baseURL string) string { if len(qb.values) == 0 { diff --git a/internal/adapters/nylas/security_test.go b/internal/adapters/nylas/security_test.go index 15ae620..6276ffd 100644 --- a/internal/adapters/nylas/security_test.go +++ b/internal/adapters/nylas/security_test.go @@ -5,6 +5,8 @@ package nylas import ( "testing" + + "github.com/nylas/cli/internal/httputil" ) // TestOTPExtractionSecurity tests that OTP extraction is secure. @@ -116,12 +118,13 @@ func TestOTPExtractionSecurity(t *testing.T) { // TestHTTPClientSecurity tests HTTP client security. func TestHTTPClientSecurity(t *testing.T) { - t.Run("client_uses_per_request_timeouts", func(t *testing.T) { + t.Run("client_timeout_matches_server_ceiling", func(t *testing.T) { client := NewHTTPClient() - // HTTP client should NOT have a global timeout (we use per-request context timeouts) - // This allows better control and prevents blocking other requests - if client.httpClient.Timeout != 0 { - t.Error("HTTP client should not have global timeout (uses per-request context timeouts)") + // The HTTP client carries the shared 120s timeout, matching the Nylas + // server-side request ceiling. Per-request context deadlines + // (requestTimeout) still apply as the tighter bound for normal calls. + if client.httpClient.Timeout != httputil.DefaultClientTimeout { + t.Errorf("HTTP client timeout = %v, want %v (server ceiling)", client.httpClient.Timeout, httputil.DefaultClientTimeout) } // Verify rate limiter is configured if client.rateLimiter == nil { diff --git a/internal/adapters/providers/registry.go b/internal/adapters/providers/registry.go index 00983cb..0c0e63a 100644 --- a/internal/adapters/providers/registry.go +++ b/internal/adapters/providers/registry.go @@ -2,7 +2,6 @@ package providers import ( - "fmt" "sync" "github.com/nylas/cli/internal/ports" @@ -37,36 +36,3 @@ func Register(name string, factory ProviderFactory) { defer defaultRegistry.mu.Unlock() defaultRegistry.factories[name] = factory } - -// Get returns a provider factory by name -func Get(name string) (ProviderFactory, error) { - defaultRegistry.mu.RLock() - defer defaultRegistry.mu.RUnlock() - - factory, ok := defaultRegistry.factories[name] - if !ok { - return nil, fmt.Errorf("provider %q not registered", name) - } - return factory, nil -} - -// List returns all registered provider names -func List() []string { - defaultRegistry.mu.RLock() - defer defaultRegistry.mu.RUnlock() - - names := make([]string, 0, len(defaultRegistry.factories)) - for name := range defaultRegistry.factories { - names = append(names, name) - } - return names -} - -// NewClient creates a new provider client by name -func NewClient(provider string, config ProviderConfig) (ports.NylasClient, error) { - factory, err := Get(provider) - if err != nil { - return nil, err - } - return factory(config) -} diff --git a/internal/adapters/templates/store.go b/internal/adapters/templates/store.go index 86238c3..9a1392d 100644 --- a/internal/adapters/templates/store.go +++ b/internal/adapters/templates/store.go @@ -19,7 +19,7 @@ import ( // ErrTemplateNotFound is returned when a template is not found. var ErrTemplateNotFound = errors.New("template not found") -// FileStore implements TemplateStore using a JSON file. +// FileStore stores email templates in a JSON file. type FileStore struct { path string mu sync.RWMutex diff --git a/internal/adapters/utilities/mock.go b/internal/adapters/utilities/mock.go deleted file mode 100644 index b2a80f7..0000000 --- a/internal/adapters/utilities/mock.go +++ /dev/null @@ -1,79 +0,0 @@ -package utilities - -import ( - "context" - "time" - - "github.com/nylas/cli/internal/domain" -) - -// MockUtilityServices implements ports.UtilityServices for testing. -type MockUtilityServices struct { - // TimeZoneService - ConvertTimeFunc func(ctx context.Context, fromZone, toZone string, t time.Time) (time.Time, error) - FindMeetingTimeFunc func(ctx context.Context, req *domain.MeetingFinderRequest) (*domain.MeetingTimeSlots, error) - GetDSTTransitionsFunc func(ctx context.Context, zone string, year int) ([]domain.DSTTransition, error) - ListTimeZonesFunc func(ctx context.Context) ([]string, error) - GetTimeZoneInfoFunc func(ctx context.Context, zone string, at time.Time) (*domain.TimeZoneInfo, error) -} - -// NewMockUtilityServices creates a new mock utility services with sensible defaults. -func NewMockUtilityServices() *MockUtilityServices { - return &MockUtilityServices{ - // TimeZoneService defaults - ConvertTimeFunc: func(ctx context.Context, fromZone, toZone string, t time.Time) (time.Time, error) { - return t, nil - }, - FindMeetingTimeFunc: func(ctx context.Context, req *domain.MeetingFinderRequest) (*domain.MeetingTimeSlots, error) { - return &domain.MeetingTimeSlots{Slots: []domain.MeetingSlot{}, TimeZones: req.TimeZones}, nil - }, - GetDSTTransitionsFunc: func(ctx context.Context, zone string, year int) ([]domain.DSTTransition, error) { - return []domain.DSTTransition{}, nil - }, - ListTimeZonesFunc: func(ctx context.Context) ([]string, error) { - return []string{"UTC", "America/New_York", "Europe/London"}, nil - }, - GetTimeZoneInfoFunc: func(ctx context.Context, zone string, at time.Time) (*domain.TimeZoneInfo, error) { - return &domain.TimeZoneInfo{Name: zone, Abbreviation: "UTC", Offset: 0}, nil - }, - } -} - -// ============================================================================ -// TimeZoneService implementation -// ============================================================================ - -func (m *MockUtilityServices) ConvertTime(ctx context.Context, fromZone, toZone string, t time.Time) (time.Time, error) { - if m.ConvertTimeFunc != nil { - return m.ConvertTimeFunc(ctx, fromZone, toZone, t) - } - return t, nil -} - -func (m *MockUtilityServices) FindMeetingTime(ctx context.Context, req *domain.MeetingFinderRequest) (*domain.MeetingTimeSlots, error) { - if m.FindMeetingTimeFunc != nil { - return m.FindMeetingTimeFunc(ctx, req) - } - return &domain.MeetingTimeSlots{}, nil -} - -func (m *MockUtilityServices) GetDSTTransitions(ctx context.Context, zone string, year int) ([]domain.DSTTransition, error) { - if m.GetDSTTransitionsFunc != nil { - return m.GetDSTTransitionsFunc(ctx, zone, year) - } - return []domain.DSTTransition{}, nil -} - -func (m *MockUtilityServices) ListTimeZones(ctx context.Context) ([]string, error) { - if m.ListTimeZonesFunc != nil { - return m.ListTimeZonesFunc(ctx) - } - return []string{}, nil -} - -func (m *MockUtilityServices) GetTimeZoneInfo(ctx context.Context, zone string, at time.Time) (*domain.TimeZoneInfo, error) { - if m.GetTimeZoneInfoFunc != nil { - return m.GetTimeZoneInfoFunc(ctx, zone, at) - } - return &domain.TimeZoneInfo{}, nil -} diff --git a/internal/adapters/utilities/timezone/service.go b/internal/adapters/utilities/timezone/service.go index 903a5a5..6b33f24 100644 --- a/internal/adapters/utilities/timezone/service.go +++ b/internal/adapters/utilities/timezone/service.go @@ -10,8 +10,7 @@ import ( "github.com/nylas/cli/internal/domain" ) -// Service implements ports.TimeZoneService. -// Provides time zone conversion, meeting finder, and DST transition utilities. +// Service provides time zone conversion, meeting finder, and DST transition utilities. type Service struct{} // NewService creates a new time zone service. diff --git a/internal/adapters/webhookserver/server.go b/internal/adapters/webhookserver/server.go index 8ee72bb..08544f2 100644 --- a/internal/adapters/webhookserver/server.go +++ b/internal/adapters/webhookserver/server.go @@ -71,7 +71,7 @@ var rootTemplate = template.Must(template.New("root").Parse(` `)) -// Server implements the WebhookServer interface. +// Server is a local webhook receiver server. type Server struct { config ports.WebhookServerConfig server *http.Server diff --git a/internal/air/handlers_reply_later.go b/internal/air/handlers_reply_later.go index cb5d18e..33bf681 100644 --- a/internal/air/handlers_reply_later.go +++ b/internal/air/handlers_reply_later.go @@ -49,7 +49,7 @@ func (s *Server) handleGetReplyLaterItems(w http.ResponseWriter, r *http.Request rlStore.mu.RLock() defer rlStore.mu.RUnlock() - showCompleted := ParseBool(r.URL.Query(), "completed") + showCompleted := NewQueryParams(r.URL.Query()).GetBool("completed") items := make([]*ReplyLaterItem, 0) for _, item := range rlStore.items { diff --git a/internal/air/query_helpers.go b/internal/air/query_helpers.go index aca8e09..79adcb6 100644 --- a/internal/air/query_helpers.go +++ b/internal/air/query_helpers.go @@ -59,17 +59,6 @@ func (q *QueryParams) GetBool(key string) bool { return q.values.Get(key) == "true" } -// GetBoolPtr parses a boolean query parameter and returns a pointer. -// Returns nil if the parameter is not present, otherwise returns pointer to the bool value. -func (q *QueryParams) GetBoolPtr(key string) *bool { - s := q.values.Get(key) - if s == "" { - return nil - } - val := s == "true" - return &val -} - // GetString returns the parameter value, or defaultVal if empty. func (q *QueryParams) GetString(key, defaultVal string) string { s := q.values.Get(key) @@ -78,33 +67,3 @@ func (q *QueryParams) GetString(key, defaultVal string) string { } return s } - -// Has returns true if the parameter is present (even if empty). -func (q *QueryParams) Has(key string) bool { - _, ok := q.values[key] - return ok -} - -// ParseLimit is a standalone helper for parsing limit with standard bounds. -// Deprecated: Use QueryParams.GetLimit() instead for new code. -func ParseLimit(query url.Values, defaultVal int) int { - return NewQueryParams(query).GetLimit(defaultVal) -} - -// ParseInt is a standalone helper for parsing an integer with bounds. -// Deprecated: Use QueryParams.GetInt() instead for new code. -func ParseInt(query url.Values, key string, defaultVal, minVal, maxVal int) int { - return NewQueryParams(query).GetInt(key, defaultVal, minVal, maxVal) -} - -// ParseInt64 is a standalone helper for parsing int64 values. -// Deprecated: Use QueryParams.GetInt64() instead for new code. -func ParseInt64(query url.Values, key string, defaultVal int64) int64 { - return NewQueryParams(query).GetInt64(key, defaultVal) -} - -// ParseBool is a standalone helper for parsing boolean values. -// Deprecated: Use QueryParams.GetBool() instead for new code. -func ParseBool(query url.Values, key string) bool { - return NewQueryParams(query).GetBool(key) -} diff --git a/internal/air/query_helpers_extended_test.go b/internal/air/query_helpers_extended_test.go index e6160c4..f014fa1 100644 --- a/internal/air/query_helpers_extended_test.go +++ b/internal/air/query_helpers_extended_test.go @@ -260,49 +260,6 @@ func TestQueryParams_GetBool_Comprehensive(t *testing.T) { } } -// TestQueryParams_GetBoolPtr_Comprehensive tests boolean pointer parsing. -func TestQueryParams_GetBoolPtr_Comprehensive(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - query string - key string - wantNil bool - wantVal bool - }{ - // Present values - {"true", "flag=true", "flag", false, true}, - {"false", "flag=false", "flag", false, false}, - {"other", "flag=other", "flag", false, false}, - - // Nil cases - {"empty value", "flag=", "flag", true, false}, - {"missing key", "", "flag", true, false}, - {"different key", "other=true", "flag", true, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - values, _ := url.ParseQuery(tt.query) - q := NewQueryParams(values) - got := q.GetBoolPtr(tt.key) - - if tt.wantNil { - if got != nil { - t.Errorf("GetBoolPtr() = %v, want nil", *got) - } - } else { - if got == nil { - t.Errorf("GetBoolPtr() = nil, want %v", tt.wantVal) - } else if *got != tt.wantVal { - t.Errorf("GetBoolPtr() = %v, want %v", *got, tt.wantVal) - } - } - }) - } -} - // TestQueryParams_GetString_Comprehensive tests string parsing. func TestQueryParams_GetString_Comprehensive(t *testing.T) { t.Parallel() @@ -342,35 +299,22 @@ func TestQueryParams_GetString_Comprehensive(t *testing.T) { } } -// TestQueryParams_Has_Comprehensive tests key presence detection. -func TestQueryParams_Has_Comprehensive(t *testing.T) { +// TestQueryParams_Get covers the raw Get accessor, including the missing-key +// case where it must return an empty string. +func TestQueryParams_Get(t *testing.T) { t.Parallel() - tests := []struct { - name string - query string - key string - want bool - }{ - {"key with value", "key=value", "key", true}, - {"key without value", "key=", "key", true}, - {"key only", "key", "key", true}, - {"missing key", "other=value", "key", false}, - {"empty query", "", "key", false}, - {"similar key prefix", "keyboard=value", "key", false}, - {"similar key suffix", "mykey=value", "key", false}, - {"case sensitive", "Key=value", "key", false}, - } + values, _ := url.ParseQuery("foo=bar&empty=") + q := NewQueryParams(values) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - values, _ := url.ParseQuery(tt.query) - q := NewQueryParams(values) - got := q.Has(tt.key) - if got != tt.want { - t.Errorf("Has() = %v, want %v", got, tt.want) - } - }) + if got := q.Get("foo"); got != "bar" { + t.Errorf("Get(foo) = %q, want %q", got, "bar") + } + if got := q.Get("empty"); got != "" { + t.Errorf("Get(empty) = %q, want empty", got) + } + if got := q.Get("missing"); got != "" { + t.Errorf("Get(missing) = %q, want empty", got) } } diff --git a/internal/air/query_helpers_test.go b/internal/air/query_helpers_test.go deleted file mode 100644 index f085cea..0000000 --- a/internal/air/query_helpers_test.go +++ /dev/null @@ -1,279 +0,0 @@ -package air - -import ( - "net/url" - "testing" -) - -func TestQueryParams_GetInt(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - query string - key string - defaultVal int - minVal int - maxVal int - want int - }{ - {"valid value", "limit=25", "limit", 50, 1, 100, 25}, - {"missing key", "", "limit", 50, 1, 100, 50}, - {"empty value", "limit=", "limit", 50, 1, 100, 50}, - {"below min", "limit=0", "limit", 50, 1, 100, 50}, - {"above max", "limit=150", "limit", 50, 1, 100, 50}, - {"invalid string", "limit=abc", "limit", 50, 1, 100, 50}, - {"negative value", "limit=-5", "limit", 50, 1, 100, 50}, - {"at min boundary", "limit=1", "limit", 50, 1, 100, 1}, - {"at max boundary", "limit=100", "limit", 50, 1, 100, 100}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - values, _ := url.ParseQuery(tt.query) - q := NewQueryParams(values) - got := q.GetInt(tt.key, tt.defaultVal, tt.minVal, tt.maxVal) - if got != tt.want { - t.Errorf("GetInt() = %d, want %d", got, tt.want) - } - }) - } -} - -func TestQueryParams_GetLimit(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - query string - defaultVal int - want int - }{ - {"valid limit", "limit=25", 50, 25}, - {"missing limit", "", 50, 50}, - {"max limit", "limit=200", 50, 200}, - {"over max limit", "limit=500", 50, 50}, - {"zero limit", "limit=0", 50, 50}, - {"different default", "limit=abc", 100, 100}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - values, _ := url.ParseQuery(tt.query) - q := NewQueryParams(values) - got := q.GetLimit(tt.defaultVal) - if got != tt.want { - t.Errorf("GetLimit() = %d, want %d", got, tt.want) - } - }) - } -} - -func TestQueryParams_GetInt64(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - query string - key string - defaultVal int64 - want int64 - }{ - {"valid timestamp", "start=1704067200", "start", 0, 1704067200}, - {"missing key", "", "start", 0, 0}, - {"empty value", "start=", "start", 100, 100}, - {"invalid string", "start=abc", "start", 0, 0}, - {"negative value", "start=-1000", "start", 0, -1000}, - {"large value", "start=9999999999999", "start", 0, 9999999999999}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - values, _ := url.ParseQuery(tt.query) - q := NewQueryParams(values) - got := q.GetInt64(tt.key, tt.defaultVal) - if got != tt.want { - t.Errorf("GetInt64() = %d, want %d", got, tt.want) - } - }) - } -} - -func TestQueryParams_GetBool(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - query string - key string - want bool - }{ - {"true value", "unread=true", "unread", true}, - {"false value", "unread=false", "unread", false}, - {"missing key", "", "unread", false}, - {"empty value", "unread=", "unread", false}, - {"uppercase TRUE", "unread=TRUE", "unread", false}, - {"numeric 1", "unread=1", "unread", false}, - {"yes value", "unread=yes", "unread", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - values, _ := url.ParseQuery(tt.query) - q := NewQueryParams(values) - got := q.GetBool(tt.key) - if got != tt.want { - t.Errorf("GetBool() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestQueryParams_GetBoolPtr(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - query string - key string - wantNil bool - wantVal bool - }{ - {"true value", "unread=true", "unread", false, true}, - {"false value", "unread=false", "unread", false, false}, - {"missing key", "", "unread", true, false}, - {"empty value", "unread=", "unread", true, false}, // Empty = not meaningfully set - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - values, _ := url.ParseQuery(tt.query) - q := NewQueryParams(values) - got := q.GetBoolPtr(tt.key) - if tt.wantNil { - if got != nil { - t.Errorf("GetBoolPtr() = %v, want nil", *got) - } - } else { - if got == nil { - t.Errorf("GetBoolPtr() = nil, want %v", tt.wantVal) - } else if *got != tt.wantVal { - t.Errorf("GetBoolPtr() = %v, want %v", *got, tt.wantVal) - } - } - }) - } -} - -func TestQueryParams_GetString(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - query string - key string - defaultVal string - want string - }{ - {"present value", "folder=inbox", "folder", "all", "inbox"}, - {"missing key", "", "folder", "all", "all"}, - {"empty value", "folder=", "folder", "all", "all"}, - {"special chars", "q=hello+world", "q", "", "hello world"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - values, _ := url.ParseQuery(tt.query) - q := NewQueryParams(values) - got := q.GetString(tt.key, tt.defaultVal) - if got != tt.want { - t.Errorf("GetString() = %q, want %q", got, tt.want) - } - }) - } -} - -func TestQueryParams_Has(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - query string - key string - want bool - }{ - {"present with value", "limit=50", "limit", true}, - {"present empty", "limit=", "limit", true}, - {"missing", "other=value", "limit", false}, - {"empty query", "", "limit", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - values, _ := url.ParseQuery(tt.query) - q := NewQueryParams(values) - got := q.Has(tt.key) - if got != tt.want { - t.Errorf("Has() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestQueryParams_Get(t *testing.T) { - t.Parallel() - - values, _ := url.ParseQuery("foo=bar&empty=") - q := NewQueryParams(values) - - if got := q.Get("foo"); got != "bar" { - t.Errorf("Get(foo) = %q, want %q", got, "bar") - } - if got := q.Get("empty"); got != "" { - t.Errorf("Get(empty) = %q, want empty", got) - } - if got := q.Get("missing"); got != "" { - t.Errorf("Get(missing) = %q, want empty", got) - } -} - -// Test standalone helper functions (deprecated but still need coverage) -func TestParseLimit(t *testing.T) { - t.Parallel() - - values, _ := url.ParseQuery("limit=75") - got := ParseLimit(values, 50) - if got != 75 { - t.Errorf("ParseLimit() = %d, want 75", got) - } -} - -func TestParseInt(t *testing.T) { - t.Parallel() - - values, _ := url.ParseQuery("page=5") - got := ParseInt(values, "page", 1, 1, 100) - if got != 5 { - t.Errorf("ParseInt() = %d, want 5", got) - } -} - -func TestParseInt64(t *testing.T) { - t.Parallel() - - values, _ := url.ParseQuery("timestamp=1704067200") - got := ParseInt64(values, "timestamp", 0) - if got != 1704067200 { - t.Errorf("ParseInt64() = %d, want 1704067200", got) - } -} - -func TestParseBool(t *testing.T) { - t.Parallel() - - values, _ := url.ParseQuery("active=true") - got := ParseBool(values, "active") - if !got { - t.Errorf("ParseBool() = %v, want true", got) - } -} diff --git a/internal/air/server_lifecycle.go b/internal/air/server_lifecycle.go index adc501f..c841aaf 100644 --- a/internal/air/server_lifecycle.go +++ b/internal/air/server_lifecycle.go @@ -14,6 +14,7 @@ import ( "github.com/nylas/cli/internal/air/cache" authapp "github.com/nylas/cli/internal/app/auth" "github.com/nylas/cli/internal/cli/common" + "github.com/nylas/cli/internal/httputil" "github.com/nylas/cli/internal/ports" ) @@ -313,14 +314,7 @@ func (s *Server) Start() error { PerformanceMonitoringMiddleware( MethodOverrideMiddleware(mux)))))))) - server := &http.Server{ - Addr: s.addr, - Handler: handler, - ReadHeaderTimeout: 10 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 120 * time.Second, - MaxHeaderBytes: 1 << 20, // 1 MB - } + server := httputil.NewServer(s.addr, handler, 30*time.Second) return server.ListenAndServe() } diff --git a/internal/app/auth/grants.go b/internal/app/auth/grants.go index 35ad777..c5a18bb 100644 --- a/internal/app/auth/grants.go +++ b/internal/app/auth/grants.go @@ -68,7 +68,18 @@ func (s *GrantService) ListGrants(ctx context.Context) ([]domain.GrantStatus, er }) } _ = s.grantStore.ReplaceGrants(cacheGrants) - if !defaultStillExists { + switch { + case defaultGrant == "": + // No default resolved from either store; nothing to reconcile. + case defaultStillExists: + // The resolved default is live. Persist it to BOTH stores so the grant + // cache (authoritative for every other command, e.g. GetGrantID) agrees + // with what auth list shows. Without this, a default that survives only + // in config surfaces here while commands like `email list` report "no + // grant" — the confusing split the user hit. + _ = PersistDefaultGrant(s.config, s.grantStore, defaultGrant) + default: + // Resolved default no longer exists on Nylas; clear it everywhere. _ = PersistDefaultGrant(s.config, s.grantStore, "") } diff --git a/internal/app/auth/grants_test.go b/internal/app/auth/grants_test.go index a2811b1..24a84ff 100644 --- a/internal/app/auth/grants_test.go +++ b/internal/app/auth/grants_test.go @@ -100,6 +100,37 @@ func TestGrantService_ListGrantsClearsStaleConfigDefault(t *testing.T) { assert.ErrorIs(t, err, domain.ErrNoDefaultGrant) } +// TestGrantService_ListGrantsRestoresLiveConfigDefaultToCache covers the split +// the user hit: the grant cache lost its default (e.g. a self-heal cleared it) +// but config still points at a valid, live grant. auth list must write that +// default back into the cache so other commands (GetGrantID/email list) agree, +// instead of leaving auth list showing a default no one else can see. +func TestGrantService_ListGrantsRestoresLiveConfigDefaultToCache(t *testing.T) { + grantStore := newMockGrantStore() + configStore := newMockConfigStore() + configStore.config.DefaultGrant = "grant-2" // config-only default; cache has none + client := nylas.NewMockClient() + client.ListGrantsFunc = func(ctx context.Context) ([]domain.Grant, error) { + return []domain.Grant{ + {ID: "grant-1", Email: "one@example.com", Provider: domain.ProviderGoogle, GrantStatus: "valid"}, + {ID: "grant-2", Email: "two@example.com", Provider: domain.ProviderGoogle, GrantStatus: "valid"}, + }, nil + } + + svc := NewGrantService(client, grantStore, configStore) + + got, err := svc.ListGrants(context.Background()) + require.NoError(t, err) + require.Len(t, got, 2) + assert.True(t, got[1].IsDefault) + + // The live default is now in the cache, so GetGrantID-style lookups resolve it. + defaultGrant, err := grantStore.GetDefaultGrant() + require.NoError(t, err) + assert.Equal(t, "grant-2", defaultGrant) + assert.Equal(t, "grant-2", configStore.config.DefaultGrant) +} + func TestGrantService_CachedGrantCountUsesGrantStore(t *testing.T) { grantStore := newMockGrantStore() grantStore.grants["grant-1"] = domain.GrantInfo{ID: "grant-1", Email: "one@example.com"} diff --git a/internal/chat/server.go b/internal/chat/server.go index e41aed4..df31ae6 100644 --- a/internal/chat/server.go +++ b/internal/chat/server.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/nylas/cli/internal/httputil" "github.com/nylas/cli/internal/ports" "github.com/nylas/cli/internal/webguard" ) @@ -130,14 +131,9 @@ func (s *Server) Start() error { webguard.OriginProtectionMiddleware( webguard.SecurityHeadersMiddleware(mux))) - server := &http.Server{ - Addr: s.addr, - Handler: handler, - ReadHeaderTimeout: 10 * time.Second, - WriteTimeout: 360 * time.Second, // baseline for SSE streaming; handleChat extends the per-connection write deadline during approval waits - IdleTimeout: 120 * time.Second, - MaxHeaderBytes: 1 << 20, - } + // 360s WriteTimeout is the baseline for SSE streaming; handleChat extends + // the per-connection write deadline during approval waits. + server := httputil.NewServer(s.addr, handler, 360*time.Second) return server.ListenAndServe() } diff --git a/internal/cli/audit_hooks.go b/internal/cli/audit_hooks.go index fd966c7..1fc7c8e 100644 --- a/internal/cli/audit_hooks.go +++ b/internal/cli/audit_hooks.go @@ -72,6 +72,12 @@ func initAuditHooks(rootCmd *cobra.Command) { // auditPreRun is called before every command execution. func auditPreRun(cmd *cobra.Command, args []string) error { + // Apply the global --quiet flag process-wide so decorative output + // (success messages, tables, spinners, progress) is suppressed. Structured + // data is handled separately by the OutputWriter. + quiet, _ := cmd.Flags().GetBool("quiet") + common.SetQuiet(quiet) + // Don't audit help, version, or completion commands if isExcludedCommand(cmd) { return nil diff --git a/internal/cli/audit_hooks_lifecycle_test.go b/internal/cli/audit_hooks_lifecycle_test.go index 911ab35..461157b 100644 --- a/internal/cli/audit_hooks_lifecycle_test.go +++ b/internal/cli/audit_hooks_lifecycle_test.go @@ -216,3 +216,31 @@ func TestInitAuditHooks(t *testing.T) { t.Error("initAuditHooks() did not set AuditRequestHook") } } + +// TestAuditPreRun_AppliesQuietFlag verifies the root PersistentPreRunE wires the +// --quiet flag into the process-wide quiet mode, so decorative output is +// actually suppressed (the guards in format/progress read common.IsQuiet()). +func TestAuditPreRun_AppliesQuietFlag(t *testing.T) { + common.SetQuiet(false) + defer common.SetQuiet(false) + + cmd := &cobra.Command{Use: "demo"} + cmd.Flags().BoolP("quiet", "q", false, "") + + if err := auditPreRun(cmd, nil); err != nil { + t.Fatalf("auditPreRun without --quiet: %v", err) + } + if common.IsQuiet() { + t.Error("expected quiet mode off when --quiet is not set") + } + + if err := cmd.Flags().Set("quiet", "true"); err != nil { + t.Fatalf("set --quiet: %v", err) + } + if err := auditPreRun(cmd, nil); err != nil { + t.Fatalf("auditPreRun with --quiet: %v", err) + } + if !common.IsQuiet() { + t.Error("expected quiet mode on when --quiet is set") + } +} diff --git a/internal/cli/common/client.go b/internal/cli/common/client.go index 2b7aaa6..3cac302 100644 --- a/internal/cli/common/client.go +++ b/internal/cli/common/client.go @@ -277,10 +277,51 @@ func WithClient[T any](args []string, fn func(ctx context.Context, client ports. ctx, cancel := CreateContext() defer cancel() + // When the grant comes from the stored default (no explicit grant arg or + // NYLAS_GRANT_ID), verify it still exists. A removed or re-authenticated + // default otherwise surfaces as a confusing downstream 404 instead of a + // clear "select a grant" message. + if grantFromDefault(args) { + if verr := validateDefaultGrant(ctx, client, grantID); verr != nil { + return zero, verr + } + } + // Execute function return fn(ctx, client, grantID) } +// grantFromDefault reports whether GetGrantID would resolve the grant from the +// stored default rather than an explicit argument or NYLAS_GRANT_ID. +func grantFromDefault(args []string) bool { + if len(args) > 0 && args[0] != "" { + return false + } + return os.Getenv("NYLAS_GRANT_ID") == "" +} + +// validateDefaultGrant confirms the resolved default grant still exists. If it +// was removed, the stale default is cleared and a clear, actionable error is +// returned. Non-"not found" errors (e.g. transient network failures) are +// ignored so the real operation can surface them. +func validateDefaultGrant(ctx context.Context, client ports.NylasClient, grantID string) error { + if _, err := client.GetGrant(ctx, grantID); err == nil || !errors.Is(err, domain.ErrGrantNotFound) { + return nil + } + + // Self-heal: drop the stale pointer so the next run reports "no default". + if store, serr := NewDefaultGrantStore(); serr == nil { + _ = store.SetDefaultGrant("") + } + + return NewUserErrorWithSuggestions( + fmt.Sprintf("The default grant (%s) is no longer available — it may have been removed or re-authenticated.", grantID), + "List current accounts with: nylas auth list", + "Select an active account with: nylas auth switch ", + "Or specify one directly: nylas [command] ", + ) +} + // WithClientNoGrant is a generic helper for commands that don't need a grant ID. // This is useful for admin commands or commands that operate without a specific account. // diff --git a/internal/cli/common/client_test.go b/internal/cli/common/client_test.go index 145dc58..c1cef07 100644 --- a/internal/cli/common/client_test.go +++ b/internal/cli/common/client_test.go @@ -3,12 +3,15 @@ package common import ( + "context" + "errors" "os" "path/filepath" "testing" "github.com/nylas/cli/internal/adapters/config" "github.com/nylas/cli/internal/adapters/keyring" + "github.com/nylas/cli/internal/adapters/nylas" "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/ports" "github.com/stretchr/testify/assert" @@ -438,3 +441,54 @@ func TestContainsAt_UnicodeSupport(t *testing.T) { }) } } + +func TestGrantFromDefault(t *testing.T) { + t.Run("explicit grant arg is not from default", func(t *testing.T) { + assert.False(t, grantFromDefault([]string{"grant-123"})) + }) + t.Run("NYLAS_GRANT_ID is not from default", func(t *testing.T) { + t.Setenv("NYLAS_GRANT_ID", "grant-env") + assert.False(t, grantFromDefault(nil)) + }) + t.Run("no arg and no env resolves from default", func(t *testing.T) { + t.Setenv("NYLAS_GRANT_ID", "") + assert.True(t, grantFromDefault(nil)) + assert.True(t, grantFromDefault([]string{""})) + }) +} + +func TestValidateDefaultGrant(t *testing.T) { + // Isolate any self-heal write to a throwaway cache location. + t.Setenv("HOME", t.TempDir()) + ctx := context.Background() + + t.Run("removed default returns a clear, actionable error", func(t *testing.T) { + client := nylas.NewMockClient() + client.GetGrantFunc = func(context.Context, string) (*domain.Grant, error) { + return nil, domain.ErrGrantNotFound + } + + err := validateDefaultGrant(ctx, client, "stale-grant-id") + require.Error(t, err) + assert.Contains(t, err.Error(), "no longer available") + assert.Contains(t, err.Error(), "stale-grant-id") + }) + + t.Run("valid default passes", func(t *testing.T) { + client := nylas.NewMockClient() + client.GetGrantFunc = func(_ context.Context, id string) (*domain.Grant, error) { + return &domain.Grant{ID: id}, nil + } + + assert.NoError(t, validateDefaultGrant(ctx, client, "good-grant-id")) + }) + + t.Run("transient error does not block the operation", func(t *testing.T) { + client := nylas.NewMockClient() + client.GetGrantFunc = func(context.Context, string) (*domain.Grant, error) { + return nil, errors.New("network unreachable") + } + + assert.NoError(t, validateDefaultGrant(ctx, client, "some-grant-id")) + }) +} diff --git a/internal/cli/common/common_test.go b/internal/cli/common/common_test.go index 0f18d91..4933122 100644 --- a/internal/cli/common/common_test.go +++ b/internal/cli/common/common_test.go @@ -18,44 +18,15 @@ import ( // Logger Tests // ============================================================================= -func TestLogger_Init(t *testing.T) { - ResetLogger() +func TestQuietMode(t *testing.T) { + SetQuiet(false) + defer SetQuiet(false) - InitLogger(false, false) - assert.NotNil(t, GetLogger()) - assert.False(t, IsDebug()) assert.False(t, IsQuiet()) -} - -func TestLogger_DebugMode(t *testing.T) { - ResetLogger() - - InitLogger(true, false) - assert.True(t, IsDebug()) - assert.False(t, IsQuiet()) -} - -func TestLogger_QuietMode(t *testing.T) { - ResetLogger() - - InitLogger(false, true) - assert.False(t, IsDebug()) + SetQuiet(true) assert.True(t, IsQuiet()) } -func TestLogger_Functions(t *testing.T) { - ResetLogger() - InitLogger(true, false) - - // These should not panic - Debug("debug message", "key", "value") - Info("info message") - Warn("warning message") - Error("error message") - DebugHTTP("GET", "https://api.nylas.com", 200, "100ms") - DebugAPI("GetMessages", "grant_id", "test") -} - // ============================================================================= // Retry Tests // ============================================================================= @@ -115,8 +86,7 @@ func TestRetry_MaxRetriesExceeded(t *testing.T) { } func TestRetry_NonRetryableError(t *testing.T) { - ResetLogger() - InitLogger(false, true) // quiet mode + SetQuiet(true) // quiet mode config := DefaultRetryConfig() attempts := 0 @@ -188,8 +158,7 @@ func TestRetry_NoRetryConfig(t *testing.T) { // ============================================================================= func TestSpinner_StartStop(t *testing.T) { - ResetLogger() - InitLogger(false, true) // quiet mode to avoid output + SetQuiet(true) // quiet mode to avoid output var buf bytes.Buffer spinner := NewSpinner("Loading...").SetWriter(&buf) @@ -203,8 +172,7 @@ func TestSpinner_StartStop(t *testing.T) { } func TestSpinner_StopWithMessage(t *testing.T) { - ResetLogger() - InitLogger(false, false) + SetQuiet(false) var buf bytes.Buffer spinner := NewSpinner("Loading...").SetWriter(&buf) @@ -216,126 +184,12 @@ func TestSpinner_StopWithMessage(t *testing.T) { assert.Contains(t, buf.String(), "Done!") } -func TestProgressBar_Increment(t *testing.T) { - ResetLogger() - InitLogger(false, true) // quiet mode - - var buf bytes.Buffer - bar := NewProgressBar(10, "Processing").SetWriter(&buf) - - for i := 0; i < 10; i++ { - bar.Increment() - } - - // In quiet mode, should not produce output - assert.Empty(t, buf.String()) -} - -func TestProgressBar_Set(t *testing.T) { - ResetLogger() - InitLogger(false, true) - - bar := NewProgressBar(100, "Processing") - bar.Set(50) - bar.Finish() -} - -func TestCounter(t *testing.T) { - ResetLogger() - InitLogger(false, true) - - counter := NewCounter("Items") - counter.Increment() - counter.Increment() - counter.Increment() - - assert.Equal(t, 3, counter.Count()) - counter.Finish() -} - // ============================================================================= // Format Tests // ============================================================================= -func TestParseFormat(t *testing.T) { - tests := []struct { - input string - expected OutputFormat - hasError bool - }{ - {"table", FormatTable, false}, - {"TABLE", FormatTable, false}, - {"", FormatTable, false}, - {"json", FormatJSON, false}, - {"JSON", FormatJSON, false}, - {"csv", FormatCSV, false}, - {"yaml", FormatYAML, false}, - {"yml", FormatYAML, false}, - {"invalid", "", true}, - } - - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - format, err := ParseFormat(tt.input) - if tt.hasError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.expected, format) - } - }) - } -} - -func TestFormatter_JSON(t *testing.T) { - var buf bytes.Buffer - formatter := NewFormatter(FormatJSON).SetWriter(&buf) - - data := map[string]string{"key": "value"} - err := formatter.Format(data) - - require.NoError(t, err) - assert.Contains(t, buf.String(), `"key"`) - assert.Contains(t, buf.String(), `"value"`) -} - -func TestFormatter_YAML(t *testing.T) { - var buf bytes.Buffer - formatter := NewFormatter(FormatYAML).SetWriter(&buf) - - data := map[string]string{"key": "value"} - err := formatter.Format(data) - - require.NoError(t, err) - assert.Contains(t, buf.String(), "key:") - assert.Contains(t, buf.String(), "value") -} - -func TestFormatter_CSV(t *testing.T) { - var buf bytes.Buffer - formatter := NewFormatter(FormatCSV).SetWriter(&buf) - - type Item struct { - Name string `json:"name"` - Value int `json:"value"` - } - - data := []Item{ - {Name: "item1", Value: 1}, - {Name: "item2", Value: 2}, - } - - err := formatter.Format(data) - - require.NoError(t, err) - assert.Contains(t, buf.String(), "name") - assert.Contains(t, buf.String(), "item1") - assert.Contains(t, buf.String(), "item2") -} - func TestTable(t *testing.T) { - ResetLogger() - InitLogger(false, false) + SetQuiet(false) var buf bytes.Buffer table := NewTable("ID", "NAME", "STATUS").SetWriter(&buf) @@ -354,8 +208,7 @@ func TestTable(t *testing.T) { } func TestTable_AlignRight(t *testing.T) { - ResetLogger() - InitLogger(false, false) + SetQuiet(false) var buf bytes.Buffer table := NewTable("NAME", "COUNT").SetWriter(&buf) @@ -367,8 +220,7 @@ func TestTable_AlignRight(t *testing.T) { } func TestConfirm(t *testing.T) { - ResetLogger() - InitLogger(false, true) // quiet mode + SetQuiet(true) // quiet mode // In quiet mode, should return default WITHOUT prompting. // Destructive commands rely on this: default-no confirms cancel in quiet @@ -399,8 +251,7 @@ func TestConfirm_Interactive(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ResetLogger() - InitLogger(false, false) // NOT quiet: exercise the stdin-reading path + SetQuiet(false) // NOT quiet: exercise the stdin-reading path r, w, err := os.Pipe() require.NoError(t, err) @@ -489,8 +340,7 @@ func TestWrapError_AlreadyCLIError(t *testing.T) { } func TestFormatError(t *testing.T) { - ResetLogger() - InitLogger(false, false) + SetQuiet(false) err := domain.ErrNotConfigured formatted := FormatError(err) @@ -500,8 +350,7 @@ func TestFormatError(t *testing.T) { } func TestFormatError_DebugMode(t *testing.T) { - ResetLogger() - InitLogger(true, false) + SetQuiet(false) err := errors.New("detailed error message") formatted := FormatError(err) diff --git a/internal/cli/common/crud_test.go b/internal/cli/common/crud_test.go index 75b731a..0f1497d 100644 --- a/internal/cli/common/crud_test.go +++ b/internal/cli/common/crud_test.go @@ -53,8 +53,7 @@ func TestRunDelete_Confirmation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ResetLogger() - InitLogger(false, tt.quiet) + SetQuiet(tt.quiet) if !tt.quiet { withStdin(t, tt.stdin) } @@ -81,8 +80,7 @@ func TestRunDelete_Confirmation(t *testing.T) { } func TestRunDelete_WrapsDeleteError(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) deleteErr := errors.New("backend exploded") err := RunDelete(DeleteConfig{ @@ -117,8 +115,7 @@ func TestNewDeleteCommand_Confirmation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ResetLogger() - InitLogger(false, tt.quiet) + SetQuiet(tt.quiet) if !tt.quiet { withStdin(t, tt.stdin) } @@ -148,8 +145,7 @@ func TestNewDeleteCommand_Confirmation(t *testing.T) { // resolves the grant ID from the second positional argument and honors // --force. func TestNewDeleteCommand_GrantPathForce(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) deleted := false cmd := NewDeleteCommand(DeleteCommandConfig{ diff --git a/internal/cli/common/errors.go b/internal/cli/common/errors.go index 084f2b5..72e232f 100644 --- a/internal/cli/common/errors.go +++ b/internal/cli/common/errors.go @@ -323,12 +323,6 @@ func FormatError(err error) string { _, _ = Yellow.Fprintf(&sb, " • %s\n", cliErr.Suggestion) } - // Original error in debug mode - if IsDebug() && cliErr.Err != nil && cliErr.Err.Error() != cliErr.Message { - _, _ = sb.WriteString("\n") - _, _ = Dim.Fprintf(&sb, "Debug details: %s\n", cliErr.Err.Error()) - } - return sb.String() } diff --git a/internal/cli/common/flags.go b/internal/cli/common/flags.go index db130a8..f7db37e 100644 --- a/internal/cli/common/flags.go +++ b/internal/cli/common/flags.go @@ -7,32 +7,12 @@ func AddLimitFlag(cmd *cobra.Command, target *int, defaultValue int) { cmd.Flags().IntVarP(target, "limit", "n", defaultValue, "Maximum number of items to show") } -// AddFormatFlag adds a --format/-f flag for output format. -func AddFormatFlag(cmd *cobra.Command, target *string) { - cmd.Flags().StringVarP(target, "format", "f", "table", "Output format (table, json, yaml, csv)") -} - -// AddIDFlag adds a --id flag to show resource IDs. -func AddIDFlag(cmd *cobra.Command, target *bool) { - cmd.Flags().BoolVar(target, "id", false, "Show IDs in output") -} - // AddPageTokenFlag adds a --page-token flag for pagination. func AddPageTokenFlag(cmd *cobra.Command, target *string) { cmd.Flags().StringVar(target, "page-token", "", "Page token for pagination") } -// AddForceFlag adds a --force/-f flag for skipping confirmations. -func AddForceFlag(cmd *cobra.Command, target *bool) { - cmd.Flags().BoolVarP(target, "force", "f", false, "Skip confirmation prompts") -} - // AddYesFlag adds a --yes/-y flag for skipping confirmations. func AddYesFlag(cmd *cobra.Command, target *bool) { cmd.Flags().BoolVarP(target, "yes", "y", false, "Skip confirmation prompts") } - -// AddVerboseFlag adds a --verbose/-v flag for verbose output. -func AddVerboseFlag(cmd *cobra.Command, target *bool) { - cmd.Flags().BoolVarP(target, "verbose", "v", false, "Show verbose output") -} diff --git a/internal/cli/common/format.go b/internal/cli/common/format.go index 53344d6..0f9ba17 100644 --- a/internal/cli/common/format.go +++ b/internal/cli/common/format.go @@ -1,270 +1,15 @@ package common import ( - "encoding/csv" "encoding/json" "fmt" "io" "os" - "reflect" "strings" "github.com/nylas/cli/internal/domain" - "gopkg.in/yaml.v3" ) -// OutputFormat represents the output format type. -type OutputFormat string - -const ( - FormatTable OutputFormat = "table" - FormatJSON OutputFormat = "json" - FormatCSV OutputFormat = "csv" - FormatYAML OutputFormat = "yaml" -) - -// ParseFormat parses a format string into OutputFormat. -func ParseFormat(s string) (OutputFormat, error) { - switch strings.ToLower(s) { - case "table", "": - return FormatTable, nil - case "json": - return FormatJSON, nil - case "csv": - return FormatCSV, nil - case "yaml", "yml": - return FormatYAML, nil - default: - return "", NewInputError(fmt.Sprintf("invalid format: %s (valid: table, json, csv, yaml)", s)) - } -} - -// Formatter handles output formatting. -type Formatter struct { - format OutputFormat - writer io.Writer -} - -// NewFormatter creates a new formatter. -func NewFormatter(format OutputFormat) *Formatter { - return &Formatter{ - format: format, - writer: os.Stdout, - } -} - -// SetWriter sets the output writer. -func (f *Formatter) SetWriter(w io.Writer) *Formatter { - f.writer = w - return f -} - -// Format formats and outputs data based on the configured format. -func (f *Formatter) Format(data any) error { - switch f.format { - case FormatJSON: - return f.formatJSON(data) - case FormatCSV: - return f.formatCSV(data) - case FormatYAML: - return f.formatYAML(data) - default: - return f.formatTable(data) - } -} - -// formatJSON outputs data as JSON. -func (f *Formatter) formatJSON(data any) error { - encoder := json.NewEncoder(f.writer) - encoder.SetIndent("", " ") - return encoder.Encode(data) -} - -// formatYAML outputs data as YAML. -func (f *Formatter) formatYAML(data any) error { - encoder := yaml.NewEncoder(f.writer) - encoder.SetIndent(2) - defer func() { _ = encoder.Close() }() - return encoder.Encode(data) -} - -// formatCSV outputs data as CSV. -func (f *Formatter) formatCSV(data any) error { - writer := csv.NewWriter(f.writer) - defer writer.Flush() - - // Handle slice of structs - v := reflect.ValueOf(data) - if v.Kind() == reflect.Pointer { - v = v.Elem() - } - - if v.Kind() != reflect.Slice { - // Single item - wrap in slice - return f.formatCSVSingle(writer, data) - } - - if v.Len() == 0 { - return nil - } - - // Get headers from first element - first := v.Index(0) - if first.Kind() == reflect.Pointer { - first = first.Elem() - } - - // Pre-compute field info once for all rows (avoids repeated reflection) - fields := getFieldInfo(first.Type()) - headers := make([]string, len(fields)) - for i, field := range fields { - headers[i] = field.name - } - if err := writer.Write(headers); err != nil { - return err - } - - // Pre-allocate row slice, reuse for all rows - row := make([]string, len(fields)) - - // Write rows using cached field info - for i := 0; i < v.Len(); i++ { - elem := v.Index(i) - if elem.Kind() == reflect.Pointer { - elem = elem.Elem() - } - getCSVRowInto(elem, fields, row) - if err := writer.Write(row); err != nil { - return err - } - } - - return nil -} - -// formatCSVSingle formats a single item as CSV. -func (f *Formatter) formatCSVSingle(writer *csv.Writer, data any) error { - v := reflect.ValueOf(data) - if v.Kind() == reflect.Pointer { - v = v.Elem() - } - - headers := getCSVHeaders(v) - if err := writer.Write(headers); err != nil { - return err - } - - row := getCSVRow(v) - return writer.Write(row) -} - -// fieldInfo caches field metadata to avoid repeated reflection. -type fieldInfo struct { - index int - name string -} - -// getFieldInfo extracts field metadata from a struct type once. -// This avoids repeated reflection on every row. -func getFieldInfo(t reflect.Type) []fieldInfo { - if t.Kind() != reflect.Struct { - return nil - } - - fields := make([]fieldInfo, 0, t.NumField()) - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - if !field.IsExported() { - continue - } - - // Skip fields with json:"-" - tag := field.Tag.Get("json") - if tag == "-" { - continue - } - - // Use json tag name if available - name := field.Name - if tag != "" { - parts := strings.Split(tag, ",") - if parts[0] != "" { - name = parts[0] - } - } - fields = append(fields, fieldInfo{index: i, name: name}) - } - return fields -} - -// getCSVHeaders extracts CSV headers from a struct. -func getCSVHeaders(v reflect.Value) []string { - if v.Kind() != reflect.Struct { - return []string{"value"} - } - - fields := getFieldInfo(v.Type()) - headers := make([]string, len(fields)) - for i, f := range fields { - headers[i] = f.name - } - return headers -} - -// getCSVRow extracts CSV row values from a struct. -func getCSVRow(v reflect.Value) []string { - if v.Kind() != reflect.Struct { - return []string{fmt.Sprintf("%v", v.Interface())} - } - - fields := getFieldInfo(v.Type()) - row := make([]string, len(fields)) - for i, f := range fields { - row[i] = formatValue(v.Field(f.index)) - } - return row -} - -// getCSVRowInto fills an existing slice with row values, avoiding allocation. -func getCSVRowInto(v reflect.Value, fields []fieldInfo, row []string) { - for i, f := range fields { - row[i] = formatValue(v.Field(f.index)) - } -} - -// formatValue formats a reflect.Value as a string. -func formatValue(v reflect.Value) string { - if !v.IsValid() { - return "" - } - - switch v.Kind() { - case reflect.Pointer, reflect.Interface: - if v.IsNil() { - return "" - } - return formatValue(v.Elem()) - case reflect.Slice, reflect.Array: - if v.Len() == 0 { - return "" - } - parts := make([]string, v.Len()) - for i := 0; i < v.Len(); i++ { - parts[i] = formatValue(v.Index(i)) - } - return strings.Join(parts, "; ") - default: - return fmt.Sprintf("%v", v.Interface()) - } -} - -// formatTable outputs data as a formatted table. -func (f *Formatter) formatTable(data any) error { - // This is a placeholder - actual table formatting is done by specific commands - // for more control over display - return f.formatJSON(data) -} - // Table provides a simple table builder. type Table struct { headers []string @@ -490,18 +235,6 @@ func PrintEmptyStateWithHint(resourceName, hint string) { } } -// PrintListHeader prints a consistent "found N items" header. -func PrintListHeader(count int, resourceName string) { - if IsQuiet() { - return - } - if count == 1 { - fmt.Printf("Found 1 %s:\n\n", resourceName) - } else { - fmt.Printf("Found %d %ss:\n\n", count, resourceName) - } -} - // PrintJSON writes data to stdout as pretty-printed JSON. // This is a convenience function for commands that need simple JSON output. func PrintJSON(data any) error { @@ -510,23 +243,6 @@ func PrintJSON(data any) error { return enc.Encode(data) } -// PrintSeparator prints a horizontal line separator of specified width. -// Common widths: 40 (narrow), 50 (medium), 60 (wide), 70 (extra wide). -func PrintSeparator(width int) { - if IsQuiet() { - return - } - fmt.Println(strings.Repeat("─", width)) -} - -// PrintDoubleSeparator prints a double-line separator for section headers. -func PrintDoubleSeparator(width int) { - if IsQuiet() { - return - } - fmt.Println(strings.Repeat("═", width)) -} - // truncateCell truncates a table cell to maxLen characters with ellipsis using proper UTF-8 rune counting func truncateCell(s string, maxLen int) string { runes := []rune(s) diff --git a/internal/cli/common/format_output_test.go b/internal/cli/common/format_output_test.go deleted file mode 100644 index 8213f71..0000000 --- a/internal/cli/common/format_output_test.go +++ /dev/null @@ -1,271 +0,0 @@ -//go:build !integration - -package common - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestFormatter_JSON_Output(t *testing.T) { - tests := []struct { - name string - data any - contains []string - }{ - { - name: "simple map", - data: map[string]string{"key": "value"}, - contains: []string{`"key"`, `"value"`}, - }, - { - name: "slice of maps", - data: []map[string]int{{"a": 1}, {"b": 2}}, - contains: []string{`"a"`, `"b"`, "1", "2"}, - }, - { - name: "struct", - data: struct { - Name string `json:"name"` - Count int `json:"count"` - }{Name: "test", Count: 42}, - contains: []string{`"name"`, `"test"`, `"count"`, "42"}, - }, - { - name: "nested struct", - data: map[string]any{"outer": map[string]string{"inner": "value"}}, - contains: []string{`"outer"`, `"inner"`, `"value"`}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var buf bytes.Buffer - formatter := NewFormatter(FormatJSON).SetWriter(&buf) - - err := formatter.Format(tt.data) - require.NoError(t, err) - - output := buf.String() - for _, s := range tt.contains { - assert.Contains(t, output, s) - } - }) - } -} - -func TestFormatter_YAML_Output(t *testing.T) { - tests := []struct { - name string - data any - contains []string - }{ - { - name: "simple map", - data: map[string]string{"key": "value"}, - contains: []string{"key:", "value"}, - }, - { - name: "multiple fields", - data: map[string]int{"count": 10, "total": 100}, - contains: []string{"count:", "10", "total:", "100"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var buf bytes.Buffer - formatter := NewFormatter(FormatYAML).SetWriter(&buf) - - err := formatter.Format(tt.data) - require.NoError(t, err) - - output := buf.String() - for _, s := range tt.contains { - assert.Contains(t, output, s) - } - }) - } -} - -func TestFormatter_CSV_Slice(t *testing.T) { - type Item struct { - Name string `json:"name"` - Value int `json:"value"` - Tag string `json:"tag"` - } - - tests := []struct { - name string - data []Item - contains []string - }{ - { - name: "multiple items", - data: []Item{ - {Name: "item1", Value: 1, Tag: "a"}, - {Name: "item2", Value: 2, Tag: "b"}, - }, - contains: []string{"name", "value", "tag", "item1", "item2", "1", "2", "a", "b"}, - }, - { - name: "single item", - data: []Item{{Name: "only", Value: 99, Tag: "x"}}, - contains: []string{"name", "value", "only", "99", "x"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var buf bytes.Buffer - formatter := NewFormatter(FormatCSV).SetWriter(&buf) - - err := formatter.Format(tt.data) - require.NoError(t, err) - - output := buf.String() - for _, s := range tt.contains { - assert.Contains(t, output, s) - } - }) - } -} - -func TestFormatter_CSV_EmptySlice(t *testing.T) { - type Item struct { - Name string `json:"name"` - } - - var buf bytes.Buffer - formatter := NewFormatter(FormatCSV).SetWriter(&buf) - - err := formatter.Format([]Item{}) - require.NoError(t, err) - - // Empty slice should produce no output - assert.Empty(t, buf.String()) -} - -func TestFormatter_CSV_SingleItem(t *testing.T) { - type Item struct { - ID string `json:"id"` - Name string `json:"name"` - } - - var buf bytes.Buffer - formatter := NewFormatter(FormatCSV).SetWriter(&buf) - - // Test single item (not in slice) - err := formatter.Format(Item{ID: "123", Name: "test"}) - require.NoError(t, err) - - output := buf.String() - assert.Contains(t, output, "id") - assert.Contains(t, output, "name") - assert.Contains(t, output, "123") - assert.Contains(t, output, "test") -} - -func TestFormatter_CSV_NonStructTypes(t *testing.T) { - var buf bytes.Buffer - formatter := NewFormatter(FormatCSV).SetWriter(&buf) - - // Non-struct types should fall back to "value" header - err := formatter.Format("simple string") - require.NoError(t, err) - - output := buf.String() - assert.Contains(t, output, "value") - assert.Contains(t, output, "simple string") -} - -func TestGetCSVHeaders(t *testing.T) { - type TestStruct struct { - Public string `json:"public_field"` - NoTag string - SkipField string `json:"-"` - unexported string //nolint:unused - } - - tests := []struct { - name string - data any - expected []string - }{ - { - name: "struct with json tags", - data: TestStruct{Public: "val", NoTag: "val2"}, - expected: []string{"public_field", "NoTag"}, - }, - { - name: "non-struct returns value", - data: "string", - expected: []string{"value"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // We need to use reflection to test this internal function - // Test through Format instead - var buf bytes.Buffer - formatter := NewFormatter(FormatCSV).SetWriter(&buf) - - switch v := tt.data.(type) { - case TestStruct: - err := formatter.Format(v) - require.NoError(t, err) - output := buf.String() - for _, exp := range tt.expected { - assert.Contains(t, output, exp) - } - case string: - err := formatter.Format(v) - require.NoError(t, err) - output := buf.String() - assert.Contains(t, output, "value") - } - }) - } -} - -func TestFormatValue_SpecialTypes(t *testing.T) { - type ItemWithSlice struct { - Tags []string `json:"tags"` - } - - tests := []struct { - name string - data any - contains string - }{ - { - name: "slice field", - data: []ItemWithSlice{{Tags: []string{"a", "b", "c"}}}, - contains: "a; b; c", - }, - { - name: "empty slice field", - data: []ItemWithSlice{{Tags: []string{}}}, - contains: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var buf bytes.Buffer - formatter := NewFormatter(FormatCSV).SetWriter(&buf) - - err := formatter.Format(tt.data) - require.NoError(t, err) - - output := buf.String() - if tt.contains != "" { - assert.Contains(t, output, tt.contains) - } - }) - } -} diff --git a/internal/cli/common/format_parse_test.go b/internal/cli/common/format_parse_test.go deleted file mode 100644 index 9e7b0ec..0000000 --- a/internal/cli/common/format_parse_test.go +++ /dev/null @@ -1,60 +0,0 @@ -//go:build !integration - -package common - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestParseFormat_AllFormats(t *testing.T) { - tests := []struct { - name string - input string - expected OutputFormat - hasError bool - }{ - // Table format variants - {"table lowercase", "table", FormatTable, false}, - {"table uppercase", "TABLE", FormatTable, false}, - {"table mixed case", "Table", FormatTable, false}, - {"empty defaults to table", "", FormatTable, false}, - - // JSON format variants - {"json lowercase", "json", FormatJSON, false}, - {"json uppercase", "JSON", FormatJSON, false}, - {"json mixed case", "Json", FormatJSON, false}, - - // CSV format variants - {"csv lowercase", "csv", FormatCSV, false}, - {"csv uppercase", "CSV", FormatCSV, false}, - {"csv mixed case", "Csv", FormatCSV, false}, - - // YAML format variants - {"yaml lowercase", "yaml", FormatYAML, false}, - {"yaml uppercase", "YAML", FormatYAML, false}, - {"yml shorthand", "yml", FormatYAML, false}, - {"YML uppercase", "YML", FormatYAML, false}, - - // Invalid formats - {"invalid format", "invalid", "", true}, - {"xml not supported", "xml", "", true}, - {"html not supported", "html", "", true}, - {"spaces not trimmed", " json ", "", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - format, err := ParseFormat(tt.input) - - if tt.hasError { - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid format") - } else { - assert.NoError(t, err) - assert.Equal(t, tt.expected, format) - } - }) - } -} diff --git a/internal/cli/common/format_table_test.go b/internal/cli/common/format_table_test.go index 7583b65..fd074b0 100644 --- a/internal/cli/common/format_table_test.go +++ b/internal/cli/common/format_table_test.go @@ -11,8 +11,7 @@ import ( ) func TestTable_BasicOperations(t *testing.T) { - ResetLogger() - InitLogger(false, false) + SetQuiet(false) t.Run("create and render table", func(t *testing.T) { var buf bytes.Buffer @@ -70,8 +69,7 @@ func TestTable_BasicOperations(t *testing.T) { } func TestTable_QuietMode(t *testing.T) { - ResetLogger() - InitLogger(false, true) // Enable quiet mode + SetQuiet(true) // Enable quiet mode var buf bytes.Buffer table := NewTable("HEADER").SetWriter(&buf) @@ -83,8 +81,7 @@ func TestTable_QuietMode(t *testing.T) { } func TestPrintFunctions_QuietMode(t *testing.T) { - ResetLogger() - InitLogger(false, true) // Enable quiet mode + SetQuiet(true) // Enable quiet mode // These should not panic in quiet mode PrintSuccess("success: %s", "test") @@ -93,8 +90,7 @@ func TestPrintFunctions_QuietMode(t *testing.T) { } func TestPrintError_AlwaysPrints(t *testing.T) { - ResetLogger() - InitLogger(false, true) // Enable quiet mode + SetQuiet(true) // Enable quiet mode // PrintError should print even in quiet mode (to stderr) // We can't easily capture stderr, so just verify no panic @@ -102,8 +98,7 @@ func TestPrintError_AlwaysPrints(t *testing.T) { } func TestTable_UTF8Support(t *testing.T) { - ResetLogger() - InitLogger(false, false) + SetQuiet(false) t.Run("handles UTF-8 characters correctly", func(t *testing.T) { var buf bytes.Buffer @@ -136,8 +131,7 @@ func TestTable_UTF8Support(t *testing.T) { } func TestTable_MaxWidth(t *testing.T) { - ResetLogger() - InitLogger(false, false) + SetQuiet(false) t.Run("truncates long text with max width", func(t *testing.T) { var buf bytes.Buffer @@ -284,8 +278,7 @@ func TestVisualWidth(t *testing.T) { } func TestTable_Alignment(t *testing.T) { - ResetLogger() - InitLogger(false, false) + SetQuiet(false) t.Run("columns align properly with varying widths", func(t *testing.T) { var buf bytes.Buffer diff --git a/internal/cli/common/logger.go b/internal/cli/common/logger.go index bc58bc6..3c846ab 100644 --- a/internal/cli/common/logger.go +++ b/internal/cli/common/logger.go @@ -1,126 +1,21 @@ // Package common provides shared utilities for CLI commands. package common -import ( - "io" - "log/slog" - "os" - "sync" -) +import "sync/atomic" -var ( - logger *slog.Logger - loggerOnce sync.Once - debugMode bool - quietMode bool -) +// quietMode suppresses decorative output (success messages, tables, spinners, +// progress) process-wide. It is set from the --quiet flag at command startup +// (see the root PersistentPreRunE); the OutputWriter handles structured data +// separately. atomic.Bool because it is written at startup while spinner/ +// progress goroutines read it concurrently. +var quietMode atomic.Bool -// LogLevel represents logging levels. -type LogLevel int - -const ( - LogLevelError LogLevel = iota - LogLevelWarn - LogLevelInfo - LogLevelDebug -) - -// InitLogger initializes the global logger with the specified options. -func InitLogger(debug, quiet bool) { - loggerOnce.Do(func() { - debugMode = debug - quietMode = quiet - - var level slog.Level - if debug { - level = slog.LevelDebug - } else { - level = slog.LevelInfo - } - - var output io.Writer = os.Stderr - if quiet { - output = io.Discard - } - - opts := &slog.HandlerOptions{ - Level: level, - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - // Remove time from output for cleaner CLI display - if a.Key == slog.TimeKey { - return slog.Attr{} - } - return a - }, - } - - handler := slog.NewTextHandler(output, opts) - logger = slog.New(handler) - }) -} - -// ResetLogger resets the logger (for testing). -func ResetLogger() { - loggerOnce = sync.Once{} - logger = nil - debugMode = false - quietMode = false -} - -// GetLogger returns the global logger, initializing with defaults if needed. -func GetLogger() *slog.Logger { - if logger == nil { - InitLogger(false, false) - } - return logger -} - -// IsDebug returns true if debug mode is enabled. -func IsDebug() bool { - return debugMode +// SetQuiet enables or disables quiet mode. +func SetQuiet(quiet bool) { + quietMode.Store(quiet) } // IsQuiet returns true if quiet mode is enabled. func IsQuiet() bool { - return quietMode -} - -// Debug logs a debug message. -func Debug(msg string, args ...any) { - GetLogger().Debug(msg, args...) -} - -// Info logs an info message. -func Info(msg string, args ...any) { - GetLogger().Info(msg, args...) -} - -// Warn logs a warning message. -func Warn(msg string, args ...any) { - GetLogger().Warn(msg, args...) -} - -// Error logs an error message. -func Error(msg string, args ...any) { - GetLogger().Error(msg, args...) -} - -// DebugHTTP logs HTTP request/response details in debug mode. -func DebugHTTP(method, url string, statusCode int, duration string) { - if debugMode { - Debug("HTTP request", - "method", method, - "url", url, - "status", statusCode, - "duration", duration, - ) - } -} - -// DebugAPI logs API operation details in debug mode. -func DebugAPI(operation string, args ...any) { - if debugMode { - allArgs := append([]any{"operation", operation}, args...) - Debug("API call", allArgs...) - } + return quietMode.Load() } diff --git a/internal/cli/common/pagination.go b/internal/cli/common/pagination.go index 738062a..13be030 100644 --- a/internal/cli/common/pagination.go +++ b/internal/cli/common/pagination.go @@ -186,52 +186,3 @@ func SetupPagination(limit int, fetchAll bool, maxItems int) PaginationLimits { Mode: PaginateSinglePage, } } - -// FetchAllWithProgress fetches all pages and shows a progress indicator. -func FetchAllWithProgress[T any](ctx context.Context, fetcher PageFetcher[T], maxItems int) ([]T, error) { - config := DefaultPaginationConfig() - config.MaxItems = maxItems - return FetchAllPages(ctx, config, fetcher) -} - -// PaginatedDisplay handles displaying paginated results with optional streaming. -type PaginatedDisplay struct { - PageSize int - CurrentPage int - TotalFetched int - Writer io.Writer -} - -// NewPaginatedDisplay creates a new paginated display helper. -func NewPaginatedDisplay(pageSize int) *PaginatedDisplay { - return &PaginatedDisplay{ - PageSize: pageSize, - CurrentPage: 0, - Writer: os.Stdout, - } -} - -// SetWriter sets the output writer. -func (p *PaginatedDisplay) SetWriter(w io.Writer) *PaginatedDisplay { - p.Writer = w - return p -} - -// DisplayPage shows a summary after displaying items. -func (p *PaginatedDisplay) DisplayPage(itemsDisplayed int, hasMore bool) { - p.CurrentPage++ - p.TotalFetched += itemsDisplayed - - if !IsQuiet() && hasMore { - _, _ = fmt.Fprintf(p.Writer, "\n--- Page %d (%d items, %d total) ---\n", - p.CurrentPage, itemsDisplayed, p.TotalFetched) - } -} - -// DisplaySummary shows a final summary. -func (p *PaginatedDisplay) DisplaySummary() { - if !IsQuiet() && p.CurrentPage > 1 { - _, _ = fmt.Fprintf(p.Writer, "\nFetched %d items across %d pages\n", - p.TotalFetched, p.CurrentPage) - } -} diff --git a/internal/cli/common/pagination_test.go b/internal/cli/common/pagination_test.go index e13a75d..0069af1 100644 --- a/internal/cli/common/pagination_test.go +++ b/internal/cli/common/pagination_test.go @@ -3,7 +3,6 @@ package common import ( - "bytes" "context" "errors" "strconv" @@ -66,8 +65,7 @@ func TestNormalizePageSize(t *testing.T) { } func TestFetchAllPages_SinglePage(t *testing.T) { - ResetLogger() - InitLogger(false, true) // Quiet mode to avoid progress output + SetQuiet(true) // Quiet mode to avoid progress output config := DefaultPaginationConfig() config.ShowProgress = false @@ -92,8 +90,7 @@ func TestFetchAllPages_SinglePage(t *testing.T) { } func TestFetchAllPages_MultiplePages(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) config := DefaultPaginationConfig() config.ShowProgress = false @@ -131,8 +128,7 @@ func TestFetchAllPages_MultiplePages(t *testing.T) { } func TestFetchAllPages_MaxItems(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) config := DefaultPaginationConfig() config.ShowProgress = false @@ -152,8 +148,7 @@ func TestFetchAllPages_MaxItems(t *testing.T) { } func TestFetchAllPages_MaxPages(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) config := DefaultPaginationConfig() config.ShowProgress = false @@ -176,8 +171,7 @@ func TestFetchAllPages_MaxPages(t *testing.T) { } func TestFetchAllPages_FetcherError(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) config := DefaultPaginationConfig() config.ShowProgress = false @@ -202,8 +196,7 @@ func TestFetchAllPages_FetcherError(t *testing.T) { } func TestFetchAllPages_ContextCancellation(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) config := DefaultPaginationConfig() config.ShowProgress = false @@ -230,8 +223,7 @@ func TestFetchAllPages_ContextCancellation(t *testing.T) { } func TestFetchAllPages_EmptyFirstPage(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) config := DefaultPaginationConfig() config.ShowProgress = false @@ -249,43 +241,8 @@ func TestFetchAllPages_EmptyFirstPage(t *testing.T) { assert.Empty(t, results) } -func TestFetchAllWithProgress(t *testing.T) { - ResetLogger() - InitLogger(false, true) - - fetcher := func(ctx context.Context, cursor string) (PageResult[string], error) { - return PageResult[string]{ - Data: []string{"a", "b", "c"}, - NextCursor: "", - }, nil - } - - results, err := FetchAllWithProgress(context.Background(), fetcher, 0) - - require.NoError(t, err) - assert.Equal(t, 3, len(results)) -} - -func TestFetchAllWithProgress_WithMaxItems(t *testing.T) { - ResetLogger() - InitLogger(false, true) - - fetcher := func(ctx context.Context, cursor string) (PageResult[string], error) { - return PageResult[string]{ - Data: []string{"a", "b", "c", "d", "e"}, - NextCursor: "more", - }, nil - } - - results, err := FetchAllWithProgress(context.Background(), fetcher, 2) - - require.NoError(t, err) - assert.Equal(t, 2, len(results)) -} - func TestFetchCursorPages(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) t.Run("caps results across multiple pages", func(t *testing.T) { fetcher := func(ctx context.Context, cursor string) (PageResult[string], error) { @@ -311,77 +268,8 @@ func TestFetchCursorPages(t *testing.T) { }) } -func TestPaginatedDisplay_Operations(t *testing.T) { - ResetLogger() - InitLogger(false, false) // Not quiet for display output - - t.Run("display page", func(t *testing.T) { - var buf bytes.Buffer - display := NewPaginatedDisplay(10).SetWriter(&buf) - - display.DisplayPage(5, true) - - assert.Equal(t, 1, display.CurrentPage) - assert.Equal(t, 5, display.TotalFetched) - assert.Contains(t, buf.String(), "Page 1") - }) - - t.Run("display multiple pages", func(t *testing.T) { - var buf bytes.Buffer - display := NewPaginatedDisplay(10).SetWriter(&buf) - - display.DisplayPage(10, true) - display.DisplayPage(10, true) - display.DisplayPage(5, false) // Last page - - assert.Equal(t, 3, display.CurrentPage) - assert.Equal(t, 25, display.TotalFetched) - }) - - t.Run("display summary", func(t *testing.T) { - var buf bytes.Buffer - display := NewPaginatedDisplay(10).SetWriter(&buf) - - display.DisplayPage(10, true) - display.DisplayPage(5, false) - display.DisplaySummary() - - output := buf.String() - assert.Contains(t, output, "15 items") - assert.Contains(t, output, "2 pages") - }) - - t.Run("no summary for single page", func(t *testing.T) { - var buf bytes.Buffer - display := NewPaginatedDisplay(10).SetWriter(&buf) - - display.DisplayPage(5, false) // Single page, no more - buf.Reset() - display.DisplaySummary() - - // Summary should not be shown for single page - assert.Empty(t, buf.String()) - }) -} - -func TestPaginatedDisplay_QuietMode(t *testing.T) { - ResetLogger() - InitLogger(false, true) // Quiet mode - - var buf bytes.Buffer - display := NewPaginatedDisplay(10).SetWriter(&buf) - - display.DisplayPage(10, true) - display.DisplayPage(5, false) - display.DisplaySummary() - - // In quiet mode, should not produce output - assert.Empty(t, buf.String()) -} - func TestFetchAllPages_WithProgress(t *testing.T) { - ResetLogger() - InitLogger(false, false) // Not quiet + SetQuiet(false) // Not quiet config := DefaultPaginationConfig() config.ShowProgress = true @@ -409,8 +297,7 @@ func TestFetchAllPages_WithProgress(t *testing.T) { } func TestFetchAllPages_ContextDeadline(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) config := DefaultPaginationConfig() config.ShowProgress = false @@ -522,8 +409,7 @@ func TestPaginationMode(t *testing.T) { } func TestFetchAllPages_StuckCursor(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) config := DefaultPaginationConfig() config.ShowProgress = false @@ -548,8 +434,7 @@ func TestFetchAllPages_StuckCursor(t *testing.T) { } func TestFetchAllPages_EmptyPageClaimingMore(t *testing.T) { - ResetLogger() - InitLogger(false, true) + SetQuiet(true) config := DefaultPaginationConfig() config.ShowProgress = false diff --git a/internal/cli/common/progress.go b/internal/cli/common/progress.go index c88a1b8..3c5a671 100644 --- a/internal/cli/common/progress.go +++ b/internal/cli/common/progress.go @@ -48,12 +48,6 @@ func NewSpinner(message string) *Spinner { } } -// SetFrames sets the spinner animation frames. -func (s *Spinner) SetFrames(frames []string) *Spinner { - s.frames = frames - return s -} - // SetWriter sets the output writer. func (s *Spinner) SetWriter(w io.Writer) *Spinner { s.writer = w @@ -110,14 +104,6 @@ func (s *Spinner) Stop() { <-s.done } -// StopWithMessage stops the spinner and prints a final message. -func (s *Spinner) StopWithMessage(message string) { - s.Stop() - if !IsQuiet() { - _, _ = fmt.Fprintln(s.writer, message) - } -} - // StopWithSuccess stops the spinner with a success message. func (s *Spinner) StopWithSuccess(message string) { s.Stop() @@ -134,129 +120,6 @@ func (s *Spinner) StopWithError(message string) { } } -// ProgressBar provides a progress bar for determinate operations. -type ProgressBar struct { - total int - current int - width int - message string - writer io.Writer - startTime time.Time - mu sync.Mutex -} - -// NewProgressBar creates a new progress bar. -func NewProgressBar(total int, message string) *ProgressBar { - return &ProgressBar{ - total: total, - current: 0, - width: 40, - message: message, - writer: os.Stderr, - startTime: time.Now(), - } -} - -// SetWidth sets the progress bar width. -func (p *ProgressBar) SetWidth(width int) *ProgressBar { - p.width = width - return p -} - -// SetWriter sets the output writer. -func (p *ProgressBar) SetWriter(w io.Writer) *ProgressBar { - p.writer = w - return p -} - -// Increment increments the progress by 1. -func (p *ProgressBar) Increment() { - p.Add(1) -} - -// Add adds n to the current progress. -func (p *ProgressBar) Add(n int) { - p.mu.Lock() - defer p.mu.Unlock() - - p.current += n - if p.current > p.total { - p.current = p.total - } - - p.render() -} - -// Set sets the current progress. -func (p *ProgressBar) Set(n int) { - p.mu.Lock() - defer p.mu.Unlock() - - p.current = n - if p.current > p.total { - p.current = p.total - } - - p.render() -} - -// render draws the progress bar. -func (p *ProgressBar) render() { - if IsQuiet() { - return - } - - percent := float64(p.current) / float64(p.total) - filled := int(percent * float64(p.width)) - empty := p.width - filled - - // Calculate ETA - elapsed := time.Since(p.startTime) - var eta string - if p.current > 0 { - remaining := time.Duration(float64(elapsed) / percent * (1 - percent)) - if remaining > time.Second { - eta = fmt.Sprintf(" ETA: %s", formatDuration(remaining)) - } - } - - bar := strings.Repeat("█", filled) + strings.Repeat("░", empty) - - _, _ = fmt.Fprintf(p.writer, "\r%s %s %s %d/%d (%.0f%%)%s", - p.message, - Cyan.Sprint("["), - bar, - p.current, - p.total, - percent*100, - eta, - ) - - if p.current >= p.total { - _, _ = fmt.Fprintln(p.writer) - } -} - -// Finish completes the progress bar. -func (p *ProgressBar) Finish() { - p.mu.Lock() - defer p.mu.Unlock() - - p.current = p.total - p.render() -} - -// formatDuration formats a duration for display. -func formatDuration(d time.Duration) string { - if d < time.Minute { - return fmt.Sprintf("%ds", int(d.Seconds())) - } - if d < time.Hour { - return fmt.Sprintf("%dm%ds", int(d.Minutes()), int(d.Seconds())%60) - } - return fmt.Sprintf("%dh%dm", int(d.Hours()), int(d.Minutes())%60) -} - // Counter provides a simple counter display. type Counter struct { message string @@ -292,13 +155,6 @@ func (c *Counter) Finish() { } } -// Count returns the current count. -func (c *Counter) Count() int { - c.mu.Lock() - defer c.mu.Unlock() - return c.count -} - // RunWithSpinner executes a function while displaying a spinner. // It handles spinner start/stop and error propagation. func RunWithSpinner(message string, fn func() error) error { diff --git a/internal/cli/common/retry.go b/internal/cli/common/retry.go index 3c510ed..9cd556a 100644 --- a/internal/cli/common/retry.go +++ b/internal/cli/common/retry.go @@ -101,7 +101,6 @@ func WithRetry(ctx context.Context, config RetryConfig, fn RetryFunc) error { // Check if we should retry if !IsRetryable(err) { - Debug("not retrying non-retryable error", "error", err) return err } @@ -113,13 +112,6 @@ func WithRetry(ctx context.Context, config RetryConfig, fn RetryFunc) error { // Calculate delay with exponential backoff and jitter delay := calculateDelay(config, attempt) - Debug("retrying after error", - "attempt", attempt+1, - "max_retries", config.MaxRetries, - "delay", delay, - "error", err, - ) - // Wait before retrying select { case <-ctx.Done(): diff --git a/internal/cli/common/string.go b/internal/cli/common/string.go index 8b9d2a6..1f2b59c 100644 --- a/internal/cli/common/string.go +++ b/internal/cli/common/string.go @@ -1,7 +1,5 @@ package common -import "strings" - // Truncate shortens a string to maxLen characters, adding "..." if truncated. func Truncate(s string, maxLen int) string { if len(s) <= maxLen { @@ -12,14 +10,3 @@ func Truncate(s string, maxLen int) string { } return s[:maxLen-3] + "..." } - -// ExtractDomain extracts the domain portion from an email address. -// For example, "info@qasim.nylas.email" returns "qasim.nylas.email". -// Returns empty string if the email format is invalid. -func ExtractDomain(email string) string { - parts := strings.Split(email, "@") - if len(parts) == 2 { - return parts[1] - } - return "" -} diff --git a/internal/cli/common/string_test.go b/internal/cli/common/string_test.go index 412e0b5..c347e01 100644 --- a/internal/cli/common/string_test.go +++ b/internal/cli/common/string_test.go @@ -74,66 +74,3 @@ func TestTruncate(t *testing.T) { }) } } - -func TestExtractDomain(t *testing.T) { - tests := []struct { - name string - email string - want string - }{ - { - name: "standard email", - email: "user@example.com", - want: "example.com", - }, - { - name: "nylas inbox email", - email: "info@qasim.nylas.email", - want: "qasim.nylas.email", - }, - { - name: "subdomain email", - email: "test@subdomain.domain.com", - want: "subdomain.domain.com", - }, - { - name: "no @ symbol", - email: "invalidemail", - want: "", - }, - { - name: "multiple @ symbols", - email: "user@domain@extra.com", - want: "", - }, - { - name: "empty string", - email: "", - want: "", - }, - { - name: "only @ symbol", - email: "@", - want: "", - }, - { - name: "@ at start", - email: "@domain.com", - want: "domain.com", - }, - { - name: "@ at end", - email: "user@", - want: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := ExtractDomain(tt.email) - if got != tt.want { - t.Errorf("ExtractDomain(%q) = %q, want %q", tt.email, got, tt.want) - } - }) - } -} diff --git a/internal/cli/common/timeutil.go b/internal/cli/common/timeutil.go index 75b1671..b89b8a5 100644 --- a/internal/cli/common/timeutil.go +++ b/internal/cli/common/timeutil.go @@ -1,7 +1,5 @@ package common -import "time" - // Standard date/time format constants. const ( // Machine-readable formats @@ -37,23 +35,3 @@ const ( ShortDateTime = "Jan 2 15:04" ShortDate = "Jan 2" ) - -// ParseDate parses a date string in YYYY-MM-DD format. -func ParseDate(s string) (time.Time, error) { - return time.Parse(DateFormat, s) -} - -// ParseTime parses a time string in HH:MM format. -func ParseTime(s string) (time.Time, error) { - return time.Parse(TimeFormat, s) -} - -// FormatDate formats a time as a date string (YYYY-MM-DD). -func FormatDate(t time.Time) string { - return t.Format(DateFormat) -} - -// FormatDisplayDate formats a time for user display. -func FormatDisplayDate(t time.Time) string { - return t.Format(DisplayDateTime) -} diff --git a/internal/cli/common/validation.go b/internal/cli/common/validation.go index f40f957..e55176c 100644 --- a/internal/cli/common/validation.go +++ b/internal/cli/common/validation.go @@ -2,7 +2,6 @@ package common import ( "fmt" - "net/url" "strings" ) @@ -42,45 +41,6 @@ func ValidateRequiredFlag(flagName, value string) error { return nil } -// ValidateRequiredArg returns an error if args is empty or first arg is empty. -// Use for commands that require at least one argument. -// -// Example: -// -// if err := common.ValidateRequiredArg(args, "message ID"); err != nil { -// return err -// } -func ValidateRequiredArg(args []string, name string) error { - if len(args) == 0 || args[0] == "" { - return NewUserError( - fmt.Sprintf("%s is required", name), - fmt.Sprintf("Provide %s as an argument", name), - ) - } - return nil -} - -// ValidateURL returns an error if value is not a valid URL. -// -// Example: -// -// if err := common.ValidateURL("webhook URL", webhookURL); err != nil { -// return err -// } -func ValidateURL(name, value string) error { - if value == "" { - return ValidateRequired(name, value) - } - u, err := url.Parse(value) - if err != nil || (u.Scheme != "http" && u.Scheme != "https") || u.Host == "" { - return NewUserError( - fmt.Sprintf("invalid %s: %s", name, value), - "URL must be a valid HTTP or HTTPS URL", - ) - } - return nil -} - // ValidateEmail returns an error if value doesn't look like an email address. // This is a basic check for @ symbol, not RFC 5322 compliant. // diff --git a/internal/cli/dashboard/exports.go b/internal/cli/dashboard/exports.go index 884664d..c7c99f4 100644 --- a/internal/cli/dashboard/exports.go +++ b/internal/cli/dashboard/exports.go @@ -2,7 +2,6 @@ package dashboard import ( dashboardapp "github.com/nylas/cli/internal/app/dashboard" - "github.com/nylas/cli/internal/cli/common" "github.com/nylas/cli/internal/ports" ) @@ -36,20 +35,3 @@ func ActivateAPIKey(apiKey, clientID, region, orgID string) error { func GetActiveOrgID() (string, error) { return getActiveOrgID() } - -// SyncSessionOrg syncs the active org from the server session (exported for setup wizard). -// Failures are logged as warnings rather than returned, since this is a -// best-effort step that should not block an otherwise successful login. -func SyncSessionOrg() { - authSvc, _, err := createAuthService() - if err != nil { - common.PrintWarning("failed to create auth service for org sync: %v", err) - return - } - syncSessionOrgWithWarning(authSvc) -} - -// ReadLine prompts for a line of text input (exported for setup wizard). -func ReadLine(prompt string) (string, error) { - return common.InputPrompt(prompt, "") -} diff --git a/internal/cli/doctor_checks.go b/internal/cli/doctor_checks.go index 61a6bae..c825143 100644 --- a/internal/cli/doctor_checks.go +++ b/internal/cli/doctor_checks.go @@ -12,6 +12,7 @@ import ( "github.com/nylas/cli/internal/adapters/nylas" "github.com/nylas/cli/internal/cli/common" "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/httputil" "github.com/nylas/cli/internal/ports" ) @@ -204,7 +205,7 @@ func checkNetworkConnectivity() CheckResult { } } - client := &http.Client{Timeout: domain.TimeoutHealthCheck} + client := httputil.DefaultClient start := time.Now() resp, err := client.Do(req) latency := time.Since(start) diff --git a/internal/cli/email/attachments.go b/internal/cli/email/attachments.go index 040e037..2ac744e 100644 --- a/internal/cli/email/attachments.go +++ b/internal/cli/email/attachments.go @@ -8,7 +8,7 @@ import ( "strings" "github.com/nylas/cli/internal/cli/common" - "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/httputil" "github.com/nylas/cli/internal/ports" "github.com/spf13/cobra" ) @@ -171,10 +171,9 @@ func newAttachmentsDownloadCmd() *cobra.Command { return struct{}{}, common.NewInputError(fmt.Sprintf("output path is a directory: %s", finalOutputPath)) } - // Download the attachment with a dedicated long timeout: - // the WithClient context carries the default API timeout, - // which would cut off large downloads mid-stream. - dlCtx, dlCancel := common.CreateContextWithTimeout(domain.TimeoutDownload) + // Download the attachment with the standard 120s timeout, + // matching the Nylas server-side request ceiling. + dlCtx, dlCancel := common.CreateContextWithTimeout(httputil.DefaultClientTimeout) defer dlCancel() reader, err := client.DownloadAttachment(dlCtx, grantID, messageID, attachmentID) if err != nil { diff --git a/internal/cli/email/search_test.go b/internal/cli/email/search_test.go index b013698..bcaf6e9 100644 --- a/internal/cli/email/search_test.go +++ b/internal/cli/email/search_test.go @@ -50,9 +50,8 @@ func (s *stubMessagesClient) GetMessagesWithCursor(ctx context.Context, grantID } func TestFetchMessages(t *testing.T) { - common.ResetLogger() - common.InitLogger(false, true) - defer common.ResetLogger() + common.SetQuiet(true) + defer common.SetQuiet(false) t.Run("uses direct fetch when maxItems is negative", func(t *testing.T) { expected := []domain.Message{{ID: "msg-1"}, {ID: "msg-2"}} diff --git a/internal/cli/email/send_gpg.go b/internal/cli/email/send_gpg.go index 7648eab..fb848fd 100644 --- a/internal/cli/email/send_gpg.go +++ b/internal/cli/email/send_gpg.go @@ -162,7 +162,7 @@ func sendSecureEmail(ctx context.Context, client ports.NylasClient, grantID stri } // resolveSigningKey determines the signing key to use. -func resolveSigningKey(ctx context.Context, gpgSvc gpg.Service, explicitKeyID string, req *domain.SendMessageRequest) (keyID, identity string) { +func resolveSigningKey(ctx context.Context, gpgSvc *gpg.Service, explicitKeyID string, req *domain.SendMessageRequest) (keyID, identity string) { if explicitKeyID != "" { return explicitKeyID, explicitKeyID } @@ -193,7 +193,7 @@ func resolveSigningKey(ctx context.Context, gpgSvc gpg.Service, explicitKeyID st } // resolveRecipientKeys determines the encryption keys for all recipients. -func resolveRecipientKeys(ctx context.Context, gpgSvc gpg.Service, explicitKeyID string, to, cc, bcc []domain.EmailParticipant) ([]string, error) { +func resolveRecipientKeys(ctx context.Context, gpgSvc *gpg.Service, explicitKeyID string, to, cc, bcc []domain.EmailParticipant) ([]string, error) { // If explicit key provided, use it if explicitKeyID != "" { return []string{explicitKeyID}, nil @@ -237,7 +237,7 @@ func resolveRecipientKeys(ctx context.Context, gpgSvc gpg.Service, explicitKeyID } // buildSignedMessage builds a signed-only PGP/MIME message. -func buildSignedMessage(ctx context.Context, gpgSvc gpg.Service, mimeBuilder mime.Builder, req *domain.SendMessageRequest, toContacts []domain.EmailParticipant, subject, body, contentType, signerKeyID, signingIdentity string) ([]byte, error) { +func buildSignedMessage(ctx context.Context, gpgSvc *gpg.Service, mimeBuilder *mime.Builder, req *domain.SendMessageRequest, toContacts []domain.EmailParticipant, subject, body, contentType, signerKeyID, signingIdentity string) ([]byte, error) { spinner := common.NewSpinner(fmt.Sprintf("Signing email with GPG identity: %s...", signingIdentity)) spinner.Start() defer spinner.Stop() @@ -285,7 +285,7 @@ func buildSignedMessage(ctx context.Context, gpgSvc gpg.Service, mimeBuilder mim } // buildEncryptedMessage builds an encrypted-only PGP/MIME message. -func buildEncryptedMessage(ctx context.Context, gpgSvc gpg.Service, mimeBuilder mime.Builder, req *domain.SendMessageRequest, toContacts []domain.EmailParticipant, subject, body, contentType string, recipientKeyIDs []string) ([]byte, error) { +func buildEncryptedMessage(ctx context.Context, gpgSvc *gpg.Service, mimeBuilder *mime.Builder, req *domain.SendMessageRequest, toContacts []domain.EmailParticipant, subject, body, contentType string, recipientKeyIDs []string) ([]byte, error) { spinner := common.NewSpinner("Encrypting email...") spinner.Start() defer spinner.Stop() @@ -323,7 +323,7 @@ func buildEncryptedMessage(ctx context.Context, gpgSvc gpg.Service, mimeBuilder // buildSignedEncryptedMessage builds a signed AND encrypted PGP/MIME message. // Order: Sign first, then encrypt (per OpenPGP best practice). -func buildSignedEncryptedMessage(ctx context.Context, gpgSvc gpg.Service, mimeBuilder mime.Builder, req *domain.SendMessageRequest, toContacts []domain.EmailParticipant, subject, body, contentType, signerKeyID string, recipientKeyIDs []string) ([]byte, error) { +func buildSignedEncryptedMessage(ctx context.Context, gpgSvc *gpg.Service, mimeBuilder *mime.Builder, req *domain.SendMessageRequest, toContacts []domain.EmailParticipant, subject, body, contentType, signerKeyID string, recipientKeyIDs []string) ([]byte, error) { spinner := common.NewSpinner("Signing and encrypting email...") spinner.Start() defer spinner.Stop() diff --git a/internal/cli/email/templates.go b/internal/cli/email/templates.go index ac63777..6ed07c5 100644 --- a/internal/cli/email/templates.go +++ b/internal/cli/email/templates.go @@ -3,7 +3,6 @@ package email import ( "github.com/nylas/cli/internal/adapters/templates" "github.com/nylas/cli/internal/cli/common" - "github.com/nylas/cli/internal/ports" "github.com/spf13/cobra" ) @@ -47,6 +46,6 @@ Templates are local to your machine and are not synced with Nylas.`, } // getTemplateStore creates a template store instance. -func getTemplateStore() ports.TemplateStore { +func getTemplateStore() *templates.FileStore { return templates.NewDefaultFileStore() } diff --git a/internal/cli/integration/ai_pattern_learning_test.go b/internal/cli/integration/ai_pattern_learning_test.go deleted file mode 100644 index 24cb79b..0000000 --- a/internal/cli/integration/ai_pattern_learning_test.go +++ /dev/null @@ -1,343 +0,0 @@ -//go:build integration - -package integration - -import ( - "context" - "testing" - "time" - - "github.com/nylas/cli/internal/adapters/ai" - "github.com/nylas/cli/internal/domain" -) - -// TestAI_PatternLearning tests the pattern learning functionality. -func TestAI_PatternLearning(t *testing.T) { - skipIfMissingCreds(t) - - t.Run("learn_patterns_from_calendar_history", func(t *testing.T) { - // Create Nylas client - client := getTestClient() - - // Create LLM router with Ollama default - cfg := &domain.AIConfig{ - DefaultProvider: "ollama", - Ollama: &domain.OllamaConfig{ - Host: "http://localhost:11434", - Model: "llama3.1:8b", - }, - } - llmRouter := ai.NewRouter(cfg) - - // Create pattern learner - learner := ai.NewPatternLearner(client, llmRouter) - - // Create learning request (analyze last 30 days) - req := &ai.LearnPatternsRequest{ - GrantID: testGrantID, - LookbackDays: 30, - MinConfidence: 0.5, - IncludeRecurring: false, - } - - // Learn patterns - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) - defer cancel() - - t.Log("Analyzing calendar history for patterns...") - patterns, err := learner.LearnPatterns(ctx, req) - if err != nil { - // No events is expected if the test account doesn't have recent calendar data - if err.Error() == "no events found in the specified period" { - t.Logf("No events found (expected for empty test account): %v", err) - return - } - t.Fatalf("Failed to learn patterns: %v", err) - } - - // Verify patterns structure - if patterns == nil { - t.Fatal("Expected patterns, got nil") - } - - t.Logf("Analysis Period: %d days (%s to %s)", - patterns.AnalysisPeriod.Days, - patterns.AnalysisPeriod.StartDate.Format("Jan 2"), - patterns.AnalysisPeriod.EndDate.Format("Jan 2")) - t.Logf("Total Events Analyzed: %d", patterns.TotalEventsAnalyzed) - - // Test acceptance patterns - if len(patterns.AcceptancePatterns) > 0 { - t.Logf("Found %d acceptance patterns:", len(patterns.AcceptancePatterns)) - for i, pattern := range patterns.AcceptancePatterns { - if i >= 3 { - break // Show top 3 - } - t.Logf(" - %s: %.0f%% acceptance rate (%d events, %.0f%% confidence)", - pattern.TimeSlot, - pattern.AcceptRate*100, - pattern.EventCount, - pattern.Confidence*100) - } - } else { - t.Log("No acceptance patterns found (might need more calendar history)") - } - - // Test duration patterns - if len(patterns.DurationPatterns) > 0 { - t.Logf("Found %d duration patterns:", len(patterns.DurationPatterns)) - for i, pattern := range patterns.DurationPatterns { - if i >= 3 { - break // Show top 3 - } - t.Logf(" - %s: avg %d min (%d events)", - pattern.MeetingType, - pattern.ScheduledDuration, - pattern.EventCount) - } - } - - // Test timezone patterns - if len(patterns.TimezonePatterns) > 0 { - t.Logf("Found %d timezone patterns:", len(patterns.TimezonePatterns)) - for i, pattern := range patterns.TimezonePatterns { - if i >= 3 { - break // Show top 3 - } - t.Logf(" - %s: %.0f%% of meetings (%d events)", - pattern.Timezone, - pattern.Percentage*100, - pattern.EventCount) - } - } - - // Test productivity insights - if len(patterns.ProductivityInsights) > 0 { - t.Logf("Found %d productivity insights:", len(patterns.ProductivityInsights)) - for _, insight := range patterns.ProductivityInsights { - t.Logf(" - %s: %s (score: %d/100)", - insight.InsightType, - insight.Description, - insight.Score) - } - } - - // Test recommendations - if len(patterns.Recommendations) > 0 { - t.Logf("AI Recommendations (%d):", len(patterns.Recommendations)) - for i, rec := range patterns.Recommendations { - if i >= 5 { - break // Show top 5 - } - t.Logf(" %d. %s", i+1, rec) - } - } - - // Verify basic structure - if patterns.UserID == "" { - t.Error("Expected UserID to be set") - } - if patterns.GeneratedAt.IsZero() { - t.Error("Expected GeneratedAt to be set") - } - }) - - t.Run("pattern_learning_with_no_events", func(t *testing.T) { - // Create Nylas client - client := getTestClient() - - // Create LLM router - cfg := &domain.AIConfig{ - DefaultProvider: "ollama", - Ollama: &domain.OllamaConfig{ - Host: "http://localhost:11434", - Model: "llama3.1:8b", - }, - } - llmRouter := ai.NewRouter(cfg) - - // Create pattern learner - learner := ai.NewPatternLearner(client, llmRouter) - - // Create learning request with very short lookback (likely no events) - req := &ai.LearnPatternsRequest{ - GrantID: testGrantID, - LookbackDays: 1, // Only 1 day - MinConfidence: 0.9, // High confidence threshold - IncludeRecurring: false, - } - - // Learn patterns - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - - patterns, err := learner.LearnPatterns(ctx, req) - - // Should handle empty/small dataset gracefully - if err != nil && patterns == nil { - t.Logf("No patterns found with minimal data (expected): %v", err) - return - } - - if patterns != nil { - t.Logf("Found patterns even with minimal data: %d events analyzed", - patterns.TotalEventsAnalyzed) - } - }) - - t.Run("pattern_export_to_json", func(t *testing.T) { - // Create Nylas client - client := getTestClient() - - // Create LLM router - cfg := &domain.AIConfig{ - DefaultProvider: "ollama", - Ollama: &domain.OllamaConfig{ - Host: "http://localhost:11434", - Model: "llama3.1:8b", - }, - } - llmRouter := ai.NewRouter(cfg) - - // Create pattern learner - learner := ai.NewPatternLearner(client, llmRouter) - - // Create learning request - req := &ai.LearnPatternsRequest{ - GrantID: testGrantID, - LookbackDays: 30, - MinConfidence: 0.5, - IncludeRecurring: false, - } - - // Learn patterns - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) - defer cancel() - - patterns, err := learner.LearnPatterns(ctx, req) - if err != nil { - // No events is expected if the test account doesn't have recent calendar data - if err.Error() == "no events found in the specified period" { - t.Logf("No events found (expected for empty test account): %v", err) - return - } - t.Fatalf("Failed to learn patterns: %v", err) - } - - // Export to JSON - jsonData, err := learner.ExportPatterns(patterns) - if err != nil { - t.Fatalf("Failed to export patterns: %v", err) - } - - if len(jsonData) == 0 { - t.Error("Expected JSON data, got empty") - } - - t.Logf("Exported %d bytes of JSON pattern data", len(jsonData)) - - // Verify JSON structure (basic check) - jsonStr := string(jsonData) - if !containsStr(jsonStr, "user_id") { - t.Error("JSON should contain 'user_id' field") - } - if !containsStr(jsonStr, "analysis_period") { - t.Error("JSON should contain 'analysis_period' field") - } - }) -} - -// TestCLI_AI_PatternAnalyze tests the CLI analyze command. -func TestCLI_AI_PatternAnalyze(t *testing.T) { - skipIfMissingCreds(t) - - // Check if Ollama is available - if !checkOllamaAvailable() { - t.Skip("Ollama not available - skipping pattern analysis test") - } - - t.Run("cli_analyze_patterns", func(t *testing.T) { - t.Log("Running: nylas calendar ai analyze --days 30") - - stdout, stderr, err := runCLI("calendar", "ai", "analyze", - "--days", "30") - - output := stdout + stderr - - // Command may timeout with context canceled - that's okay if we got partial output - if err != nil && !containsStr(output, "Analysis Period") { - // Only fail if we got no output at all - if len(output) < 100 { - t.Fatalf("Command failed with no output: %v\nOutput: %s", err, output) - } - t.Logf("Command completed with error but produced output: %v", err) - } - - // Verify output contains expected sections - if !containsStr(output, "Analysis Period") && !containsStr(output, "Analyzing") { - t.Error("Output should contain 'Analysis Period' or 'Analyzing' text") - } - - t.Logf("Pattern Analysis Output:\n%s", truncateOutput(output, 500)) - }) - - t.Run("cli_analyze_patterns_json", func(t *testing.T) { - t.Log("Running: nylas calendar ai analyze --json") - - stdout, stderr, err := runCLI("calendar", "ai", "analyze", - "--days", "30", - "--json") - - output := stdout + stderr - - // Command may show insufficient data warning - that's expected - if containsStr(output, "Insufficient data") || containsStr(output, "no events found") { - t.Logf("No events found (expected): %s", output) - return - } - - // Command may timeout but still produce output - that's okay - if err != nil && !containsStr(output, "Analysis Period") { - // Only fail if we got no output at all - if len(output) < 100 { - t.Fatalf("Command failed with no output: %v\nOutput: %s", err, output) - } - t.Logf("Command completed with error but produced output: %v", err) - } - - // Note: --json flag may not be fully implemented yet for analyze command - // Just verify we got some output with analysis data - if !containsStr(output, "Analysis Period") && !containsStr(output, "Total Meetings") { - t.Error("Output should contain analysis data") - } - - t.Logf("Output (first 500 chars):\n%s", truncateOutput(output, 500)) - }) -} - -// Helper functions - -func containsStr(s, substr string) bool { - // Simple substring check using strings package would be better, - // but keeping it simple for now - for i := 0; i <= len(s)-len(substr); i++ { - match := true - for j := 0; j < len(substr); j++ { - if s[i+j] != substr[j] { - match = false - break - } - } - if match { - return true - } - } - return false -} - -func truncateOutput(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - return s[:maxLen] + "..." -} diff --git a/internal/cli/integration/slack_operations_test.go b/internal/cli/integration/slack_operations_test.go index 6193603..16f8fd6 100644 --- a/internal/cli/integration/slack_operations_test.go +++ b/internal/cli/integration/slack_operations_test.go @@ -9,57 +9,6 @@ import ( "time" ) -// ============================================================================= -// SLACK USERS TESTS -// ============================================================================= - -func TestSlack_UsersList(t *testing.T) { - skipIfMissingSlackCreds(t) - - tests := []struct { - name string - args []string - contains []string - }{ - { - name: "list users (subcommand)", - args: []string{"slack", "users", "list", "--limit", "5"}, - contains: []string{}, // Just verify it runs - }, - { - name: "list users with limit", - args: []string{"slack", "users", "list", "--limit", "5"}, - contains: []string{}, - }, - { - name: "list users with IDs", - args: []string{"slack", "users", "list", "--id", "--limit", "5"}, - contains: []string{"[U"}, // User IDs start with U - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - stdout, stderr, err := runSlackCLI(t, tt.args...) - - if err != nil { - if strings.Contains(stderr, "not authenticated") { - t.Skip("Not authenticated with Slack") - } - t.Fatalf("Command failed: %v\nstderr: %s", err, stderr) - } - - for _, expected := range tt.contains { - if !strings.Contains(stdout, expected) { - t.Errorf("Expected output to contain %q\nGot: %s", expected, stdout) - } - } - - t.Logf("Output:\n%s", stdout) - }) - } -} - // ============================================================================= // SLACK SEARCH TESTS // ============================================================================= @@ -203,58 +152,3 @@ func TestSlack_SendMessage(t *testing.T) { t.Logf("Send output:\n%s", stdout) } - -// ============================================================================= -// SLACK WORKFLOW TEST -// ============================================================================= - -func TestSlack_Workflow(t *testing.T) { - skipIfMissingSlackCreds(t) - - // Test a typical workflow: auth status -> list channels -> read messages - - t.Run("auth_status", func(t *testing.T) { - stdout, stderr, err := runSlackCLI(t, "slack", "auth", "status") - if err != nil { - if strings.Contains(stderr, "not authenticated") { - t.Skip("Not authenticated with Slack") - } - t.Fatalf("Auth status failed: %v", err) - } - t.Logf("Auth: %s", strings.TrimSpace(stdout)) - }) - - t.Run("list_channels", func(t *testing.T) { - stdout, stderr, err := runSlackCLI(t, "slack", "channels", "list", "--limit", "5") - if err != nil { - t.Fatalf("List channels failed: %v\nstderr: %s", err, stderr) - } - - // Verify test channel exists - if !strings.Contains(stdout, slackUserChannel) { - t.Logf("Warning: Test channel %s not found in first 5 channels", slackUserChannel) - } - t.Logf("Channels: %d lines", len(strings.Split(stdout, "\n"))) - }) - - t.Run("read_messages", func(t *testing.T) { - stdout, stderr, err := runSlackCLI(t, "slack", "messages", "list", "--channel-id", slackUserChannel, "--limit", "3") - if err != nil { - if strings.Contains(stderr, "channel not found") { - t.Skipf("Channel %s not found", slackUserChannel) - } - t.Fatalf("Read messages failed: %v\nstderr: %s", err, stderr) - } - - lines := strings.Split(strings.TrimSpace(stdout), "\n") - t.Logf("Messages: %d lines of output", len(lines)) - }) - - t.Run("list_users", func(t *testing.T) { - stdout, stderr, err := runSlackCLI(t, "slack", "users", "list", "--limit", "5") - if err != nil { - t.Fatalf("List users failed: %v\nstderr: %s", err, stderr) - } - t.Logf("Users: %d lines", len(strings.Split(stdout, "\n"))) - }) -} diff --git a/internal/cli/root.go b/internal/cli/root.go index b8093c7..d671668 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -8,12 +8,13 @@ import ( "github.com/nylas/cli/internal/cli/common" "github.com/nylas/cli/internal/cli/setup" + "github.com/nylas/cli/internal/version" ) var rootCmd = &cobra.Command{ Use: "nylas", Short: "Nylas CLI - Email, calendar, and contacts from your terminal", - Version: Version, + Version: version.Version, Long: `Quick start: nylas init Guided setup (first time) nylas email list List recent emails diff --git a/internal/cli/update/github.go b/internal/cli/update/github.go index 2afce82..58ab3d7 100644 --- a/internal/cli/update/github.go +++ b/internal/cli/update/github.go @@ -5,9 +5,9 @@ import ( "encoding/json" "fmt" "net/http" - "time" "github.com/nylas/cli/internal/cli/common" + "github.com/nylas/cli/internal/httputil" ) const ( @@ -17,9 +17,6 @@ const ( // GitHub API endpoints releasesAPIURL = "https://api.github.com/repos/%s/%s/releases/latest" - - // HTTP client timeout - httpTimeout = 30 * time.Second ) // Release represents a GitHub release. @@ -51,9 +48,8 @@ func getLatestRelease(ctx context.Context) (*Release, error) { } req.Header.Set("Accept", "application/vnd.github.v3+json") - req.Header.Set("User-Agent", "nylas-cli") - client := &http.Client{Timeout: httpTimeout} + client := httputil.DefaultClient resp, err := client.Do(req) if err != nil { return nil, common.WrapFetchError("release", err) diff --git a/internal/cli/update/installer.go b/internal/cli/update/installer.go index f067103..0e29c62 100644 --- a/internal/cli/update/installer.go +++ b/internal/cli/update/installer.go @@ -18,6 +18,7 @@ import ( "strings" "github.com/nylas/cli/internal/cli/common" + "github.com/nylas/cli/internal/httputil" ) const binaryName = "nylas" @@ -53,9 +54,7 @@ func downloadFile(ctx context.Context, url string) (string, error) { return "", common.WrapCreateError("request", err) } - req.Header.Set("User-Agent", "nylas-cli") - - client := &http.Client{Timeout: httpTimeout} + client := httputil.DefaultClient resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("download: %w", err) @@ -92,9 +91,7 @@ func downloadChecksums(ctx context.Context, url string) (map[string]string, erro return nil, common.WrapCreateError("request", err) } - req.Header.Set("User-Agent", "nylas-cli") - - client := &http.Client{Timeout: httpTimeout} + client := httputil.DefaultClient resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("download checksums: %w", err) diff --git a/internal/cli/update/update.go b/internal/cli/update/update.go index 44909be..22484c0 100644 --- a/internal/cli/update/update.go +++ b/internal/cli/update/update.go @@ -8,9 +8,9 @@ import ( "github.com/spf13/cobra" - "github.com/nylas/cli/internal/cli" "github.com/nylas/cli/internal/cli/common" "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/version" ) // NewUpdateCmd creates the update command. @@ -53,7 +53,7 @@ This command will: } func runUpdate(ctx context.Context, checkOnly, force, yes bool) error { - currentVersion := cli.Version + currentVersion := version.Version fmt.Printf("Current version: %s\n", currentVersion) diff --git a/internal/cli/version.go b/internal/cli/version.go index fb43ac2..aa0f898 100644 --- a/internal/cli/version.go +++ b/internal/cli/version.go @@ -8,13 +8,6 @@ import ( "github.com/spf13/cobra" ) -// Version aliases for backwards compatibility -var ( - Version = version.Version - Commit = version.Commit - BuildDate = version.BuildDate -) - func newVersionCmd() *cobra.Command { return &cobra.Command{ Use: "version", diff --git a/internal/cli/webhook/list.go b/internal/cli/webhook/list.go index 4616454..e1dd65d 100644 --- a/internal/cli/webhook/list.go +++ b/internal/cli/webhook/list.go @@ -11,7 +11,6 @@ import ( "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/ports" "github.com/spf13/cobra" - "gopkg.in/yaml.v3" ) func newListCmd() *cobra.Command { @@ -71,14 +70,6 @@ Shows webhook ID, description, URL, status, and trigger types.`, return cmd } -func outputJSON(webhooks []domain.Webhook) error { - return common.PrintJSON(webhooks) -} - -func outputYAML(webhooks []domain.Webhook) error { - return yaml.NewEncoder(os.Stdout).Encode(webhooks) -} - func outputCSV(webhooks []domain.Webhook) error { w := csv.NewWriter(os.Stdout) defer w.Flush() diff --git a/internal/cli/webhook/render_test.go b/internal/cli/webhook/render_test.go index e8fd676..fdfc12b 100644 --- a/internal/cli/webhook/render_test.go +++ b/internal/cli/webhook/render_test.go @@ -1,7 +1,6 @@ package webhook import ( - "encoding/json" "io" "os" "testing" @@ -99,29 +98,6 @@ func TestOutputHelpers(t *testing.T) { }, } - t.Run("outputJSON", func(t *testing.T) { - output := captureStdout(t, func() { - err := outputJSON(webhooks) - require.NoError(t, err) - }) - - var decoded []domain.Webhook - err := json.Unmarshal([]byte(output), &decoded) - require.NoError(t, err) - require.Len(t, decoded, 1) - assert.Equal(t, "webhook-123", decoded[0].ID) - }) - - t.Run("outputYAML", func(t *testing.T) { - output := captureStdout(t, func() { - err := outputYAML(webhooks) - require.NoError(t, err) - }) - - assert.Contains(t, output, "id: webhook-123") - assert.Contains(t, output, "status: active") - }) - t.Run("outputCSV", func(t *testing.T) { output := captureStdout(t, func() { err := outputCSV(webhooks) diff --git a/internal/domain/analytics.go b/internal/domain/analytics.go index e40fa8f..0f2ce7b 100644 --- a/internal/domain/analytics.go +++ b/internal/domain/analytics.go @@ -119,13 +119,6 @@ type FocusTimeBlock struct { Conflicts int `json:"conflicts"` // Number of meetings that would conflict } -// PatternStore defines the interface for storing learned patterns. -type PatternStore interface { - SavePattern(pattern *MeetingPattern) error - LoadPattern(userEmail string) (*MeetingPattern, error) - DeletePattern(userEmail string) error -} - // ============================================================================ // Conflict Detection & Resolution // ============================================================================ diff --git a/internal/domain/config.go b/internal/domain/config.go index 3949712..be47b0a 100644 --- a/internal/domain/config.go +++ b/internal/domain/config.go @@ -11,14 +11,6 @@ const ( // TimeoutAPI is the default timeout for Nylas API calls (90s). TimeoutAPI = 90 * time.Second - // TimeoutAI is the timeout for AI/LLM operations (120s). - // AI providers may take longer due to model inference time. - TimeoutAI = 120 * time.Second - - // TimeoutMCP is the timeout for MCP proxy operations (90s). - // Allows time for tool execution and response processing. - TimeoutMCP = 90 * time.Second - // TimeoutHealthCheck is the timeout for health/connectivity checks (10s). TimeoutHealthCheck = 10 * time.Second @@ -37,10 +29,8 @@ const ( // TimeoutQuickCheck is the timeout for quick checks like version checking (5s). TimeoutQuickCheck = 5 * time.Second - // HTTP Server timeouts + // HTTP server timeouts (used by httputil.NewServer) HTTPReadHeaderTimeout = 10 * time.Second // Time to read request headers - HTTPReadTimeout = 30 * time.Second // Time to read entire request - HTTPWriteTimeout = 30 * time.Second // Time to write response HTTPIdleTimeout = 120 * time.Second // Keep-alive connection idle timeout ) diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go index 97eeefe..9a53659 100644 --- a/internal/domain/interfaces.go +++ b/internal/domain/interfaces.go @@ -2,35 +2,8 @@ package domain import ( "fmt" - "time" ) -// ============================================================================= -// Shared Interfaces -// ============================================================================= - -// Validator is implemented by types that can validate themselves. -type Validator interface { - Validate() error -} - -// Paginated is implemented by all paginated response types. -type Paginated interface { - GetPagination() Pagination - HasMore() bool -} - -// Resource is implemented by all domain resources with ID. -type Resource interface { - GetID() string -} - -// Timestamped is implemented by resources with creation/update timestamps. -type Timestamped interface { - GetCreatedAt() time.Time - GetUpdatedAt() time.Time -} - // ============================================================================= // Person Type (Base for EmailParticipant and Participant) // ============================================================================= @@ -57,40 +30,3 @@ func (p Person) DisplayName() string { } return p.Email } - -// ============================================================================= -// List Response Helpers -// ============================================================================= - -// FilterFunc is a predicate function for filtering list items. -type FilterFunc[T any] func(T) bool - -// MapFunc transforms an item of type T to type R. -type MapFunc[T, R any] func(T) R - -// Filter returns items matching the predicate. -func Filter[T any](items []T, predicate FilterFunc[T]) []T { - var result []T - for _, item := range items { - if predicate(item) { - result = append(result, item) - } - } - return result -} - -// Map transforms a slice using the provided function. -func Map[T, R any](items []T, fn MapFunc[T, R]) []R { - result := make([]R, len(items)) - for i, item := range items { - result[i] = fn(item) - } - return result -} - -// Each iterates over items and calls the function for each. -func Each[T any](items []T, fn func(T)) { - for _, item := range items { - fn(item) - } -} diff --git a/internal/httputil/httputil.go b/internal/httputil/httputil.go index 5e19c3b..1686f07 100644 --- a/internal/httputil/httputil.go +++ b/internal/httputil/httputil.go @@ -6,12 +6,75 @@ import ( "io" "log/slog" "net/http" + "time" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/version" ) // MaxRequestBodySize is the maximum allowed request body size (1MB). // This prevents memory exhaustion attacks via large payloads. const MaxRequestBodySize = 1 << 20 +// DefaultClientTimeout is the standard timeout for outbound HTTP clients. +const DefaultClientTimeout = 120 * time.Second + +// userAgentTransport sets the standard CLI User-Agent on every outbound +// request that doesn't already specify one, so all clients built via +// NewClient identify consistently. +type userAgentTransport struct { + base http.RoundTripper +} + +func (t userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if req.Header.Get("User-Agent") != "" { + return t.base.RoundTrip(req) + } + // RoundTrippers must not mutate the caller's request (net/http contract) — + // a retried or shared request would race. Clone before setting the header. + clone := req.Clone(req.Context()) + clone.Header.Set("User-Agent", version.UserAgent()) + return t.base.RoundTrip(clone) +} + +// NewClient returns an outbound HTTP client with the given timeout and the +// standard CLI User-Agent. A zero (or negative) timeout uses DefaultClientTimeout. +// +// Most callers should use the shared DefaultClient instead of minting their +// own; NewClient exists for the rare case that needs a non-default timeout. +func NewClient(timeout time.Duration) *http.Client { + if timeout <= 0 { + timeout = DefaultClientTimeout + } + return &http.Client{ + Timeout: timeout, + Transport: userAgentTransport{base: http.DefaultTransport}, + } +} + +// DefaultClient is the shared outbound HTTP client (120s timeout + standard +// User-Agent). http.Client is safe for concurrent use, so every caller that +// wants the default policy should reuse this single instance rather than +// constructing its own. Do not mutate it; build a dedicated client for +// special behavior (e.g. disabled redirects). +var DefaultClient = NewClient(DefaultClientTimeout) + +// NewServer returns an *http.Server hardened with the standard CLI defaults: +// a 10s header-read timeout, 120s idle timeout, and a 1MB max header size. +// +// writeTimeout is per-caller because streaming endpoints (e.g. SSE) need a +// long or disabled (0) write deadline; pass 0 for no write timeout. +func NewServer(addr string, handler http.Handler, writeTimeout time.Duration) *http.Server { + return &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: domain.HTTPReadHeaderTimeout, + WriteTimeout: writeTimeout, + IdleTimeout: domain.HTTPIdleTimeout, + MaxHeaderBytes: 1 << 20, // 1 MB + } +} + // LimitedBody wraps a request body with a size limit. // Returns a ReadCloser that will return an error if the body exceeds maxBytes. func LimitedBody(w http.ResponseWriter, r *http.Request, maxBytes int64) io.ReadCloser { diff --git a/internal/ports/templates.go b/internal/ports/templates.go deleted file mode 100644 index 9ea7e79..0000000 --- a/internal/ports/templates.go +++ /dev/null @@ -1,34 +0,0 @@ -// Package ports defines the interfaces for external dependencies. -package ports - -import ( - "context" - - "github.com/nylas/cli/internal/domain" -) - -// TemplateStore defines the interface for email template storage. -// Templates are stored locally (not in Nylas API) for use with email sending. -type TemplateStore interface { - // List returns all templates, optionally filtered by category. - // If category is empty, all templates are returned. - List(ctx context.Context, category string) ([]domain.EmailTemplate, error) - - // Get retrieves a template by its ID. - Get(ctx context.Context, id string) (*domain.EmailTemplate, error) - - // Create creates a new template and returns it with generated ID. - Create(ctx context.Context, t *domain.EmailTemplate) (*domain.EmailTemplate, error) - - // Update updates an existing template. - Update(ctx context.Context, t *domain.EmailTemplate) (*domain.EmailTemplate, error) - - // Delete removes a template by its ID. - Delete(ctx context.Context, id string) error - - // IncrementUsage increments the usage count for a template. - IncrementUsage(ctx context.Context, id string) error - - // Path returns the path to the templates file. - Path() string -} diff --git a/internal/ports/utilities.go b/internal/ports/utilities.go deleted file mode 100644 index 30f3ec1..0000000 --- a/internal/ports/utilities.go +++ /dev/null @@ -1,33 +0,0 @@ -package ports - -import ( - "context" - "time" - - "github.com/nylas/cli/internal/domain" -) - -// UtilityServices defines interfaces for non-Nylas utility features. -// These services provide offline-capable tools that don't require Nylas API access. -type UtilityServices interface { - TimeZoneService -} - -// TimeZoneService provides time zone conversion and meeting finder utilities. -// Addresses the pain point where 83% of professionals struggle with time zone scheduling. -type TimeZoneService interface { - // ConvertTime converts a time from one zone to another - ConvertTime(ctx context.Context, fromZone, toZone string, t time.Time) (time.Time, error) - - // FindMeetingTime finds overlapping working hours across multiple time zones - FindMeetingTime(ctx context.Context, req *domain.MeetingFinderRequest) (*domain.MeetingTimeSlots, error) - - // GetDSTTransitions returns DST transition dates for a zone in a given year - GetDSTTransitions(ctx context.Context, zone string, year int) ([]domain.DSTTransition, error) - - // ListTimeZones returns all available IANA time zones - ListTimeZones(ctx context.Context) ([]string, error) - - // GetTimeZoneInfo returns detailed information about a time zone - GetTimeZoneInfo(ctx context.Context, zone string, at time.Time) (*domain.TimeZoneInfo, error) -} diff --git a/internal/ports/webhook_server.go b/internal/ports/webhook_server.go index 3895e89..18f28a3 100644 --- a/internal/ports/webhook_server.go +++ b/internal/ports/webhook_server.go @@ -53,30 +53,6 @@ type WebhookServerStats struct { // WebhookEventHandler is called when a webhook event is received. type WebhookEventHandler func(event *WebhookEvent) -// WebhookServer defines the interface for a local webhook receiver server. -type WebhookServer interface { - // Start starts the webhook server. - Start(ctx context.Context) error - - // Stop stops the webhook server. - Stop() error - - // GetLocalURL returns the local server URL. - GetLocalURL() string - - // GetPublicURL returns the public URL (from tunnel, if any). - GetPublicURL() string - - // GetStats returns server statistics. - GetStats() WebhookServerStats - - // OnEvent registers a handler for webhook events. - OnEvent(handler WebhookEventHandler) - - // Events returns a channel for receiving webhook events. - Events() <-chan *WebhookEvent -} - // TunnelConfig holds configuration for a tunnel. type TunnelConfig struct { Provider string // cloudflared or ngrok diff --git a/internal/studio/server.go b/internal/studio/server.go index e5db324..52467de 100644 --- a/internal/studio/server.go +++ b/internal/studio/server.go @@ -13,6 +13,7 @@ import ( "sync" "time" + "github.com/nylas/cli/internal/httputil" "github.com/nylas/cli/internal/ports" "github.com/nylas/cli/internal/webguard" ) @@ -46,11 +47,7 @@ func (s *Server) Start(ctx context.Context) error { webguard.OriginProtectionMiddleware( webguard.SecurityHeadersMiddleware(mux))) - s.httpServer = &http.Server{ - Addr: s.addr, - Handler: handler, - ReadHeaderTimeout: 10 * time.Second, - } + s.httpServer = httputil.NewServer(s.addr, handler, 0) errCh := make(chan error, 1) go func() { diff --git a/internal/testutil/context.go b/internal/testutil/context.go deleted file mode 100644 index da5e459..0000000 --- a/internal/testutil/context.go +++ /dev/null @@ -1,35 +0,0 @@ -// Package testutil provides common test utilities and helpers for the Nylas CLI. -package testutil - -import ( - "context" - "testing" - "time" -) - -// TestContext creates a 30-second timeout context with automatic cleanup. -// This is the standard context for most tests that interact with the API. -func TestContext(t *testing.T) context.Context { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - t.Cleanup(cancel) - return ctx -} - -// LongTestContext creates a 120-second timeout context with automatic cleanup. -// Use this for integration tests that may take longer (e.g., multiple API calls). -func LongTestContext(t *testing.T) context.Context { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) - t.Cleanup(cancel) - return ctx -} - -// QuickTestContext creates a 5-second timeout context with automatic cleanup. -// Use this for unit tests that should complete quickly (e.g., no network calls). -func QuickTestContext(t *testing.T) context.Context { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - t.Cleanup(cancel) - return ctx -} diff --git a/internal/testutil/context_test.go b/internal/testutil/context_test.go deleted file mode 100644 index c2b1943..0000000 --- a/internal/testutil/context_test.go +++ /dev/null @@ -1,124 +0,0 @@ -package testutil_test - -import ( - "testing" - "time" - - "github.com/nylas/cli/internal/testutil" -) - -func TestTestContext(t *testing.T) { - ctx := testutil.TestContext(t) - if ctx == nil { - t.Fatal("Expected non-nil context") - } - - deadline, ok := ctx.Deadline() - if !ok { - t.Fatal("Expected context to have deadline") - } - - // Should be approximately 30 seconds from now - actualDuration := time.Until(deadline) - if actualDuration < 29*time.Second || actualDuration > 31*time.Second { - t.Errorf("Expected deadline ~30s from now, got %v", actualDuration) - } -} - -func TestLongTestContext(t *testing.T) { - ctx := testutil.LongTestContext(t) - if ctx == nil { - t.Fatal("Expected non-nil context") - } - - deadline, ok := ctx.Deadline() - if !ok { - t.Fatal("Expected context to have deadline") - } - - // Should be approximately 120 seconds from now - actualDuration := time.Until(deadline) - if actualDuration < 119*time.Second || actualDuration > 121*time.Second { - t.Errorf("Expected deadline ~120s from now, got %v", actualDuration) - } -} - -func TestQuickTestContext(t *testing.T) { - ctx := testutil.QuickTestContext(t) - if ctx == nil { - t.Fatal("Expected non-nil context") - } - - deadline, ok := ctx.Deadline() - if !ok { - t.Fatal("Expected context to have deadline") - } - - // Should be approximately 5 seconds from now - actualDuration := time.Until(deadline) - if actualDuration < 4*time.Second || actualDuration > 6*time.Second { - t.Errorf("Expected deadline ~5s from now, got %v", actualDuration) - } -} - -func TestContextCancellation(t *testing.T) { - ctx := testutil.TestContext(t) - - // Context should not be cancelled initially - select { - case <-ctx.Done(): - t.Fatal("Context should not be cancelled initially") - default: - // Expected - } - - // Context should be cancelled after test cleanup - // (This is implicitly tested by t.Cleanup when test finishes) -} - -func TestPointerHelpers(t *testing.T) { - // Test StringPtr - str := "test" - strPtr := testutil.StringPtr(str) - if strPtr == nil { - t.Fatal("StringPtr returned nil") - return - } - if *strPtr != str { - t.Errorf("Expected %q, got %q", str, *strPtr) - } - - // Test BoolPtr - boolVal := true - boolPtr := testutil.BoolPtr(boolVal) - if boolPtr == nil { - t.Fatal("BoolPtr returned nil") - return - } - if *boolPtr != boolVal { - t.Errorf("Expected %v, got %v", boolVal, *boolPtr) - } - - // Test IntPtr - intVal := 42 - intPtr := testutil.IntPtr(intVal) - if intPtr == nil { - t.Fatal("IntPtr returned nil") - return - } - if *intPtr != intVal { - t.Errorf("Expected %d, got %d", intVal, *intPtr) - } -} - -func TestRequireEnv(t *testing.T) { - // Test with existing env var - t.Setenv("TEST_VAR", "test_value") - value := testutil.RequireEnv(t, "TEST_VAR") - if value != "test_value" { - t.Errorf("Expected 'test_value', got %q", value) - } - - // Note: Testing the skip behavior is complex with the standard testing package - // The skip functionality is tested implicitly when RequireEnv is used in real tests -} diff --git a/internal/testutil/helpers.go b/internal/testutil/helpers.go deleted file mode 100644 index 63cf1e5..0000000 --- a/internal/testutil/helpers.go +++ /dev/null @@ -1,305 +0,0 @@ -// Package testutil provides common test utilities and helpers to reduce duplication -// across test files. -package testutil - -import ( - "bytes" - "encoding/json" - "net/http" - "os" - "path/filepath" - "testing" -) - -// StringPtr returns a pointer to the given string value. -// This is useful for test data that requires *string fields. -func StringPtr(s string) *string { - return &s -} - -// BoolPtr returns a pointer to the given boolean value. -// This is useful for test data that requires *bool fields. -func BoolPtr(b bool) *bool { - return &b -} - -// IntPtr returns a pointer to the given integer value. -// This is useful for test data that requires *int fields. -func IntPtr(i int) *int { - return &i -} - -// RequireEnv retrieves an environment variable and skips the test if not set. -// This is useful for integration tests that require specific configuration. -func RequireEnv(t *testing.T, key string) string { - t.Helper() - value := os.Getenv(key) - if value == "" { - t.Skipf("Environment variable %s not set, skipping test", key) - } - return value -} - -// TempConfig creates a temporary config file with the given content for testing. -// The file is automatically cleaned up when the test completes. -// -// Example: -// -// configPath := testutil.TempConfig(t, `region: "us"\ncallback_port: 8080`) -// store := config.NewFileStore(configPath) -func TempConfig(t *testing.T, content string) string { - t.Helper() - - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - - if err := os.WriteFile(configPath, []byte(content), 0600); err != nil { - t.Fatalf("Failed to create temp config: %v", err) - } - - return configPath -} - -// TempDir creates a temporary directory for testing. -// The directory is automatically cleaned up when the test completes. -// -// Note: This is just a wrapper around t.TempDir() for consistency. -func TempDir(t *testing.T) string { - t.Helper() - return t.TempDir() -} - -// TempFile creates a temporary file with the given content for testing. -// The file is automatically cleaned up when the test completes. -// -// Example: -// -// filePath := testutil.TempFile(t, "test.txt", "file contents") -func TempFile(t *testing.T, name, content string) string { - t.Helper() - - tmpDir := t.TempDir() - filePath := filepath.Join(tmpDir, name) - - if err := os.WriteFile(filePath, []byte(content), 0600); err != nil { - t.Fatalf("Failed to create temp file: %v", err) - } - - return filePath -} - -// AssertNoError fails the test if err is not nil. -// -// Example: -// -// testutil.AssertNoError(t, err, "failed to create user") -func AssertNoError(t *testing.T, err error, msg string) { - t.Helper() - - if err != nil { - t.Fatalf("%s: %v", msg, err) - } -} - -// AssertError fails the test if err is nil. -// -// Example: -// -// testutil.AssertError(t, err, "expected error for invalid input") -func AssertError(t *testing.T, err error, msg string) { - t.Helper() - - if err == nil { - t.Fatalf("%s: expected error, got nil", msg) - } -} - -// AssertEqual fails the test if got != want. -// -// Example: -// -// testutil.AssertEqual(t, user.Name, "John", "user name") -func AssertEqual[T comparable](t *testing.T, got, want T, msg string) { - t.Helper() - - if got != want { - t.Errorf("%s: got %v, want %v", msg, got, want) - } -} - -// AssertNotEqual fails the test if got == want. -// -// Example: -// -// testutil.AssertNotEqual(t, user.ID, "", "user ID should not be empty") -func AssertNotEqual[T comparable](t *testing.T, got, unwanted T, msg string) { - t.Helper() - - if got == unwanted { - t.Errorf("%s: got %v, should not equal %v", msg, got, unwanted) - } -} - -// AssertContains fails the test if haystack does not contain needle. -// -// Example: -// -// testutil.AssertContains(t, output, "Success", "output should contain success message") -func AssertContains(t *testing.T, haystack, needle, msg string) { - t.Helper() - - if !contains(haystack, needle) { - t.Errorf("%s: %q does not contain %q", msg, haystack, needle) - } -} - -// AssertNotContains fails the test if haystack contains needle. -// -// Example: -// -// testutil.AssertNotContains(t, output, "Error", "output should not contain error") -func AssertNotContains(t *testing.T, haystack, needle, msg string) { - t.Helper() - - if contains(haystack, needle) { - t.Errorf("%s: %q should not contain %q", msg, haystack, needle) - } -} - -// AssertNil fails the test if value is not nil. -// -// Example: -// -// testutil.AssertNil(t, err, "error should be nil") -func AssertNil(t *testing.T, value any, msg string) { - t.Helper() - - if value != nil { - t.Errorf("%s: expected nil, got %v", msg, value) - } -} - -// AssertNotNil fails the test if value is nil. -// -// Example: -// -// testutil.AssertNotNil(t, user, "user should not be nil") -func AssertNotNil(t *testing.T, value any, msg string) { - t.Helper() - - if value == nil { - t.Errorf("%s: expected non-nil value, got nil", msg) - } -} - -// AssertTrue fails the test if condition is false. -// -// Example: -// -// testutil.AssertTrue(t, user.IsActive, "user should be active") -func AssertTrue(t *testing.T, condition bool, msg string) { - t.Helper() - - if !condition { - t.Errorf("%s: expected true, got false", msg) - } -} - -// AssertFalse fails the test if condition is true. -// -// Example: -// -// testutil.AssertFalse(t, user.IsDeleted, "user should not be deleted") -func AssertFalse(t *testing.T, condition bool, msg string) { - t.Helper() - - if condition { - t.Errorf("%s: expected false, got true", msg) - } -} - -// AssertLen fails the test if the length of slice is not equal to expected. -// -// Example: -// -// testutil.AssertLen(t, users, 5, "should have 5 users") -func AssertLen[T any](t *testing.T, slice []T, expected int, msg string) { - t.Helper() - - if len(slice) != expected { - t.Errorf("%s: got length %d, want %d", msg, len(slice), expected) - } -} - -// AssertEmpty fails the test if slice is not empty. -// -// Example: -// -// testutil.AssertEmpty(t, errors, "should have no errors") -func AssertEmpty[T any](t *testing.T, slice []T, msg string) { - t.Helper() - - if len(slice) != 0 { - t.Errorf("%s: expected empty slice, got length %d", msg, len(slice)) - } -} - -// AssertNotEmpty fails the test if slice is empty. -// -// Example: -// -// testutil.AssertNotEmpty(t, results, "should have results") -func AssertNotEmpty[T any](t *testing.T, slice []T, msg string) { - t.Helper() - - if len(slice) == 0 { - t.Errorf("%s: expected non-empty slice, got empty slice", msg) - } -} - -// WriteJSONResponse writes a JSON response to an http.ResponseWriter in tests. -// This eliminates the duplicate pattern of setting Content-Type, writing status, -// and encoding JSON that appears 180+ times across test files. -// -// Safe to call from httptest.NewServer handler goroutines — on encode failure it -// writes a 500 error response instead of calling t.Fatalf (which is unsafe from -// non-test goroutines). -// -// Example: -// -// server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// testutil.WriteJSONResponse(t, w, http.StatusOK, map[string]string{"id": "123"}) -// })) -func WriteJSONResponse(t *testing.T, w http.ResponseWriter, statusCode int, data any) { - t.Helper() - - // Encode into a buffer first so that on failure we can return 500 - // instead of a success status with a broken body. Writing headers - // is deferred until we know encoding succeeded. - var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(data); err != nil { - // t.Errorf is safe from any goroutine (unlike t.Fatalf). - t.Errorf("WriteJSONResponse: failed to encode JSON: %v", err) - http.Error(w, "WriteJSONResponse: encode failed", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(statusCode) - _, _ = w.Write(buf.Bytes()) -} - -// Helper function to check if string contains substring -func contains(s, substr string) bool { - if len(substr) == 0 { - return true - } - - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - - return false -} diff --git a/internal/testutil/helpers_test.go b/internal/testutil/helpers_test.go deleted file mode 100644 index d6a098c..0000000 --- a/internal/testutil/helpers_test.go +++ /dev/null @@ -1,167 +0,0 @@ -package testutil - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "os" - "testing" -) - -func TestTempConfig(t *testing.T) { - content := "test: value" - path := TempConfig(t, content) - - // #nosec G304 -- reading test file created by test helper - data, err := os.ReadFile(path) - if err != nil { - t.Fatalf("Failed to read temp config: %v", err) - } - - if string(data) != content { - t.Errorf("Content = %q, want %q", string(data), content) - } -} - -func TestTempFile(t *testing.T) { - content := "test content" - path := TempFile(t, "test.txt", content) - - // #nosec G304 -- reading test file created by test helper - data, err := os.ReadFile(path) - if err != nil { - t.Fatalf("Failed to read temp file: %v", err) - } - - if string(data) != content { - t.Errorf("Content = %q, want %q", string(data), content) - } -} - -func TestAssertEqual(t *testing.T) { - // Test successful assertion (should not fail) - AssertEqual(t, "hello", "hello", "strings should be equal") - AssertEqual(t, 42, 42, "numbers should be equal") - AssertEqual(t, true, true, "booleans should be equal") -} - -func TestAssertContains(t *testing.T) { - AssertContains(t, "hello world", "world", "should contain substring") -} - -func TestContainsHelper(t *testing.T) { - tests := []struct { - name string - s string - substr string - want bool - }{ - {"empty substring", "hello", "", true}, - {"contains", "hello world", "world", true}, - {"does not contain", "hello", "xyz", false}, - {"exact match", "test", "test", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := contains(tt.s, tt.substr) - if got != tt.want { - t.Errorf("contains(%q, %q) = %v, want %v", tt.s, tt.substr, got, tt.want) - } - }) - } -} - -func TestWriteJSONResponse(t *testing.T) { - t.Run("writes status and JSON body", func(t *testing.T) { - data := map[string]string{"id": "123", "name": "test"} - - rec := httptest.NewRecorder() - WriteJSONResponse(t, rec, http.StatusOK, data) - - resp := rec.Result() - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) - } - - if ct := resp.Header.Get("Content-Type"); ct != "application/json" { - t.Errorf("Content-Type = %q, want %q", ct, "application/json") - } - - var got map[string]string - if err := json.NewDecoder(resp.Body).Decode(&got); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - if got["id"] != "123" { - t.Errorf("id = %q, want %q", got["id"], "123") - } - if got["name"] != "test" { - t.Errorf("name = %q, want %q", got["name"], "test") - } - }) - - t.Run("writes custom status codes", func(t *testing.T) { - rec := httptest.NewRecorder() - WriteJSONResponse(t, rec, http.StatusCreated, map[string]bool{"ok": true}) - - if rec.Code != http.StatusCreated { - t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated) - } - }) - - t.Run("writes array data", func(t *testing.T) { - items := []string{"a", "b", "c"} - rec := httptest.NewRecorder() - WriteJSONResponse(t, rec, http.StatusOK, items) - - var got []string - if err := json.NewDecoder(rec.Body).Decode(&got); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - if len(got) != 3 { - t.Errorf("len = %d, want 3", len(got)) - } - }) - - t.Run("integrates with httptest server", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - WriteJSONResponse(t, w, http.StatusOK, map[string]string{"status": "ok"}) - })) - defer server.Close() - - resp, err := http.Get(server.URL) //nolint:gosec // test URL - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer func() { _ = resp.Body.Close() }() - - var result map[string]string - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - t.Fatalf("decode failed: %v", err) - } - if result["status"] != "ok" { - t.Errorf("status = %q, want %q", result["status"], "ok") - } - }) - - t.Run("returns 500 on encode failure", func(t *testing.T) { - // json.Marshal cannot encode channels — use this to trigger failure. - unencodable := make(chan int) - - // Use a fake *testing.T so we can observe the Errorf call without - // failing the real test. - fakeT := &testing.T{} - - rec := httptest.NewRecorder() - WriteJSONResponse(fakeT, rec, http.StatusOK, unencodable) - - if rec.Code != http.StatusInternalServerError { - t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) - } - if !fakeT.Failed() { - t.Error("expected fakeT to be marked as failed") - } - }) -} diff --git a/internal/tui/formatting_helpers.go b/internal/tui/formatting_helpers.go index 093d3d5..db2da83 100644 --- a/internal/tui/formatting_helpers.go +++ b/internal/tui/formatting_helpers.go @@ -1,8 +1,6 @@ package tui import ( - "html" - "strings" "time" "github.com/nylas/cli/internal/cli/common" @@ -27,80 +25,5 @@ func formatFileSize(size int64) string { // stripHTMLForTUI removes HTML tags from a string for terminal display. func stripHTMLForTUI(s string) string { - // Remove style and script tags and their contents. - s = removeTagWithContent(s, "style") - s = removeTagWithContent(s, "script") - s = removeTagWithContent(s, "head") - - // Replace block-level elements with newlines before stripping tags. - blockTags := []string{"br", "p", "div", "tr", "li", "h1", "h2", "h3", "h4", "h5", "h6"} - for _, tag := range blockTags { - s = strings.ReplaceAll(s, "<"+tag+">", "\n") - s = strings.ReplaceAll(s, "<"+tag+"/>", "\n") - s = strings.ReplaceAll(s, "<"+tag+" />", "\n") - s = strings.ReplaceAll(s, "", "\n") - s = strings.ReplaceAll(s, "<"+strings.ToUpper(tag)+">", "\n") - s = strings.ReplaceAll(s, "", "\n") - } - - // Strip remaining HTML tags. - var result strings.Builder - inTag := false - for _, r := range s { - switch { - case r == '<': - inTag = true - case r == '>': - inTag = false - case !inTag: - result.WriteRune(r) - } - } - - // Decode HTML entities. - text := html.UnescapeString(result.String()) - - // Clean up whitespace. - text = strings.ReplaceAll(text, "\r\n", "\n") - text = strings.ReplaceAll(text, "\r", "\n") - - for strings.Contains(text, " ") { - text = strings.ReplaceAll(text, " ", " ") - } - - for strings.Contains(text, "\n\n\n") { - text = strings.ReplaceAll(text, "\n\n\n", "\n\n") - } - - lines := strings.Split(text, "\n") - for i, line := range lines { - lines[i] = strings.TrimSpace(line) - } - text = strings.Join(lines, "\n") - - return strings.TrimSpace(text) -} - -// removeTagWithContent removes an HTML tag and all its content. -func removeTagWithContent(s, tag string) string { - result := s - for { - lower := strings.ToLower(result) - startIdx := strings.Index(lower, "<"+tag) - if startIdx == -1 { - break - } - endTag := "" - endIdx := strings.Index(lower[startIdx:], endTag) - if endIdx == -1 { - closeIdx := strings.Index(result[startIdx:], ">") - if closeIdx == -1 { - break - } - result = result[:startIdx] + result[startIdx+closeIdx+1:] - } else { - result = result[:startIdx] + result[startIdx+endIdx+len(endTag):] - } - } - return result + return common.StripHTML(s) } diff --git a/internal/ui/server.go b/internal/ui/server.go index 1bbbda9..6e9f1aa 100644 --- a/internal/ui/server.go +++ b/internal/ui/server.go @@ -5,12 +5,12 @@ import ( "html/template" "io/fs" "net/http" - "time" "github.com/nylas/cli/internal/adapters/config" "github.com/nylas/cli/internal/adapters/keyring" authapp "github.com/nylas/cli/internal/app/auth" "github.com/nylas/cli/internal/cli/common" + "github.com/nylas/cli/internal/httputil" "github.com/nylas/cli/internal/ports" "github.com/nylas/cli/internal/webguard" ) @@ -101,11 +101,7 @@ func (s *Server) Start() error { webguard.OriginProtectionMiddleware( webguard.SecurityHeadersMiddleware(mux))) - server := &http.Server{ - Addr: s.addr, - Handler: handler, - ReadHeaderTimeout: 10 * time.Second, - } + server := httputil.NewServer(s.addr, handler, 0) return server.ListenAndServe() } diff --git a/internal/util/slices.go b/internal/util/slices.go index 42f3429..4776774 100644 --- a/internal/util/slices.go +++ b/internal/util/slices.go @@ -12,93 +12,3 @@ func Map[T, U any](s []T, f func(T) U) []U { } return result } - -// Filter returns a new slice containing only elements that satisfy the predicate. -// Returns nil if the input slice is nil. -func Filter[T any](s []T, keep func(T) bool) []T { - if s == nil { - return nil - } - result := make([]T, 0, len(s)) - for _, v := range s { - if keep(v) { - result = append(result, v) - } - } - return result -} - -// Reduce accumulates values using a reducer function. -// The reducer function takes an accumulator and the current element, -// and returns the new accumulator value. -func Reduce[T, U any](s []T, initial U, reduce func(U, T) U) U { - acc := initial - for _, v := range s { - acc = reduce(acc, v) - } - return acc -} - -// Contains checks if a slice contains a value. -// Uses == for comparison, so T must be comparable. -func Contains[T comparable](s []T, v T) bool { - for _, item := range s { - if item == v { - return true - } - } - return false -} - -// Partition splits a slice into two based on a predicate. -// The first slice contains elements where predicate returns true, -// the second contains elements where it returns false. -// Returns (nil, nil) if the input slice is nil. -func Partition[T any](s []T, predicate func(T) bool) ([]T, []T) { - if s == nil { - return nil, nil - } - trueSlice := make([]T, 0, len(s)/2) - falseSlice := make([]T, 0, len(s)/2) - for _, v := range s { - if predicate(v) { - trueSlice = append(trueSlice, v) - } else { - falseSlice = append(falseSlice, v) - } - } - return trueSlice, falseSlice -} - -// Find returns the first element that satisfies the predicate. -// Returns the zero value of T and false if no element is found. -func Find[T any](s []T, predicate func(T) bool) (T, bool) { - for _, v := range s { - if predicate(v) { - return v, true - } - } - var zero T - return zero, false -} - -// Any returns true if any element satisfies the predicate. -func Any[T any](s []T, predicate func(T) bool) bool { - for _, v := range s { - if predicate(v) { - return true - } - } - return false -} - -// All returns true if all elements satisfy the predicate. -// Returns true for empty slices. -func All[T any](s []T, predicate func(T) bool) bool { - for _, v := range s { - if !predicate(v) { - return false - } - } - return true -} diff --git a/internal/util/slices_test.go b/internal/util/slices_test.go index dbe964a..333a027 100644 --- a/internal/util/slices_test.go +++ b/internal/util/slices_test.go @@ -55,310 +55,6 @@ func TestMap_DifferentTypes(t *testing.T) { assert.Equal(t, []int{1, 2, 3}, result) } -func TestFilter(t *testing.T) { - tests := []struct { - name string - input []int - predicate func(int) bool - expected []int - }{ - { - name: "filter even numbers", - input: []int{1, 2, 3, 4, 5, 6}, - predicate: func(x int) bool { return x%2 == 0 }, - expected: []int{2, 4, 6}, - }, - { - name: "filter greater than 3", - input: []int{1, 2, 3, 4, 5}, - predicate: func(x int) bool { return x > 3 }, - expected: []int{4, 5}, - }, - { - name: "nil slice", - input: nil, - predicate: func(x int) bool { return x%2 == 0 }, - expected: nil, - }, - { - name: "empty slice", - input: []int{}, - predicate: func(x int) bool { return x%2 == 0 }, - expected: []int{}, - }, - { - name: "no matches", - input: []int{1, 3, 5}, - predicate: func(x int) bool { return x%2 == 0 }, - expected: []int{}, - }, - { - name: "all match", - input: []int{2, 4, 6}, - predicate: func(x int) bool { return x%2 == 0 }, - expected: []int{2, 4, 6}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := util.Filter(tt.input, tt.predicate) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestReduce(t *testing.T) { - tests := []struct { - name string - input []int - initial int - reduce func(int, int) int - expected int - }{ - { - name: "sum", - input: []int{1, 2, 3, 4, 5}, - initial: 0, - reduce: func(acc, x int) int { return acc + x }, - expected: 15, - }, - { - name: "product", - input: []int{1, 2, 3, 4}, - initial: 1, - reduce: func(acc, x int) int { return acc * x }, - expected: 24, - }, - { - name: "empty slice", - input: []int{}, - initial: 10, - reduce: func(acc, x int) int { return acc + x }, - expected: 10, - }, - { - name: "nil slice", - input: nil, - initial: 5, - reduce: func(acc, x int) int { return acc + x }, - expected: 5, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := util.Reduce(tt.input, tt.initial, tt.reduce) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestContains(t *testing.T) { - tests := []struct { - name string - slice []string - value string - expected bool - }{ - { - name: "contains value", - slice: []string{"foo", "bar", "baz"}, - value: "bar", - expected: true, - }, - { - name: "does not contain value", - slice: []string{"foo", "bar", "baz"}, - value: "qux", - expected: false, - }, - { - name: "empty slice", - slice: []string{}, - value: "foo", - expected: false, - }, - { - name: "nil slice", - slice: nil, - value: "foo", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := util.Contains(tt.slice, tt.value) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestPartition(t *testing.T) { - tests := []struct { - name string - input []int - predicate func(int) bool - expectedTrue []int - expectedFalse []int - }{ - { - name: "partition even and odd", - input: []int{1, 2, 3, 4, 5, 6}, - predicate: func(x int) bool { return x%2 == 0 }, - expectedTrue: []int{2, 4, 6}, - expectedFalse: []int{1, 3, 5}, - }, - { - name: "all true", - input: []int{2, 4, 6}, - predicate: func(x int) bool { return x%2 == 0 }, - expectedTrue: []int{2, 4, 6}, - expectedFalse: []int{}, - }, - { - name: "all false", - input: []int{1, 3, 5}, - predicate: func(x int) bool { return x%2 == 0 }, - expectedTrue: []int{}, - expectedFalse: []int{1, 3, 5}, - }, - { - name: "empty slice", - input: []int{}, - predicate: func(x int) bool { return x%2 == 0 }, - expectedTrue: []int{}, - expectedFalse: []int{}, - }, - { - name: "nil slice", - input: nil, - predicate: func(x int) bool { return x%2 == 0 }, - expectedTrue: nil, - expectedFalse: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - trueSlice, falseSlice := util.Partition(tt.input, tt.predicate) - assert.Equal(t, tt.expectedTrue, trueSlice) - assert.Equal(t, tt.expectedFalse, falseSlice) - }) - } -} - -func TestFind(t *testing.T) { - tests := []struct { - name string - input []int - predicate func(int) bool - expectedValue int - expectedFound bool - }{ - { - name: "find first even", - input: []int{1, 3, 4, 5, 6}, - predicate: func(x int) bool { return x%2 == 0 }, - expectedValue: 4, - expectedFound: true, - }, - { - name: "not found", - input: []int{1, 3, 5}, - predicate: func(x int) bool { return x%2 == 0 }, - expectedValue: 0, - expectedFound: false, - }, - { - name: "empty slice", - input: []int{}, - predicate: func(x int) bool { return x%2 == 0 }, - expectedValue: 0, - expectedFound: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - value, found := util.Find(tt.input, tt.predicate) - assert.Equal(t, tt.expectedValue, value) - assert.Equal(t, tt.expectedFound, found) - }) - } -} - -func TestAny(t *testing.T) { - tests := []struct { - name string - input []int - predicate func(int) bool - expected bool - }{ - { - name: "has even number", - input: []int{1, 3, 4, 5}, - predicate: func(x int) bool { return x%2 == 0 }, - expected: true, - }, - { - name: "no even numbers", - input: []int{1, 3, 5}, - predicate: func(x int) bool { return x%2 == 0 }, - expected: false, - }, - { - name: "empty slice", - input: []int{}, - predicate: func(x int) bool { return x%2 == 0 }, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := util.Any(tt.input, tt.predicate) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestAll(t *testing.T) { - tests := []struct { - name string - input []int - predicate func(int) bool - expected bool - }{ - { - name: "all even", - input: []int{2, 4, 6}, - predicate: func(x int) bool { return x%2 == 0 }, - expected: true, - }, - { - name: "not all even", - input: []int{2, 3, 4}, - predicate: func(x int) bool { return x%2 == 0 }, - expected: false, - }, - { - name: "empty slice", - input: []int{}, - predicate: func(x int) bool { return x%2 == 0 }, - expected: true, // vacuous truth - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := util.All(tt.input, tt.predicate) - assert.Equal(t, tt.expected, result) - }) - } -} - // Benchmark tests func BenchmarkMap(b *testing.B) { input := make([]int, 1000) @@ -371,15 +67,3 @@ func BenchmarkMap(b *testing.B) { _ = util.Map(input, func(x int) int { return x * 2 }) } } - -func BenchmarkFilter(b *testing.B) { - input := make([]int, 1000) - for i := range input { - input[i] = i - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = util.Filter(input, func(x int) bool { return x%2 == 0 }) - } -}