diff --git a/src/providers/embedding/local.ts b/src/providers/embedding/local.ts index ad6c2d214..9f74f33c7 100644 --- a/src/providers/embedding/local.ts +++ b/src/providers/embedding/local.ts @@ -1,4 +1,5 @@ import type { EmbeddingProvider } from "../../types.js"; +import { getEnvVar } from "../../config.js"; type Pipeline = ( task: string, @@ -10,11 +11,57 @@ type Pipeline = ( ) => Promise<{ tolist: () => number[][] }> >; +/** Embedding dimensions of known models; for a model not listed here, set OPENAI_EMBEDDING_DIMENSIONS, otherwise DEFAULT_DIMS is used */ +const KNOWN_DIMS: Record = { + // MiniLM family (English) + "Xenova/all-MiniLM-L6-v2": 384, + // BGE Chinese family + "Xenova/bge-large-zh-v1.5": 1024, + "Xenova/bge-base-zh-v1.5": 768, + "Xenova/bge-small-zh-v1.5": 512, + // BGE multilingual family + "Xenova/bge-m3": 1024, + // Multilingual MiniLM + "Xenova/paraphrase-multilingual-MiniLM-L12-v2": 384, + // E5 multilingual family + "Xenova/multilingual-e5-large": 1024, + "Xenova/multilingual-e5-base": 768, + "Xenova/multilingual-e5-small": 384, +}; + +const DEFAULT_MODEL = "Xenova/all-MiniLM-L6-v2"; +const DEFAULT_DIMS = 384; + +function resolveDimensions( + modelName: string, + override: string | undefined, +): number { + if (override !== undefined && override.trim().length > 0) { + const parsed = parseInt(override, 10); + if (!Number.isFinite(parsed) || parsed <= 0) { + throw new Error( + `OPENAI_EMBEDDING_DIMENSIONS must be a positive integer, got: ${override}`, + ); + } + return parsed; + } + return KNOWN_DIMS[modelName] ?? 
DEFAULT_DIMS; +} + export class LocalEmbeddingProvider implements EmbeddingProvider { readonly name = "local"; - readonly dimensions = 384; + readonly dimensions: number; + private modelName: string; private extractor: Awaited> | null = null; + constructor() { + this.modelName = getEnvVar("EMBEDDING_MODEL") || DEFAULT_MODEL; + this.dimensions = resolveDimensions( + this.modelName, + getEnvVar("OPENAI_EMBEDDING_DIMENSIONS"), + ); + } + async embed(text: string): Promise { const [result] = await this.embedBatch([text]); return result; @@ -45,7 +92,8 @@ export class LocalEmbeddingProvider implements EmbeddingProvider { this.extractor = await transformers.pipeline( "feature-extraction", - "Xenova/all-MiniLM-L6-v2", + this.modelName, + { local_files_only: true, quantized: false }, ); return this.extractor; }