diff --git a/cmd/config.go b/cmd/config.go index e6bcd11..85be651 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -104,7 +104,10 @@ func detectProjectRoot() *project.Context { } // Store in config package for GetMemoryBasePath and other consumers - config.SetProjectContext(ctx) + if err := config.SetProjectContext(ctx); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to set project context: %v\n", err) + return nil + } // Log in verbose mode if viper.GetBool("verbose") && ctx.RootPath != cwd { diff --git a/cmd/hook.go b/cmd/hook.go index 480f9b3..f24f09e 100644 --- a/cmd/hook.go +++ b/cmd/hook.go @@ -436,22 +436,32 @@ Tasks Completed: %d `, session.SessionID, int(elapsed.Minutes()), session.TasksCompleted) // Remove session file - sessionPath := getHookSessionPath() - _ = os.Remove(sessionPath) + sessionPath, err := getHookSessionPath() + if err == nil { + _ = os.Remove(sessionPath) + } return nil } // Session persistence helpers -func getHookSessionPath() string { +func getHookSessionPath() (string, error) { // Hook commands use GetMemoryBasePathOrGlobal since they may run // before project context is fully established (e.g., SessionStart) - return filepath.Join(config.GetMemoryBasePathOrGlobal(), "hook_session.json") + memoryPath, err := config.GetMemoryBasePathOrGlobal() + if err != nil { + return "", fmt.Errorf("get memory path: %w", err) + } + return filepath.Join(memoryPath, "hook_session.json"), nil } func loadHookSession() (*HookSession, error) { - data, err := os.ReadFile(getHookSessionPath()) + sessionPath, err := getHookSessionPath() + if err != nil { + return nil, err + } + data, err := os.ReadFile(sessionPath) if err != nil { return nil, err } @@ -470,7 +480,10 @@ func saveHookSession(session *HookSession) error { return err } - sessionPath := getHookSessionPath() + sessionPath, err := getHookSessionPath() + if err != nil { + return err + } // Ensure directory exists if err := os.MkdirAll(filepath.Dir(sessionPath), 0755); err != nil { return fmt.Errorf("create session directory: %w", err) diff --git a/cmd/mcp_server.go b/cmd/mcp_server.go index af2c706..301ed61 100644 --- a/cmd/mcp_server.go +++ b/cmd/mcp_server.go @@ -97,7 +97,10 @@ func mcpFormattedErrorResponse(formattedError string) (*mcpsdk.CallToolResultFor func initMCPRepository() (*memory.Repository, error) { // MCP server is a special case - it may run in sandboxed environments // where project context isn't available. Use the fallback-enabled path. - memoryPath := config.GetMemoryBasePathOrGlobal() + memoryPath, err := config.GetMemoryBasePathOrGlobal() + if err != nil { + return nil, fmt.Errorf("determine memory path: %w", err) + } repo, err := memory.NewDefaultRepository(memoryPath) if err != nil { @@ -214,7 +217,13 @@ func runMCPServer(ctx context.Context) error { - next: Get next pending task from plan (use auto_start=true to claim immediately) - current: Get current in-progress task for session - start: Claim a specific task by ID -- complete: Mark task as completed with summary`, +- complete: Mark task as completed with summary + +REQUIRED FIELDS BY ACTION: +- next: session_id (required) +- current: session_id (required) +- start: task_id (required), session_id (required) +- complete: task_id (required)`, } mcpsdk.AddTool(server, taskTool, func(ctx context.Context, session *mcpsdk.ServerSession, params *mcpsdk.CallToolParamsFor[mcppresenter.TaskToolParams]) (*mcpsdk.CallToolResultFor[any], error) { result, err := mcppresenter.HandleTaskTool(ctx, repo, params.Arguments) @@ -233,7 +242,12 @@ func runMCPServer(ctx context.Context) error { Description: `Unified plan creation tool. Use action parameter to select operation: - clarify: Refine goal with clarifying questions (loop until is_ready_to_plan=true) - generate: Create plan with tasks from enriched goal -- audit: Verify completed plan with build/test/semantic checks (auto-fixes failures)`, +- audit: Verify completed plan with build/test/semantic checks (auto-fixes failures) + +REQUIRED FIELDS BY ACTION: +- clarify: goal (required) +- generate: goal (required), enriched_goal (required) - call clarify first to get enriched_goal +- audit: none required (defaults to active plan)`, } mcpsdk.AddTool(server, planTool, func(ctx context.Context, session *mcpsdk.ServerSession, params *mcpsdk.CallToolParamsFor[mcppresenter.PlanToolParams]) (*mcpsdk.CallToolResultFor[any], error) { result, err := mcppresenter.HandlePlanTool(ctx, repo, params.Arguments) diff --git a/cmd/plan.go b/cmd/plan.go index 4b104eb..58c4126 100644 --- a/cmd/plan.go +++ b/cmd/plan.go @@ -357,13 +357,12 @@ func printPlanTable(plans []task.Plan) { goal = goal[:57] + "..." } // Tasks count - service ListPlans probably returns plans without tasks or with? - // task.Repository interface implies ListPlans returns []Plan which contains Tasks? - // SQLite implementation usually does. If not, we might check tasks length. - // Assuming populated for now or length 0. + // ListPlans sets TaskCount but leaves Tasks nil for efficiency. + // Use GetTaskCount() to get the count regardless of how the plan was loaded. fmt.Printf("%-18s %-12s %-6d %s\n", idStyle.Render(p.ID), dateStyle.Render(p.CreatedAt.Format("2006-01-02")), - len(p.Tasks), + p.GetTaskCount(), goalStyle.Render(goal)) } fmt.Printf("\n%s\n", lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render(fmt.Sprintf("Total: %d plan(s)", len(plans)))) diff --git a/internal/config/paths.go b/internal/config/paths.go index e87b924..e594723 100644 --- a/internal/config/paths.go +++ b/internal/config/paths.go @@ -38,13 +38,15 @@ var GetGlobalConfigDir = func() (string, error) { // SetProjectContext sets the detected project context for use by GetMemoryBasePath. // This MUST be called during CLI initialization before any command that needs project context. -func SetProjectContext(ctx *project.Context) { +// Returns error if ctx is nil. +func SetProjectContext(ctx *project.Context) error { if ctx == nil { - panic("SetProjectContext called with nil context") + return errors.New("SetProjectContext called with nil context") } projectContextMu.Lock() defer projectContextMu.Unlock() projectContext = ctx + return nil } // ClearProjectContext resets the project context. Only use in tests. @@ -62,14 +64,14 @@ func GetProjectContext() *project.Context { return projectContext } -// MustGetProjectContext returns the project context or panics if not set. -// Use this when project context is required and absence is a programming error. -func MustGetProjectContext() *project.Context { +// GetProjectContextOrError returns the project context or an error if not set. +// Use this when project context is required. +func GetProjectContextOrError() (*project.Context, error) { ctx := GetProjectContext() if ctx == nil { - panic(ErrProjectContextNotSet) + return nil, ErrProjectContextNotSet } - return ctx + return ctx, nil } // DetectAndSetProjectContext detects the project root and sets it. @@ -90,7 +92,9 @@ func DetectAndSetProjectContext() (*project.Context, error) { return nil, fmt.Errorf("%w: %v", ErrDetectionFailed, err) } - SetProjectContext(ctx) + if err := SetProjectContext(ctx); err != nil { + return nil, fmt.Errorf("set project context: %w", err) + } return ctx, nil } @@ -133,19 +137,18 @@ func GetMemoryBasePath() (string, error) { // // ALL OTHER COMMANDS should use GetMemoryBasePath() which enforces fail-fast behavior. // Using this function inappropriately masks project detection failures. -func GetMemoryBasePathOrGlobal() string { +func GetMemoryBasePathOrGlobal() (string, error) { path, err := GetMemoryBasePath() if err == nil { - return path + return path, nil } // Only fall back to global for non-project commands dir, err := GetGlobalConfigDir() if err != nil { - // This is a critical failure - can't determine any valid path - panic(fmt.Sprintf("cannot determine memory path: %v", err)) + return "", fmt.Errorf("cannot determine memory path: %w", err) } - return filepath.Join(dir, "memory") + return filepath.Join(dir, "memory"), nil } // GetProjectRoot returns the detected project root path. @@ -160,13 +163,3 @@ func GetProjectRoot() (string, error) { } return ctx.RootPath, nil } - -// MustGetProjectRoot returns the project root or panics. -// Use when project root is required and absence is a programming error. -func MustGetProjectRoot() string { - root, err := GetProjectRoot() - if err != nil { - panic(err) - } - return root -} diff --git a/internal/config/paths_test.go b/internal/config/paths_test.go new file mode 100644 index 0000000..c75dbbd --- /dev/null +++ b/internal/config/paths_test.go @@ -0,0 +1,180 @@ +package config + +import ( + "errors" + "testing" + + "github.com/josephgoksu/TaskWing/internal/project" +) + +func TestSetProjectContext_NilReturnsError(t *testing.T) { + // Clear any existing context + ClearProjectContext() + + err := SetProjectContext(nil) + if err == nil { + t.Fatal("expected error for nil context, got nil") + } + + // Verify error message is helpful + if err.Error() != "SetProjectContext called with nil context" { + t.Errorf("unexpected error message: %s", err.Error()) + } +} + +func TestSetProjectContext_ValidContext(t *testing.T) { + // Clear any existing context + ClearProjectContext() + defer ClearProjectContext() + + ctx := &project.Context{ + RootPath: "/test/path", + MarkerType: project.MarkerGit, + } + + err := SetProjectContext(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify context was set + got := GetProjectContext() + if got == nil { + t.Fatal("expected context to be set") + } + if got.RootPath != ctx.RootPath { + t.Errorf("expected RootPath %q, got %q", ctx.RootPath, got.RootPath) + } +} + +func TestGetProjectContextOrError_NotSet(t *testing.T) { + ClearProjectContext() + + ctx, err := GetProjectContextOrError() + if err == nil { + t.Fatal("expected error when context not set") + } + if !errors.Is(err, ErrProjectContextNotSet) { + t.Errorf("expected ErrProjectContextNotSet, got: %v", err) + } + if ctx != nil { + t.Error("expected nil context") + } +} + +func TestGetProjectContextOrError_Set(t *testing.T) { + ClearProjectContext() + defer ClearProjectContext() + + expected := &project.Context{RootPath: "/test"} + _ = SetProjectContext(expected) + + ctx, err := GetProjectContextOrError() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ctx != expected { + t.Error("context does not match expected") + } +} + +func TestGetProjectRoot_NotSet(t *testing.T) { + ClearProjectContext() + + root, err := GetProjectRoot() + if err == nil { + t.Fatal("expected error when context not set") + } + if !errors.Is(err, ErrProjectContextNotSet) { + t.Errorf("expected ErrProjectContextNotSet, got: %v", err) + } + if root != "" { + t.Errorf("expected empty root, got: %s", root) + } +} + +func TestGetProjectRoot_EmptyRootPath(t *testing.T) { + ClearProjectContext() + defer ClearProjectContext() + + ctx := &project.Context{RootPath: ""} + _ = SetProjectContext(ctx) + + root, err := GetProjectRoot() + if err == nil { + t.Fatal("expected error for empty RootPath") + } + if root != "" { + t.Errorf("expected empty root, got: %s", root) + } +} + +func TestGetProjectRoot_Valid(t *testing.T) { + ClearProjectContext() + defer ClearProjectContext() + + expected := "/my/project" + ctx := &project.Context{RootPath: expected} + _ = SetProjectContext(ctx) + + root, err := GetProjectRoot() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if root != expected { + t.Errorf("expected %q, got %q", expected, root) + } +} + +func TestGetMemoryBasePath_NotSet(t *testing.T) { + ClearProjectContext() + + path, err := GetMemoryBasePath() + if err == nil { + t.Fatal("expected error when context not set") + } + if !errors.Is(err, ErrProjectContextNotSet) { + t.Errorf("expected ErrProjectContextNotSet, got: %v", err) + } + if path != "" { + t.Errorf("expected empty path, got: %s", path) + } +} + +func TestGetMemoryBasePathOrGlobal_FallsBackToGlobal(t *testing.T) { + ClearProjectContext() + + // Should fall back to global without error + path, err := GetMemoryBasePathOrGlobal() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if path == "" { + t.Error("expected non-empty path") + } + // Should contain "memory" in the path + if len(path) < 6 || path[len(path)-6:] != "memory" { + t.Errorf("expected path to end with 'memory', got: %s", path) + } +} + +func TestGetMemoryBasePathOrGlobal_GlobalDirError(t *testing.T) { + ClearProjectContext() + + // Save original function + original := GetGlobalConfigDir + defer func() { GetGlobalConfigDir = original }() + + // Mock to return error + GetGlobalConfigDir = func() (string, error) { + return "", errors.New("test error: cannot get home dir") + } + + path, err := GetMemoryBasePathOrGlobal() + if err == nil { + t.Fatal("expected error when global config dir fails") + } + if path != "" { + t.Errorf("expected empty path on error, got: %s", path) + } +} diff --git a/internal/mcp/handlers.go b/internal/mcp/handlers.go index 841246f..d467df9 100644 --- a/internal/mcp/handlers.go +++ b/internal/mcp/handlers.go @@ -536,6 +536,20 @@ func HandleTaskTool(ctx context.Context, repo *memory.Repository, params TaskToo // handleTaskNext implements the 'next' action - get the next pending task. func handleTaskNext(ctx context.Context, repo *memory.Repository, params TaskToolParams) (*TaskToolResult, error) { + // Validate required fields + sessionID := strings.TrimSpace(params.SessionID) + if sessionID == "" { + return &TaskToolResult{ + Action: "next", + Error: "session_id is required for next action", + Content: FormatMultiValidationError( + "next", + []string{"session_id"}, + "Provide a unique session identifier (e.g., from hook session-init or a UUID).", + ), + }, nil + } + appCtx := app.NewContext(repo) taskApp := app.NewTaskApp(appCtx) @@ -547,7 +561,7 @@ func handleTaskNext(ctx context.Context, repo *memory.Repository, params TaskToo result, err := taskApp.Next(ctx, app.TaskNextOptions{ PlanID: params.PlanID, - SessionID: params.SessionID, + SessionID: sessionID, // Use validated/trimmed value AutoStart: params.AutoStart, CreateBranch: createBranch, SkipUnpushedCheck: params.SkipUnpushedCheck, @@ -567,10 +581,24 @@ func handleTaskNext(ctx context.Context, repo *memory.Repository, params TaskToo // handleTaskCurrent implements the 'current' action - get the current in-progress task. func handleTaskCurrent(ctx context.Context, repo *memory.Repository, params TaskToolParams) (*TaskToolResult, error) { + // Validate required fields + sessionID := strings.TrimSpace(params.SessionID) + if sessionID == "" { + return &TaskToolResult{ + Action: "current", + Error: "session_id is required for current action", + Content: FormatMultiValidationError( + "current", + []string{"session_id"}, + "Provide the session identifier used when starting the task.", + ), + }, nil + } + appCtx := app.NewContext(repo) taskApp := app.NewTaskApp(appCtx) - result, err := taskApp.Current(ctx, params.SessionID, params.PlanID) + result, err := taskApp.Current(ctx, sessionID, params.PlanID) if err != nil { return &TaskToolResult{ Action: "current", diff --git a/internal/mcp/handlers_test.go b/internal/mcp/handlers_test.go index e9091c9..fcce8b9 100644 --- a/internal/mcp/handlers_test.go +++ b/internal/mcp/handlers_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "path/filepath" + "strings" "testing" ) @@ -210,6 +211,54 @@ func TestHandleTaskTool_CompleteMissingTaskID(t *testing.T) { } } +func TestHandleTaskTool_NextMissingSessionID(t *testing.T) { + params := TaskToolParams{ + Action: TaskActionNext, + SessionID: "", // missing + } + + result, err := HandleTaskTool(context.Background(), nil, params) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Error == "" { + t.Error("expected error for missing session_id") + } + if !strings.Contains(result.Error, "session_id") { + t.Errorf("error should mention session_id: %s", result.Error) + } + if result.Action != "next" { + t.Errorf("expected action 'next', got %q", result.Action) + } + // Should have actionable guidance in content + if !strings.Contains(result.Content, "session") { + t.Error("content should mention session for guidance") + } +} + +func TestHandleTaskTool_CurrentMissingSessionID(t *testing.T) { + params := TaskToolParams{ + Action: TaskActionCurrent, + SessionID: "", // missing + } + + result, err := HandleTaskTool(context.Background(), nil, params) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Error == "" { + t.Error("expected error for missing session_id") + } + if !strings.Contains(result.Error, "session_id") { + t.Errorf("error should mention session_id: %s", result.Error) + } + if result.Action != "current" { + t.Errorf("expected action 'current', got %q", result.Action) + } +} + func TestHandleTaskTool_ActionRouting(t *testing.T) { // Test actions that have validation before hitting the repo tests := []struct { @@ -217,6 +266,8 @@ func TestHandleTaskTool_ActionRouting(t *testing.T) { name string expectError bool }{ + {TaskActionNext, "next", true}, // missing session_id + {TaskActionCurrent, "current", true}, // missing session_id {TaskActionStart, "start", true}, // missing task_id {TaskActionComplete, "complete", true}, // missing task_id } @@ -324,6 +375,71 @@ func TestHandlePlanTool_GenerateMissingEnrichedGoal(t *testing.T) { } } +// TestHandlePlanTool_GenerateErrorContainsFieldDetails validates that validation errors +// contain actionable field-level details to help AI clients self-correct. +func TestHandlePlanTool_GenerateErrorContainsFieldDetails(t *testing.T) { + tests := []struct { + name string + params PlanToolParams + expectedFields []string + fieldCount int // expected number of missing fields + }{ + { + name: "missing_both_fields_lists_both", + params: PlanToolParams{ + Action: PlanActionGenerate, + // Both goal and enriched_goal missing + }, + expectedFields: []string{"goal", "enriched_goal"}, + fieldCount: 2, + }, + { + name: "missing_goal_only_lists_goal", + params: PlanToolParams{ + Action: PlanActionGenerate, + EnrichedGoal: "some enriched goal", + }, + expectedFields: []string{"goal"}, + fieldCount: 1, + }, + { + name: "missing_enriched_goal_only_lists_enriched_goal", + params: PlanToolParams{ + Action: PlanActionGenerate, + Goal: "some goal", + }, + expectedFields: []string{"enriched_goal"}, + fieldCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := HandlePlanTool(context.Background(), nil, tt.params) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Error should list missing fields + for _, field := range tt.expectedFields { + if !strings.Contains(result.Error, field) { + t.Errorf("error should contain field %q: %s", field, result.Error) + } + } + + // Verify correct number of fields reported (check bracket contents) + if tt.fieldCount == 1 && strings.Contains(result.Error, ", ") { + t.Errorf("error lists more fields than expected for single missing field: %s", result.Error) + } + + // Content should have actionable guidance + if !strings.Contains(result.Content, "clarify") { + t.Error("content should mention 'clarify' action for guidance") + } + }) + } +} + func TestHandlePlanTool_ActionRouting(t *testing.T) { // Test actions that have validation before hitting the repo tests := []struct { diff --git a/internal/mcp/types.go b/internal/mcp/types.go index 1aaa05f..7133cfa 100644 --- a/internal/mcp/types.go +++ b/internal/mcp/types.go @@ -148,13 +148,19 @@ type CodeToolParams struct { // TaskToolParams defines the parameters for the unified task tool. // Consolidates: task_next, task_current, task_start, task_complete +// +// Required fields by action: +// - next: session_id +// - current: session_id +// - start: task_id, session_id +// - complete: task_id type TaskToolParams struct { // Action specifies which operation to perform. // Required. One of: next, current, start, complete Action TaskAction `json:"action"` // TaskID is the task identifier. - // Required for: start, complete + // REQUIRED for: start, complete (will error if empty for these actions) TaskID string `json:"task_id,omitempty"` // PlanID is the plan identifier. @@ -163,7 +169,7 @@ type TaskToolParams struct { PlanID string `json:"plan_id,omitempty"` // SessionID is the unique AI session identifier. - // Required for: next, current, start + // REQUIRED for: next, current, start (will error if empty for these actions) SessionID string `json:"session_id,omitempty"` // Summary describes what was accomplished. @@ -251,17 +257,22 @@ type DebugToolParams struct { // PlanToolParams defines the parameters for the unified plan tool. // Consolidates: plan_clarify, plan_generate, audit_plan +// +// Required fields by action: +// - clarify: goal +// - generate: goal, enriched_goal (call clarify first to get enriched_goal) +// - audit: none (defaults to active plan) type PlanToolParams struct { // Action specifies which operation to perform. // Required. One of: clarify, generate, audit Action PlanAction `json:"action"` // Goal is the user's development goal. - // Required for: clarify, generate + // REQUIRED for: clarify, generate (will error if empty for these actions) Goal string `json:"goal,omitempty"` // EnrichedGoal is the full technical specification from clarify. - // Required for: generate + // REQUIRED for: generate (will error if empty; call clarify first to get this) EnrichedGoal string `json:"enriched_goal,omitempty"` // History is a JSON array of previous Q&A from clarify loop. diff --git a/internal/memory/rows_err_test.go b/internal/memory/rows_err_test.go index 947f4b2..fd94cd5 100644 --- a/internal/memory/rows_err_test.go +++ b/internal/memory/rows_err_test.go @@ -578,20 +578,6 @@ func TestMockRowsErr(t *testing.T) { } } -// scannerMock implements sql.Rows-like interface for testing error scenarios -type scannerMock struct { - scanErr error - rowsErr error -} - -func (s *scannerMock) Scan(dest ...any) error { - return s.scanErr -} - -func (s *scannerMock) Err() error { - return s.rowsErr -} - // TestCheckRowsErrHelper_FunctionExists verifies the helper function works. func TestCheckRowsErrHelper_FunctionExists(t *testing.T) { tmpDir, err := os.MkdirTemp("", "taskwing-rows-err-test-*") diff --git a/internal/memory/sqlite.go b/internal/memory/sqlite.go index 031806d..7b909ac 100644 --- a/internal/memory/sqlite.go +++ b/internal/memory/sqlite.go @@ -640,7 +640,7 @@ func (s *SQLiteStore) CreateFeature(f Feature) error { if err != nil { return fmt.Errorf("begin transaction: %w", err) } - defer func() { _ = tx.Rollback() }() + defer func() { rollbackWithLog(tx, "sqlite") }() // Insert into SQLite _, err = tx.Exec(` @@ -785,7 +785,7 @@ func (s *SQLiteStore) AddDecision(featureID string, d Decision) error { if err != nil { return fmt.Errorf("begin transaction: %w", err) } - defer func() { _ = tx.Rollback() }() + defer func() { rollbackWithLog(tx, "sqlite") }() _, err = tx.Exec(` INSERT INTO decisions (id, feature_id, title, summary, reasoning, tradeoffs, created_at) @@ -847,7 +847,7 @@ func (s *SQLiteStore) DeleteDecision(id string) error { if err != nil { return fmt.Errorf("begin transaction: %w", err) } - defer func() { _ = tx.Rollback() }() + defer func() { rollbackWithLog(tx, "sqlite") }() _, err = tx.Exec("DELETE FROM decisions WHERE id = ?", id) if err != nil { @@ -1643,7 +1643,7 @@ func (s *SQLiteStore) ClearAllKnowledge() error { if err != nil { return fmt.Errorf("begin transaction: %w", err) } - defer func() { _ = tx.Rollback() }() + defer func() { rollbackWithLog(tx, "sqlite") }() // Clear in order respecting foreign key constraints tables := []string{"node_edges", "nodes", "decisions", "patterns", "features"} @@ -1685,7 +1685,7 @@ func (s *SQLiteStore) UpsertNodeBySummary(n Node) error { if err != nil { return fmt.Errorf("begin transaction: %w", err) } - defer func() { _ = tx.Rollback() }() + defer func() { rollbackWithLog(tx, "sqlite") }() // First check if node with exact summary+agent exists var existingID string diff --git a/internal/memory/task_store.go b/internal/memory/task_store.go index 7b07a94..976fa68 100644 --- a/internal/memory/task_store.go +++ b/internal/memory/task_store.go @@ -3,7 +3,9 @@ package memory import ( "database/sql" "encoding/json" + "errors" "fmt" + "log" "strings" "time" @@ -11,6 +13,14 @@ import ( "github.com/josephgoksu/TaskWing/internal/task" ) +// rollbackWithLog attempts rollback and logs non-ErrTxDone errors at warn level. +// This ensures transaction cleanup failures are visible without masking the original error. +func rollbackWithLog(tx *sql.Tx, context string) { + if err := tx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) { + log.Printf("[WARN] rollback failed (%s): %v", context, err) + } +} + // nullTimeString returns nil for zero time, RFC3339 string otherwise func nullTimeString(t time.Time) interface{} { if t.IsZero() { @@ -27,19 +37,37 @@ type txExecutor interface { // insertTaskTx inserts a task and its relations within a transaction. // This is the SINGLE source of truth for task insertion logic. func insertTaskTx(tx txExecutor, t *task.Task) error { - acJSON, _ := json.Marshal(t.AcceptanceCriteria) - vsJSON, _ := json.Marshal(t.ValidationSteps) - keywordsJSON, _ := json.Marshal(t.Keywords) - queriesJSON, _ := json.Marshal(t.SuggestedRecallQueries) - filesJSON, _ := json.Marshal(t.FilesModified) - expectedFilesJSON, _ := json.Marshal(t.ExpectedFiles) + acJSON, err := json.Marshal(t.AcceptanceCriteria) + if err != nil { + return fmt.Errorf("marshal acceptance_criteria for task %s: %w", t.ID, err) + } + vsJSON, err := json.Marshal(t.ValidationSteps) + if err != nil { + return fmt.Errorf("marshal validation_steps for task %s: %w", t.ID, err) + } + keywordsJSON, err := json.Marshal(t.Keywords) + if err != nil { + return fmt.Errorf("marshal keywords for task %s: %w", t.ID, err) + } + queriesJSON, err := json.Marshal(t.SuggestedRecallQueries) + if err != nil { + return fmt.Errorf("marshal suggested_recall_queries for task %s: %w", t.ID, err) + } + filesJSON, err := json.Marshal(t.FilesModified) + if err != nil { + return fmt.Errorf("marshal files_modified for task %s: %w", t.ID, err) + } + expectedFilesJSON, err := json.Marshal(t.ExpectedFiles) + if err != nil { + return fmt.Errorf("marshal expected_files for task %s: %w", t.ID, err) + } var parentID interface{} if t.ParentTaskID != "" { parentID = t.ParentTaskID } - _, err := tx.Exec(` + _, err = tx.Exec(` INSERT INTO tasks ( id, plan_id, title, description, acceptance_criteria, validation_steps, @@ -108,7 +136,7 @@ func (s *SQLiteStore) CreatePlan(p *task.Plan) error { if err != nil { return fmt.Errorf("begin tx: %w", err) } - defer func() { _ = tx.Rollback() }() + defer func() { rollbackWithLog(tx, "task_store") }() if _, err = tx.Exec(` INSERT INTO plans (id, goal, enriched_goal, status, created_at, updated_at) @@ -189,9 +217,11 @@ func (s *SQLiteStore) ListPlans() ([]task.Plan, error) { if lastAuditReport.Valid { p.LastAuditReport = lastAuditReport.String } - // Store task count in a placeholder slice (just for count display) - // This avoids loading all tasks but allows len(p.Tasks) to work - p.Tasks = make([]task.Task, taskCount) + // Store task count for efficient list views without loading all tasks. + // Use plan.GetTaskCount() to get the count regardless of how the plan was loaded. + p.TaskCount = taskCount + // Leave Tasks nil - callers should use GetTaskCount() for counts, + // or call GetPlanWithTasks() if they need actual task data. plans = append(plans, p) } if err := checkRowsErr(rows); err != nil { @@ -258,7 +288,7 @@ func (s *SQLiteStore) UpdatePlanAuditReport(id string, status task.PlanStatus, a if err != nil { return fmt.Errorf("begin transaction: %w", err) } - defer func() { _ = tx.Rollback() }() + defer func() { rollbackWithLog(tx, "task_store") }() now := time.Now().UTC().Format(time.RFC3339) @@ -316,7 +346,7 @@ func (s *SQLiteStore) CreateTask(t *task.Task) error { if err != nil { return fmt.Errorf("begin tx: %w", err) } - defer func() { _ = tx.Rollback() }() + defer func() { rollbackWithLog(tx, "task_store") }() if err := insertTaskTx(tx, t); err != nil { return err @@ -912,7 +942,7 @@ func (s *SQLiteStore) SetActivePlan(id string) error { if err != nil { return fmt.Errorf("begin tx: %w", err) } - defer func() { _ = tx.Rollback() }() + defer func() { rollbackWithLog(tx, "task_store") }() now := time.Now().UTC().Format(time.RFC3339) diff --git a/internal/memory/task_store_test.go b/internal/memory/task_store_test.go new file mode 100644 index 0000000..b02b877 --- /dev/null +++ b/internal/memory/task_store_test.go @@ -0,0 +1,110 @@ +package memory + +import ( + "os" + "path/filepath" + "testing" + + "github.com/josephgoksu/TaskWing/internal/task" +) + +func TestListPlans_TaskCountNotPlaceholderSlice(t *testing.T) { + // Create a temporary database + tmpDir, err := os.MkdirTemp("", "taskwing-test-*") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "memory.db") + store, err := NewSQLiteStore(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + defer store.Close() + + // Create a plan + plan := &task.Plan{ + ID: "plan-test-123", + Goal: "Test goal", + EnrichedGoal: "Enriched test goal", + Status: task.PlanStatusActive, + } + if err := store.CreatePlan(plan); err != nil { + t.Fatalf("create plan: %v", err) + } + + // Create some tasks for this plan + tasks := []task.Task{ + {ID: "task-1", PlanID: plan.ID, Title: "Task 1", Description: "Desc 1", Status: task.StatusPending, Priority: 80}, + {ID: "task-2", PlanID: plan.ID, Title: "Task 2", Description: "Desc 2", Status: task.StatusCompleted, Priority: 70}, + {ID: "task-3", PlanID: plan.ID, Title: "Task 3", Description: "Desc 3", Status: task.StatusInProgress, Priority: 60}, + } + for _, tsk := range tasks { + if err := store.CreateTask(&tsk); err != nil { + t.Fatalf("create task %s: %v", tsk.ID, err) + } + } + + // List plans + plans, err := store.ListPlans() + if err != nil { + t.Fatalf("list plans: %v", err) + } + + if len(plans) != 1 { + t.Fatalf("expected 1 plan, got %d", len(plans)) + } + + p := plans[0] + + // Verify TaskCount is set correctly + if p.TaskCount != 3 { + t.Errorf("expected TaskCount=3, got %d", p.TaskCount) + } + + // Verify Tasks slice is nil (not placeholder slice) + if p.Tasks != nil { + t.Errorf("expected Tasks to be nil, got slice of length %d", len(p.Tasks)) + } + + // Verify GetTaskCount() returns correct value + if p.GetTaskCount() != 3 { + t.Errorf("expected GetTaskCount()=3, got %d", p.GetTaskCount()) + } + + // Verify iterating over Tasks doesn't yield misleading zero-value structs + for i, tsk := range p.Tasks { + // This loop should not execute since Tasks is nil + t.Errorf("unexpected task at index %d: %+v", i, tsk) + } +} + +func TestGetTaskCount_FallsBackToTasksLength(t *testing.T) { + // When TaskCount is 0 but Tasks are populated, use len(Tasks) + plan := &task.Plan{ + ID: "test-plan", + TaskCount: 0, // Not set + Tasks: []task.Task{ + {ID: "t1", Title: "Task 1"}, + {ID: "t2", Title: "Task 2"}, + }, + } + + if plan.GetTaskCount() != 2 { + t.Errorf("expected GetTaskCount()=2 (from len(Tasks)), got %d", plan.GetTaskCount()) + } +} + +func TestGetTaskCount_UsesTaskCountIfSet(t *testing.T) { + // When TaskCount is set, use it regardless of Tasks + plan := &task.Plan{ + ID: "test-plan", + TaskCount: 5, + Tasks: nil, // Not loaded + } + + if plan.GetTaskCount() != 5 { + t.Errorf("expected GetTaskCount()=5 (from TaskCount), got %d", plan.GetTaskCount()) + } +} diff --git a/internal/task/models.go b/internal/task/models.go index f4c506d..16fb9af 100644 --- a/internal/task/models.go +++ b/internal/task/models.go @@ -151,6 +151,7 @@ type Plan struct { EnrichedGoal string `json:"enrichedGoal"` // Full technical specification refined by Clarifying Agent Status PlanStatus `json:"status"` // draft, active, completed, verified, needs_revision, archived Tasks []Task `json:"tasks"` + TaskCount int `json:"taskCount,omitempty"` // Precomputed count for list views (avoids loading all tasks) CreatedAt time.Time `json:"createdAt"` UpdatedAt time.Time `json:"updatedAt"` @@ -158,6 +159,17 @@ type Plan struct { LastAuditReport string `json:"lastAuditReport,omitempty"` // JSON-serialized AuditReport } +// GetTaskCount returns the number of tasks in this plan. +// It uses TaskCount if set (from ListPlans), otherwise falls back to len(Tasks). +// This handles both cases: ListPlans (which sets TaskCount but not Tasks) +// and GetPlanWithTasks (which populates Tasks but not TaskCount). +func (p *Plan) GetTaskCount() int { + if p.TaskCount > 0 { + return p.TaskCount + } + return len(p.Tasks) +} + // Scope customization is available via .taskwing.yaml or ~/.taskwing/config.yaml: // // task: diff --git a/internal/util/id_test.go b/internal/util/id_test.go index 31a3f54..6316abb 100644 --- a/internal/util/id_test.go +++ b/internal/util/id_test.go @@ -285,7 +285,7 @@ func containsError(err, target error) bool { } return err.Error() == target.Error() || len(err.Error()) > len(target.Error()) && - err.Error()[len(err.Error())-len(target.Error()):] == target.Error() + err.Error()[len(err.Error())-len(target.Error()):] == target.Error() } func TestAmbiguousErrorMessage(t *testing.T) {