Skip to content
Merged
5 changes: 4 additions & 1 deletion cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ func detectProjectRoot() *project.Context {
}

// Store in config package for GetMemoryBasePath and other consumers
config.SetProjectContext(ctx)
if err := config.SetProjectContext(ctx); err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to set project context: %v\n", err)
return nil
}

// Log in verbose mode
if viper.GetBool("verbose") && ctx.RootPath != cwd {
Expand Down
25 changes: 19 additions & 6 deletions cmd/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,22 +436,32 @@ Tasks Completed: %d
`, session.SessionID, int(elapsed.Minutes()), session.TasksCompleted)

// Remove session file
sessionPath := getHookSessionPath()
_ = os.Remove(sessionPath)
sessionPath, err := getHookSessionPath()
if err == nil {
_ = os.Remove(sessionPath)
}

return nil
}

// Session persistence helpers

func getHookSessionPath() string {
func getHookSessionPath() (string, error) {
// Hook commands use GetMemoryBasePathOrGlobal since they may run
// before project context is fully established (e.g., SessionStart)
return filepath.Join(config.GetMemoryBasePathOrGlobal(), "hook_session.json")
memoryPath, err := config.GetMemoryBasePathOrGlobal()
if err != nil {
return "", fmt.Errorf("get memory path: %w", err)
}
return filepath.Join(memoryPath, "hook_session.json"), nil
}

func loadHookSession() (*HookSession, error) {
data, err := os.ReadFile(getHookSessionPath())
sessionPath, err := getHookSessionPath()
if err != nil {
return nil, err
}
data, err := os.ReadFile(sessionPath)
if err != nil {
return nil, err
}
Expand All @@ -470,7 +480,10 @@ func saveHookSession(session *HookSession) error {
return err
}

sessionPath := getHookSessionPath()
sessionPath, err := getHookSessionPath()
if err != nil {
return err
}
// Ensure directory exists
if err := os.MkdirAll(filepath.Dir(sessionPath), 0755); err != nil {
return fmt.Errorf("create session directory: %w", err)
Expand Down
20 changes: 17 additions & 3 deletions cmd/mcp_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ func mcpFormattedErrorResponse(formattedError string) (*mcpsdk.CallToolResultFor
func initMCPRepository() (*memory.Repository, error) {
// MCP server is a special case - it may run in sandboxed environments
// where project context isn't available. Use the fallback-enabled path.
memoryPath := config.GetMemoryBasePathOrGlobal()
memoryPath, err := config.GetMemoryBasePathOrGlobal()
if err != nil {
return nil, fmt.Errorf("determine memory path: %w", err)
}

repo, err := memory.NewDefaultRepository(memoryPath)
if err != nil {
Expand Down Expand Up @@ -214,7 +217,13 @@ func runMCPServer(ctx context.Context) error {
- next: Get next pending task from plan (use auto_start=true to claim immediately)
- current: Get current in-progress task for session
- start: Claim a specific task by ID
- complete: Mark task as completed with summary`,
- complete: Mark task as completed with summary

REQUIRED FIELDS BY ACTION:
- next: session_id (required)
- current: session_id (required)
- start: task_id (required), session_id (required)
- complete: task_id (required)`,
}
mcpsdk.AddTool(server, taskTool, func(ctx context.Context, session *mcpsdk.ServerSession, params *mcpsdk.CallToolParamsFor[mcppresenter.TaskToolParams]) (*mcpsdk.CallToolResultFor[any], error) {
result, err := mcppresenter.HandleTaskTool(ctx, repo, params.Arguments)
Expand All @@ -233,7 +242,12 @@ func runMCPServer(ctx context.Context) error {
Description: `Unified plan creation tool. Use action parameter to select operation:
- clarify: Refine goal with clarifying questions (loop until is_ready_to_plan=true)
- generate: Create plan with tasks from enriched goal
- audit: Verify completed plan with build/test/semantic checks (auto-fixes failures)`,
- audit: Verify completed plan with build/test/semantic checks (auto-fixes failures)

REQUIRED FIELDS BY ACTION:
- clarify: goal (required)
- generate: goal (required), enriched_goal (required) - call clarify first to get enriched_goal
- audit: none required (defaults to active plan)`,
}
mcpsdk.AddTool(server, planTool, func(ctx context.Context, session *mcpsdk.ServerSession, params *mcpsdk.CallToolParamsFor[mcppresenter.PlanToolParams]) (*mcpsdk.CallToolResultFor[any], error) {
result, err := mcppresenter.HandlePlanTool(ctx, repo, params.Arguments)
Expand Down
7 changes: 3 additions & 4 deletions cmd/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,13 +357,12 @@ func printPlanTable(plans []task.Plan) {
goal = goal[:57] + "..."
}
// Tasks count - service ListPlans probably returns plans without tasks or with?
// task.Repository interface implies ListPlans returns []Plan which contains Tasks?
// SQLite implementation usually does. If not, we might check tasks length.
// Assuming populated for now or length 0.
// ListPlans sets TaskCount but leaves Tasks nil for efficiency.
// Use GetTaskCount() to get the count regardless of how the plan was loaded.
fmt.Printf("%-18s %-12s %-6d %s\n",
idStyle.Render(p.ID),
dateStyle.Render(p.CreatedAt.Format("2006-01-02")),
len(p.Tasks),
p.GetTaskCount(),
goalStyle.Render(goal))
}
fmt.Printf("\n%s\n", lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render(fmt.Sprintf("Total: %d plan(s)", len(plans))))
Expand Down
39 changes: 16 additions & 23 deletions internal/config/paths.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ var GetGlobalConfigDir = func() (string, error) {

// SetProjectContext sets the detected project context for use by GetMemoryBasePath.
// This MUST be called during CLI initialization before any command that needs project context.
func SetProjectContext(ctx *project.Context) {
// Returns error if ctx is nil.
func SetProjectContext(ctx *project.Context) error {
if ctx == nil {
panic("SetProjectContext called with nil context")
return errors.New("SetProjectContext called with nil context")
}
projectContextMu.Lock()
defer projectContextMu.Unlock()
projectContext = ctx
return nil
}

// ClearProjectContext resets the project context. Only use in tests.
Expand All @@ -62,14 +64,14 @@ func GetProjectContext() *project.Context {
return projectContext
}

// MustGetProjectContext returns the project context or panics if not set.
// Use this when project context is required and absence is a programming error.
func MustGetProjectContext() *project.Context {
// GetProjectContextOrError returns the project context or an error if not set.
// Use this when project context is required.
func GetProjectContextOrError() (*project.Context, error) {
ctx := GetProjectContext()
if ctx == nil {
panic(ErrProjectContextNotSet)
return nil, ErrProjectContextNotSet
}
return ctx
return ctx, nil
}

// DetectAndSetProjectContext detects the project root and sets it.
Expand All @@ -90,7 +92,9 @@ func DetectAndSetProjectContext() (*project.Context, error) {
return nil, fmt.Errorf("%w: %v", ErrDetectionFailed, err)
}

SetProjectContext(ctx)
if err := SetProjectContext(ctx); err != nil {
return nil, fmt.Errorf("set project context: %w", err)
}
return ctx, nil
}

Expand Down Expand Up @@ -133,19 +137,18 @@ func GetMemoryBasePath() (string, error) {
//
// ALL OTHER COMMANDS should use GetMemoryBasePath() which enforces fail-fast behavior.
// Using this function inappropriately masks project detection failures.
func GetMemoryBasePathOrGlobal() string {
func GetMemoryBasePathOrGlobal() (string, error) {
path, err := GetMemoryBasePath()
if err == nil {
return path
return path, nil
}

// Only fall back to global for non-project commands
dir, err := GetGlobalConfigDir()
if err != nil {
// This is a critical failure - can't determine any valid path
panic(fmt.Sprintf("cannot determine memory path: %v", err))
return "", fmt.Errorf("cannot determine memory path: %w", err)
}
return filepath.Join(dir, "memory")
return filepath.Join(dir, "memory"), nil
}

// GetProjectRoot returns the detected project root path.
Expand All @@ -160,13 +163,3 @@ func GetProjectRoot() (string, error) {
}
return ctx.RootPath, nil
}

// MustGetProjectRoot returns the project root or panics.
// Use when project root is required and absence is a programming error.
func MustGetProjectRoot() string {
root, err := GetProjectRoot()
if err != nil {
panic(err)
}
return root
}
180 changes: 180 additions & 0 deletions internal/config/paths_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package config

import (
"errors"
"testing"

"github.com/josephgoksu/TaskWing/internal/project"
)

func TestSetProjectContext_NilReturnsError(t *testing.T) {
// Clear any existing context
ClearProjectContext()

err := SetProjectContext(nil)
if err == nil {
t.Fatal("expected error for nil context, got nil")
}

// Verify error message is helpful
if err.Error() != "SetProjectContext called with nil context" {
t.Errorf("unexpected error message: %s", err.Error())
}
}

func TestSetProjectContext_ValidContext(t *testing.T) {
// Clear any existing context
ClearProjectContext()
defer ClearProjectContext()

ctx := &project.Context{
RootPath: "/test/path",
MarkerType: project.MarkerGit,
}

err := SetProjectContext(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Verify context was set
got := GetProjectContext()
if got == nil {
t.Fatal("expected context to be set")
}
if got.RootPath != ctx.RootPath {
t.Errorf("expected RootPath %q, got %q", ctx.RootPath, got.RootPath)
}
}

func TestGetProjectContextOrError_NotSet(t *testing.T) {
ClearProjectContext()

ctx, err := GetProjectContextOrError()
if err == nil {
t.Fatal("expected error when context not set")
}
if !errors.Is(err, ErrProjectContextNotSet) {
t.Errorf("expected ErrProjectContextNotSet, got: %v", err)
}
if ctx != nil {
t.Error("expected nil context")
}
}

func TestGetProjectContextOrError_Set(t *testing.T) {
ClearProjectContext()
defer ClearProjectContext()

expected := &project.Context{RootPath: "/test"}
_ = SetProjectContext(expected)

ctx, err := GetProjectContextOrError()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ctx != expected {
t.Error("context does not match expected")
}
}

func TestGetProjectRoot_NotSet(t *testing.T) {
ClearProjectContext()

root, err := GetProjectRoot()
if err == nil {
t.Fatal("expected error when context not set")
}
if !errors.Is(err, ErrProjectContextNotSet) {
t.Errorf("expected ErrProjectContextNotSet, got: %v", err)
}
if root != "" {
t.Errorf("expected empty root, got: %s", root)
}
}

func TestGetProjectRoot_EmptyRootPath(t *testing.T) {
ClearProjectContext()
defer ClearProjectContext()

ctx := &project.Context{RootPath: ""}
_ = SetProjectContext(ctx)

root, err := GetProjectRoot()
if err == nil {
t.Fatal("expected error for empty RootPath")
}
if root != "" {
t.Errorf("expected empty root, got: %s", root)
}
}

func TestGetProjectRoot_Valid(t *testing.T) {
ClearProjectContext()
defer ClearProjectContext()

expected := "/my/project"
ctx := &project.Context{RootPath: expected}
_ = SetProjectContext(ctx)

root, err := GetProjectRoot()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if root != expected {
t.Errorf("expected %q, got %q", expected, root)
}
}

func TestGetMemoryBasePath_NotSet(t *testing.T) {
ClearProjectContext()

path, err := GetMemoryBasePath()
if err == nil {
t.Fatal("expected error when context not set")
}
if !errors.Is(err, ErrProjectContextNotSet) {
t.Errorf("expected ErrProjectContextNotSet, got: %v", err)
}
if path != "" {
t.Errorf("expected empty path, got: %s", path)
}
}

func TestGetMemoryBasePathOrGlobal_FallsBackToGlobal(t *testing.T) {
ClearProjectContext()

// Should fall back to global without error
path, err := GetMemoryBasePathOrGlobal()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if path == "" {
t.Error("expected non-empty path")
}
// Should contain "memory" in the path
if len(path) < 6 || path[len(path)-6:] != "memory" {
t.Errorf("expected path to end with 'memory', got: %s", path)
}
}

func TestGetMemoryBasePathOrGlobal_GlobalDirError(t *testing.T) {
ClearProjectContext()

// Save original function
original := GetGlobalConfigDir
defer func() { GetGlobalConfigDir = original }()

// Mock to return error
GetGlobalConfigDir = func() (string, error) {
return "", errors.New("test error: cannot get home dir")
}

path, err := GetMemoryBasePathOrGlobal()
if err == nil {
t.Fatal("expected error when global config dir fails")
}
if path != "" {
t.Errorf("expected empty path on error, got: %s", path)
}
}
Loading