diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 815a3bcd..fdeef963 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ on: jobs: quality: name: Format, Lint, Typecheck, Test, Browser Test, Build - runs-on: blacksmith-4vcpu-ubuntu-2404 + runs-on: ubuntu-24.04 steps: - name: Checkout uses: actions/checkout@v4 diff --git a/README.md b/README.md index 4a711cf6..42114e57 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # T3 Code + Copilot -This repo is a T3 Code fork that stays up to date with upstream and adds GitHub Copilot support. +This repo is a T3 Code fork that stays up to date with upstream and adds GitHub Copilot plus browser-local WebGPU support. -T3 Code is a minimal web GUI for coding agents. This fork supports both Codex and GitHub Copilot. +T3 Code is a minimal web GUI for coding agents. This fork supports Codex, GitHub Copilot, and a browser-side Local WebGPU adapter powered by Hugging Face Transformers.js. ## Preview @@ -13,17 +13,18 @@ T3 Code is a minimal web GUI for coding agents. This fork supports both Codex an - tracks upstream `pingdotgg/t3code` - adds GitHub Copilot provider support +- adds a browser-side Local WebGPU provider for curated Hugging Face ONNX models - keeps Codex support working too ## How to use > [!WARNING] -> You need to have either [Codex CLI](https://github.com/openai/codex) or GitHub Copilot available and authorized for T3 Code to work. +> You need either [Codex CLI](https://github.com/openai/codex), GitHub Copilot, or a WebGPU-capable browser for the Local WebGPU adapter. The easiest way to use this fork is the desktop app. 
- Download it from the [releases page](https://github.com/zortos293/t3code-copilot/releases) -- Launch the app and choose either `Codex` or `GitHub Copilot` +- Launch the app and choose `Codex`, `GitHub Copilot`, or `Local WebGPU` You can also run it from source: @@ -34,6 +35,13 @@ bun run dev Open the app, connect your provider, and start chatting. +### Local WebGPU notes + +- Local WebGPU runs entirely in the browser and does not use the server provider runtime. +- The first run downloads model files from Hugging Face/CDN endpoints and may take a while. +- Use the Settings browser to search compatible Hugging Face models, then start with the smaller instruct variants for the best chance of fitting browser memory limits. +- WebGPU availability depends on browser, OS, and GPU support. Recent Chromium-based browsers work best today. + ## Some notes We are very very early in this project. Expect bugs. diff --git a/apps/server/src/huggingFaceModelSearch.test.ts b/apps/server/src/huggingFaceModelSearch.test.ts new file mode 100644 index 00000000..1c4ca689 --- /dev/null +++ b/apps/server/src/huggingFaceModelSearch.test.ts @@ -0,0 +1,110 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; + +import { + clearHuggingFaceModelSearchCache, + searchHuggingFaceModels, +} from "./huggingFaceModelSearch"; + +describe("huggingFaceModelSearch", () => { + afterEach(() => { + clearHuggingFaceModelSearchCache(); + vi.restoreAllMocks(); + }); + + it("returns featured recommended models and caches repeated requests", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response( + JSON.stringify([ + { + id: "onnx-community/Qwen2.5-1.5B-Instruct", + likes: 12, + downloads: 845, + private: false, + tags: ["transformers.js", "onnx", "text-generation", "license:apache-2.0"], + pipeline_tag: "text-generation", + library_name: "transformers.js", + }, + { + id: "community/SmallLM-Instruct", + likes: 4, + downloads: 120, + private: false, + 
tags: ["transformers.js", "text-generation"], + pipeline_tag: "text-generation", + library_name: "transformers.js", + }, + { + id: "onnx-community/Skip-Me", + likes: 1, + downloads: 2, + private: false, + tags: ["transformers.js", "custom_code"], + pipeline_tag: "text-generation", + library_name: "transformers.js", + }, + ]), + { status: 200, headers: { "content-type": "application/json" } }, + ), + ); + + const first = await searchHuggingFaceModels({ limit: 5 }); + const second = await searchHuggingFaceModels({ limit: 5 }); + + expect(fetchSpy).toHaveBeenCalledTimes(1); + expect(first).toEqual({ + mode: "featured", + models: [ + { + id: "onnx-community/Qwen2.5-1.5B-Instruct", + author: "onnx-community", + name: "Qwen2.5-1.5B-Instruct", + downloads: 845, + likes: 12, + pipelineTag: "text-generation", + libraryName: "transformers.js", + license: "apache-2.0", + compatibility: "recommended", + }, + ], + truncated: false, + }); + expect(second).toEqual(first); + }); + + it("returns compatible search results for explicit queries", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response( + JSON.stringify([ + { + id: "community/Phi-lite-Instruct", + likes: 14, + downloads: 912, + private: false, + tags: ["transformers.js", "text-generation", "license:mit"], + pipeline_tag: "text-generation", + library_name: "transformers.js", + }, + { + id: "onnx-community/Phi-3.5-mini-instruct-onnx-web", + likes: 36, + downloads: 1500, + private: false, + tags: ["transformers.js", "onnx", "text-generation", "license:mit"], + pipeline_tag: "text-generation", + library_name: "transformers.js", + }, + ]), + { status: 200, headers: { "content-type": "application/json" } }, + ), + ); + + const result = await searchHuggingFaceModels({ query: "phi", limit: 10 }); + + expect(result.mode).toBe("search"); + expect(result.query).toBe("phi"); + expect(result.models.map((model) => model.id)).toEqual([ + "onnx-community/Phi-3.5-mini-instruct-onnx-web", + 
"community/Phi-lite-Instruct", + ]); + }); +}); diff --git a/apps/server/src/huggingFaceModelSearch.ts b/apps/server/src/huggingFaceModelSearch.ts new file mode 100644 index 00000000..455dacb6 --- /dev/null +++ b/apps/server/src/huggingFaceModelSearch.ts @@ -0,0 +1,264 @@ +import type { + ServerHuggingFaceModel, + ServerHuggingFaceModelSearchInput, + ServerHuggingFaceModelSearchResult, +} from "@t3tools/contracts"; + +const HUGGING_FACE_MODELS_API_URL = "https://huggingface.co/api/models"; +const HUGGING_FACE_SEARCH_CACHE_TTL_MS = 60_000; +const HUGGING_FACE_SEARCH_TIMEOUT_MS = 8_000; +const HUGGING_FACE_FETCH_LIMIT_MULTIPLIER = 4; +const HUGGING_FACE_FETCH_LIMIT_MIN = 24; +const HUGGING_FACE_FETCH_LIMIT_MAX = 96; + +type HuggingFaceApiModel = { + id?: unknown; + modelId?: unknown; + likes?: unknown; + downloads?: unknown; + private?: unknown; + tags?: unknown; + pipeline_tag?: unknown; + library_name?: unknown; +}; + +interface SearchCacheEntry { + expiresAt: number; + result: ServerHuggingFaceModelSearchResult; +} + +const searchCache = new Map(); +const inFlightSearches = new Map>(); + +function normalizeSearchQuery(input: string | undefined): string | undefined { + const normalized = input?.trim(); + return normalized && normalized.length > 0 ? normalized : undefined; +} + +function searchCacheKey(input: ServerHuggingFaceModelSearchInput): string { + return `${normalizeSearchQuery(input.query)?.toLowerCase() ?? ""}::${input.limit ?? 
""}`; +} + +function clampFetchLimit(limit: number): number { + return Math.max( + HUGGING_FACE_FETCH_LIMIT_MIN, + Math.min(HUGGING_FACE_FETCH_LIMIT_MAX, limit * HUGGING_FACE_FETCH_LIMIT_MULTIPLIER), + ); +} + +function coerceNonNegativeInt(value: unknown): number { + const parsed = Number(value); + if (!Number.isFinite(parsed) || parsed <= 0) { + return 0; + } + return Math.trunc(parsed); +} + +function toStringArray(value: unknown): string[] { + if (!Array.isArray(value)) { + return []; + } + return value.flatMap((entry) => { + if (typeof entry !== "string") { + return []; + } + const normalized = entry.trim(); + return normalized.length > 0 ? [normalized] : []; + }); +} + +function hasTransformersJsSupport(model: HuggingFaceApiModel, tags: readonly string[]): boolean { + return model.library_name === "transformers.js" || tags.includes("transformers.js"); +} + +function readLicense(tags: readonly string[]): string | undefined { + const licenseTag = tags.find((tag) => tag.startsWith("license:")); + const license = licenseTag?.slice("license:".length).trim(); + return license && license.length > 0 ? license : undefined; +} + +function normalizeHuggingFaceModel(raw: unknown): ServerHuggingFaceModel | null { + if (typeof raw !== "object" || raw === null) { + return null; + } + const model = raw as HuggingFaceApiModel; + const id = typeof model.id === "string" ? model.id.trim() : ""; + const fallbackId = typeof model.modelId === "string" ? model.modelId.trim() : ""; + const resolvedId = id || fallbackId; + if (!resolvedId || model.private === true) { + return null; + } + + const tags = toStringArray(model.tags); + if ( + model.pipeline_tag !== "text-generation" || + !hasTransformersJsSupport(model, tags) || + tags.includes("custom_code") + ) { + return null; + } + + const slashIndex = resolvedId.indexOf("/"); + const author = + slashIndex > 0 ? resolvedId.slice(0, slashIndex).trim() : "huggingface"; + const name = + slashIndex > 0 ? 
/**
 * Ranks how closely a model matches the search text; a lower number is a
 * better match.
 *
 * Tiers, compared case-insensitively: 0 id equals the query, 1 name equals it,
 * 2 id starts with it, 3 name starts with it, 4 id contains it, 5 name
 * contains it. A missing/empty query, or no match at all, yields 6.
 *
 * @param model - Normalized model whose `id` and `name` are inspected.
 * @param query - Search text, or `undefined` when no search was requested.
 * @returns The match tier in the range 0..6.
 */
function scoreModelMatch(model: ServerHuggingFaceModel, query: string | undefined): number {
  const NO_MATCH = 6;
  if (!query) {
    return NO_MATCH;
  }

  const needle = query.toLowerCase();
  const loweredId = model.id.toLowerCase();
  const loweredName = model.name.toLowerCase();

  // Ordered from strongest to weakest signal; the index of the first hit is the score.
  const tiers: readonly boolean[] = [
    loweredId === needle,
    loweredName === needle,
    loweredId.startsWith(needle),
    loweredName.startsWith(needle),
    loweredId.includes(needle),
    loweredName.includes(needle),
  ];
  const firstHit = tiers.indexOf(true);
  return firstHit < 0 ? NO_MATCH : firstHit;
}

/**
 * Tie-breakers applied in order by `compareModels`. Each one returns a
 * negative number when `left` sorts first, a positive number when `right`
 * sorts first, and 0 when it cannot decide.
 */
const MODEL_TIE_BREAKERS: ReadonlyArray<
  (
    left: ServerHuggingFaceModel,
    right: ServerHuggingFaceModel,
    query: string | undefined,
  ) => number
> = [
  // 1. Better (lower) query match tier first.
  (left, right, query) => scoreModelMatch(left, query) - scoreModelMatch(right, query),
  // 2. "recommended" compatibility ahead of anything else.
  (left, right) => {
    if (left.compatibility === right.compatibility) {
      return 0;
    }
    return left.compatibility === "recommended" ? -1 : 1;
  },
  // 3. Names containing "instruct" ahead of names that do not.
  (left, right) => {
    const leftIsInstruct = left.name.toLowerCase().includes("instruct");
    const rightIsInstruct = right.name.toLowerCase().includes("instruct");
    if (leftIsInstruct === rightIsInstruct) {
      return 0;
    }
    return leftIsInstruct ? -1 : 1;
  },
  // 4. More downloads first.
  (left, right) => right.downloads - left.downloads,
  // 5. More likes first.
  (left, right) => right.likes - left.likes,
];

/**
 * Sort comparator for normalized models.
 *
 * Walks `MODEL_TIE_BREAKERS` and returns the first non-zero verdict; when every
 * tie-breaker is undecided the ids are compared with `localeCompare`.
 *
 * @param left - First model.
 * @param right - Second model.
 * @param query - Search text forwarded to `scoreModelMatch`.
 * @returns Negative when `left` sorts first, positive when `right` does, 0 when equal.
 */
function compareModels(
  left: ServerHuggingFaceModel,
  right: ServerHuggingFaceModel,
  query: string | undefined,
): number {
  for (const tieBreaker of MODEL_TIE_BREAKERS) {
    const verdict = tieBreaker(left, right, query);
    if (verdict !== 0) {
      return verdict;
    }
  }
  return left.id.localeCompare(right.id);
}

/**
 * Collapses models that share an `id`.
 *
 * The surviving entry sits at the position of the first occurrence but carries
 * the data of the last occurrence, because later entries overwrite earlier
 * ones in the underlying `Map`.
 *
 * @param models - Models in their incoming order.
 * @returns A new array with one entry per distinct id.
 */
function dedupeModels(models: readonly ServerHuggingFaceModel[]): ServerHuggingFaceModel[] {
  const latestById = new Map(
    models.map((candidate): [string, ServerHuggingFaceModel] => [candidate.id, candidate]),
  );
  return [...latestById.values()];
}
/**
 * Empties the search result cache and forgets every in-flight request.
 *
 * Requests that are still running when this is called are not aborted, but
 * `searchHuggingFaceModels` will no longer write their results into the cache.
 */
export function clearHuggingFaceModelSearchCache(): void {
  searchCache.clear();
  inFlightSearches.clear();
}

/**
 * Searches Hugging Face models with a short-lived cache and request coalescing.
 *
 * - An unexpired cached result for the same key is returned without a fetch.
 * - Concurrent calls with the same key share one in-flight promise.
 * - Successful results are cached for `HUGGING_FACE_SEARCH_CACHE_TTL_MS`;
 *   rejections are never cached, so the next call retries.
 *
 * @param input - Search input; the cache key is derived via `searchCacheKey`.
 * @returns The search result produced by `requestHuggingFaceModels`.
 * @throws Whatever `requestHuggingFaceModels` rejects with.
 */
export async function searchHuggingFaceModels(
  input: ServerHuggingFaceModelSearchInput,
): Promise<ServerHuggingFaceModelSearchResult> {
  const key = searchCacheKey(input);
  const now = Date.now();
  const cached = searchCache.get(key);
  if (cached) {
    if (cached.expiresAt > now) {
      return cached.result;
    }
    // Stale entry: drop it now so it does not linger if the refresh fails.
    searchCache.delete(key);
  }

  const inFlight = inFlightSearches.get(key);
  if (inFlight) {
    return inFlight;
  }

  // Explicit annotation: `request` is referenced inside its own callbacks.
  const request: Promise<ServerHuggingFaceModelSearchResult> = requestHuggingFaceModels(input)
    .then((result) => {
      // Only cache when this request is still the registered one. After
      // clearHuggingFaceModelSearchCache() (or a newer request for the same
      // key) a late result must not repopulate the cache.
      if (inFlightSearches.get(key) === request) {
        const storedAt = Date.now();
        // Evict expired entries so the cache cannot grow without bound as
        // clients issue distinct queries. Deleting while iterating a Map is safe.
        for (const [staleKey, entry] of searchCache) {
          if (entry.expiresAt <= storedAt) {
            searchCache.delete(staleKey);
          }
        }
        searchCache.set(key, {
          expiresAt: storedAt + HUGGING_FACE_SEARCH_CACHE_TTL_MS,
          result,
        });
      }
      return result;
    })
    .finally(() => {
      // Never remove a newer request that was registered under the same key.
      if (inFlightSearches.get(key) === request) {
        inFlightSearches.delete(key);
      }
    });

  inFlightSearches.set(key, request);
  return request;
}
TurnId.makeUnsafe(value); type LegacyProviderRuntimeEvent = { readonly type: string; readonly eventId: EventId; - readonly provider: "codex" | "copilot"; + readonly provider: ProviderRuntimeEvent["provider"]; readonly createdAt: string; readonly threadId: ThreadId; readonly turnId?: string | undefined; diff --git a/apps/server/src/provider/Layers/ProviderService.test.ts b/apps/server/src/provider/Layers/ProviderService.test.ts index 178f8691..799cee23 100644 --- a/apps/server/src/provider/Layers/ProviderService.test.ts +++ b/apps/server/src/provider/Layers/ProviderService.test.ts @@ -52,7 +52,7 @@ const asTurnId = (value: string): TurnId => TurnId.makeUnsafe(value); type LegacyProviderRuntimeEvent = { readonly type: string; readonly eventId: EventId; - readonly provider: "codex"; + readonly provider: ProviderRuntimeEvent["provider"]; readonly createdAt: string; readonly threadId: ThreadId; readonly turnId?: string | undefined; diff --git a/apps/server/src/wsServer.test.ts b/apps/server/src/wsServer.test.ts index ecb4c09e..d5a2fafb 100644 --- a/apps/server/src/wsServer.test.ts +++ b/apps/server/src/wsServer.test.ts @@ -761,6 +761,76 @@ describe("WebSocket Server", () => { expectAvailableEditors((response.result as { availableEditors: unknown }).availableEditors); }); + it("responds to server.searchHuggingFaceModels", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response( + JSON.stringify([ + { + id: "onnx-community/Qwen2.5-Coder-1.5B-Instruct", + likes: 7, + downloads: 104, + private: false, + tags: ["transformers.js", "onnx", "text-generation", "license:apache-2.0"], + pipeline_tag: "text-generation", + library_name: "transformers.js", + }, + { + id: "community/Qwen-helper", + likes: 1, + downloads: 12, + private: false, + tags: ["transformers.js", "text-generation"], + pipeline_tag: "text-generation", + library_name: "transformers.js", + }, + ]), + { status: 200, headers: { "content-type": "application/json" } }, + ), + ); + + server = 
await createTestServer({ cwd: "/my/workspace" }); + const addr = server.address(); + const port = typeof addr === "object" && addr !== null ? addr.port : 0; + + const ws = await connectWs(port); + connections.push(ws); + await waitForMessage(ws); + + const response = await sendRequest(ws, WS_METHODS.serverSearchHuggingFaceModels, { + query: "qwen coder", + limit: 5, + }); + expect(response.error).toBeUndefined(); + expect(response.result).toEqual({ + mode: "search", + query: "qwen coder", + models: [ + { + id: "onnx-community/Qwen2.5-Coder-1.5B-Instruct", + author: "onnx-community", + name: "Qwen2.5-Coder-1.5B-Instruct", + downloads: 104, + likes: 7, + pipelineTag: "text-generation", + libraryName: "transformers.js", + license: "apache-2.0", + compatibility: "recommended", + }, + { + id: "community/Qwen-helper", + author: "community", + name: "Qwen-helper", + downloads: 12, + likes: 1, + pipelineTag: "text-generation", + libraryName: "transformers.js", + compatibility: "community", + }, + ], + truncated: false, + }); + }); + it("bootstraps default keybindings file when missing", async () => { const stateDir = makeTempDir("t3code-state-bootstrap-keybindings-"); const keybindingsPath = path.join(stateDir, "keybindings.json"); diff --git a/apps/server/src/wsServer.ts b/apps/server/src/wsServer.ts index 1bd1c02c..ac83f54b 100644 --- a/apps/server/src/wsServer.ts +++ b/apps/server/src/wsServer.ts @@ -74,6 +74,7 @@ import { import { parseBase64DataUrl } from "./imageMime.ts"; import { AnalyticsService } from "./telemetry/Services/AnalyticsService.ts"; import { expandHomePath } from "./os-jank.ts"; +import { searchHuggingFaceModels } from "./huggingFaceModelSearch"; /** * ServerShape - Service API for server lifecycle control. 
@@ -917,6 +918,20 @@ export const createServer = Effect.fn(function* (): Effect.fn.Return< return { keybindings: keybindingsConfig, issues: [] }; } + case WS_METHODS.serverSearchHuggingFaceModels: { + const body = stripRequestTag(request.body); + return yield* Effect.tryPromise({ + try: () => searchHuggingFaceModels(body), + catch: (cause) => + new RouteRequestError({ + message: + cause instanceof Error + ? cause.message + : "Failed to search Hugging Face models.", + }), + }); + } + default: { const _exhaustiveCheck: never = request.body; return yield* new RouteRequestError({ diff --git a/apps/web/src/appSettings.ts b/apps/web/src/appSettings.ts index 70f2a5b5..6708aca2 100644 --- a/apps/web/src/appSettings.ts +++ b/apps/web/src/appSettings.ts @@ -1,6 +1,6 @@ import { useCallback, useSyncExternalStore } from "react"; import { Option, Schema } from "effect"; -import { type ProviderKind } from "@t3tools/contracts"; +import { WEBGPU_DTYPE_OPTIONS, type ProviderKind, type WebGpuModelDtype } from "@t3tools/contracts"; import { getDefaultModel, getModelOptions, normalizeModelSlug } from "@t3tools/shared/model"; const APP_SETTINGS_STORAGE_KEY = "t3code:app-settings:v1"; @@ -9,6 +9,7 @@ export const MAX_CUSTOM_MODEL_LENGTH = 256; const BUILT_IN_MODEL_SLUGS_BY_PROVIDER: Record> = { codex: new Set(getModelOptions("codex").map((option) => option.slug)), copilot: new Set(getModelOptions("copilot").map((option) => option.slug)), + webgpu: new Set(getModelOptions("webgpu").map((option) => option.slug)), }; const AppSettingsSchema = Schema.Struct({ @@ -34,6 +35,16 @@ const AppSettingsSchema = Schema.Struct({ customCopilotModels: Schema.Array(Schema.String).pipe( Schema.withConstructorDefault(() => Option.some([])), ), + webGpuEnabled: Schema.Boolean.pipe(Schema.withConstructorDefault(() => Option.some(true))), + webGpuDefaultModel: Schema.String.pipe( + Schema.withConstructorDefault(() => Option.some(getDefaultModel("webgpu"))), + ), + webGpuPreferredDtype: 
/**
 * Returns a copy of the settings with the custom model lists normalized per
 * provider and the WebGPU default model re-resolved against the normalized
 * WebGPU list.
 *
 * @param settings - Settings as decoded from storage.
 * @returns A new settings object; all other fields are copied unchanged.
 */
function normalizeAppSettings(settings: AppSettings): AppSettings {
  // The WebGPU list is normalized first because the default-model resolution
  // below consumes the normalized list.
  const webGpuModels = normalizeCustomModelSlugs(settings.customWebGpuModels, "webgpu");
  const codexModels = normalizeCustomModelSlugs(settings.customCodexModels, "codex");
  const copilotModels = normalizeCustomModelSlugs(settings.customCopilotModels, "copilot");
  const webGpuDefaultModel = resolveAppModelSelection(
    "webgpu",
    webGpuModels,
    settings.webGpuDefaultModel,
  );

  return {
    ...settings,
    customCodexModels: codexModels,
    customCopilotModels: copilotModels,
    customWebGpuModels: webGpuModels,
    webGpuDefaultModel,
  };
}
store.clearProjectDraftThreadId, ); + const clearProjectDraftThreadById = useComposerDraftStore( + (store) => store.clearProjectDraftThreadById, + ); const draftThread = useComposerDraftStore( (store) => store.draftThreadsByThreadId[threadId] ?? null, ); @@ -1329,6 +1338,11 @@ export default function ChatView({ threadId }: ChatViewProps) { ? (sessionProvider ?? selectedProviderByThreadId ?? null) : null; const selectedProvider: ProviderKind = lockedProvider ?? selectedProviderByThreadId ?? "codex"; + const localWebGpuStatus = useSyncExternalStore( + subscribeLocalWebGpuStatus, + getLocalWebGpuStatusSnapshot, + getLocalWebGpuStatusSnapshot, + ); const providerStatuses = serverConfigQuery.data?.providers ?? EMPTY_PROVIDER_STATUSES; const copilotProviderStatus = providerStatuses.find((status) => status.provider === "copilot") ?? null; @@ -1346,30 +1360,41 @@ export default function ChatView({ threadId }: ChatViewProps) { copilotProviderModels.length > 0 ? copilotProviderModels.map((model) => ({ slug: model.id, name: model.name })) : getModelOptions("copilot"), + webgpu: getModelOptions("webgpu"), }), [copilotProviderModels], ); + const customModelsByProvider = useMemo>( + () => ({ + codex: settings.customCodexModels, + copilot: settings.customCopilotModels, + webgpu: settings.customWebGpuModels, + }), + [settings.customCodexModels, settings.customCopilotModels, settings.customWebGpuModels], + ); const defaultModelByProvider = useMemo>( () => ({ codex: getDefaultModel("codex"), copilot: builtInModelOptionsByProvider.copilot[0]?.slug ?? getDefaultModel("copilot"), + webgpu: settings.webGpuDefaultModel, }), - [builtInModelOptionsByProvider], + [builtInModelOptionsByProvider, settings.webGpuDefaultModel], ); const baseThreadModel = - selectedProvider === "copilot" + selectedProvider === "copilot" || selectedProvider === "webgpu" ? resolveAppModelSelection( - "copilot", - settings.customCopilotModels, - activeThread?.model ?? activeProject?.model ?? 
defaultModelByProvider.copilot, - builtInModelOptionsByProvider.copilot, + selectedProvider, + customModelsByProvider[selectedProvider], + activeThread?.model ?? + (selectedProvider === "webgpu" ? defaultModelByProvider.webgpu : activeProject?.model) ?? + defaultModelByProvider[selectedProvider], + builtInModelOptionsByProvider[selectedProvider], ) : resolveModelSlugForProvider( selectedProvider, activeThread?.model ?? activeProject?.model ?? defaultModelByProvider[selectedProvider], ); - const customModelsForSelectedProvider = - selectedProvider === "copilot" ? settings.customCopilotModels : settings.customCodexModels; + const customModelsForSelectedProvider = customModelsByProvider[selectedProvider]; const selectedModel = useMemo(() => { const draftModel = composerDraft.model; if (!draftModel) { @@ -1418,9 +1443,28 @@ export default function ChatView({ threadId }: ChatViewProps) { if (selectedProvider === "copilot" && supportsReasoningEffort && selectedEffort) { return { copilot: { reasoningEffort: selectedEffort } }; } + if (selectedProvider === "webgpu") { + return { + webgpu: { + dtype: settings.webGpuPreferredDtype, + maxTokens: 384, + temperature: 0.7, + topP: 0.95, + }, + }; + } return undefined; - }, [selectedCodexFastModeEnabled, selectedEffort, selectedProvider, supportsReasoningEffort]); + }, [ + selectedCodexFastModeEnabled, + selectedEffort, + selectedProvider, + settings.webGpuPreferredDtype, + supportsReasoningEffort, + ]); const providerOptionsForDispatch = useMemo(() => { + if (selectedProvider !== "codex") { + return undefined; + } if (!settings.codexBinaryPath && !settings.codexHomePath) { return undefined; } @@ -1430,12 +1474,43 @@ export default function ChatView({ threadId }: ChatViewProps) { ...(settings.codexHomePath ? 
{ homePath: settings.codexHomePath } : {}), }, }; - }, [settings.codexBinaryPath, settings.codexHomePath]); + }, [selectedProvider, settings.codexBinaryPath, settings.codexHomePath]); const selectedModelForPicker = selectedModel; const modelOptionsByProvider = useMemo( () => getCustomModelOptionsByProvider(settings, builtInModelOptionsByProvider.copilot), [builtInModelOptionsByProvider.copilot, settings], ); + const localWebGpuUnavailableMessage = + selectedProvider !== "webgpu" + ? null + : !settings.webGpuEnabled + ? "Enable Local WebGPU in Settings to use the browser-side adapter." + : localWebGpuStatus.supportMessage; + const localWebGpuStatusMessage = + selectedProvider !== "webgpu" + ? null + : localWebGpuUnavailableMessage ?? + (composerImages.length > 0 + ? "Image attachments are not supported for local WebGPU turns yet." + : localWebGpuStatus.phase === "loading-model" + ? localWebGpuStatus.progress?.total && localWebGpuStatus.progress.total > 0 + ? `Loading local model… ${Math.max( + 0, + Math.min( + 100, + Math.round( + (localWebGpuStatus.progress.loaded / localWebGpuStatus.progress.total) * 100, + ), + ), + )}%` + : "Loading local model…" + : localWebGpuStatus.phase === "generating" + ? "Running locally on this device." 
+ : localWebGpuStatus.lastError); + const isLocalWebGpuAttachmentUnsupported = + selectedProvider === "webgpu" && composerImages.length > 0; + const disableLocalWebGpuSend = + Boolean(localWebGpuUnavailableMessage) || isLocalWebGpuAttachmentUnsupported; const selectedModelForPickerWithCustomFallback = useMemo(() => { const currentOptions = modelOptionsByProvider[selectedProvider]; return currentOptions.some((option) => option.slug === selectedModelForPicker) @@ -3099,6 +3174,10 @@ export default function ChatView({ threadId }: ChatViewProps) { } if (!trimmed && composerImages.length === 0) return; if (!activeProject) return; + if (localWebGpuUnavailableMessage) { + setStoreThreadError(activeThread.id, localWebGpuUnavailableMessage); + return; + } const threadIdForSend = activeThread.id; const isFirstMessage = !isServerThread || activeThread.messages.length === 0; const baseBranchForWorktree = @@ -3212,7 +3291,8 @@ export default function ChatView({ threadId }: ChatViewProps) { let threadCreateModel: ModelSlug = selectedModel || (activeProject.model as ModelSlug) || DEFAULT_MODEL_BY_PROVIDER.codex; - if (isLocalDraftThread) { + const shouldCreateServerThreadForDraft = isLocalDraftThread && selectedProvider !== "webgpu"; + if (shouldCreateServerThreadForDraft) { await api.orchestration.dispatchCommand({ type: "thread.create", commandId: newCommandId(), @@ -3256,7 +3336,7 @@ export default function ChatView({ threadId }: ChatViewProps) { } // Auto-title from first message - if (isFirstMessage && isServerThread) { + if (isFirstMessage && isServerThread && selectedProvider !== "webgpu") { await api.orchestration.dispatchCommand({ type: "thread.meta.update", commandId: newCommandId(), @@ -3265,7 +3345,7 @@ export default function ChatView({ threadId }: ChatViewProps) { }); } - if (isServerThread) { + if (isServerThread && selectedProvider !== "webgpu") { await persistThreadSettingsForNextTurn({ threadId: threadIdForSend, createdAt: messageCreatedAt, @@ -3519,6 +3599,10 
@@ export default function ChatView({ threadId }: ChatViewProps) { ) { return; } + if (localWebGpuUnavailableMessage) { + setThreadError(activeThread.id, localWebGpuUnavailableMessage); + return; + } const trimmed = text.trim(); if (!trimmed) { @@ -3616,6 +3700,7 @@ export default function ChatView({ threadId }: ChatViewProps) { setComposerDraftInteractionMode, setThreadError, settings.enableAssistantStreaming, + localWebGpuUnavailableMessage, ], ); @@ -3633,6 +3718,10 @@ export default function ChatView({ threadId }: ChatViewProps) { ) { return; } + if (localWebGpuUnavailableMessage) { + setStoreThreadError(activeThread.id, localWebGpuUnavailableMessage); + return; + } const createdAt = new Date().toISOString(); const nextThreadId = newThreadId(); @@ -3652,6 +3741,60 @@ export default function ChatView({ threadId }: ChatViewProps) { resetSendPhase(); }; + if (selectedProvider === "webgpu") { + setProjectDraftThreadId(activeProject.id, nextThreadId, { + branch: activeThread.branch, + worktreePath: activeThread.worktreePath, + createdAt, + envMode: activeThread.worktreePath ? "worktree" : "local", + runtimeMode, + interactionMode: "default", + }); + setComposerDraftProvider(nextThreadId, "webgpu"); + setComposerDraftModel(nextThreadId, nextThreadModel, "webgpu"); + await api.orchestration + .dispatchCommand({ + type: "thread.turn.start", + commandId: newCommandId(), + threadId: nextThreadId, + message: { + messageId: newMessageId(), + role: "user", + text: implementationPrompt, + attachments: [], + }, + provider: "webgpu", + model: nextThreadModel, + ...(selectedModelOptionsForDispatch + ? { modelOptions: selectedModelOptionsForDispatch } + : {}), + assistantDeliveryMode: settings.enableAssistantStreaming ? 
"streaming" : "buffered", + runtimeMode, + interactionMode: "default", + createdAt, + }) + .then(() => { + clearDraftThread(nextThreadId); + planSidebarOpenOnNextThreadRef.current = true; + return navigate({ + to: "/$threadId", + params: { threadId: nextThreadId }, + }); + }) + .catch((err) => { + clearProjectDraftThreadById(activeProject.id, nextThreadId); + clearDraftThread(nextThreadId); + toastManager.add({ + type: "error", + title: "Could not start implementation thread", + description: + err instanceof Error ? err.message : "An error occurred while creating the new thread.", + }); + }) + .then(finish, finish); + return; + } + await api.orchestration .dispatchCommand({ type: "thread.create", @@ -3734,11 +3877,18 @@ export default function ChatView({ threadId }: ChatViewProps) { runtimeMode, selectedModel, selectedModelOptionsForDispatch, - providerOptionsForDispatch, - selectedProvider, - settings.enableAssistantStreaming, - syncServerReadModel, - ]); + providerOptionsForDispatch, + selectedProvider, + setComposerDraftModel, + setComposerDraftProvider, + setProjectDraftThreadId, + settings.enableAssistantStreaming, + syncServerReadModel, + clearDraftThread, + clearProjectDraftThreadById, + localWebGpuUnavailableMessage, + setStoreThreadError, + ]); const onProviderModelSelect = useCallback( (provider: ProviderKind, model: ModelSlug) => { @@ -3752,7 +3902,11 @@ export default function ChatView({ threadId }: ChatViewProps) { activeThread.id, resolveAppModelSelection( provider, - provider === "copilot" ? settings.customCopilotModels : settings.customCodexModels, + provider === "copilot" + ? settings.customCopilotModels + : provider === "webgpu" + ? 
settings.customWebGpuModels + : settings.customCodexModels, model, builtInModelOptionsByProvider[provider], ), @@ -3768,6 +3922,7 @@ export default function ChatView({ threadId }: ChatViewProps) { setComposerDraftProvider, settings.customCopilotModels, settings.customCodexModels, + settings.customWebGpuModels, ], ); const onEffortSelect = useCallback( @@ -4314,6 +4469,15 @@ export default function ChatView({ threadId }: ChatViewProps) { /> + {localWebGpuStatusMessage ? ( +
+
+ + {localWebGpuStatusMessage} +
+
+ ) : null} + {/* Bottom toolbar */} {activePendingApproval ? (
@@ -4536,21 +4700,21 @@ export default function ChatView({ threadId }: ChatViewProps) { ) : pendingUserInputs.length === 0 ? ( showPlanFollowUpPrompt ? ( prompt.trim().length > 0 ? ( - + ) : (
@@ -4562,7 +4726,7 @@ export default function ChatView({ threadId }: ChatViewProps) { variant="default" className="h-9 rounded-l-none rounded-r-full border-l-white/12 px-2 sm:h-8" aria-label="Implementation actions" - disabled={isSendBusy || isConnecting} + disabled={isSendBusy || isConnecting || disableLocalWebGpuSend} /> } > @@ -4586,6 +4750,7 @@ export default function ChatView({ threadId }: ChatViewProps) { disabled={ isSendBusy || isConnecting || + disableLocalWebGpuSend || (!prompt.trim() && composerImages.length === 0) } aria-label={ @@ -6170,6 +6335,7 @@ function getCustomModelOptionsByProvider( settings: { customCodexModels: readonly string[]; customCopilotModels: readonly string[]; + customWebGpuModels: readonly string[]; }, builtInCopilotOptions: ReadonlyArray, ): Record> { @@ -6181,12 +6347,14 @@ function getCustomModelOptionsByProvider( undefined, builtInCopilotOptions, ), + webgpu: getAppModelOptions("webgpu", settings.customWebGpuModels), }; } -const PROVIDER_ICON_BY_PROVIDER: Record = { +const PROVIDER_ICON_BY_PROVIDER: Record = { codex: OpenAI, copilot: GitHubIcon, + webgpu: HuggingFaceIcon, claudeCode: ClaudeAI, cursor: CursorIcon, }; diff --git a/apps/web/src/components/Icons.tsx b/apps/web/src/components/Icons.tsx index 4e1a586d..5adfa02f 100644 --- a/apps/web/src/components/Icons.tsx +++ b/apps/web/src/components/Icons.tsx @@ -287,6 +287,20 @@ export const Gemini: Icon = (props) => ( ); +export const HUGGING_FACE_BRAND_ASSET_URL = + "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg"; + +export const HuggingFaceIcon: Icon = (props) => ( + +); + export const OpenCodeIcon: Icon = (props) => ( diff --git a/apps/web/src/composerDraftStore.ts b/apps/web/src/composerDraftStore.ts index d3dd40c1..6ad9e280 100644 --- a/apps/web/src/composerDraftStore.ts +++ b/apps/web/src/composerDraftStore.ts @@ -249,7 +249,7 @@ function shouldRemoveDraft(draft: ComposerThreadDraftState): boolean { } function 
normalizeProviderKind(value: unknown): ProviderKind | null { - return value === "codex" || value === "copilot" ? value : null; + return value === "codex" || value === "copilot" || value === "webgpu" ? value : null; } function normalizeDraftModel( diff --git a/apps/web/src/lib/serverReactQuery.ts b/apps/web/src/lib/serverReactQuery.ts index 85853e2e..a61f2c9c 100644 --- a/apps/web/src/lib/serverReactQuery.ts +++ b/apps/web/src/lib/serverReactQuery.ts @@ -1,9 +1,20 @@ +import type { ServerHuggingFaceModelSearchResult } from "@t3tools/contracts"; import { queryOptions } from "@tanstack/react-query"; import { ensureNativeApi } from "~/nativeApi"; export const serverQueryKeys = { all: ["server"] as const, config: () => ["server", "config"] as const, + huggingFaceModels: (query: string | null, limit: number) => + ["server", "hugging-face-models", query, limit] as const, +}; + +const DEFAULT_HUGGING_FACE_MODEL_SEARCH_LIMIT = 12; +const DEFAULT_HUGGING_FACE_MODEL_SEARCH_STALE_TIME = 30_000; +const EMPTY_HUGGING_FACE_MODEL_SEARCH_RESULT: ServerHuggingFaceModelSearchResult = { + mode: "featured", + models: [], + truncated: false, }; export function serverConfigQueryOptions() { @@ -16,3 +27,26 @@ export function serverConfigQueryOptions() { staleTime: Infinity, }); } + +export function huggingFaceModelSearchQueryOptions(input: { + query: string | null; + enabled?: boolean; + limit?: number; + staleTime?: number; +}) { + const normalizedQuery = input.query?.trim() || null; + const limit = input.limit ?? DEFAULT_HUGGING_FACE_MODEL_SEARCH_LIMIT; + return queryOptions({ + queryKey: serverQueryKeys.huggingFaceModels(normalizedQuery, limit), + queryFn: async () => { + const api = ensureNativeApi(); + return api.server.searchHuggingFaceModels({ + ...(normalizedQuery ? { query: normalizedQuery } : {}), + limit, + }); + }, + enabled: input.enabled ?? true, + staleTime: input.staleTime ?? DEFAULT_HUGGING_FACE_MODEL_SEARCH_STALE_TIME, + placeholderData: (previous) => previous ?? 
EMPTY_HUGGING_FACE_MODEL_SEARCH_RESULT, + }); +} diff --git a/apps/web/src/localWebGpuOrchestration.test.ts b/apps/web/src/localWebGpuOrchestration.test.ts new file mode 100644 index 00000000..e65a0950 --- /dev/null +++ b/apps/web/src/localWebGpuOrchestration.test.ts @@ -0,0 +1,239 @@ +import { + MessageId, + ProjectId, + ThreadId, + type NativeApi, + type OrchestrationEvent, +} from "@t3tools/contracts"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +type MockWorkerMessage = + | { + type: "generate"; + requestId: string; + model: string; + dtype: string; + messages: Array<{ role: "user" | "assistant" | "system"; content: string }>; + } + | { type: "dispose" }; + +class MockWorker { + private readonly listeners = new Set<(event: MessageEvent) => void>(); + + addEventListener(_type: "message", listener: (event: MessageEvent) => void): void { + this.listeners.add(listener); + } + + postMessage(message: MockWorkerMessage): void { + if (message.type !== "generate") { + return; + } + queueMicrotask(() => { + this.emit({ + type: "status", + status: "ready", + model: message.model, + dtype: message.dtype, + }); + this.emit({ + type: "text-delta", + requestId: message.requestId, + delta: "Hello from WebGPU", + }); + this.emit({ + type: "complete", + requestId: message.requestId, + text: "Hello from WebGPU", + }); + }); + } + + terminate(): void {} + + private emit(data: unknown): void { + for (const listener of this.listeners) { + listener({ data } as MessageEvent); + } + } +} + +function createLocalStorageMock() { + const storage = new Map(); + return { + getItem: (key: string) => storage.get(key) ?? 
null, + setItem: (key: string, value: string) => { + storage.set(key, value); + }, + removeItem: (key: string) => { + storage.delete(key); + }, + clear: () => { + storage.clear(); + }, + }; +} + +function createBaseApi(): NativeApi { + return { + dialogs: { pickFolder: vi.fn(), confirm: vi.fn() }, + terminal: { + open: vi.fn(), + write: vi.fn(), + resize: vi.fn(), + clear: vi.fn(), + restart: vi.fn(), + close: vi.fn(), + onEvent: vi.fn(() => () => {}), + }, + projects: { + searchEntries: vi.fn(), + writeFile: vi.fn(), + }, + shell: { + openInEditor: vi.fn(), + openExternal: vi.fn(), + }, + git: { + pull: vi.fn(), + status: vi.fn(), + runStackedAction: vi.fn(), + listBranches: vi.fn(), + createWorktree: vi.fn(), + removeWorktree: vi.fn(), + createBranch: vi.fn(), + checkout: vi.fn(), + init: vi.fn(), + resolvePullRequest: vi.fn(), + preparePullRequestThread: vi.fn(), + }, + contextMenu: { + show: vi.fn(), + }, + server: { + getConfig: vi.fn(), + upsertKeybinding: vi.fn(), + searchHuggingFaceModels: vi.fn(), + }, + orchestration: { + getSnapshot: vi.fn(async () => ({ + snapshotSequence: 7, + updatedAt: "2026-03-10T00:00:00.000Z", + projects: [ + { + id: ProjectId.makeUnsafe("project-1"), + title: "Project", + workspaceRoot: "/tmp/project", + defaultModel: null, + scripts: [], + createdAt: "2026-03-10T00:00:00.000Z", + updatedAt: "2026-03-10T00:00:00.000Z", + deletedAt: null, + }, + ], + threads: [], + })), + dispatchCommand: vi.fn(async () => ({ sequence: 1 })), + getTurnDiff: vi.fn(), + getFullThreadDiff: vi.fn(), + replayEvents: vi.fn(), + onDomainEvent: vi.fn(() => () => {}), + }, + }; +} + +beforeEach(() => { + vi.resetModules(); + const localStorageMock = createLocalStorageMock(); + Object.defineProperty(globalThis, "localStorage", { + configurable: true, + value: localStorageMock, + }); + Object.defineProperty(globalThis, "window", { + configurable: true, + value: { + localStorage: localStorageMock, + addEventListener: vi.fn(), + removeEventListener: vi.fn(), + 
Worker: MockWorker, + }, + }); + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: { gpu: {} }, + }); + Object.defineProperty(globalThis, "Worker", { + configurable: true, + value: MockWorker, + }); +}); + +afterEach(async () => { + const { clearLocalWebGpuState } = await import("./localWebGpuOrchestration"); + await clearLocalWebGpuState(); + vi.restoreAllMocks(); +}); + +describe("localWebGpuOrchestration", () => { + it("returns the same status snapshot reference when nothing changed", async () => { + const { getLocalWebGpuStatusSnapshot } = await import("./localWebGpuOrchestration"); + + const first = getLocalWebGpuStatusSnapshot(); + const second = getLocalWebGpuStatusSnapshot(); + + expect(second).toBe(first); + }); + + it("routes webgpu turns locally and merges them into snapshots", async () => { + const { createHybridNativeApi, getLocalWebGpuStatusSnapshot } = await import( + "./localWebGpuOrchestration" + ); + const { useComposerDraftStore } = await import("./composerDraftStore"); + const api = createHybridNativeApi(createBaseApi()); + const threadId = ThreadId.makeUnsafe("local-thread-1"); + const projectId = ProjectId.makeUnsafe("project-1"); + const initialStatus = getLocalWebGpuStatusSnapshot(); + useComposerDraftStore.getState().setProjectDraftThreadId(projectId, threadId, { + createdAt: "2026-03-10T00:00:00.000Z", + envMode: "local", + }); + + const receivedEvents: OrchestrationEvent[] = []; + api.orchestration.onDomainEvent((event) => { + receivedEvents.push(event); + }); + + await api.orchestration.dispatchCommand({ + type: "thread.turn.start", + commandId: "command-1" as never, + threadId, + message: { + messageId: MessageId.makeUnsafe("message-1"), + role: "user", + text: "Say hi", + attachments: [], + }, + provider: "webgpu", + model: "onnx-community/Qwen2.5-0.5B-Instruct", + runtimeMode: "full-access", + interactionMode: "default", + createdAt: "2026-03-10T00:00:00.000Z", + }); + + await new Promise((resolve) => 
setTimeout(resolve, 0)); + await new Promise((resolve) => setTimeout(resolve, 0)); + + const snapshot = await api.orchestration.getSnapshot(); + const thread = snapshot.threads.find((entry) => entry.id === threadId); + const finalStatus = getLocalWebGpuStatusSnapshot(); + + expect(thread).toBeDefined(); + expect(thread?.session?.providerName).toBe("webgpu"); + expect(thread?.messages.map((message) => message.text)).toEqual(["Say hi", "Hello from WebGPU"]); + expect(thread?.messages.at(-1)?.streaming).toBe(false); + expect(receivedEvents.some((event) => event.type === "thread.turn-start-requested")).toBe(true); + expect(receivedEvents.some((event) => event.type === "thread.session-set")).toBe(true); + expect(finalStatus).not.toBe(initialStatus); + expect(finalStatus.phase).toBe("ready"); + expect(finalStatus.model).toBe("onnx-community/Qwen2.5-0.5B-Instruct"); + }); +}); diff --git a/apps/web/src/localWebGpuOrchestration.ts b/apps/web/src/localWebGpuOrchestration.ts new file mode 100644 index 00000000..ce0c21e0 --- /dev/null +++ b/apps/web/src/localWebGpuOrchestration.ts @@ -0,0 +1,1290 @@ +import { + DEFAULT_MODEL_BY_PROVIDER, + EventId, + MessageId, + ThreadId, + TurnId, + type ClientOrchestrationCommand, + type NativeApi, + type OrchestrationEvent, + type OrchestrationReadModel, + type OrchestrationSession, + type OrchestrationSessionStatus, + type OrchestrationThread, + type WebGpuModelDtype, +} from "@t3tools/contracts"; +import { getAppSettingsSnapshot } from "./appSettings"; +import { useComposerDraftStore } from "./composerDraftStore"; +import { randomUUID } from "./lib/utils"; +import { truncateTitle } from "./truncateTitle"; + +const LOCAL_WEBGPU_STORAGE_KEY = "t3code:webgpu-local-state:v1"; + +type LocalWebGpuRuntimePhase = + | "idle" + | "loading-model" + | "ready" + | "generating" + | "error" + | "unsupported"; + +export interface LocalWebGpuStatusSnapshot { + enabled: boolean; + supported: boolean; + supportMessage: string | null; + phase: 
LocalWebGpuRuntimePhase; + model: string | null; + dtype: WebGpuModelDtype; + progress: + | { + file: string | null; + loaded: number; + total: number | null; + } + | null; + lastError: string | null; +} + +function localWebGpuProgressEquals( + left: LocalWebGpuStatusSnapshot["progress"], + right: LocalWebGpuStatusSnapshot["progress"], +): boolean { + if (left === right) { + return true; + } + if (left === null || right === null) { + return left === right; + } + return ( + left.file === right.file && + left.loaded === right.loaded && + left.total === right.total + ); +} + +function localWebGpuStatusSnapshotEquals( + left: LocalWebGpuStatusSnapshot | null, + right: LocalWebGpuStatusSnapshot, +): boolean { + if (left === null) { + return false; + } + return ( + left.enabled === right.enabled && + left.supported === right.supported && + left.supportMessage === right.supportMessage && + left.phase === right.phase && + left.model === right.model && + left.dtype === right.dtype && + localWebGpuProgressEquals(left.progress, right.progress) && + left.lastError === right.lastError + ); +} + +interface PersistedLocalWebGpuState { + threads: OrchestrationThread[]; + updatedAt: string; +} + +interface LocalWebGpuChatMessage { + role: "user" | "assistant" | "system"; + content: string; +} + +type WorkerGenerateMessage = { + type: "generate"; + requestId: string; + model: string; + dtype: WebGpuModelDtype; + messages: LocalWebGpuChatMessage[]; + maxNewTokens: number; + temperature: number; + topP: number; +}; + +type WorkerDisposeMessage = { + type: "dispose"; +}; + +type WorkerInboundMessage = WorkerGenerateMessage | WorkerDisposeMessage; + +type WorkerStatusMessage = { + type: "status"; + status: Exclude; + model: string | null; + dtype: WebGpuModelDtype; + message?: string; +}; + +type WorkerProgressMessage = { + type: "download-progress"; + file: string | null; + loaded: number; + total: number | null; +}; + +type WorkerTextDeltaMessage = { + type: "text-delta"; + requestId: 
string; + delta: string; +}; + +type WorkerCompleteMessage = { + type: "complete"; + requestId: string; + text: string; +}; + +type WorkerErrorMessage = { + type: "error"; + requestId?: string; + message: string; +}; + +type WorkerOutboundMessage = + | WorkerStatusMessage + | WorkerProgressMessage + | WorkerTextDeltaMessage + | WorkerCompleteMessage + | WorkerErrorMessage; + +type WorkerGenerationRequest = { + requestId: string; + threadId: ThreadId; + turnId: TurnId; + assistantMessageId: MessageId; + resolve: (text: string) => void; + reject: (error: Error) => void; + onDelta: (delta: string) => void; +}; + +function nowIso(): string { + return new Date().toISOString(); +} + +function getWebGpuSupportMessage(): string | null { + if (typeof window === "undefined") { + return "Local WebGPU is only available in the browser."; + } + if (!("Worker" in window)) { + return "This browser does not support Web Workers."; + } + if (!("gpu" in navigator)) { + return "WebGPU is unavailable in this browser. Try a recent Chromium build with WebGPU enabled."; + } + return null; +} + +function newEventId() { + return EventId.makeUnsafe(randomUUID()); +} + +function newMessageId() { + return MessageId.makeUnsafe(randomUUID()); +} + +function newTurnId() { + return TurnId.makeUnsafe(randomUUID()); +} + +function emptyMetadata() { + return { adapterKey: "webgpu.local" }; +} + +function createThreadEvent(input: { + type: TType; + threadId: ThreadId; + payload: Extract["payload"]; + commandId: string | null; + occurredAt?: string; +}): OrchestrationEvent { + return { + sequence: 0, + eventId: newEventId(), + aggregateKind: "thread", + aggregateId: input.threadId, + occurredAt: input.occurredAt ?? 
nowIso(), + commandId: input.commandId, + causationEventId: null, + correlationId: input.commandId, + metadata: emptyMetadata(), + type: input.type, + payload: input.payload, + } as OrchestrationEvent; +} + +function persistLocalState(threadsById: ReadonlyMap): void { + if (typeof window === "undefined") { + return; + } + const value: PersistedLocalWebGpuState = { + threads: Array.from(threadsById.values()), + updatedAt: nowIso(), + }; + try { + window.localStorage.setItem(LOCAL_WEBGPU_STORAGE_KEY, JSON.stringify(value)); + } catch { + // Best-effort persistence only. + } +} + +function loadPersistedLocalState(): Map { + if (typeof window === "undefined") { + return new Map(); + } + try { + const raw = window.localStorage.getItem(LOCAL_WEBGPU_STORAGE_KEY); + if (!raw) { + return new Map(); + } + const parsed = JSON.parse(raw) as Partial; + const threads = Array.isArray(parsed.threads) ? parsed.threads : []; + return new Map( + threads.flatMap((thread) => + thread && typeof thread.id === "string" + ? 
[[ThreadId.makeUnsafe(thread.id), thread as OrchestrationThread] as const] + : [], + ), + ); + } catch { + return new Map(); + } +} + +class WebGpuWorkerClient { + private worker: Worker | null = null; + private activeRequest: WorkerGenerationRequest | null = null; + + constructor( + private readonly onStatus: (message: WorkerStatusMessage) => void, + private readonly onProgress: (message: WorkerProgressMessage) => void, + ) {} + + private ensureWorker(): Worker { + if (this.worker) { + return this.worker; + } + const worker = new Worker(new URL("./workers/webgpuInference.worker.ts", import.meta.url), { + type: "module", + }); + worker.addEventListener("message", (event: MessageEvent) => { + const message = event.data; + switch (message.type) { + case "status": + this.onStatus(message); + break; + case "download-progress": + this.onProgress(message); + break; + case "text-delta": + if (this.activeRequest?.requestId === message.requestId) { + this.activeRequest.onDelta(message.delta); + } + break; + case "complete": + if (this.activeRequest?.requestId === message.requestId) { + this.activeRequest.resolve(message.text); + this.activeRequest = null; + } + break; + case "error": + if (this.activeRequest && (!message.requestId || this.activeRequest.requestId === message.requestId)) { + this.activeRequest.reject(new Error(message.message)); + this.activeRequest = null; + } else { + this.onStatus({ + type: "status", + status: "error", + model: null, + dtype: getAppSettingsSnapshot().webGpuPreferredDtype, + message: message.message, + }); + } + break; + } + }); + this.worker = worker; + return worker; + } + + async generate(input: { + requestId: string; + threadId: ThreadId; + turnId: TurnId; + assistantMessageId: MessageId; + model: string; + dtype: WebGpuModelDtype; + messages: LocalWebGpuChatMessage[]; + maxNewTokens: number; + temperature: number; + topP: number; + onDelta: (delta: string) => void; + }): Promise { + if (this.activeRequest !== null) { + throw new 
Error("Only one local WebGPU generation can run at a time in v1."); + } + const worker = this.ensureWorker(); + const result = new Promise((resolve, reject) => { + this.activeRequest = { + requestId: input.requestId, + threadId: input.threadId, + turnId: input.turnId, + assistantMessageId: input.assistantMessageId, + resolve, + reject, + onDelta: input.onDelta, + }; + }); + const message: WorkerGenerateMessage = { + type: "generate", + requestId: input.requestId, + model: input.model, + dtype: input.dtype, + messages: input.messages, + maxNewTokens: input.maxNewTokens, + temperature: input.temperature, + topP: input.topP, + }; + // eslint-disable-next-line unicorn/require-post-message-target-origin -- Dedicated worker messaging does not use targetOrigin. + worker.postMessage(message satisfies WorkerInboundMessage); + return result; + } + + interrupt(): void { + if (this.activeRequest) { + this.activeRequest.reject(new Error("Local WebGPU generation was interrupted.")); + this.activeRequest = null; + } + if (this.worker) { + this.worker.terminate(); + this.worker = null; + } + } + + dispose(): void { + if (this.worker) { + // eslint-disable-next-line unicorn/require-post-message-target-origin -- Dedicated worker messaging does not use targetOrigin. 
+ this.worker.postMessage({ type: "dispose" } satisfies WorkerInboundMessage); + this.worker.terminate(); + this.worker = null; + } + this.activeRequest = null; + } +} + +class LocalWebGpuOrchestrationController { + private readonly threadsById = loadPersistedLocalState(); + private readonly statusListeners = new Set<() => void>(); + private readonly domainEventListeners = new Set<(event: OrchestrationEvent) => void>(); + private cachedStatusSnapshot: LocalWebGpuStatusSnapshot | null = null; + private status: Omit = { + phase: "idle", + model: null, + dtype: getAppSettingsSnapshot().webGpuPreferredDtype, + progress: null, + lastError: null, + }; + private activeGeneration: + | { + threadId: ThreadId; + turnId: TurnId; + assistantMessageId: MessageId; + } + | null = null; + + private readonly workerClient = new WebGpuWorkerClient( + (message) => { + this.status = { + ...this.status, + phase: message.status, + model: message.model, + dtype: message.dtype, + lastError: message.message ?? null, + ...(message.status === "ready" || message.status === "idle" ? { progress: null } : {}), + }; + this.emitStatus(); + }, + (message) => { + this.status = { + ...this.status, + phase: "loading-model", + progress: message, + }; + this.emitStatus(); + }, + ); + + subscribeStatus(listener: () => void): () => void { + this.statusListeners.add(listener); + return () => { + this.statusListeners.delete(listener); + }; + } + + subscribeDomainEvents(listener: (event: OrchestrationEvent) => void): () => void { + this.domainEventListeners.add(listener); + return () => { + this.domainEventListeners.delete(listener); + }; + } + + getStatusSnapshot(): LocalWebGpuStatusSnapshot { + const settings = getAppSettingsSnapshot(); + const supportMessage = getWebGpuSupportMessage(); + const nextSnapshot: LocalWebGpuStatusSnapshot = { + enabled: settings.webGpuEnabled, + supported: supportMessage === null, + supportMessage, + ...this.status, + dtype: this.status.dtype ?? 
settings.webGpuPreferredDtype, + }; + const cachedSnapshot = this.cachedStatusSnapshot; + if (cachedSnapshot && localWebGpuStatusSnapshotEquals(cachedSnapshot, nextSnapshot)) { + return cachedSnapshot; + } + this.cachedStatusSnapshot = nextSnapshot; + return nextSnapshot; + } + + private emitStatus(): void { + for (const listener of this.statusListeners) { + listener(); + } + } + + private emitDomainEvent(event: OrchestrationEvent): void { + for (const listener of this.domainEventListeners) { + listener(event); + } + } + + private persist(): void { + persistLocalState(this.threadsById); + } + + private getThread(threadId: ThreadId): OrchestrationThread { + const thread = this.threadsById.get(threadId); + if (!thread) { + throw new Error("Local WebGPU thread not found."); + } + return thread; + } + + private setThread(thread: OrchestrationThread): void { + this.threadsById.set(thread.id, thread); + this.persist(); + } + + isLocalThread(threadId: ThreadId): boolean { + return this.threadsById.has(threadId); + } + + mergeSnapshot(baseSnapshot: OrchestrationReadModel, snapshotSequence: number): OrchestrationReadModel { + const projectIds = new Set(baseSnapshot.projects.filter((project) => project.deletedAt === null).map((project) => project.id)); + let changed = false; + for (const [threadId, thread] of this.threadsById.entries()) { + if (!projectIds.has(thread.projectId)) { + this.threadsById.delete(threadId); + changed = true; + } + } + if (changed) { + this.persist(); + } + const localThreads = Array.from(this.threadsById.values()).toSorted((left, right) => + left.createdAt.localeCompare(right.createdAt), + ); + return { + ...baseSnapshot, + snapshotSequence, + threads: [...baseSnapshot.threads.filter((thread) => !this.threadsById.has(thread.id)), ...localThreads], + updatedAt: + localThreads.length > 0 + ? [baseSnapshot.updatedAt, localThreads.at(-1)?.updatedAt ?? baseSnapshot.updatedAt] + .toSorted() + .at(-1) ?? 
+ baseSnapshot.updatedAt + : baseSnapshot.updatedAt, + }; + } + + private buildSession( + threadId: ThreadId, + status: OrchestrationSessionStatus, + runtimeMode: OrchestrationThread["runtimeMode"], + activeTurnId: TurnId | null, + updatedAt: string, + lastError?: string | null, + ): OrchestrationSession { + return { + threadId, + status, + providerName: "webgpu", + runtimeMode, + activeTurnId, + lastError: lastError ?? null, + updatedAt, + }; + } + + private ensureThreadForTurnStart( + command: Extract, + selectedModel: string, + ): { thread: OrchestrationThread; created: boolean } { + const existing = this.threadsById.get(command.threadId); + if (existing) { + return { thread: existing, created: false }; + } + const draftThread = useComposerDraftStore.getState().getDraftThread(command.threadId); + if (!draftThread) { + throw new Error("Create a local draft thread before starting a WebGPU turn."); + } + const titleSeed = command.message.text.trim() || "New local thread"; + const thread: OrchestrationThread = { + id: command.threadId, + projectId: draftThread.projectId, + title: truncateTitle(titleSeed), + model: selectedModel, + runtimeMode: draftThread.runtimeMode, + interactionMode: draftThread.interactionMode, + branch: draftThread.branch, + worktreePath: draftThread.worktreePath, + latestTurn: null, + createdAt: draftThread.createdAt, + updatedAt: command.createdAt, + deletedAt: null, + messages: [], + proposedPlans: [], + activities: [], + checkpoints: [], + session: null, + }; + this.setThread(thread); + return { thread, created: true }; + } + + private buildGenerationMessages( + thread: OrchestrationThread, + userText: string, + ): LocalWebGpuChatMessage[] { + return [ + ...thread.messages + .filter((message) => !message.streaming && message.text.trim().length > 0) + .map((message) => ({ + role: message.role, + content: message.text, + })), + { role: "user", content: userText }, + ]; + } + + private appendAssistantDelta( + threadId: ThreadId, + 
assistantMessageId: MessageId, + turnId: TurnId, + delta: string, + commandId: string, + ): void { + const thread = this.getThread(threadId); + const messages = thread.messages.map((message) => + message.id === assistantMessageId + ? { + ...message, + text: `${message.text}${delta}`, + updatedAt: nowIso(), + } + : message, + ); + const updatedThread = { + ...thread, + messages, + updatedAt: nowIso(), + }; + this.setThread(updatedThread); + const assistantMessage = messages.find((message) => message.id === assistantMessageId); + if (!assistantMessage) { + return; + } + this.emitDomainEvent( + createThreadEvent({ + type: "thread.message-sent", + threadId, + commandId, + payload: { + threadId, + messageId: assistantMessage.id, + role: assistantMessage.role, + text: assistantMessage.text, + turnId, + streaming: true, + createdAt: assistantMessage.createdAt, + updatedAt: assistantMessage.updatedAt, + }, + }), + ); + } + + private completeAssistantMessage( + threadId: ThreadId, + turnId: TurnId, + assistantMessageId: MessageId, + finalText: string, + commandId: string, + ): void { + const completedAt = nowIso(); + const thread = this.getThread(threadId); + const messages = thread.messages.map((message) => + message.id === assistantMessageId + ? { + ...message, + text: finalText.length > 0 ? finalText : message.text, + streaming: false, + updatedAt: completedAt, + } + : message, + ); + const assistantMessage = messages.find((message) => message.id === assistantMessageId); + const updatedThread: OrchestrationThread = { + ...thread, + messages, + latestTurn: { + turnId, + state: "completed", + requestedAt: thread.latestTurn?.requestedAt ?? completedAt, + startedAt: thread.latestTurn?.startedAt ?? 
completedAt, + completedAt, + assistantMessageId, + }, + session: this.buildSession(threadId, "ready", thread.runtimeMode, null, completedAt), + updatedAt: completedAt, + }; + this.setThread(updatedThread); + if (assistantMessage) { + this.emitDomainEvent( + createThreadEvent({ + type: "thread.message-sent", + threadId, + commandId, + payload: { + threadId, + messageId: assistantMessage.id, + role: assistantMessage.role, + text: assistantMessage.text, + turnId, + streaming: false, + createdAt: assistantMessage.createdAt, + updatedAt: assistantMessage.updatedAt, + }, + }), + ); + } + this.emitDomainEvent( + createThreadEvent({ + type: "thread.session-set", + threadId, + commandId, + payload: { + threadId, + session: updatedThread.session!, + }, + }), + ); + } + + private failGeneration( + threadId: ThreadId, + turnId: TurnId, + assistantMessageId: MessageId, + error: Error, + commandId: string, + ): void { + const completedAt = nowIso(); + const thread = this.getThread(threadId); + const messages = thread.messages.map((message) => + message.id === assistantMessageId + ? { + ...message, + streaming: false, + updatedAt: completedAt, + } + : message, + ); + const updatedThread: OrchestrationThread = { + ...thread, + messages, + latestTurn: { + turnId, + state: "error", + requestedAt: thread.latestTurn?.requestedAt ?? completedAt, + startedAt: thread.latestTurn?.startedAt ?? 
completedAt, + completedAt, + assistantMessageId, + }, + session: this.buildSession(threadId, "error", thread.runtimeMode, null, completedAt, error.message), + updatedAt: completedAt, + }; + this.setThread(updatedThread); + this.status = { + ...this.status, + phase: "error", + lastError: error.message, + }; + this.emitStatus(); + this.emitDomainEvent( + createThreadEvent({ + type: "thread.session-set", + threadId, + commandId, + payload: { + threadId, + session: updatedThread.session!, + }, + }), + ); + } + + async clearPersistedState(): Promise { + const threadIds = Array.from(this.threadsById.keys()); + this.workerClient.dispose(); + this.activeGeneration = null; + this.status = { + phase: "idle", + model: null, + dtype: getAppSettingsSnapshot().webGpuPreferredDtype, + progress: null, + lastError: null, + }; + this.threadsById.clear(); + if (typeof window !== "undefined") { + try { + window.localStorage.removeItem(LOCAL_WEBGPU_STORAGE_KEY); + } catch { + // Ignore storage failures. + } + } + this.emitStatus(); + for (const threadId of threadIds) { + this.emitDomainEvent( + createThreadEvent({ + type: "thread.deleted", + threadId, + commandId: null, + payload: { + threadId, + deletedAt: nowIso(), + }, + }), + ); + } + } + + async dispatchCommand( + baseApi: NativeApi, + command: ClientOrchestrationCommand, + ): Promise<{ sequence: number }> { + if (!this.shouldHandleLocally(command)) { + return baseApi.orchestration.dispatchCommand(command); + } + switch (command.type) { + case "thread.turn.start": + this.startLocalTurn(command); + return { sequence: 0 }; + case "thread.turn.interrupt": + this.interruptLocalTurn(command); + return { sequence: 0 }; + case "thread.session.stop": + this.stopLocalSession(command); + return { sequence: 0 }; + case "thread.delete": + this.deleteLocalThread(command); + return { sequence: 0 }; + case "thread.meta.update": + this.updateLocalThreadMeta(command); + return { sequence: 0 }; + case "thread.runtime-mode.set": + 
this.setLocalRuntimeMode(command); + return { sequence: 0 }; + case "thread.interaction-mode.set": + this.setLocalInteractionMode(command); + return { sequence: 0 }; + case "thread.checkpoint.revert": + throw new Error("Checkpoint revert is not supported for local WebGPU threads yet."); + case "thread.approval.respond": + case "thread.user-input.respond": + throw new Error("Local WebGPU threads do not support interactive approval requests."); + default: + return baseApi.orchestration.dispatchCommand(command); + } + } + + private shouldHandleLocally(command: ClientOrchestrationCommand): boolean { + if (command.type === "thread.turn.start" && command.provider === "webgpu") { + return true; + } + if (!("threadId" in command)) { + return false; + } + return this.threadsById.has(command.threadId); + } + + private startLocalTurn(command: Extract): void { + const settings = getAppSettingsSnapshot(); + const supportMessage = getWebGpuSupportMessage(); + if (!settings.webGpuEnabled) { + throw new Error("Enable Local WebGPU in Settings before using the browser adapter."); + } + if (supportMessage) { + throw new Error(supportMessage); + } + if (command.message.attachments.length > 0) { + throw new Error("Local WebGPU does not support image attachments yet."); + } + if (this.activeGeneration !== null) { + throw new Error("A local WebGPU turn is already running."); + } + + const selectedModel = command.model?.trim() || settings.webGpuDefaultModel || DEFAULT_MODEL_BY_PROVIDER.webgpu; + const { thread: existingThread, created } = this.ensureThreadForTurnStart(command, selectedModel); + const requestId = randomUUID(); + const turnId = newTurnId(); + const assistantMessageId = newMessageId(); + const startedAt = command.createdAt; + const generationMessages = this.buildGenerationMessages(existingThread, command.message.text); + const userMessage = { + id: command.message.messageId, + role: "user" as const, + text: command.message.text, + attachments: [], + turnId, + streaming: 
false, + createdAt: command.createdAt, + updatedAt: command.createdAt, + }; + const assistantMessage = { + id: assistantMessageId, + role: "assistant" as const, + text: "", + turnId, + streaming: true, + createdAt: startedAt, + updatedAt: startedAt, + }; + const updatedThread: OrchestrationThread = { + ...existingThread, + model: selectedModel, + runtimeMode: command.runtimeMode, + interactionMode: command.interactionMode, + title: + existingThread.messages.length === 0 && existingThread.title === "New thread" + ? truncateTitle(command.message.text.trim() || existingThread.title) + : existingThread.title, + messages: [...existingThread.messages, userMessage, assistantMessage], + latestTurn: { + turnId, + state: "running", + requestedAt: startedAt, + startedAt, + completedAt: null, + assistantMessageId, + }, + session: this.buildSession(command.threadId, "running", command.runtimeMode, turnId, startedAt), + updatedAt: startedAt, + }; + this.setThread(updatedThread); + if (created) { + this.emitDomainEvent( + createThreadEvent({ + type: "thread.created", + threadId: command.threadId, + commandId: command.commandId, + occurredAt: command.createdAt, + payload: { + threadId: updatedThread.id, + projectId: updatedThread.projectId, + title: updatedThread.title, + model: updatedThread.model, + runtimeMode: updatedThread.runtimeMode, + interactionMode: updatedThread.interactionMode, + branch: updatedThread.branch, + worktreePath: updatedThread.worktreePath, + createdAt: updatedThread.createdAt, + updatedAt: updatedThread.updatedAt, + }, + }), + ); + } + this.emitDomainEvent( + createThreadEvent({ + type: "thread.message-sent", + threadId: command.threadId, + commandId: command.commandId, + occurredAt: command.createdAt, + payload: { + threadId: command.threadId, + messageId: userMessage.id, + role: "user", + text: userMessage.text, + attachments: [], + turnId, + streaming: false, + createdAt: userMessage.createdAt, + updatedAt: userMessage.updatedAt, + }, + }), + ); + 
this.emitDomainEvent( + createThreadEvent({ + type: "thread.turn-start-requested", + threadId: command.threadId, + commandId: command.commandId, + occurredAt: command.createdAt, + payload: { + threadId: command.threadId, + messageId: userMessage.id, + provider: "webgpu", + model: selectedModel, + ...(command.modelOptions ? { modelOptions: command.modelOptions } : {}), + assistantDeliveryMode: command.assistantDeliveryMode, + runtimeMode: command.runtimeMode, + interactionMode: command.interactionMode, + createdAt: command.createdAt, + }, + }), + ); + this.emitDomainEvent( + createThreadEvent({ + type: "thread.message-sent", + threadId: command.threadId, + commandId: command.commandId, + occurredAt: command.createdAt, + payload: { + threadId: command.threadId, + messageId: assistantMessage.id, + role: "assistant", + text: "", + turnId, + streaming: true, + createdAt: assistantMessage.createdAt, + updatedAt: assistantMessage.updatedAt, + }, + }), + ); + this.emitDomainEvent( + createThreadEvent({ + type: "thread.session-set", + threadId: command.threadId, + commandId: command.commandId, + occurredAt: command.createdAt, + payload: { + threadId: command.threadId, + session: updatedThread.session!, + }, + }), + ); + + const webgpuOptions = command.modelOptions?.webgpu; + const dtype = webgpuOptions?.dtype ?? settings.webGpuPreferredDtype; + this.activeGeneration = { + threadId: command.threadId, + turnId, + assistantMessageId, + }; + void this.workerClient + .generate({ + requestId, + threadId: command.threadId, + turnId, + assistantMessageId, + model: selectedModel, + dtype, + messages: generationMessages, + maxNewTokens: webgpuOptions?.maxTokens ?? 384, + temperature: webgpuOptions?.temperature ?? 0.7, + topP: webgpuOptions?.topP ?? 
0.95, + onDelta: (delta) => + this.appendAssistantDelta(command.threadId, assistantMessageId, turnId, delta, command.commandId), + }) + .then((finalText) => { + if ( + !this.activeGeneration || + this.activeGeneration.threadId !== command.threadId || + this.activeGeneration.turnId !== turnId + ) { + return; + } + this.completeAssistantMessage( + command.threadId, + turnId, + assistantMessageId, + finalText, + command.commandId, + ); + this.activeGeneration = null; + }) + .catch((error: unknown) => { + if ( + !this.activeGeneration || + this.activeGeneration.threadId !== command.threadId || + this.activeGeneration.turnId !== turnId + ) { + return; + } + this.failGeneration( + command.threadId, + turnId, + assistantMessageId, + error instanceof Error ? error : new Error("Local WebGPU generation failed."), + command.commandId, + ); + this.activeGeneration = null; + }); + } + + private interruptLocalTurn( + command: Extract, + ): void { + const thread = this.getThread(command.threadId); + const activeGeneration = + this.activeGeneration && + this.activeGeneration.threadId === command.threadId && + (command.turnId === undefined || this.activeGeneration.turnId === command.turnId) + ? this.activeGeneration + : null; + if (!activeGeneration) { + return; + } + this.workerClient.interrupt(); + const completedAt = command.createdAt; + const messages = thread.messages.map((message) => + message.id === activeGeneration.assistantMessageId + ? { + ...message, + streaming: false, + updatedAt: completedAt, + } + : message, + ); + const updatedThread: OrchestrationThread = { + ...thread, + messages, + latestTurn: { + turnId: activeGeneration.turnId, + state: "interrupted", + requestedAt: thread.latestTurn?.requestedAt ?? completedAt, + startedAt: thread.latestTurn?.startedAt ?? 
completedAt, + completedAt, + assistantMessageId: activeGeneration.assistantMessageId, + }, + session: this.buildSession(command.threadId, "interrupted", thread.runtimeMode, null, completedAt), + updatedAt: completedAt, + }; + this.setThread(updatedThread); + this.activeGeneration = null; + this.status = { + ...this.status, + phase: "idle", + progress: null, + lastError: null, + }; + this.emitStatus(); + this.emitDomainEvent( + createThreadEvent({ + type: "thread.session-set", + threadId: command.threadId, + commandId: command.commandId, + occurredAt: completedAt, + payload: { + threadId: command.threadId, + session: updatedThread.session!, + }, + }), + ); + } + + private stopLocalSession( + command: Extract, + ): void { + if (this.activeGeneration?.threadId === command.threadId) { + this.interruptLocalTurn({ + type: "thread.turn.interrupt", + commandId: command.commandId, + threadId: command.threadId, + turnId: this.activeGeneration.turnId, + createdAt: command.createdAt, + }); + } + const thread = this.getThread(command.threadId); + const updatedThread: OrchestrationThread = { + ...thread, + session: this.buildSession(command.threadId, "stopped", thread.runtimeMode, null, command.createdAt), + updatedAt: command.createdAt, + }; + this.setThread(updatedThread); + this.emitDomainEvent( + createThreadEvent({ + type: "thread.session-set", + threadId: command.threadId, + commandId: command.commandId, + occurredAt: command.createdAt, + payload: { + threadId: command.threadId, + session: updatedThread.session!, + }, + }), + ); + } + + private deleteLocalThread(command: Extract): void { + if (this.activeGeneration?.threadId === command.threadId) { + this.workerClient.interrupt(); + this.activeGeneration = null; + } + this.threadsById.delete(command.threadId); + this.persist(); + this.emitDomainEvent( + createThreadEvent({ + type: "thread.deleted", + threadId: command.threadId, + commandId: command.commandId, + payload: { + threadId: command.threadId, + deletedAt: 
nowIso(), + }, + }), + ); + } + + private updateLocalThreadMeta( + command: Extract, + ): void { + const thread = this.getThread(command.threadId); + const updatedThread: OrchestrationThread = { + ...thread, + ...(command.title ? { title: command.title } : {}), + ...(command.model ? { model: command.model } : {}), + ...(command.branch !== undefined ? { branch: command.branch } : {}), + ...(command.worktreePath !== undefined ? { worktreePath: command.worktreePath } : {}), + updatedAt: nowIso(), + }; + this.setThread(updatedThread); + this.emitDomainEvent( + createThreadEvent({ + type: "thread.meta-updated", + threadId: command.threadId, + commandId: command.commandId, + payload: { + threadId: command.threadId, + ...(command.title ? { title: command.title } : {}), + ...(command.model ? { model: command.model } : {}), + ...(command.branch !== undefined ? { branch: command.branch } : {}), + ...(command.worktreePath !== undefined ? { worktreePath: command.worktreePath } : {}), + updatedAt: updatedThread.updatedAt, + }, + }), + ); + } + + private setLocalRuntimeMode( + command: Extract, + ): void { + const thread = this.getThread(command.threadId); + const updatedThread = { + ...thread, + runtimeMode: command.runtimeMode, + updatedAt: command.createdAt, + }; + this.setThread(updatedThread); + this.emitDomainEvent( + createThreadEvent({ + type: "thread.runtime-mode-set", + threadId: command.threadId, + commandId: command.commandId, + payload: { + threadId: command.threadId, + runtimeMode: command.runtimeMode, + updatedAt: command.createdAt, + }, + }), + ); + } + + private setLocalInteractionMode( + command: Extract, + ): void { + const thread = this.getThread(command.threadId); + const updatedThread = { + ...thread, + interactionMode: command.interactionMode, + updatedAt: command.createdAt, + }; + this.setThread(updatedThread); + this.emitDomainEvent( + createThreadEvent({ + type: "thread.interaction-mode-set", + threadId: command.threadId, + commandId: command.commandId, 
+ payload: { + threadId: command.threadId, + interactionMode: command.interactionMode, + updatedAt: command.createdAt, + }, + }), + ); + } +} + +const localWebGpuController = new LocalWebGpuOrchestrationController(); + +export function getLocalWebGpuStatusSnapshot(): LocalWebGpuStatusSnapshot { + return localWebGpuController.getStatusSnapshot(); +} + +export function subscribeLocalWebGpuStatus(listener: () => void): () => void { + return localWebGpuController.subscribeStatus(listener); +} + +export function clearLocalWebGpuState(): Promise { + return localWebGpuController.clearPersistedState(); +} + +export function createHybridNativeApi(baseApi: NativeApi): NativeApi { + const domainEventListeners = new Set<(event: OrchestrationEvent) => void>(); + let snapshotSequence = 0; + + const emitMappedDomainEvent = (event: OrchestrationEvent) => { + snapshotSequence += 1; + const mappedEvent = { + ...event, + sequence: snapshotSequence, + } satisfies OrchestrationEvent; + for (const listener of domainEventListeners) { + listener(mappedEvent); + } + }; + + baseApi.orchestration.onDomainEvent((event) => { + emitMappedDomainEvent(event); + }); + localWebGpuController.subscribeDomainEvents((event) => { + emitMappedDomainEvent(event); + }); + + return { + ...baseApi, + orchestration: { + ...baseApi.orchestration, + getSnapshot: async () => { + const baseSnapshot = await baseApi.orchestration.getSnapshot(); + snapshotSequence = Math.max(snapshotSequence, baseSnapshot.snapshotSequence); + return localWebGpuController.mergeSnapshot(baseSnapshot, snapshotSequence); + }, + dispatchCommand: (command) => localWebGpuController.dispatchCommand(baseApi, command), + onDomainEvent: (callback) => { + domainEventListeners.add(callback); + return () => { + domainEventListeners.delete(callback); + }; + }, + }, + }; +} diff --git a/apps/web/src/nativeApi.ts b/apps/web/src/nativeApi.ts index 40443f67..7e53df1b 100644 --- a/apps/web/src/nativeApi.ts +++ b/apps/web/src/nativeApi.ts @@ -1,5 +1,6 
@@ import type { NativeApi } from "@t3tools/contracts"; +import { createHybridNativeApi } from "./localWebGpuOrchestration"; import { createWsNativeApi } from "./wsNativeApi"; let cachedApi: NativeApi | undefined; @@ -8,12 +9,7 @@ export function readNativeApi(): NativeApi | undefined { if (typeof window === "undefined") return undefined; if (cachedApi) return cachedApi; - if (window.nativeApi) { - cachedApi = window.nativeApi; - return cachedApi; - } - - cachedApi = createWsNativeApi(); + cachedApi = createHybridNativeApi(window.nativeApi ?? createWsNativeApi()); return cachedApi; } diff --git a/apps/web/src/routes/_chat.settings.tsx b/apps/web/src/routes/_chat.settings.tsx index b4922c72..4c7e449a 100644 --- a/apps/web/src/routes/_chat.settings.tsx +++ b/apps/web/src/routes/_chat.settings.tsx @@ -1,17 +1,32 @@ import { createFileRoute } from "@tanstack/react-router"; import { useQuery } from "@tanstack/react-query"; -import { useCallback, useState } from "react"; -import { type ProviderKind } from "@t3tools/contracts"; +import { useDebouncedValue } from "@tanstack/react-pacer"; +import { ArrowUpRightIcon, DownloadIcon, HeartIcon, SearchIcon } from "lucide-react"; +import { useCallback, useMemo, useState, useSyncExternalStore } from "react"; +import { + WEBGPU_DTYPE_OPTIONS, + type ProviderKind, + type ServerHuggingFaceModel, + type WebGpuModelDtype, +} from "@t3tools/contracts"; import { getModelOptions, normalizeModelSlug } from "@t3tools/shared/model"; import { MAX_CUSTOM_MODEL_LENGTH, useAppSettings } from "../appSettings"; import { isElectron } from "../env"; import { useTheme } from "../hooks/useTheme"; -import { serverConfigQueryOptions } from "../lib/serverReactQuery"; +import { huggingFaceModelSearchQueryOptions, serverConfigQueryOptions } from "../lib/serverReactQuery"; +import { + clearLocalWebGpuState, + getLocalWebGpuStatusSnapshot, + subscribeLocalWebGpuStatus, +} from "../localWebGpuOrchestration"; import { ensureNativeApi } from "../nativeApi"; 
import { preferredTerminalEditor } from "../terminal-links"; +import { HUGGING_FACE_BRAND_ASSET_URL } from "../components/Icons"; +import { Badge } from "../components/ui/badge"; import { Button } from "../components/ui/button"; import { Input } from "../components/ui/input"; +import { Spinner } from "../components/ui/spinner"; import { Switch } from "../components/ui/switch"; import { SidebarInset } from "~/components/ui/sidebar"; @@ -33,6 +48,20 @@ const THEME_OPTIONS = [ }, ] as const; +const HUGGING_FACE_QUICK_FILTERS = [ + { label: "Featured", query: "" }, + { label: "Qwen", query: "Qwen instruct" }, + { label: "Coder", query: "coder instruct" }, + { label: "Llama", query: "Llama instruct" }, + { label: "Phi", query: "Phi instruct" }, + { label: "SmolLM", query: "SmolLM instruct" }, +] as const; + +const compactNumberFormatter = new Intl.NumberFormat(undefined, { + notation: "compact", + maximumFractionDigits: 1, +}); + const MODEL_PROVIDER_SETTINGS: Array<{ provider: ProviderKind; title: string; @@ -47,14 +76,55 @@ const MODEL_PROVIDER_SETTINGS: Array<{ placeholder: "your-codex-model-slug", example: "gpt-6.7-codex-ultra-preview", }, + { + provider: "webgpu", + title: "Local WebGPU", + description: "Save additional Hugging Face / ONNX model ids for the local browser adapter.", + placeholder: "onnx-community/your-model-id", + example: "onnx-community/Qwen2.5-0.5B-Instruct", + }, ] as const; +type SaveCustomModelResult = + | { + ok: true; + slug: string; + builtIn: boolean; + alreadySaved: boolean; + added: boolean; + } + | { + ok: false; + error: string; + }; + +function formatCompactMetric(value: number): string { + return compactNumberFormatter.format(Math.max(0, value)); +} + +function huggingFaceModelUrl(modelId: string): string { + return `https://huggingface.co/${modelId + .split("/") + .map((segment) => encodeURIComponent(segment)) + .join("/")}`; +} + +function huggingFaceSearchUrl(query: string): string { + const normalizedQuery = query.trim(); + 
return normalizedQuery.length > 0 + ? `https://huggingface.co/models?search=${encodeURIComponent(normalizedQuery)}` + : "https://huggingface.co/models?author=onnx-community&pipeline_tag=text-generation"; +} + function getCustomModelsForProvider( settings: ReturnType["settings"], provider: ProviderKind, ) { switch (provider) { case "codex": + return settings.customCodexModels; + case "webgpu": + return settings.customWebGpuModels; default: return settings.customCodexModels; } @@ -66,6 +136,9 @@ function getDefaultCustomModelsForProvider( ) { switch (provider) { case "codex": + return defaults.customCodexModels; + case "webgpu": + return defaults.customWebGpuModels; default: return defaults.customCodexModels; } @@ -74,23 +147,53 @@ function getDefaultCustomModelsForProvider( function patchCustomModels(provider: ProviderKind, models: string[]) { switch (provider) { case "codex": + return { customCodexModels: models }; + case "webgpu": + return { customWebGpuModels: models }; default: return { customCodexModels: models }; } } +function formatLocalWebGpuProgress(progress: { + loaded: number; + total: number | null; + file: string | null; +} | null): string | null { + if (!progress) return null; + if (progress.total && progress.total > 0) { + const percent = Math.max(0, Math.min(100, Math.round((progress.loaded / progress.total) * 100))); + return `${percent}%${progress.file ? ` · ${progress.file}` : ""}`; + } + return progress.file ? 
`Downloading ${progress.file}` : "Downloading model files"; +} + function SettingsRouteView() { const { theme, setTheme, resolvedTheme } = useTheme(); const { settings, defaults, updateSettings } = useAppSettings(); const serverConfigQuery = useQuery(serverConfigQueryOptions()); + const localWebGpuStatus = useSyncExternalStore( + subscribeLocalWebGpuStatus, + getLocalWebGpuStatusSnapshot, + getLocalWebGpuStatusSnapshot, + ); const [isOpeningKeybindings, setIsOpeningKeybindings] = useState(false); const [openKeybindingsError, setOpenKeybindingsError] = useState(null); + const [isClearingLocalWebGpuState, setIsClearingLocalWebGpuState] = useState(false); + const [localWebGpuActionMessage, setLocalWebGpuActionMessage] = useState(null); + const [huggingFaceModelQuery, setHuggingFaceModelQuery] = useState(""); const [customModelInputByProvider, setCustomModelInputByProvider] = useState< Record >({ codex: "", copilot: "", + webgpu: "", }); + const [debouncedHuggingFaceModelQuery, huggingFaceModelQueryDebouncer] = useDebouncedValue( + huggingFaceModelQuery, + { wait: 350 }, + (debouncerState) => ({ isPending: debouncerState.isPending }), + ); const [customModelErrorByProvider, setCustomModelErrorByProvider] = useState< Partial> >({}); @@ -98,6 +201,36 @@ function SettingsRouteView() { const codexBinaryPath = settings.codexBinaryPath; const codexHomePath = settings.codexHomePath; const keybindingsConfigPath = serverConfigQuery.data?.keybindingsConfigPath ?? 
null; + const webGpuBuiltInOptions = getModelOptions("webgpu"); + const webGpuModelOptions = useMemo( + () => [ + ...webGpuBuiltInOptions.map((option) => ({ slug: option.slug, name: option.name })), + ...settings.customWebGpuModels + .filter((slug) => !webGpuBuiltInOptions.some((option) => option.slug === slug)) + .map((slug) => ({ slug, name: slug })), + ], + [settings.customWebGpuModels, webGpuBuiltInOptions], + ); + const webGpuBuiltInModelSlugs = useMemo( + () => new Set(webGpuBuiltInOptions.map((option) => option.slug)), + [webGpuBuiltInOptions], + ); + const localWebGpuProgressLabel = formatLocalWebGpuProgress(localWebGpuStatus.progress); + const normalizedHuggingFaceModelQuery = debouncedHuggingFaceModelQuery.trim(); + const huggingFaceModelsQuery = useQuery( + huggingFaceModelSearchQueryOptions({ + query: normalizedHuggingFaceModelQuery.length > 0 ? normalizedHuggingFaceModelQuery : null, + limit: normalizedHuggingFaceModelQuery.length > 0 ? 10 : 8, + }), + ); + const huggingFaceModels = huggingFaceModelsQuery.data?.models ?? []; + const isRefreshingHuggingFaceModels = + huggingFaceModelQueryDebouncer.state.isPending || huggingFaceModelsQuery.isFetching; + const huggingFaceBrowseError = huggingFaceModelsQuery.isError + ? huggingFaceModelsQuery.error instanceof Error + ? huggingFaceModelsQuery.error.message + : "Unable to load Hugging Face models right now." 
+ : null; const openKeybindingsFile = useCallback(() => { if (!keybindingsConfigPath) return; @@ -116,33 +249,60 @@ function SettingsRouteView() { }); }, [keybindingsConfigPath]); - const addCustomModel = useCallback( - (provider: ProviderKind) => { - const customModelInput = customModelInputByProvider[provider]; + const saveCustomModel = useCallback( + (provider: ProviderKind, rawModelSlug: string): SaveCustomModelResult => { const customModels = getCustomModelsForProvider(settings, provider); - const normalized = normalizeModelSlug(customModelInput, provider); + const normalized = normalizeModelSlug(rawModelSlug, provider); if (!normalized) { - setCustomModelErrorByProvider((existing) => ({ - ...existing, - [provider]: "Enter a model slug.", - })); - return; + return { ok: false, error: "Enter a model slug." }; + } + if (normalized.length > MAX_CUSTOM_MODEL_LENGTH) { + return { + ok: false, + error: `Model slugs must be ${MAX_CUSTOM_MODEL_LENGTH} characters or less.`, + }; + } + + const builtIn = getModelOptions(provider).some((option) => option.slug === normalized); + const alreadySaved = customModels.includes(normalized); + if (!builtIn && !alreadySaved) { + updateSettings(patchCustomModels(provider, [...customModels, normalized])); } - if (getModelOptions(provider).some((option) => option.slug === normalized)) { + + setCustomModelErrorByProvider((existing) => ({ + ...existing, + [provider]: null, + })); + + return { + ok: true, + slug: normalized, + builtIn, + alreadySaved, + added: !builtIn && !alreadySaved, + }; + }, + [settings, updateSettings], + ); + + const addCustomModel = useCallback( + (provider: ProviderKind) => { + const result = saveCustomModel(provider, customModelInputByProvider[provider]); + if (!result.ok) { setCustomModelErrorByProvider((existing) => ({ ...existing, - [provider]: "That model is already built in.", + [provider]: result.error, })); return; } - if (normalized.length > MAX_CUSTOM_MODEL_LENGTH) { + if (result.builtIn) { 
setCustomModelErrorByProvider((existing) => ({ ...existing, - [provider]: `Model slugs must be ${MAX_CUSTOM_MODEL_LENGTH} characters or less.`, + [provider]: "That model is already built in.", })); return; } - if (customModels.includes(normalized)) { + if (result.alreadySaved) { setCustomModelErrorByProvider((existing) => ({ ...existing, [provider]: "That custom model is already saved.", @@ -150,17 +310,12 @@ function SettingsRouteView() { return; } - updateSettings(patchCustomModels(provider, [...customModels, normalized])); setCustomModelInputByProvider((existing) => ({ ...existing, [provider]: "", })); - setCustomModelErrorByProvider((existing) => ({ - ...existing, - [provider]: null, - })); }, - [customModelInputByProvider, settings, updateSettings], + [customModelInputByProvider, saveCustomModel], ); const removeCustomModel = useCallback( @@ -180,6 +335,66 @@ function SettingsRouteView() { [settings, updateSettings], ); + const resetLocalWebGpuState = useCallback(async () => { + setIsClearingLocalWebGpuState(true); + setLocalWebGpuActionMessage(null); + try { + await clearLocalWebGpuState(); + setLocalWebGpuActionMessage("Cleared local WebGPU threads and unloaded the active model."); + } catch (error) { + setLocalWebGpuActionMessage( + error instanceof Error ? error.message : "Unable to clear local WebGPU state.", + ); + } finally { + setIsClearingLocalWebGpuState(false); + } + }, []); + + const openHuggingFaceModelPage = useCallback((modelId: string) => { + const api = ensureNativeApi(); + void api.shell.openExternal(huggingFaceModelUrl(modelId)).catch((error) => { + setLocalWebGpuActionMessage( + error instanceof Error ? 
error.message : "Unable to open the Hugging Face model page.", + ); + }); + }, []); + + const saveHuggingFaceModel = useCallback( + (model: ServerHuggingFaceModel, options?: { setDefault?: boolean }) => { + const result = saveCustomModel("webgpu", model.id); + if (!result.ok) { + setLocalWebGpuActionMessage(result.error); + return; + } + + const setDefault = options?.setDefault ?? false; + const isAlreadyDefault = settings.webGpuDefaultModel === result.slug; + if (setDefault && !isAlreadyDefault) { + updateSettings({ webGpuDefaultModel: result.slug }); + } + + if (setDefault) { + setLocalWebGpuActionMessage( + isAlreadyDefault + ? `${result.slug} is already the default local WebGPU model.` + : `Set ${result.slug} as the default local WebGPU model.`, + ); + return; + } + + if (result.added) { + setLocalWebGpuActionMessage(`Added ${result.slug} to your local WebGPU models.`); + return; + } + if (result.builtIn) { + setLocalWebGpuActionMessage(`${result.slug} is already available as a built-in local model.`); + return; + } + setLocalWebGpuActionMessage(`${result.slug} is already saved in your custom local models.`); + }, + [saveCustomModel, settings.webGpuDefaultModel, updateSettings], + ); + return (
@@ -434,6 +649,341 @@ function SettingsRouteView() {
+
+
+

Local WebGPU

+

+ Run curated Hugging Face ONNX models in the browser. The first run downloads + model files and may take a while. +

+
+ +
+
+
+

Enable local WebGPU provider

+

+ Disable this to hide the browser-side local model adapter from the picker. +

+
+ updateSettings({ webGpuEnabled: Boolean(checked) })} + /> +
+ +
+ + + +
+ +
+
+ + Status:{" "} + + {localWebGpuStatus.supported ? localWebGpuStatus.phase : "unsupported"} + + + {localWebGpuProgressLabel ? {localWebGpuProgressLabel} : null} +
+

+ {localWebGpuStatus.supportMessage ?? + localWebGpuStatus.lastError ?? + "Use a recent Chromium-based browser with WebGPU enabled for the best results."} +

+ {localWebGpuActionMessage ? ( +

{localWebGpuActionMessage}

+ ) : null} +
+ +
+
+
+
+ + + +
+
+

+ Browse Hugging Face models +

+ + Local WebGPU + +
+

+ Search public text-generation repos filtered for Transformers.js + compatibility, then save them straight into your local model picker. +

+
+
+ + +
+ +
+
+ + +
+ {HUGGING_FACE_QUICK_FILTERS.map((filter) => { + const active = + (filter.query.length === 0 && huggingFaceModelQuery.length === 0) || + huggingFaceModelQuery === filter.query; + return ( + + ); + })} +
+
+
+
+ +
+
+
+ + {huggingFaceModelsQuery.data?.mode === "search" + ? "Search results" + : "Featured picks"} + + + {huggingFaceModels.length} model{huggingFaceModels.length === 1 ? "" : "s"} + {huggingFaceModelsQuery.data?.truncated ? " shown" : ""} + +
+ + {normalizedHuggingFaceModelQuery.length > 0 + ? `Query: ${normalizedHuggingFaceModelQuery}` + : "Showing onnx-community instruct models first"} + +
+ + {huggingFaceBrowseError ? ( +
+ {huggingFaceBrowseError} +
+ ) : null} + + {huggingFaceModels.length > 0 ? ( +
+ {huggingFaceModels.map((model) => { + const isBuiltIn = webGpuBuiltInModelSlugs.has(model.id); + const isSaved = settings.customWebGpuModels.includes(model.id); + const isDefault = settings.webGpuDefaultModel === model.id; + return ( +
+
+
+
+

+ {model.name} +

+ + {model.compatibility === "recommended" + ? "Recommended" + : "Community"} + + {isBuiltIn ? ( + + Built in + + ) : null} + {isSaved ? ( + + Saved + + ) : null} + {isDefault ? ( + + Default + + ) : null} +
+ + {model.id} + +
+ + +
+ +
+ + + {formatCompactMetric(model.downloads)} + + + + {formatCompactMetric(model.likes)} + + + {model.pipelineTag} + + {model.libraryName ? ( + + {model.libraryName} + + ) : null} + {model.license ? ( + + {model.license} + + ) : null} +
+ +
+ + +
+
+ ); + })} +
+ ) : isRefreshingHuggingFaceModels ? ( +
+ + Loading compatible Hugging Face models... +
+ ) : ( +
+ No compatible public text-generation models matched that search. Try a + broader family name like Qwen or Llama. +
+ )} +
+
+ +
+ +

+ This resets locally persisted threads and unloads the current browser model. +

+
+
+
+

Responses

diff --git a/apps/web/src/session-logic.test.ts b/apps/web/src/session-logic.test.ts index 53fa003f..f6402835 100644 --- a/apps/web/src/session-logic.test.ts +++ b/apps/web/src/session-logic.test.ts @@ -770,13 +770,15 @@ describe("deriveActiveWorkStartedAt", () => { }); describe("PROVIDER_OPTIONS", () => { - it("keeps Claude Code and Cursor visible as unavailable placeholders in the stack base", () => { + it("keeps available providers and unavailable placeholders visible in the stack base", () => { const copilot = PROVIDER_OPTIONS.find((option) => option.value === "copilot"); + const webgpu = PROVIDER_OPTIONS.find((option) => option.value === "webgpu"); const claude = PROVIDER_OPTIONS.find((option) => option.value === "claudeCode"); const cursor = PROVIDER_OPTIONS.find((option) => option.value === "cursor"); expect(PROVIDER_OPTIONS).toEqual([ { value: "codex", label: "Codex", available: true }, { value: "copilot", label: "GitHub Copilot", available: true }, + { value: "webgpu", label: "Local WebGPU", available: true }, { value: "claudeCode", label: "Claude Code", available: false }, { value: "cursor", label: "Cursor", available: false }, ]); @@ -785,6 +787,11 @@ describe("PROVIDER_OPTIONS", () => { label: "GitHub Copilot", available: true, }); + expect(webgpu).toEqual({ + value: "webgpu", + label: "Local WebGPU", + available: true, + }); expect(claude).toEqual({ value: "claudeCode", label: "Claude Code", diff --git a/apps/web/src/session-logic.ts b/apps/web/src/session-logic.ts index e938deca..9556936e 100644 --- a/apps/web/src/session-logic.ts +++ b/apps/web/src/session-logic.ts @@ -25,6 +25,7 @@ export const PROVIDER_OPTIONS: Array<{ }> = [ { value: "codex", label: "Codex", available: true }, { value: "copilot", label: "GitHub Copilot", available: true }, + { value: "webgpu", label: "Local WebGPU", available: true }, { value: "claudeCode", label: "Claude Code", available: false }, { value: "cursor", label: "Cursor", available: false }, ]; diff --git 
a/apps/web/src/store.ts b/apps/web/src/store.ts index 08925af1..28ccd8e4 100644 --- a/apps/web/src/store.ts +++ b/apps/web/src/store.ts @@ -189,7 +189,7 @@ function toLegacySessionStatus( } function toLegacyProvider(providerName: string | null): ProviderKind { - if (providerName === "codex" || providerName === "copilot") { + if (providerName === "codex" || providerName === "copilot" || providerName === "webgpu") { return providerName; } return "codex"; @@ -197,14 +197,23 @@ function toLegacyProvider(providerName: string | null): ProviderKind { const CODEX_MODEL_SLUGS = new Set(getModelOptions("codex").map((option) => option.slug)); const COPILOT_MODEL_SLUGS = new Set(getModelOptions("copilot").map((option) => option.slug)); +const WEBGPU_MODEL_SLUGS = new Set(getModelOptions("webgpu").map((option) => option.slug)); function inferProviderForThreadModel(input: { readonly model: string; readonly sessionProviderName: string | null; }): ProviderKind { - if (input.sessionProviderName === "codex" || input.sessionProviderName === "copilot") { + if ( + input.sessionProviderName === "codex" || + input.sessionProviderName === "copilot" || + input.sessionProviderName === "webgpu" + ) { return input.sessionProviderName; } + const normalizedWebGpu = normalizeModelSlug(input.model, "webgpu"); + if (normalizedWebGpu && WEBGPU_MODEL_SLUGS.has(normalizedWebGpu)) { + return "webgpu"; + } const normalizedCopilot = normalizeModelSlug(input.model, "copilot"); if (normalizedCopilot && COPILOT_MODEL_SLUGS.has(normalizedCopilot)) { return "copilot"; diff --git a/apps/web/src/workers/webgpuInference.worker.ts b/apps/web/src/workers/webgpuInference.worker.ts new file mode 100644 index 00000000..0ca0da82 --- /dev/null +++ b/apps/web/src/workers/webgpuInference.worker.ts @@ -0,0 +1,239 @@ +import type { WebGpuModelDtype } from "@t3tools/contracts"; + +const TRANSFORMERS_JS_CDN_URL = "https://cdn.jsdelivr.net/npm/@huggingface/transformers/+esm"; + +type LocalWebGpuChatMessage = { + role: "user" 
| "assistant" | "system"; + content: string; +}; + +type WorkerGenerateMessage = { + type: "generate"; + requestId: string; + model: string; + dtype: WebGpuModelDtype; + messages: LocalWebGpuChatMessage[]; + maxNewTokens: number; + temperature: number; + topP: number; +}; + +type WorkerDisposeMessage = { + type: "dispose"; +}; + +type WorkerInboundMessage = WorkerGenerateMessage | WorkerDisposeMessage; + +type WorkerStatusMessage = { + type: "status"; + status: "idle" | "loading-model" | "ready" | "generating" | "error"; + model: string | null; + dtype: WebGpuModelDtype; + message?: string; +}; + +type WorkerProgressMessage = { + type: "download-progress"; + file: string | null; + loaded: number; + total: number | null; +}; + +type WorkerTextDeltaMessage = { + type: "text-delta"; + requestId: string; + delta: string; +}; + +type WorkerCompleteMessage = { + type: "complete"; + requestId: string; + text: string; +}; + +type WorkerErrorMessage = { + type: "error"; + requestId?: string; + message: string; +}; + +type WorkerOutboundMessage = + | WorkerStatusMessage + | WorkerProgressMessage + | WorkerTextDeltaMessage + | WorkerCompleteMessage + | WorkerErrorMessage; + +type TextGenerationResult = + | Array<{ + generated_text?: string | Array<{ role?: string; content?: string }>; + }> + | { + generated_text?: string | Array<{ role?: string; content?: string }>; + }; + +type TextGenerationPipeline = ((messages: LocalWebGpuChatMessage[], options: Record) => Promise) & { + tokenizer: unknown; +}; + +type TransformersModule = { + pipeline: ( + task: "text-generation", + model: string, + options: Record, + ) => Promise; + TextStreamer: new ( + tokenizer: unknown, + options: { + skip_prompt?: boolean; + callback_function?: (text: string) => void; + }, + ) => unknown; +}; + +const workerScope = self as typeof globalThis & { + postMessage: (message: WorkerOutboundMessage) => void; +}; + +let transformersPromise: Promise | null = null; +let activePipeline: TextGenerationPipeline 
| null = null; +let activeModel: string | null = null; +let activeDtype: WebGpuModelDtype | null = null; + +function postMessageToMain(message: WorkerOutboundMessage): void { + // eslint-disable-next-line unicorn/require-post-message-target-origin -- Dedicated worker messaging does not use targetOrigin. + workerScope.postMessage(message); +} + +async function loadTransformersModule(): Promise { + if (!transformersPromise) { + transformersPromise = import(/* @vite-ignore */ TRANSFORMERS_JS_CDN_URL) as Promise; + } + return transformersPromise; +} + +function extractGeneratedText(result: TextGenerationResult, fallback: string): string { + const first = Array.isArray(result) ? result[0] : result; + const generated = first?.generated_text; + if (typeof generated === "string") { + return generated; + } + if (Array.isArray(generated)) { + const lastMessage = generated.at(-1); + if (lastMessage?.content) { + return lastMessage.content; + } + } + return fallback; +} + +async function ensurePipeline( + model: string, + dtype: WebGpuModelDtype, +): Promise<{ module: TransformersModule; pipeline: TextGenerationPipeline }> { + const module = await loadTransformersModule(); + if (activePipeline && activeModel === model && activeDtype === dtype) { + return { module, pipeline: activePipeline }; + } + postMessageToMain({ + type: "status", + status: "loading-model", + model, + dtype, + }); + activePipeline = await module.pipeline("text-generation", model, { + device: "webgpu", + dtype, + progress_callback: (progress: { file?: string; loaded?: number; total?: number }) => { + postMessageToMain({ + type: "download-progress", + file: progress.file ?? null, + loaded: progress.loaded ?? 0, + total: progress.total ?? 
null, + }); + }, + }); + activeModel = model; + activeDtype = dtype; + postMessageToMain({ + type: "status", + status: "ready", + model, + dtype, + }); + return { module, pipeline: activePipeline }; +} + +async function handleGenerate(message: WorkerGenerateMessage): Promise { + const { module, pipeline } = await ensurePipeline(message.model, message.dtype); + postMessageToMain({ + type: "status", + status: "generating", + model: message.model, + dtype: message.dtype, + }); + let streamedText = ""; + const streamer = new module.TextStreamer(pipeline.tokenizer, { + skip_prompt: true, + callback_function: (text) => { + if (!text) { + return; + } + streamedText += text; + postMessageToMain({ + type: "text-delta", + requestId: message.requestId, + delta: text, + }); + }, + }); + const result = await pipeline(message.messages, { + max_new_tokens: message.maxNewTokens, + temperature: message.temperature, + top_p: message.topP, + do_sample: message.temperature > 0, + return_full_text: false, + streamer, + }); + postMessageToMain({ + type: "complete", + requestId: message.requestId, + text: extractGeneratedText(result, streamedText), + }); + postMessageToMain({ + type: "status", + status: "ready", + model: message.model, + dtype: message.dtype, + }); +} + +workerScope.addEventListener("message", (event: MessageEvent) => { + const message = event.data; + if (message.type === "dispose") { + activePipeline = null; + activeModel = null; + activeDtype = null; + postMessageToMain({ + type: "status", + status: "idle", + model: null, + dtype: "q4", + }); + return; + } + void handleGenerate(message).catch((error: unknown) => { + postMessageToMain({ + type: "error", + requestId: message.requestId, + message: error instanceof Error ? error.message : "Local WebGPU generation failed.", + }); + postMessageToMain({ + type: "status", + status: "error", + model: message.model, + dtype: message.dtype, + message: error instanceof Error ? 
error.message : "Local WebGPU generation failed.", + }); + }); +}); diff --git a/apps/web/src/wsNativeApi.ts b/apps/web/src/wsNativeApi.ts index 5f2bfcc8..e1306cee 100644 --- a/apps/web/src/wsNativeApi.ts +++ b/apps/web/src/wsNativeApi.ts @@ -188,6 +188,8 @@ export function createWsNativeApi(): NativeApi { server: { getConfig: () => transport.request(WS_METHODS.serverGetConfig), upsertKeybinding: (input) => transport.request(WS_METHODS.serverUpsertKeybinding, input), + searchHuggingFaceModels: (input) => + transport.request(WS_METHODS.serverSearchHuggingFaceModels, input), }, orchestration: { getSnapshot: () => transport.request(ORCHESTRATION_WS_METHODS.getSnapshot), diff --git a/packages/contracts/src/ipc.ts b/packages/contracts/src/ipc.ts index 4c93970f..84357ec7 100644 --- a/packages/contracts/src/ipc.ts +++ b/packages/contracts/src/ipc.ts @@ -24,7 +24,11 @@ import type { ProjectWriteFileInput, ProjectWriteFileResult, } from "./project"; -import type { ServerConfig } from "./server"; +import type { + ServerConfig, + ServerHuggingFaceModelSearchInput, + ServerHuggingFaceModelSearchResult, +} from "./server"; import type { TerminalClearInput, TerminalCloseInput, @@ -156,6 +160,9 @@ export interface NativeApi { server: { getConfig: () => Promise; upsertKeybinding: (input: ServerUpsertKeybindingInput) => Promise; + searchHuggingFaceModels: ( + input: ServerHuggingFaceModelSearchInput, + ) => Promise; }; orchestration: { getSnapshot: () => Promise; diff --git a/packages/contracts/src/model.ts b/packages/contracts/src/model.ts index af7fb956..46dbd9be 100644 --- a/packages/contracts/src/model.ts +++ b/packages/contracts/src/model.ts @@ -1,4 +1,5 @@ import { Schema } from "effect"; +import { NonNegativeInt } from "./baseSchemas"; import { ProviderKind } from "./orchestration"; export const CODEX_REASONING_EFFORT_OPTIONS = ["xhigh", "high", "medium", "low"] as const; @@ -13,10 +14,20 @@ export const CopilotModelOptions = Schema.Struct({ reasoningEffort: 
Schema.optional(Schema.Literals(CODEX_REASONING_EFFORT_OPTIONS)), }); export type CopilotModelOptions = typeof CopilotModelOptions.Type; +export const WEBGPU_DTYPE_OPTIONS = ["q4", "q8", "fp16", "fp32"] as const; +export type WebGpuModelDtype = (typeof WEBGPU_DTYPE_OPTIONS)[number]; +export const WebGpuModelOptions = Schema.Struct({ + temperature: Schema.optional(Schema.Number), + topP: Schema.optional(Schema.Number), + maxTokens: Schema.optional(NonNegativeInt), + dtype: Schema.optional(Schema.Literals(WEBGPU_DTYPE_OPTIONS)), +}); +export type WebGpuModelOptions = typeof WebGpuModelOptions.Type; export const ProviderModelOptions = Schema.Struct({ codex: Schema.optional(CodexModelOptions), copilot: Schema.optional(CopilotModelOptions), + webgpu: Schema.optional(WebGpuModelOptions), }); export type ProviderModelOptions = typeof ProviderModelOptions.Type; @@ -53,6 +64,10 @@ export const MODEL_OPTIONS_BY_PROVIDER = { { slug: "gpt-5-mini", name: "GPT-5 mini" }, { slug: "gpt-4.1", name: "GPT-4.1" }, ], + webgpu: [ + { slug: "onnx-community/Qwen2.5-0.5B-Instruct", name: "Qwen 2.5 0.5B Instruct" }, + { slug: "onnx-community/SmolLM2-360M-Instruct", name: "SmolLM2 360M Instruct" }, + ], } as const satisfies Record; export type ModelOptionsByProvider = typeof MODEL_OPTIONS_BY_PROVIDER; @@ -62,6 +77,7 @@ export type ModelSlug = BuiltInModelSlug | (string & {}); export const DEFAULT_MODEL_BY_PROVIDER = { codex: "gpt-5.4", copilot: "claude-sonnet-4.6", + webgpu: "onnx-community/Qwen2.5-0.5B-Instruct", } as const satisfies Record; export const MODEL_SLUG_ALIASES_BY_PROVIDER = { @@ -88,14 +104,22 @@ export const MODEL_SLUG_ALIASES_BY_PROVIDER = { opus: "claude-opus-4.6", gemini: "gemini-3-pro-preview", }, + webgpu: { + qwen: "onnx-community/Qwen2.5-0.5B-Instruct", + "qwen-0.5b": "onnx-community/Qwen2.5-0.5B-Instruct", + smollm: "onnx-community/SmolLM2-360M-Instruct", + "smollm-360m": "onnx-community/SmolLM2-360M-Instruct", + }, } as const satisfies Record>; export const 
REASONING_EFFORT_OPTIONS_BY_PROVIDER = { codex: CODEX_REASONING_EFFORT_OPTIONS, copilot: [], + webgpu: [], } as const satisfies Record; export const DEFAULT_REASONING_EFFORT_BY_PROVIDER = { codex: "high", copilot: null, + webgpu: null, } as const satisfies Record; diff --git a/packages/contracts/src/orchestration.ts b/packages/contracts/src/orchestration.ts index bd35dcb9..30ec31d5 100644 --- a/packages/contracts/src/orchestration.ts +++ b/packages/contracts/src/orchestration.ts @@ -27,7 +27,11 @@ export const ORCHESTRATION_WS_CHANNELS = { domainEvent: "orchestration.domainEvent", } as const; -export const ProviderKind = Schema.Union([Schema.Literal("codex"), Schema.Literal("copilot")]); +export const ProviderKind = Schema.Union([ + Schema.Literal("codex"), + Schema.Literal("copilot"), + Schema.Literal("webgpu"), +]); export type ProviderKind = typeof ProviderKind.Type; export const ProviderApprovalPolicy = Schema.Literals([ "untrusted", diff --git a/packages/contracts/src/server.ts b/packages/contracts/src/server.ts index ec2e9be8..f492eb2e 100644 --- a/packages/contracts/src/server.ts +++ b/packages/contracts/src/server.ts @@ -1,5 +1,10 @@ import { Schema } from "effect"; -import { IsoDateTime, NonNegativeInt, TrimmedNonEmptyString } from "./baseSchemas"; +import { + IsoDateTime, + NonNegativeInt, + PositiveInt, + TrimmedNonEmptyString, +} from "./baseSchemas"; import { KeybindingRule, ResolvedKeybindingsConfig } from "./keybindings"; import { EditorId } from "./editor"; import { ProviderKind } from "./orchestration"; @@ -97,3 +102,43 @@ export const ServerConfigUpdatedPayload = Schema.Struct({ providers: ServerProviderStatuses, }); export type ServerConfigUpdatedPayload = typeof ServerConfigUpdatedPayload.Type; + +const SERVER_HUGGING_FACE_MODEL_SEARCH_QUERY_MAX_LENGTH = 120; +const SERVER_HUGGING_FACE_MODEL_SEARCH_MAX_LIMIT = 24; + +export const ServerHuggingFaceModelSearchMode = Schema.Literals(["featured", "search"]); +export type 
ServerHuggingFaceModelSearchMode = typeof ServerHuggingFaceModelSearchMode.Type; + +export const ServerHuggingFaceModelCompatibility = Schema.Literals(["recommended", "community"]); +export type ServerHuggingFaceModelCompatibility = typeof ServerHuggingFaceModelCompatibility.Type; + +export const ServerHuggingFaceModelSearchInput = Schema.Struct({ + query: Schema.optional( + TrimmedNonEmptyString.check(Schema.isMaxLength(SERVER_HUGGING_FACE_MODEL_SEARCH_QUERY_MAX_LENGTH)), + ), + limit: Schema.optional( + PositiveInt.check(Schema.isLessThanOrEqualTo(SERVER_HUGGING_FACE_MODEL_SEARCH_MAX_LIMIT)), + ), +}); +export type ServerHuggingFaceModelSearchInput = typeof ServerHuggingFaceModelSearchInput.Type; + +export const ServerHuggingFaceModel = Schema.Struct({ + id: TrimmedNonEmptyString, + author: TrimmedNonEmptyString, + name: TrimmedNonEmptyString, + downloads: NonNegativeInt, + likes: NonNegativeInt, + pipelineTag: TrimmedNonEmptyString, + libraryName: Schema.optional(TrimmedNonEmptyString), + license: Schema.optional(TrimmedNonEmptyString), + compatibility: ServerHuggingFaceModelCompatibility, +}); +export type ServerHuggingFaceModel = typeof ServerHuggingFaceModel.Type; + +export const ServerHuggingFaceModelSearchResult = Schema.Struct({ + mode: ServerHuggingFaceModelSearchMode, + query: Schema.optional(TrimmedNonEmptyString), + models: Schema.Array(ServerHuggingFaceModel), + truncated: Schema.Boolean, +}); +export type ServerHuggingFaceModelSearchResult = typeof ServerHuggingFaceModelSearchResult.Type; diff --git a/packages/contracts/src/ws.ts b/packages/contracts/src/ws.ts index 9814cae8..894bcf9d 100644 --- a/packages/contracts/src/ws.ts +++ b/packages/contracts/src/ws.ts @@ -33,6 +33,7 @@ import { import { KeybindingRule } from "./keybindings"; import { ProjectSearchEntriesInput, ProjectWriteFileInput } from "./project"; import { OpenInEditorInput } from "./editor"; +import { ServerHuggingFaceModelSearchInput } from "./server"; // ── WebSocket RPC Method Names 
─────────────────────────────────────── @@ -71,6 +72,7 @@ export const WS_METHODS = { // Server meta serverGetConfig: "server.getConfig", serverUpsertKeybinding: "server.upsertKeybinding", + serverSearchHuggingFaceModels: "server.searchHuggingFaceModels", } as const; // ── Push Event Channels ────────────────────────────────────────────── @@ -135,6 +137,7 @@ const WebSocketRequestBody = Schema.Union([ // Server meta tagRequestBody(WS_METHODS.serverGetConfig, Schema.Struct({})), tagRequestBody(WS_METHODS.serverUpsertKeybinding, KeybindingRule), + tagRequestBody(WS_METHODS.serverSearchHuggingFaceModels, ServerHuggingFaceModelSearchInput), ]); export const WebSocketRequest = Schema.Struct({ diff --git a/packages/shared/src/model.test.ts b/packages/shared/src/model.test.ts index 8771a24c..636bf423 100644 --- a/packages/shared/src/model.test.ts +++ b/packages/shared/src/model.test.ts @@ -54,6 +54,12 @@ describe("resolveModelSlug", () => { expect(getDefaultModel()).toBe(DEFAULT_MODEL_BY_PROVIDER.codex); expect(getModelOptions()).toEqual(MODEL_OPTIONS_BY_PROVIDER.codex); }); + + it("resolves built-in webgpu models", () => { + expect(resolveModelSlug("onnx-community/Qwen2.5-0.5B-Instruct", "webgpu")).toBe( + "onnx-community/Qwen2.5-0.5B-Instruct", + ); + }); }); describe("getReasoningEffortOptions", () => { diff --git a/packages/shared/src/model.ts b/packages/shared/src/model.ts index 1c058d76..0333d4f6 100644 --- a/packages/shared/src/model.ts +++ b/packages/shared/src/model.ts @@ -13,6 +13,7 @@ type CatalogProvider = keyof typeof MODEL_OPTIONS_BY_PROVIDER; const MODEL_SLUG_SET_BY_PROVIDER: Record> = { codex: new Set(MODEL_OPTIONS_BY_PROVIDER.codex.map((option) => option.slug)), copilot: new Set(MODEL_OPTIONS_BY_PROVIDER.copilot.map((option) => option.slug)), + webgpu: new Set(MODEL_OPTIONS_BY_PROVIDER.webgpu.map((option) => option.slug)), }; export function getModelOptions(provider: ProviderKind = "codex") {