From 6e9df68c66f3c2b0a27b9d82a91359e88db4a3ea Mon Sep 17 00:00:00 2001 From: Adam Bowker Date: Thu, 9 Apr 2026 11:02:05 -0700 Subject: [PATCH] feat(code): warn on local task branch mismatch --- .../components/MessageEditor.tsx | 3 + .../message-editor/tiptap/useTiptapEditor.ts | 35 ++- .../sessions/components/SessionView.tsx | 3 + .../components/BranchMismatchDialog.tsx | 148 +++++++++++ .../task-detail/components/TaskLogsPanel.tsx | 11 + .../workspace/hooks/useBranchMismatch.test.ts | 138 ++++++++++ .../workspace/hooks/useBranchMismatch.ts | 62 +++++ .../hooks/useBranchMismatchDialog.test.ts | 240 ++++++++++++++++++ .../hooks/useBranchMismatchDialog.ts | 116 +++++++++ 9 files changed, 747 insertions(+), 9 deletions(-) create mode 100644 apps/code/src/renderer/features/task-detail/components/BranchMismatchDialog.tsx create mode 100644 apps/code/src/renderer/features/workspace/hooks/useBranchMismatch.test.ts create mode 100644 apps/code/src/renderer/features/workspace/hooks/useBranchMismatch.ts create mode 100644 apps/code/src/renderer/features/workspace/hooks/useBranchMismatchDialog.test.ts create mode 100644 apps/code/src/renderer/features/workspace/hooks/useBranchMismatchDialog.ts diff --git a/apps/code/src/renderer/features/message-editor/components/MessageEditor.tsx b/apps/code/src/renderer/features/message-editor/components/MessageEditor.tsx index 19973c11a..d13d95745 100644 --- a/apps/code/src/renderer/features/message-editor/components/MessageEditor.tsx +++ b/apps/code/src/renderer/features/message-editor/components/MessageEditor.tsx @@ -136,6 +136,7 @@ function ModeAndBranchRow({ interface MessageEditorProps { sessionId: string; placeholder?: string; + onBeforeSubmit?: (text: string, clearEditor: () => void) => boolean; onSubmit?: (text: string) => void; onBashCommand?: (command: string) => void; onBashModeChange?: (isBashMode: boolean) => void; @@ -154,6 +155,7 @@ export const MessageEditor = forwardRef( { sessionId, placeholder = "Type a message... @ to mention files, ! for bash mode, / for skills", + onBeforeSubmit, onSubmit, onBashCommand, onBashModeChange, @@ -213,6 +215,7 @@ export const MessageEditor = forwardRef( context: { taskId, repoPath }, getPromptHistory, capabilities: { bashMode: !isCloud }, + onBeforeSubmit, onSubmit, onBashCommand, onBashModeChange, diff --git a/apps/code/src/renderer/features/message-editor/tiptap/useTiptapEditor.ts b/apps/code/src/renderer/features/message-editor/tiptap/useTiptapEditor.ts index a5752157a..5de5e2801 100644 --- a/apps/code/src/renderer/features/message-editor/tiptap/useTiptapEditor.ts +++ b/apps/code/src/renderer/features/message-editor/tiptap/useTiptapEditor.ts @@ -29,6 +29,7 @@ export interface UseTiptapEditorOptions { }; clearOnSubmit?: boolean; getPromptHistory?: () => string[]; + onBeforeSubmit?: (text: string, clearEditor: () => void) => boolean; onSubmit?: (text: string) => void; onBashCommand?: (command: string) => void; onBashModeChange?: (isBashMode: boolean) => void; @@ -84,6 +85,7 @@ export function useTiptapEditor(options: UseTiptapEditorOptions) { capabilities = {}, clearOnSubmit = true, getPromptHistory, + onBeforeSubmit, onSubmit, onBashCommand, onBashModeChange, @@ -99,6 +101,7 @@ export function useTiptapEditor(options: UseTiptapEditorOptions) { } = capabilities; const callbackRefs = useRef({ + onBeforeSubmit, onSubmit, onBashCommand, onBashModeChange, @@ -107,6 +110,7 @@ export function useTiptapEditor(options: UseTiptapEditorOptions) { onBlur, }); callbackRefs.current = { + onBeforeSubmit, onSubmit, onBashCommand, onBashModeChange, @@ -450,8 +454,19 @@ export function useTiptapEditor(options: UseTiptapEditorOptions) { const text = editor.getText().trim(); + const doClear = () => { + if (!clearOnSubmit) return; + editor.commands.clearContent(); + prevBashModeRef.current = false; + pasteCountRef.current = 0; + setAttachments([]); + draft.clearDraft(); + }; + if (enableBashMode && text.startsWith("!")) { - // Bash mode requires immediate execution, can't be queued + // Bash mode requires immediate execution, can't be queued. + // Intentionally bypasses onBeforeSubmit — bash commands run inline and + // cannot be deferred the way normal prompts can. if (isLoading) { toast.error("Cannot run shell commands while agent is generating"); return; @@ -459,17 +474,19 @@ export function useTiptapEditor(options: UseTiptapEditorOptions) { const command = text.slice(1).trim(); if (command) callbackRefs.current.onBashCommand?.(command); } else { + const serialized = contentToXml(content); + + if (callbackRefs.current.onBeforeSubmit) { + if (!callbackRefs.current.onBeforeSubmit(serialized, doClear)) { + return; + } + } + // Normal prompts can be queued when loading - callbackRefs.current.onSubmit?.(contentToXml(content)); + callbackRefs.current.onSubmit?.(serialized); } - if (clearOnSubmit) { - editor.commands.clearContent(); - prevBashModeRef.current = false; - pasteCountRef.current = 0; - setAttachments([]); - draft.clearDraft(); - } + doClear(); }, [ editor, disabled, diff --git a/apps/code/src/renderer/features/sessions/components/SessionView.tsx b/apps/code/src/renderer/features/sessions/components/SessionView.tsx index 5ce4716fc..6907b2159 100644 --- a/apps/code/src/renderer/features/sessions/components/SessionView.tsx +++ b/apps/code/src/renderer/features/sessions/components/SessionView.tsx @@ -38,6 +38,7 @@ interface SessionViewProps { isRunning: boolean; isPromptPending?: boolean | null; promptStartedAt?: number | null; + onBeforeSubmit?: (text: string, clearEditor: () => void) => boolean; onSendPrompt: (text: string) => void; onBashCommand?: (command: string) => void; onCancelPrompt: () => void; @@ -73,6 +74,7 @@ export function SessionView({ isRunning, isPromptPending = false, promptStartedAt, + onBeforeSubmit, onSendPrompt, onBashCommand, onCancelPrompt, @@ -538,6 +540,7 @@ export function SessionView({ ref={editorRef} sessionId={sessionId} placeholder="Type a message... @ to mention files, ! for bash mode, / for skills" + onBeforeSubmit={onBeforeSubmit} onSubmit={handleSubmit} onBashCommand={onBashCommand} onCancel={onCancelPrompt} diff --git a/apps/code/src/renderer/features/task-detail/components/BranchMismatchDialog.tsx b/apps/code/src/renderer/features/task-detail/components/BranchMismatchDialog.tsx new file mode 100644 index 000000000..2553a0a56 --- /dev/null +++ b/apps/code/src/renderer/features/task-detail/components/BranchMismatchDialog.tsx @@ -0,0 +1,148 @@ +import { GitBranch, Warning } from "@phosphor-icons/react"; +import { + AlertDialog, + Button, + Callout, + Code, + Flex, + Text, +} from "@radix-ui/themes"; + +interface BranchMismatchDialogProps { + open: boolean; + linkedBranch: string; + currentBranch: string; + hasUncommittedChanges: boolean; + switchError: string | null; + onSwitch: () => void; + onContinue: () => void; + onCancel: () => void; + isSwitching?: boolean; +} + +function BranchLabel({ name }: { name: string }) { + return ( + + + + {name} + + + ); +} + +export function BranchMismatchDialog({ + open, + linkedBranch, + currentBranch, + hasUncommittedChanges, + switchError, + onSwitch, + onContinue, + onCancel, + isSwitching, +}: BranchMismatchDialogProps) { + return ( + { + if (!isOpen) onCancel(); + }} + > + + + + + Wrong branch + + + + This task is linked to a different branch than the one you're + currently on. The agent will make changes on the current branch. + + + + + Linked + + + + + + Current + + + + + + {hasUncommittedChanges && !switchError && ( + + + You have uncommitted changes on your current branch. If needed, + commit or stash them first. + + + )} + + {switchError && ( + + {switchError} + + )} + + + + + + + + + + + + + + + ); +} diff --git a/apps/code/src/renderer/features/task-detail/components/TaskLogsPanel.tsx b/apps/code/src/renderer/features/task-detail/components/TaskLogsPanel.tsx index 1347c7347..1f370f8fb 100644 --- a/apps/code/src/renderer/features/task-detail/components/TaskLogsPanel.tsx +++ b/apps/code/src/renderer/features/task-detail/components/TaskLogsPanel.tsx @@ -15,7 +15,9 @@ import { useSessionConnection } from "@features/sessions/hooks/useSessionConnect import { useSessionViewState } from "@features/sessions/hooks/useSessionViewState"; import { useRestoreTask } from "@features/suspension/hooks/useRestoreTask"; import { useSuspendedTaskIds } from "@features/suspension/hooks/useSuspendedTaskIds"; +import { BranchMismatchDialog } from "@features/task-detail/components/BranchMismatchDialog"; import { WorkspaceSetupPrompt } from "@features/task-detail/components/WorkspaceSetupPrompt"; +import { useBranchMismatchDialog } from "@features/workspace/hooks/useBranchMismatchDialog"; import { useCreateWorkspace, useWorkspaceLoaded, @@ -81,6 +83,12 @@ export function TaskLogsPanel({ taskId, task, hideInput }: TaskLogsPanelProps) { handleBashCommand, } = useSessionCallbacks({ taskId, task, session, repoPath }); + const { handleBeforeSubmit, dialogProps } = useBranchMismatchDialog({ + taskId, + repoPath, + onSendPrompt: handleSendPrompt, + }); + const cloudOutput = session?.cloudOutput ?? null; const prUrl = isCloud && cloudOutput?.pr_url ? (cloudOutput.pr_url as string) : null; @@ -147,6 +155,7 @@ export function TaskLogsPanel({ taskId, task, hideInput }: TaskLogsPanelProps) { isRestoring={isRestoring} isPromptPending={isPromptPending} promptStartedAt={promptStartedAt} + onBeforeSubmit={handleBeforeSubmit} onSendPrompt={handleSendPrompt} onBashCommand={isCloud ? undefined : handleBashCommand} onCancelPrompt={handleCancelPrompt} @@ -165,6 +174,8 @@ export function TaskLogsPanel({ taskId, task, hideInput }: TaskLogsPanelProps) { + + {dialogProps && } ); } diff --git a/apps/code/src/renderer/features/workspace/hooks/useBranchMismatch.test.ts b/apps/code/src/renderer/features/workspace/hooks/useBranchMismatch.test.ts new file mode 100644 index 000000000..3aef24359 --- /dev/null +++ b/apps/code/src/renderer/features/workspace/hooks/useBranchMismatch.test.ts @@ -0,0 +1,138 @@ +import { act, renderHook } from "@testing-library/react"; +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const mockUseWorkspace = vi.hoisted(() => + vi.fn((): { branchName: string; linkedBranch: string | null } | null => null), +); +vi.mock("./useWorkspace", () => ({ useWorkspace: mockUseWorkspace })); + +import { + useBranchMismatchGuard, + useBranchWarningStore, +} from "./useBranchMismatch"; + +describe("useBranchWarningStore", () => { + beforeEach(() => { + useBranchWarningStore.setState({ dismissed: {} }); + }); + + it("starts with no dismissed tasks", () => { + expect(useBranchWarningStore.getState().dismissed).toEqual({}); + }); + + it("dismiss marks task as dismissed", () => { + useBranchWarningStore.getState().dismiss("task-1"); + expect(useBranchWarningStore.getState().dismissed["task-1"]).toBe(true); + }); + + it("reset clears dismissed for a task", () => { + useBranchWarningStore.getState().dismiss("task-1"); + useBranchWarningStore.getState().reset("task-1"); + expect(useBranchWarningStore.getState().dismissed["task-1"]).toBe(false); + }); + + it("dismiss/reset are independent per task", () => { + useBranchWarningStore.getState().dismiss("task-1"); + useBranchWarningStore.getState().dismiss("task-2"); + useBranchWarningStore.getState().reset("task-1"); + + expect(useBranchWarningStore.getState().dismissed["task-1"]).toBe(false); + expect(useBranchWarningStore.getState().dismissed["task-2"]).toBe(true); + }); +}); + +describe("useBranchMismatchGuard", () => { + beforeEach(() => { + useBranchWarningStore.setState({ dismissed: {} }); + mockUseWorkspace.mockReturnValue(null); + }); + + it("shouldWarn is false when no workspace", () => { + const { result } = renderHook(() => useBranchMismatchGuard("task-1")); + expect(result.current.shouldWarn).toBe(false); + }); + + it("shouldWarn is false when no linked branch", () => { + mockUseWorkspace.mockReturnValue({ + branchName: "main", + linkedBranch: null, + }); + const { result } = renderHook(() => useBranchMismatchGuard("task-1")); + expect(result.current.shouldWarn).toBe(false); + }); + + it("shouldWarn is false when branches match", () => { + mockUseWorkspace.mockReturnValue({ + branchName: "feat/foo", + linkedBranch: "feat/foo", + }); + const { result } = renderHook(() => useBranchMismatchGuard("task-1")); + expect(result.current.shouldWarn).toBe(false); + }); + + it("shouldWarn is true when branches mismatch", () => { + mockUseWorkspace.mockReturnValue({ + branchName: "main", + linkedBranch: "feat/foo", + }); + const { result } = renderHook(() => useBranchMismatchGuard("task-1")); + + expect(result.current.shouldWarn).toBe(true); + expect(result.current.linkedBranch).toBe("feat/foo"); + expect(result.current.currentBranch).toBe("main"); + }); + + it("dismissWarning stops shouldWarn", () => { + mockUseWorkspace.mockReturnValue({ + branchName: "main", + linkedBranch: "feat/foo", + }); + const { result } = renderHook(() => useBranchMismatchGuard("task-1")); + expect(result.current.shouldWarn).toBe(true); + + act(() => result.current.dismissWarning()); + + expect(result.current.shouldWarn).toBe(false); + }); + + it("shouldWarn resets when currentBranch changes", () => { + mockUseWorkspace.mockReturnValue({ + branchName: "main", + linkedBranch: "feat/foo", + }); + const { result, rerender } = renderHook(() => + useBranchMismatchGuard("task-1"), + ); + + act(() => result.current.dismissWarning()); + expect(result.current.shouldWarn).toBe(false); + + // Simulate switching to a different (still mismatched) branch + mockUseWorkspace.mockReturnValue({ + branchName: "develop", + linkedBranch: "feat/foo", + }); + rerender(); + + expect(result.current.shouldWarn).toBe(true); + }); + + it("shouldWarn is false after switching to the linked branch", () => { + mockUseWorkspace.mockReturnValue({ + branchName: "main", + linkedBranch: "feat/foo", + }); + const { result, rerender } = renderHook(() => + useBranchMismatchGuard("task-1"), + ); + expect(result.current.shouldWarn).toBe(true); + + mockUseWorkspace.mockReturnValue({ + branchName: "feat/foo", + linkedBranch: "feat/foo", + }); + rerender(); + + expect(result.current.shouldWarn).toBe(false); + }); +}); diff --git a/apps/code/src/renderer/features/workspace/hooks/useBranchMismatch.ts b/apps/code/src/renderer/features/workspace/hooks/useBranchMismatch.ts new file mode 100644 index 000000000..00950414e --- /dev/null +++ b/apps/code/src/renderer/features/workspace/hooks/useBranchMismatch.ts @@ -0,0 +1,62 @@ +import { useCallback, useEffect, useRef } from "react"; +import { create } from "zustand"; +import { useWorkspace } from "./useWorkspace"; + +interface BranchWarningState { + dismissed: Record; + dismiss: (taskId: string) => void; + reset: (taskId: string) => void; +} + +export const useBranchWarningStore = create()((set) => ({ + dismissed: {}, + dismiss: (taskId) => + set((state) => ({ + dismissed: { ...state.dismissed, [taskId]: true }, + })), + reset: (taskId) => + set((state) => ({ + dismissed: { ...state.dismissed, [taskId]: false }, + })), +})); + +function useBranchMismatch(taskId: string) { + const workspace = useWorkspace(taskId); + const linkedBranch = workspace?.linkedBranch ?? null; + const currentBranch = workspace?.branchName ?? null; + const isMismatch = + !!linkedBranch && !!currentBranch && linkedBranch !== currentBranch; + + const branchWarningDismissed = useBranchWarningStore( + (s) => s.dismissed[taskId] ?? false, + ); + const reset = useBranchWarningStore((s) => s.reset); + + const prevBranchRef = useRef(currentBranch); + useEffect(() => { + if (prevBranchRef.current !== currentBranch) { + prevBranchRef.current = currentBranch; + reset(taskId); + } + }, [currentBranch, taskId, reset]); + + const shouldWarn = isMismatch && !branchWarningDismissed; + + return { + linkedBranch, + currentBranch, + isMismatch, + shouldWarn, + }; +} + +export function useBranchMismatchGuard(taskId: string) { + const { shouldWarn, linkedBranch, currentBranch } = useBranchMismatch(taskId); + const dismiss = useBranchWarningStore((s) => s.dismiss); + + const dismissWarning = useCallback(() => { + dismiss(taskId); + }, [dismiss, taskId]); + + return { shouldWarn, linkedBranch, currentBranch, dismissWarning }; +} diff --git a/apps/code/src/renderer/features/workspace/hooks/useBranchMismatchDialog.test.ts b/apps/code/src/renderer/features/workspace/hooks/useBranchMismatchDialog.test.ts new file mode 100644 index 000000000..c212d53c7 --- /dev/null +++ b/apps/code/src/renderer/features/workspace/hooks/useBranchMismatchDialog.test.ts @@ -0,0 +1,240 @@ +import { act, renderHook } from "@testing-library/react"; +import { beforeEach, describe, expect, it, vi } from "vitest"; + +let mockShouldWarn = false; +const mockDismissWarning = vi.fn(); + +const mockGuard = vi.hoisted(() => ({ + useBranchMismatchGuard: vi.fn( + (): { + shouldWarn: boolean; + linkedBranch: string | null; + currentBranch: string | null; + dismissWarning: () => void; + } => ({ + shouldWarn: mockShouldWarn, + linkedBranch: "feat/foo", + currentBranch: "main", + dismissWarning: mockDismissWarning, + }), + ), +})); +vi.mock("@features/workspace/hooks/useBranchMismatch", () => mockGuard); + +vi.mock("@features/git-interaction/hooks/useGitQueries", () => ({ + useGitQueries: () => ({ hasChanges: false }), +})); + +vi.mock("@features/git-interaction/utils/gitCacheKeys", () => ({ + invalidateGitBranchQueries: vi.fn(), +})); + +let capturedMutationOptions: { + onSuccess?: () => void; + onError?: (e: Error) => void; +} = {}; +const mockMutate = vi.fn(); + +vi.mock("@renderer/trpc/client", () => ({ + useTRPC: () => ({ + git: { + checkoutBranch: { + mutationOptions: (opts: Record) => { + capturedMutationOptions = opts as typeof capturedMutationOptions; + return opts; + }, + }, + }, + }), +})); + +vi.mock("@tanstack/react-query", () => ({ + useMutation: () => ({ + mutate: mockMutate, + isPending: false, + }), +})); + +vi.mock("@utils/logger", () => ({ + logger: { scope: () => ({ error: vi.fn() }) }, +})); + +import { useBranchMismatchDialog } from "./useBranchMismatchDialog"; + +function renderDialog(overrides?: { shouldWarn?: boolean }) { + mockShouldWarn = overrides?.shouldWarn ?? false; + mockGuard.useBranchMismatchGuard.mockReturnValue({ + shouldWarn: mockShouldWarn, + linkedBranch: "feat/foo", + currentBranch: "main", + dismissWarning: mockDismissWarning, + }); + + const onSendPrompt = vi.fn(); + const hook = renderHook(() => + useBranchMismatchDialog({ + taskId: "task-1", + repoPath: "/repo", + onSendPrompt, + }), + ); + return { ...hook, onSendPrompt }; +} + +describe("useBranchMismatchDialog", () => { + beforeEach(() => { + vi.clearAllMocks(); + capturedMutationOptions = {}; + mockShouldWarn = false; + }); + + describe("handleBeforeSubmit", () => { + it("returns true when shouldWarn is false", () => { + const { result } = renderDialog({ shouldWarn: false }); + const clearEditor = vi.fn(); + + const allowed = result.current.handleBeforeSubmit("hello", clearEditor); + + expect(allowed).toBe(true); + expect(result.current.dialogProps?.open).toBeFalsy(); + }); + + it("returns false and opens dialog when shouldWarn is true", () => { + const { result } = renderDialog({ shouldWarn: true }); + const clearEditor = vi.fn(); + + let allowed: boolean; + act(() => { + allowed = result.current.handleBeforeSubmit("hello", clearEditor); + }); + + expect(allowed!).toBe(false); + expect(result.current.dialogProps?.open).toBe(true); + expect(clearEditor).not.toHaveBeenCalled(); + }); + }); + + describe("handleContinue", () => { + it("sends the pending message and clears editor", () => { + const { result, onSendPrompt } = renderDialog({ shouldWarn: true }); + const clearEditor = vi.fn(); + + act(() => { + result.current.handleBeforeSubmit("hello", clearEditor); + }); + + act(() => { + result.current.dialogProps?.onContinue(); + }); + + expect(onSendPrompt).toHaveBeenCalledWith("hello"); + expect(clearEditor).toHaveBeenCalled(); + expect(mockDismissWarning).toHaveBeenCalled(); + expect(result.current.dialogProps?.open).toBe(false); + }); + }); + + describe("handleCancel", () => { + it("clears state without sending or clearing editor", () => { + const { result, onSendPrompt } = renderDialog({ shouldWarn: true }); + const clearEditor = vi.fn(); + + act(() => { + result.current.handleBeforeSubmit("hello", clearEditor); + }); + expect(result.current.dialogProps?.open).toBe(true); + + act(() => { + result.current.dialogProps?.onCancel(); + }); + + expect(onSendPrompt).not.toHaveBeenCalled(); + expect(clearEditor).not.toHaveBeenCalled(); + expect(result.current.dialogProps?.open).toBe(false); + }); + }); + + describe("handleSwitch", () => { + it("calls checkoutBranch mutation", () => { + const { result } = renderDialog({ shouldWarn: true }); + const clearEditor = vi.fn(); + + act(() => { + result.current.handleBeforeSubmit("hello", clearEditor); + }); + + act(() => { + result.current.dialogProps?.onSwitch(); + }); + + expect(mockMutate).toHaveBeenCalledWith({ + directoryPath: "/repo", + branchName: "feat/foo", + }); + }); + + it("on success: sends pending message and clears editor", () => { + const { result, onSendPrompt } = renderDialog({ shouldWarn: true }); + const clearEditor = vi.fn(); + + act(() => { + result.current.handleBeforeSubmit("hello", clearEditor); + }); + + act(() => { + result.current.dialogProps?.onSwitch(); + }); + + // Simulate mutation success + act(() => { + capturedMutationOptions.onSuccess?.(); + }); + + expect(onSendPrompt).toHaveBeenCalledWith("hello"); + expect(clearEditor).toHaveBeenCalled(); + expect(mockDismissWarning).toHaveBeenCalled(); + }); + + it("on error: shows error without sending message", () => { + const { result, onSendPrompt } = renderDialog({ shouldWarn: true }); + const clearEditor = vi.fn(); + + act(() => { + result.current.handleBeforeSubmit("hello", clearEditor); + }); + + act(() => { + result.current.dialogProps?.onSwitch(); + }); + + act(() => { + capturedMutationOptions.onError?.(new Error("dirty worktree")); + }); + + expect(onSendPrompt).not.toHaveBeenCalled(); + expect(clearEditor).not.toHaveBeenCalled(); + expect(result.current.dialogProps?.switchError).toBe("dirty worktree"); + }); + }); + + describe("dialogProps", () => { + it("is null when no linked branch", () => { + mockGuard.useBranchMismatchGuard.mockReturnValue({ + shouldWarn: false, + linkedBranch: null, + currentBranch: "main", + dismissWarning: mockDismissWarning, + }); + + const { result } = renderHook(() => + useBranchMismatchDialog({ + taskId: "task-1", + repoPath: "/repo", + onSendPrompt: vi.fn(), + }), + ); + + expect(result.current.dialogProps).toBeNull(); + }); + }); +}); diff --git a/apps/code/src/renderer/features/workspace/hooks/useBranchMismatchDialog.ts b/apps/code/src/renderer/features/workspace/hooks/useBranchMismatchDialog.ts new file mode 100644 index 000000000..bd6e7d9c8 --- /dev/null +++ b/apps/code/src/renderer/features/workspace/hooks/useBranchMismatchDialog.ts @@ -0,0 +1,116 @@ +import { useGitQueries } from "@features/git-interaction/hooks/useGitQueries"; +import { invalidateGitBranchQueries } from "@features/git-interaction/utils/gitCacheKeys"; +import { useBranchMismatchGuard } from "@features/workspace/hooks/useBranchMismatch"; +import { useTRPC } from "@renderer/trpc/client"; +import { useMutation } from "@tanstack/react-query"; +import { logger } from "@utils/logger"; +import { useCallback, useRef, useState } from "react"; + +const log = logger.scope("branch-mismatch"); + +interface UseBranchMismatchDialogOptions { + taskId: string; + repoPath: string | null; + onSendPrompt: (text: string) => void; +} + +export function useBranchMismatchDialog({ + taskId, + repoPath, + onSendPrompt, +}: UseBranchMismatchDialogOptions) { + const { shouldWarn, linkedBranch, currentBranch, dismissWarning } = + useBranchMismatchGuard(taskId); + + // State drives dialog visibility (`open`), refs avoid stale closures in + // mutation callbacks (onSuccess / handleContinue) that capture at mount time. + const [pendingMessage, setPendingMessage] = useState(null); + const pendingMessageRef = useRef(null); + const pendingClearRef = useRef<(() => void) | null>(null); + const onSendPromptRef = useRef(onSendPrompt); + onSendPromptRef.current = onSendPrompt; + const [switchError, setSwitchError] = useState(null); + + const { hasChanges: hasUncommittedChanges } = useGitQueries( + repoPath ?? undefined, + ); + + const trpc = useTRPC(); + const { mutate: checkoutBranch, isPending: isSwitching } = useMutation( + trpc.git.checkoutBranch.mutationOptions({ + onSuccess: () => { + if (repoPath) invalidateGitBranchQueries(repoPath); + dismissWarning(); + pendingClearRef.current?.(); + pendingClearRef.current = null; + const message = pendingMessageRef.current; + if (message) onSendPromptRef.current(message); + setPendingMessage(null); + pendingMessageRef.current = null; + }, + onError: (error) => { + log.error("Failed to switch branch", error); + setSwitchError( + error instanceof Error ? error.message : "Failed to switch branch", + ); + }, + }), + ); + + const handleBeforeSubmit = useCallback( + (text: string, clearEditor: () => void): boolean => { + if (shouldWarn) { + setPendingMessage(text); + pendingMessageRef.current = text; + pendingClearRef.current = clearEditor; + return false; + } + return true; + }, + [shouldWarn], + ); + + const handleSwitch = useCallback(() => { + if (!linkedBranch || !repoPath) return; + setSwitchError(null); + checkoutBranch({ + directoryPath: repoPath, + branchName: linkedBranch, + }); + }, [linkedBranch, repoPath, checkoutBranch]); + + const handleContinue = useCallback(() => { + dismissWarning(); + pendingClearRef.current?.(); + pendingClearRef.current = null; + const message = pendingMessageRef.current; + if (message) onSendPromptRef.current(message); + setPendingMessage(null); + pendingMessageRef.current = null; + setSwitchError(null); + }, [dismissWarning]); + + const handleCancel = useCallback(() => { + setPendingMessage(null); + pendingMessageRef.current = null; + pendingClearRef.current = null; + setSwitchError(null); + }, []); + + const dialogProps = + linkedBranch && currentBranch + ? { + open: pendingMessage !== null, + linkedBranch, + currentBranch, + hasUncommittedChanges, + switchError, + onSwitch: handleSwitch, + onContinue: handleContinue, + onCancel: handleCancel, + isSwitching, + } + : null; + + return { handleBeforeSubmit, dialogProps }; +}