diff --git a/frontend/src/hooks/useTaskSubscription.ts b/frontend/src/hooks/useTaskSubscription.ts new file mode 100644 index 00000000..708ff0b2 --- /dev/null +++ b/frontend/src/hooks/useTaskSubscription.ts @@ -0,0 +1,181 @@ +import { useEffect, useRef, useState, useCallback } from 'react' +import { API_BASE, getTaskStatus } from '../api' + +export interface UseTaskSubscriptionOptions { + taskId: string + onStatus?: (status: string) => void + onPhase?: (phase: string) => void + onOutput?: (chunk: string) => void + pollingInterval?: number + maxReconnectAttempts?: number + reconnectBaseDelay?: number +} + +export interface UseTaskSubscriptionResult { + isConnected: boolean + isPolling: boolean + error: string | null +} + +export function useTaskSubscription({ + taskId, + onStatus, + onPhase, + onOutput, + pollingInterval = 5000, + maxReconnectAttempts = 5, + reconnectBaseDelay = 1000, +}: UseTaskSubscriptionOptions): UseTaskSubscriptionResult { + const [isConnected, setIsConnected] = useState(false) + const [isPolling, setIsPolling] = useState(false) + const [error, setError] = useState(null) + + const onStatusRef = useRef(onStatus) + const onPhaseRef = useRef(onPhase) + const onOutputRef = useRef(onOutput) + const esRef = useRef(null) + const pollIntervalRef = useRef | null>(null) + const reconnectAttemptRef = useRef(0) + const reconnectTimerRef = useRef | null>(null) + const lastStatusRef = useRef(null) + const seenOutputsRef = useRef>(new Set()) + const cleanupRef = useRef(false) + + onStatusRef.current = onStatus + onPhaseRef.current = onPhase + onOutputRef.current = onOutput + + const cleanupAll = useCallback(() => { + cleanupRef.current = true + if (esRef.current) { + esRef.current.close() + esRef.current = null + } + if (pollIntervalRef.current) { + clearInterval(pollIntervalRef.current) + pollIntervalRef.current = null + } + if (reconnectTimerRef.current) { + clearTimeout(reconnectTimerRef.current) + reconnectTimerRef.current = null + } + }, []) + + const startPolling = useCallback(() => { + if (cleanupRef.current) return + setIsPolling(true) + setIsConnected(false) + pollIntervalRef.current = setInterval(async () => { + if (cleanupRef.current) return + try { + const data = await getTaskStatus(taskId) as { status?: string } + if (cleanupRef.current) return + if (data.status && data.status !== lastStatusRef.current) { + lastStatusRef.current = data.status + onStatusRef.current?.(data.status) + } + if (data.status && ['completed', 'failed', 'cancelled'].includes(data.status)) { + cleanupAll() + setIsPolling(false) + } + } catch { + } + }, pollingInterval) + }, [taskId, pollingInterval, cleanupAll]) + + const connectSSE = useCallback(() => { + if (cleanupRef.current) return + if (esRef.current) { + esRef.current.close() + esRef.current = null + } + + const url = `${API_BASE}/task/${taskId}/stream` + const es = new EventSource(url) + esRef.current = es + + es.addEventListener('status', (e: MessageEvent) => { + if (cleanupRef.current) return + try { + const data = JSON.parse(e.data) as { status: string; scan_phase?: string } + if (data.scan_phase) { + onPhaseRef.current?.(data.scan_phase) + } + if (data.status && data.status !== lastStatusRef.current) { + lastStatusRef.current = data.status + onStatusRef.current?.(data.status) + } + if (['completed', 'failed', 'cancelled'].includes(data.status)) { + cleanupAll() + setIsConnected(false) + setIsPolling(false) + } + } catch { + } + }) + + es.addEventListener('phase', (e: MessageEvent) => { + if (cleanupRef.current) return + try { + const data = JSON.parse(e.data) as { scan_phase: string } + if (data.scan_phase) { + onPhaseRef.current?.(data.scan_phase) + } + } catch { + } + }) + + es.addEventListener('output', (e: MessageEvent) => { + if (cleanupRef.current) return + try { + const data = JSON.parse(e.data) as { chunk: string } + if (data.chunk && !seenOutputsRef.current.has(data.chunk)) { + seenOutputsRef.current.add(data.chunk) + onOutputRef.current?.(data.chunk) + } + } catch { + } + }) + + es.onerror = () => { + if (cleanupRef.current) return + es.close() + esRef.current = null + setIsConnected(false) + setError('SSE connection lost') + + if (reconnectAttemptRef.current < maxReconnectAttempts) { + const delay = reconnectBaseDelay * Math.pow(2, reconnectAttemptRef.current) + reconnectAttemptRef.current++ + reconnectTimerRef.current = setTimeout(() => { + if (!cleanupRef.current) connectSSE() + }, delay) + } else { + startPolling() + } + } + + es.onopen = () => { + if (cleanupRef.current) return + reconnectAttemptRef.current = 0 + setIsConnected(true) + setIsPolling(false) + setError(null) + } + }, [taskId, maxReconnectAttempts, reconnectBaseDelay, cleanupAll, startPolling]) + + useEffect(() => { + cleanupRef.current = false + lastStatusRef.current = null + seenOutputsRef.current = new Set() + reconnectAttemptRef.current = 0 + + connectSSE() + + return () => { + cleanupAll() + } + }, [taskId, connectSSE, cleanupAll]) + + return { isConnected, isPolling, error } +} diff --git a/frontend/src/pages/TaskDetails.tsx b/frontend/src/pages/TaskDetails.tsx index 2389983a..9b2c881a 100644 --- a/frontend/src/pages/TaskDetails.tsx +++ b/frontend/src/pages/TaskDetails.tsx @@ -12,6 +12,7 @@ import { Refresh01Icon, } from '@hugeicons/core-free-icons' import { API_BASE, getPluginSchema, getTaskResult, getTaskStatus, PluginFieldSchema, PluginSchemaResponse, startTask } from '../api' +import { useTaskSubscription } from '../hooks/useTaskSubscription' import { routes, routePath } from '../routes' import { parseDateSafe, formatDateLong, formatLocaleTime } from '../utils/date' import { @@ -361,50 +362,23 @@ export default function TaskDetails() { useEffect(() => { loadTask() + }, [taskId]) - const es = new EventSource(`${API_BASE}/task/${taskId}/stream`) - - es.addEventListener('status', (e) => { - try { - const data = JSON.parse(e.data) - setTask((prev: Task | null) => prev ? { ...prev, status: data.status } : null) - if (data.scan_phase) { - setScanPhase(data.scan_phase) - } - if (['completed', 'failed', 'cancelled'].includes(data.status)) { - es.close() - loadTask() - } - } catch (err) { - console.error("Status stream error", err) - } - }) - - es.addEventListener('phase', (e) => { - try { - const data = JSON.parse(e.data) - setScanPhase(data.scan_phase) - } catch (err) { - console.error("Phase stream error", err) + useTaskSubscription({ + taskId: taskId!, + onStatus: (status) => { + setTask((prev: Task | null) => prev ? { ...prev, status } : null) + if (['completed', 'failed', 'cancelled'].includes(status)) { + loadTask() } - }) - - es.addEventListener('output', (e) => { - try { - const data = JSON.parse(e.data) - setRawOutput(prev => prev + data.chunk) - } catch (err) { - console.error("Output stream error", err) - } - }) - - es.onerror = (err) => { - console.error("EventSource error:", err) - es.close() - } - - return () => es.close() - }, [taskId]) + }, + onPhase: (phase) => { + setScanPhase(phase) + }, + onOutput: (chunk) => { + setRawOutput((prev) => prev + chunk) + }, + }) async function loadTask() { try { diff --git a/frontend/testing/unit/hooks/useTaskSubscription.test.ts b/frontend/testing/unit/hooks/useTaskSubscription.test.ts new file mode 100644 index 00000000..80349b36 --- /dev/null +++ b/frontend/testing/unit/hooks/useTaskSubscription.test.ts @@ -0,0 +1,197 @@ +import { render, act } from '@testing-library/react' +import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest' +import React from 'react' +import { useTaskSubscription } from '../../../src/hooks/useTaskSubscription' + +vi.mock('../../../src/api', () => ({ + API_BASE: 'http://localhost', + getTaskStatus: vi.fn().mockResolvedValue({ status: 'running' }), +})) + +import { getTaskStatus } from '../../../src/api' + +class MockEventSource { + static instances: MockEventSource[] = [] + onopen: (() => void) | null = null + onerror: ((err: Event) => void) | null = null + listeners: Map void> = new Map() + url: string + readyState = 0 + closeCount = 0 + constructor(url: string) { this.url = url; MockEventSource.instances.push(this) } + addEventListener(event: string, handler: (e: MessageEvent) => void) { this.listeners.set(event, handler) } + close() { this.closeCount++; const idx = MockEventSource.instances.indexOf(this); if (idx !== -1) MockEventSource.instances.splice(idx, 1) } + dispatchEvent(event: string, data: string) { const h = this.listeners.get(event); if (h) h(new MessageEvent(event, { data })) } + triggerOpen() { this.readyState = 1; this.onopen?.() } + triggerError() { this.onerror?.(new Event('error')) } + static reset() { MockEventSource.instances = [] } +} + +function renderHook(props: { taskId: string; onStatus?: (s: string) => void; onOutput?: (c: string) => void; pollingInterval?: number; maxReconnectAttempts?: number }) { + const Comp = () => { useTaskSubscription(props); return null } + return render(React.createElement(Comp)) +} + +function getES() { return MockEventSource.instances[0] } + +beforeEach(() => { + MockEventSource.reset() + vi.stubGlobal('EventSource', MockEventSource as any) + vi.useFakeTimers() + vi.mocked(getTaskStatus).mockReset() + vi.mocked(getTaskStatus).mockResolvedValue({ status: 'running' }) +}) + +afterEach(() => { + vi.useRealTimers() + vi.unstubAllGlobals() +}) + +async function flush() { await act(async () => { await Promise.resolve(); await Promise.resolve(); await Promise.resolve() }) } +async function tickTime(ms: number) { await act(async () => { vi.advanceTimersByTime(ms); await Promise.resolve(); await Promise.resolve(); await Promise.resolve() }) } + +describe('useTaskSubscription', () => { + it('connects to SSE on mount', async () => { + renderHook({ taskId: 'task-1' }) + await flush() + const es = getES() + expect(es).toBeTruthy() + expect(es!.url).toContain('/task/task-1/stream') + }) + + it('calls onStatus on status event', async () => { + const onStatus = vi.fn() + renderHook({ taskId: 'task-1', onStatus }) + await flush() + getES()!.triggerOpen() + getES()!.dispatchEvent('status', JSON.stringify({ status: 'running' })) + expect(onStatus).toHaveBeenCalledWith('running') + }) + + it('deduplicates same status value', async () => { + const onStatus = vi.fn() + renderHook({ taskId: 'task-1', onStatus }) + await flush() + const es = getES()! + es.triggerOpen() + es.dispatchEvent('status', JSON.stringify({ status: 'running' })) + es.dispatchEvent('status', JSON.stringify({ status: 'running' })) + es.dispatchEvent('status', JSON.stringify({ status: 'running' })) + expect(onStatus).toHaveBeenCalledTimes(1) + }) + + it('does not deduplicate different status values', async () => { + const onStatus = vi.fn() + renderHook({ taskId: 'task-1', onStatus }) + await flush() + const es = getES()! + es.triggerOpen() + es.dispatchEvent('status', JSON.stringify({ status: 'queued' })) + es.dispatchEvent('status', JSON.stringify({ status: 'running' })) + es.dispatchEvent('status', JSON.stringify({ status: 'completed' })) + expect(onStatus).toHaveBeenCalledTimes(3) + expect(onStatus).toHaveBeenNthCalledWith(1, 'queued') + expect(onStatus).toHaveBeenNthCalledWith(2, 'running') + expect(onStatus).toHaveBeenNthCalledWith(3, 'completed') + }) + + it('calls onOutput on output event', async () => { + const onOutput = vi.fn() + renderHook({ taskId: 'task-1', onOutput }) + await flush() + getES()!.triggerOpen() + getES()!.dispatchEvent('output', JSON.stringify({ chunk: 'line1\n' })) + expect(onOutput).toHaveBeenCalledWith('line1\n') + }) + + it('deduplicates identical output chunks', async () => { + const onOutput = vi.fn() + renderHook({ taskId: 'task-1', onOutput }) + await flush() + const es = getES()! + es.triggerOpen() + es.dispatchEvent('output', JSON.stringify({ chunk: 'line1\n' })) + es.dispatchEvent('output', JSON.stringify({ chunk: 'line1\n' })) + expect(onOutput).toHaveBeenCalledTimes(1) + }) + + it('falls back to polling after SSE max reconnect attempts', async () => { + renderHook({ taskId: 'task-1', pollingInterval: 50, maxReconnectAttempts: 3 }) + await flush() + + for (let i = 0; i < 4; i++) { + await act(() => { getES()!.triggerError() }) + await tickTime(1000 * Math.pow(2, Math.min(i, 3))) + } + + await tickTime(50) + expect(getTaskStatus).toHaveBeenCalled() + }) + + it('polls getTaskStatus at the configured interval', async () => { + renderHook({ taskId: 'task-1', pollingInterval: 50, maxReconnectAttempts: 0 }) + await flush() + + const es = getES()! + await act(() => { es.triggerError() }) + + await tickTime(50) + expect(getTaskStatus).toHaveBeenCalledTimes(1) + + await tickTime(50) + expect(getTaskStatus).toHaveBeenCalledTimes(2) + }) + + it('stops polling on terminal status', async () => { + let resolveGetTaskStatus: (value: unknown) => void + vi.mocked(getTaskStatus).mockImplementation(() => new Promise(resolve => { + resolveGetTaskStatus = resolve + })) + + renderHook({ taskId: 'task-1', pollingInterval: 50, maxReconnectAttempts: 0 }) + await flush() + + await act(() => { getES()!.triggerError() }) + + await tickTime(50) + // Resolve with terminal status so cleanupAll() clears the interval + await act(async () => { resolveGetTaskStatus({ status: 'completed' }); await Promise.resolve(); await Promise.resolve() }) + + const callsAfterCleanup = vi.mocked(getTaskStatus).mock.calls.length + + await tickTime(200) + expect(vi.mocked(getTaskStatus).mock.calls.length).toBe(callsAfterCleanup) + }) + + it('stops SSE on terminal status event', async () => { + const onStatus = vi.fn() + renderHook({ taskId: 'task-1', onStatus }) + await flush() + const es = getES()! + es.triggerOpen() + es.dispatchEvent('status', JSON.stringify({ status: 'completed' })) + expect(onStatus).toHaveBeenCalledWith('completed') + }) + + it('cleans up EventSource on unmount', async () => { + const { unmount } = renderHook({ taskId: 'task-1', pollingInterval: 50, maxReconnectAttempts: 0 }) + await flush() + + const es = getES()! + const closeSpy = vi.spyOn(es, 'close') + unmount() + expect(closeSpy).toHaveBeenCalled() + }) + + it('polls and calls onStatus with new statuses', async () => { + const onStatus = vi.fn() + renderHook({ taskId: 'task-1', onStatus, pollingInterval: 50, maxReconnectAttempts: 0 }) + await flush() + + const es = getES()! + await act(() => { es.triggerError() }) + + await tickTime(50) + expect(onStatus).toHaveBeenCalledWith('running') + }) +})