-
Notifications
You must be signed in to change notification settings. Fork 2
Fix the code generation and custom decoders running pipelines #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9a9fd6b
4518c34
c2fe37a
0ce48d7
f04ea61
761c6db
b7b815c
a013668
130faf3
c8ab382
fba1a10
738b452
ba1ca5a
2ec04fe
b6a1903
6b3d52a
c7dc1e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,81 +1,151 @@ | ||
| // Decoder execution wrapper | ||
| /** | ||
| * Unified Decoder Execution | ||
| * | ||
| * Execution paths: | ||
| * 1. Built-in TFJS (tfjsModelType) → Web Worker inference | ||
| * 2. Custom TFJS code → Web Worker (same as built-in) | ||
| * 3. Simple JS code → Direct execution (baselines) | ||
| * | ||
| * Custom TensorFlow.js code is executed in the Web Worker | ||
| * where the full tf namespace (including tf.train) is available. | ||
| */ | ||
|
|
||
| import type { DecoderInput, DecoderOutput, Decoder } from '../types/decoders'; | ||
| import { PERFORMANCE_THRESHOLDS } from '../utils/constants'; | ||
| import { tfWorker, getWorkerModelType } from './tfWorkerManager'; | ||
|
|
||
| // Spike history for temporal models | ||
| // Spike history for temporal models (LSTM, Attention) | ||
| const spikeHistory: number[][] = []; | ||
| const MAX_HISTORY = 10; | ||
|
|
||
| // Cache compiled decoder functions to avoid recompiling on every call | ||
| const compiledDecoders = new Map<string, (input: DecoderInput) => { x: number; y: number; vx?: number; vy?: number; confidence?: number }>(); | ||
| // Cache for compiled JS functions (simple decoders, not TFJS) | ||
| type JSDecoderFn = (input: DecoderInput) => { x: number; y: number; vx?: number; vy?: number; confidence?: number }; | ||
| const jsFunctions = new Map<string, JSDecoderFn>(); | ||
|
|
||
| function getCompiledDecoder(decoder: Decoder): (input: DecoderInput) => { x: number; y: number; vx?: number; vy?: number; confidence?: number } { | ||
| const cacheKey = `${decoder.id}:${decoder.code}`; | ||
|
|
||
| if (!compiledDecoders.has(cacheKey)) { | ||
| console.log(`[Decoder] Compiling decoder: ${decoder.name}`); | ||
| const fn = new Function('input', decoder.code!) as (input: DecoderInput) => { x: number; y: number; vx?: number; vy?: number; confidence?: number }; | ||
| compiledDecoders.set(cacheKey, fn); | ||
| // Track which custom TFJS models have been loaded into the worker | ||
| const customModelsLoaded = new Set<string>(); | ||
|
|
||
| /** | ||
| * Check if code is TensorFlow.js model creation code | ||
| */ | ||
| function isTFJSModelCode(code: string): boolean { | ||
| return code.includes('tf.sequential') || | ||
| code.includes('tf.model') || | ||
| code.includes('tf.layers'); | ||
| } | ||
|
|
||
| /** | ||
| * Get or compile a JS decoder function (cached) | ||
| */ | ||
| function getOrCompileJSDecoder(decoder: Decoder): JSDecoderFn { | ||
| if (!jsFunctions.has(decoder.id)) { | ||
| console.log(`[Decoder] Compiling JS: ${decoder.name}`); | ||
| const fn = new Function('input', decoder.code!) as JSDecoderFn; | ||
| jsFunctions.set(decoder.id, fn); | ||
| } | ||
|
|
||
| return compiledDecoders.get(cacheKey)!; | ||
| return jsFunctions.get(decoder.id)!; | ||
| } | ||
|
|
||
| // Clear cache when decoder changes | ||
| /** | ||
| * Clear decoder cache - call when decoder is updated or removed | ||
| */ | ||
| export function clearDecoderCache(decoderId?: string) { | ||
| if (decoderId) { | ||
| for (const key of compiledDecoders.keys()) { | ||
| if (key.startsWith(decoderId + ':')) { | ||
| compiledDecoders.delete(key); | ||
| } | ||
| jsFunctions.delete(decoderId); | ||
| // Also dispose from worker if it was a custom TFJS model | ||
| if (customModelsLoaded.has(decoderId)) { | ||
| tfWorker.disposeModel(decoderId); | ||
| customModelsLoaded.delete(decoderId); | ||
| } | ||
| } else { | ||
| compiledDecoders.clear(); | ||
| jsFunctions.clear(); | ||
| // Dispose all custom models from worker | ||
| for (const id of customModelsLoaded) { | ||
| tfWorker.disposeModel(id); | ||
| } | ||
| customModelsLoaded.clear(); | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Execute a JavaScript decoder (synchronous) | ||
| * Execute a custom TensorFlow.js model decoder via Web Worker | ||
| * Same execution path as built-in TFJS models - code runs in worker | ||
| */ | ||
| async function executeCustomTFJSDecoder( | ||
| decoder: Decoder, | ||
| input: DecoderInput | ||
| ): Promise<DecoderOutput> { | ||
| const startTime = performance.now(); | ||
|
|
||
| // Create model in worker if not already loaded | ||
| if (!customModelsLoaded.has(decoder.id)) { | ||
| await tfWorker.createModelFromCode(decoder.id, decoder.code!); | ||
| customModelsLoaded.add(decoder.id); | ||
| } | ||
|
Comment on lines
+80
to
+84
|
||
|
|
||
| // Run inference via worker (same as built-in models) | ||
| const output = await tfWorker.infer(decoder.id, [...input.spikes]); | ||
| const latency = performance.now() - startTime; | ||
|
|
||
| // Scale output to velocity (same as built-in) | ||
| const VELOCITY_SCALE = 50; | ||
| const vx = output[0] * VELOCITY_SCALE; | ||
| const vy = output[1] * VELOCITY_SCALE; | ||
|
|
||
| // Calculate new position | ||
| const DT = 0.025; | ||
| const x = input.kinematics.x + vx * DT; | ||
| const y = input.kinematics.y + vy * DT; | ||
|
|
||
| return { | ||
| x: Math.max(-100, Math.min(100, x)), | ||
| y: Math.max(-100, Math.min(100, y)), | ||
| vx, | ||
| vy, | ||
| latency, | ||
| }; | ||
| } | ||
|
|
||
| /** | ||
| * Execute a simple JavaScript decoder (baselines + custom JS) | ||
| */ | ||
| export function executeJSDecoder( | ||
| function executeSimpleJSDecoder( | ||
| decoder: Decoder, | ||
| input: DecoderInput | ||
| ): DecoderOutput { | ||
| const startTime = performance.now(); | ||
|
|
||
| try { | ||
| // Use cached compiled function | ||
| const decoderFunction = getCompiledDecoder(decoder); | ||
| const result = decoderFunction(input); | ||
|
|
||
| const latency = performance.now() - startTime; | ||
| const decoderFn = getOrCompileJSDecoder(decoder); | ||
| const result = decoderFn(input); | ||
|
|
||
| const latency = performance.now() - startTime; | ||
|
|
||
| // Enforce timeout | ||
| if (latency > PERFORMANCE_THRESHOLDS.DECODER_TIMEOUT_MS) { | ||
| console.warn(`[Decoder] ${decoder.name} exceeded timeout: ${latency.toFixed(2)}ms`); | ||
| } | ||
| if (latency > PERFORMANCE_THRESHOLDS.DECODER_TIMEOUT_MS) { | ||
| console.warn(`[Decoder] ${decoder.name} exceeded timeout: ${latency.toFixed(2)}ms`); | ||
| } | ||
|
|
||
| return { | ||
| x: result.x, | ||
| y: result.y, | ||
| vx: result.vx, | ||
| vy: result.vy, | ||
| confidence: result.confidence, | ||
| latency, | ||
| }; | ||
| } catch (error) { | ||
| console.error(`[Decoder] JS execution error in ${decoder.name}:`, error); | ||
|
|
||
| // Return current position as fallback | ||
| return { | ||
| x: input.kinematics.x, | ||
| y: input.kinematics.y, | ||
| vx: input.kinematics.vx, | ||
| vy: input.kinematics.vy, | ||
| latency: performance.now() - startTime, | ||
| }; | ||
| return { | ||
| x: result.x, | ||
| y: result.y, | ||
| vx: result.vx, | ||
| vy: result.vy, | ||
| confidence: result.confidence, | ||
| latency, | ||
| }; | ||
| } | ||
|
Comment on lines
+112
to
+135
|
||
|
|
||
| /** | ||
| * Execute a code-based decoder (auto-detects type) | ||
| */ | ||
| async function executeCodeDecoder( | ||
| decoder: Decoder, | ||
| input: DecoderInput | ||
| ): Promise<DecoderOutput> { | ||
| // Detect if this is TensorFlow.js model code or simple JS | ||
| if (isTFJSModelCode(decoder.code!)) { | ||
| return executeCustomTFJSDecoder(decoder, input); | ||
| } else { | ||
| return executeSimpleJSDecoder(decoder, input); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -93,14 +163,9 @@ export async function executeTFJSDecoder( | |
| const workerType = getWorkerModelType(modelType); | ||
|
|
||
| if (!workerType) { | ||
| console.warn(`[Decoder] Unknown TFJS model type: ${decoder.tfjsModelType}`); | ||
| return { | ||
| x: input.kinematics.x, | ||
| y: input.kinematics.y, | ||
| vx: input.kinematics.vx, | ||
| vy: input.kinematics.vy, | ||
| latency: 0, | ||
| }; | ||
| throw new Error( | ||
| `[Decoder] Unknown TFJS model type "${decoder.tfjsModelType}" for decoder "${decoder.name}"` | ||
| ); | ||
| } | ||
|
|
||
| // Prepare input based on model type | ||
|
|
@@ -124,7 +189,7 @@ export async function executeTFJSDecoder( | |
| workerInput = [...input.spikes]; | ||
| } | ||
|
|
||
| // Run inference in worker | ||
| // Run inference in worker (off main thread) | ||
| const output = await tfWorker.infer(workerType, workerInput); | ||
| const latency = performance.now() - startTime; | ||
|
|
||
|
|
@@ -146,37 +211,40 @@ export async function executeTFJSDecoder( | |
| latency, | ||
| }; | ||
| } catch (error) { | ||
| console.error(`[Decoder] TFJS Worker execution error in ${decoder.name}:`, error); | ||
|
|
||
| return { | ||
| x: input.kinematics.x, | ||
| y: input.kinematics.y, | ||
| vx: input.kinematics.vx, | ||
| vy: input.kinematics.vy, | ||
| latency: performance.now() - startTime, | ||
| }; | ||
| // Never silently corrupt data - rethrow with context | ||
| throw new Error( | ||
| `[Decoder] TFJS Worker failed for "${decoder.name}": ${error instanceof Error ? error.message : String(error)}` | ||
| ); | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Execute any decoder (routes to appropriate executor) | ||
| * Execute any decoder - unified routing | ||
| * | ||
| * Execution paths: | ||
| * 1. Has code → Auto-detect: TFJS model code or simple JS | ||
| * 2. Has tfjsModelType → Web Worker inference (built-in TFJS models) | ||
| * | ||
| * THROWS on invalid decoder configuration - never silently corrupts data | ||
| */ | ||
| export async function executeDecoder( | ||
| decoder: Decoder, | ||
| input: DecoderInput | ||
| ): Promise<DecoderOutput> { | ||
| if (decoder.type === 'tfjs') { | ||
| // Code-based decoders (custom + baselines) | ||
| // Auto-detects if code creates a TF model or is simple JS | ||
| if (decoder.code) { | ||
| return executeCodeDecoder(decoder, input); | ||
| } | ||
|
|
||
| // Built-in TFJS decoders (Web Worker, non-blocking) | ||
| if (decoder.tfjsModelType) { | ||
| return executeTFJSDecoder(decoder, input); | ||
| } else if (decoder.type === 'javascript' && decoder.code) { | ||
| return executeJSDecoder(decoder, input); | ||
| } | ||
|
|
||
| // Fallback - passthrough | ||
| return { | ||
| x: input.kinematics.x, | ||
| y: input.kinematics.y, | ||
| vx: input.kinematics.vx, | ||
| vy: input.kinematics.vy, | ||
| latency: 0, | ||
| }; | ||
| // Invalid decoder configuration - fail hard, never silently corrupt data | ||
| throw new Error( | ||
| `[Decoder] Invalid decoder configuration for "${decoder.name}" (id: ${decoder.id}). ` + | ||
| `Decoder must have either 'code' (JavaScript/TFJS) or 'tfjsModelType' (built-in TFJS).` | ||
| ); | ||
| } | ||
|
Comment on lines
230
to
250
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The TFJS model detection logic is fragile and may produce false positives: code that merely mentions 'tf.sequential', 'tf.model', or 'tf.layers' inside a comment, string literal, or template string will be incorrectly classified as TFJS model code and routed to the Web Worker. Consider a more robust detection method — for example, stripping comments and string literals before matching, requiring an actual call pattern (e.g. `tf.sequential(` with an opening parenthesis outside of strings), or letting the decoder declare its type explicitly instead of inferring it from source text.