Skip to content

Commit 488d9ad

Browse files
authored
Move Q² WASM kernel from main thread to worker (issue #76) (#77)
1 parent e6b5fc1 commit 488d9ad

6 files changed

Lines changed: 274 additions & 85 deletions

File tree

DESIGN.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ $$f_{\text{shell}}(n, \varepsilon) = 1 - (1-\varepsilon)^n$$
7474
For any fixed $\varepsilon > 0$, $f_{\text{shell}} \to 1$ as $n \to \infty$. The shell
7575
thickness required to capture fraction $f$ is:
7676

77-
$$\varepsilon^{*}(f, n) = 1 - (1-f)^{1/n} \approx \frac{-\ln(1-f)}{n}$$
77+
$$\varepsilon^{\ast}(f, n) = 1 - (1-f)^{1/n} \approx \frac{-\ln(1-f)}{n}$$
7878

7979
| Fraction captured | Shell thickness |
8080
|:-----------------:|:---------------:|

bun.lock

Lines changed: 5 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/app.ts

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,13 @@ import type {
2525
ChatMessage,
2626
GenerationConfig,
2727
EmbeddingMsg,
28+
ModelOutputsMsg,
29+
Q2Msg,
2830
} from './types.js';
2931
import {
30-
getKernel,
3132
l2Normalise,
3233
q2EncodeDirect,
3334
q2KeyDirect,
34-
DTYPE_TO_Q2,
35-
Q2_DTYPE_FP32,
36-
Q2_INPUT_OFFSET,
37-
Q2_OUTPUT_OFFSET,
3835
} from './q2.js';
3936
import {
4037
deleteStoredFile,
@@ -681,6 +678,12 @@ export function handleWorkerMessage(msg: WorkerOutMsg): void {
681678
case 'embedding':
682679
onEmbedding(msg);
683680
break;
681+
case 'model-outputs':
682+
onModelOutputs(msg);
683+
break;
684+
case 'q2':
685+
onQ2(msg);
686+
break;
684687
case 'done':
685688
onDone();
686689
break;
@@ -828,55 +831,48 @@ export function onEmbedding(msg: EmbeddingMsg): void {
828831
`Shape: [${seqLen} × ${hiddenDim}] dtype=${dtype} stats=unavailable`;
829832
}
830833

831-
// ── Q² kernel ────────────────────────────────────────────────────────────
832-
// Run the quaternary quantisation in the background. The WASM kernel is
833-
// preferred; if instantiation fails (e.g. in test environments that lack
834-
// WebAssembly.instantiate) we fall back to the pure-TS implementation.
835-
const n = hiddenDim;
836-
const dtypeId = DTYPE_TO_Q2[dtype] ?? Q2_DTYPE_FP32;
834+
}
837835

838-
if (seqLen < 1) {
839-
appLog('warn', 'Q² embedding: seqLen < 1; skipping quantisation', { seqLen });
840-
return;
836+
/**
837+
* `onQ2` (defined further below) — handles the compact Q² quantisation result sent by the worker kernel.
838+
*
839+
* The worker runs the Q² WASM kernel before sending, so only packed bytes
840+
* and the 64-bit key cross the thread boundary (see worker.ts quantiseAndSend).
841+
*/
842+
/**
843+
* Shows the user which ONNX output nodes the loaded model exports and whether
844+
* Q² fingerprinting was able to locate a hidden-state tensor among them.
845+
*
846+
* Called once per generation turn, immediately after the embedding forward
847+
* pass in the worker. Surfaced in the embedding panel so the user knows
848+
* exactly why Q² may be unavailable and what the model actually exports.
849+
*/
850+
export function onModelOutputs(msg: ModelOutputsMsg): void {
851+
// Log the whole message for diagnostics, then make sure the panel is visible.
appLog('info', 'onModelOutputs received', msg);
852+
embeddingPanel.classList.remove('hidden');
853+
854+
// Format each output as name[d0×d1×…] for compact display.
855+
const outputList = Object.entries(msg.outputs)
856+
.map(([name, dims]) => `${name}[${dims.join('×')}]`)
857+
.join(' ');
858+
859+
// Both branches assign embeddingStats.textContent, replacing whatever text
// was there before rather than appending to it.
if (msg.hiddenStateKey !== null) {
860+
embeddingStats.textContent =
861+
`ONNX outputs: ${outputList}\n` +
862+
// If the selected key is absent from msg.outputs, an empty dims list is
// rendered ("name[]") instead of throwing.
`Q² using: ${msg.hiddenStateKey}[${(msg.outputs[msg.hiddenStateKey] ?? []).join('×')}]`;
863+
} else {
864+
// No hidden-state output was selected: explain why Q² is unavailable and
// what the user can do about it.
embeddingStats.textContent =
865+
`ONNX outputs: ${outputList}\n` +
866+
`Q² unavailable — no 3-D hidden-state output found.\n` +
867+
`To enable Q² fingerprinting, re-export the model with a last_hidden_state ` +
868+
`(or equivalent) output node, or use a model that already exports one.`;
841869
}
870+
}
842871

843-
appLog('debug', 'onEmbedding: starting Q² kernel', { hiddenDim: n, dtypeId, seqLen });
844-
void (async () => {
845-
try {
846-
const kernel = await getKernel();
847-
const mem = new Uint8Array(kernel.memory.buffer);
848-
849-
// Copy the raw activation buffer into WASM memory at the input offset.
850-
const inputBytes = new Uint8Array(msg.data);
851-
mem.set(inputBytes, Q2_INPUT_OFFSET);
852-
853-
// Run quantisation: L2-normalise last token, threshold, Gray-encode.
854-
kernel.quantise(Q2_INPUT_OFFSET, seqLen, n, dtypeId, Q2_OUTPUT_OFFSET);
855-
856-
// Derive the 64-bit transition key.
857-
const rawKey = kernel.key(Q2_OUTPUT_OFFSET, n);
858-
const key = BigInt.asUintN(64, rawKey);
859-
860-
appLog('debug', 'Q² WASM kernel produced key', { key: `0x${key.toString(16).padStart(16, '0')}`, hiddenDim: n });
861-
// Read back packed bytes.
862-
const packed = new Uint8Array(kernel.memory.buffer, Q2_OUTPUT_OFFSET, n >> 2);
863-
renderQ2Result(packed, key, n, currentSettings.q2KeyDisplayMode);
864-
} catch {
865-
// WASM unavailable — use the pure-TypeScript fallback (fp32 only).
866-
// This path is taken in test environments and SSR contexts.
867-
// For sub-fp32 dtypes the WASM kernel is required; log a warning and skip.
868-
if (dtype !== 'fp32') {
869-
appLog('warn', 'Q² TS fallback: non-fp32 dtype requires WASM kernel; skipping', { dtype });
870-
return;
871-
}
872-
appLog('debug', 'Q² falling back to TS implementation', { seqLen, hiddenDim: n });
873-
const all = new Float32Array(msg.data);
874-
const vec = l2Normalise(all.subarray((seqLen - 1) * n, seqLen * n), n);
875-
const { packed, key } = q2EncodeDirect(vec, n);
876-
appLog('debug', 'Q² TS fallback produced key', { key: `0x${BigInt.asUintN(64, key).toString(16).padStart(16, '0')}`, hiddenDim: n });
877-
renderQ2Result(packed, BigInt.asUintN(64, key), n, currentSettings.q2KeyDisplayMode);
878-
}
879-
})();
872+
/**
 * Handles a `q2` message from the worker.
 *
 * Wraps `msg.packed` in a `Uint8Array`, logs the key as a zero-padded
 * 16-digit hex string, and forwards the bytes, key and dimension to
 * `renderQ2Result` together with the current key display mode.
 *
 * @param msg - Q² result message; its `packed`, `key` and `n` fields are read.
 */
export function onQ2(msg: Q2Msg): void {
873+
// A typed-array view over the received ArrayBuffer; no copy is made here.
const packed = new Uint8Array(msg.packed);
874+
// NOTE(review): the key is used as received. The removed main-thread path
// applied BigInt.asUintN(64, …) before rendering; presumably the worker now
// does that before posting — confirm, since a negative bigint would print
// with a leading '-' here.
appLog('debug', 'onQ2 received', { n: msg.n, key: `0x${msg.key.toString(16).padStart(16, '0')}` });
875+
renderQ2Result(packed, msg.key, msg.n, currentSettings.q2KeyDisplayMode);
880876
}
881877

882878
export function onDone(): void {

src/types.ts

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,47 @@ export interface EmbeddingMsg {
101101
dtype: 'fp32' | 'fp16' | 'q8' | 'q4' | 'q2';
102102
}
103103

104+
/**
105+
* Sent once per generation turn immediately after the embedding forward pass,
106+
* regardless of whether a usable hidden-state output was found.
107+
*
108+
* Lets the main thread show the user exactly which ONNX output nodes the
109+
* loaded model exposes and explain why Q² fingerprinting may be unavailable.
110+
*/
111+
export interface ModelOutputsMsg {
112+
type: 'model-outputs';
113+
/**
114+
* Every output node the model's ONNX session exposes.
115+
* Key: node name. Value: dimension array, e.g. [1, 42, 4096].
116+
*/
117+
outputs: Record<string, number[]>;
118+
/**
119+
* The output node name that was selected for Q² quantisation,
120+
* or null when no suitable hidden-state tensor was found.
*
* When non-null, the main-thread handler also uses this value to index
* `outputs` for display; a name missing from `outputs` is rendered with
* empty dimensions.
121+
*/
122+
hiddenStateKey: string | null;
123+
}
124+
125+
/**
126+
* Q² quantisation result produced by the worker kernel.
127+
*
128+
* The worker runs the Q² WASM kernel immediately after extracting an embedding,
129+
* so only the compact quantised representation crosses the thread boundary
130+
* instead of the raw activation buffer (~64× smaller for fp32 n=4096).
*
* NOTE(review): n/4 packed bytes versus 4·n bytes for a single fp32 vector is
* a 16× reduction; the ~64× figure presumably assumes the raw buffer carries
* several token rows — confirm or restate the ratio.
131+
*/
132+
export interface Q2Msg {
133+
type: 'q2';
134+
/**
135+
* n/4 packed Gray-encoded bytes (transferable ArrayBuffer).
136+
* Transfer via postMessage(msg, [packed]) to avoid structured-clone copy.
137+
*/
138+
packed: ArrayBuffer;
139+
/** 64-bit MSB-aligned transition key (DESIGN.md §2.2). */
140+
key: bigint;
141+
/** Original embedding dimension (n). */
142+
n: number;
143+
}
144+
104145
export interface DoneMsg {
105146
type: 'done';
106147
}
@@ -115,6 +156,8 @@ export type WorkerOutMsg =
115156
| ProgressMsg
116157
| TokenMsg
117158
| EmbeddingMsg
159+
| ModelOutputsMsg
160+
| Q2Msg
118161
| DoneMsg
119162
| ErrorMsg;
120163

0 commit comments

Comments
 (0)