diff --git a/packages/transformers/docs/source/guides/dtypes.md b/packages/transformers/docs/source/guides/dtypes.md index 156df0fe1..b3e24727c 100644 --- a/packages/transformers/docs/source/guides/dtypes.md +++ b/packages/transformers/docs/source/guides/dtypes.md @@ -38,6 +38,31 @@ const output = await generator(messages, { max_new_tokens: 128 }); console.log(output[0].generated_text.at(-1).content); ``` +## Detecting available dtypes + +Not sure which quantizations a model offers? Use `ModelRegistry.get_available_dtypes()` to probe the repository and find out: + +```js +import { ModelRegistry } from "@huggingface/transformers"; + +const dtypes = await ModelRegistry.get_available_dtypes("onnx-community/all-MiniLM-L6-v2-ONNX"); +console.log(dtypes); // e.g., [ 'fp32', 'fp16', 'int8', 'uint8', 'q8', 'q4' ] +``` + +This checks which ONNX files exist on the Hugging Face Hub for each dtype. For multi-session models (e.g., encoder-decoder), a dtype is only listed if **all** required session files are present. + +You can use this to build UIs that let users pick a quantization level, or to automatically select the smallest available dtype: + +```js +const dtypes = await ModelRegistry.get_available_dtypes("onnx-community/Qwen3-0.6B-ONNX"); + +// Pick the smallest available quantization, falling back to fp32 +const preferred = ["q4", "q8", "fp16", "fp32"]; +const dtype = preferred.find((d) => dtypes.includes(d)) ?? "fp32"; + +const generator = await pipeline("text-generation", "onnx-community/Qwen3-0.6B-ONNX", { dtype }); +``` + ## Per-module dtypes Some encoder-decoder models, like Whisper or Florence-2, are extremely sensitive to quantization settings: especially of the encoder. For this reason, we added the ability to select per-module dtypes, which can be done by providing a mapping from module name to dtype. 
diff --git a/packages/transformers/src/utils/model_registry/ModelRegistry.js b/packages/transformers/src/utils/model_registry/ModelRegistry.js index bf392f1df..1179a92f6 100644 --- a/packages/transformers/src/utils/model_registry/ModelRegistry.js +++ b/packages/transformers/src/utils/model_registry/ModelRegistry.js @@ -3,6 +3,7 @@ * * Provides static methods for: * - Discovering which files a model needs + * - Detecting available quantization levels (dtypes) * - Getting file metadata * - Checking cache status * @@ -35,6 +36,16 @@ * console.log(processorFiles); // [ ] * ``` * + * **Example:** Detect available quantization levels for a model + * ```javascript + * const dtypes = await ModelRegistry.get_available_dtypes("onnx-community/all-MiniLM-L6-v2-ONNX"); + * console.log(dtypes); // [ 'fp32', 'fp16', 'int8', 'uint8', 'q8', 'q4' ] + * + * // Use the result to pick the best available dtype + * const preferredDtype = dtypes.includes("q4") ? "q4" : "fp32"; + * const files = await ModelRegistry.get_files("onnx-community/all-MiniLM-L6-v2-ONNX", { dtype: preferredDtype }); + * ``` + * * **Example:** Check file metadata without downloading * ```javascript * const metadata = await ModelRegistry.get_file_metadata( @@ -98,6 +109,7 @@ import { get_processor_files } from './get_processor_files.js'; import { is_cached, is_cached_files, is_pipeline_cached, is_pipeline_cached_files } from './is_cached.js'; import { get_file_metadata } from './get_file_metadata.js'; import { clear_cache, clear_pipeline_cache } from './clear_cache.js'; +import { get_available_dtypes } from './get_available_dtypes.js'; /** * Static class for cache and file management operations. @@ -193,6 +205,30 @@ export class ModelRegistry { return get_processor_files(modelId); } + /** + * Detects which quantization levels (dtypes) are available for a model + * by checking which ONNX files exist on the hub or locally. 
+ * + * A dtype is considered available if all required model session files + * exist for that dtype. + * + * @param {string} modelId - The model id (e.g., "onnx-community/all-MiniLM-L6-v2-ONNX") + * @param {Object} [options] - Optional parameters + * @param {import('../../configs.js').PretrainedConfig} [options.config=null] - Pre-loaded config + * @param {string} [options.model_file_name=null] - Override the model file name (excluding .onnx suffix) + * @param {string} [options.revision='main'] - Model revision + * @param {string} [options.cache_dir=null] - Custom cache directory + * @param {boolean} [options.local_files_only=false] - Only check local files + * @returns {Promise<string[]>} Array of available dtype strings (e.g., ['fp32', 'fp16', 'q4', 'q8']) + * + * @example + * const dtypes = await ModelRegistry.get_available_dtypes('onnx-community/all-MiniLM-L6-v2-ONNX'); + * console.log(dtypes); // ['fp32', 'fp16', 'int8', 'uint8', 'q8', 'q4'] + */ + static async get_available_dtypes(modelId, options = {}) { + return get_available_dtypes(modelId, options); + } + /** * Quickly checks if a model is fully cached by verifying `config.json` is present, * then confirming all required files are cached. 
diff --git a/packages/transformers/src/utils/model_registry/get_available_dtypes.js b/packages/transformers/src/utils/model_registry/get_available_dtypes.js new file mode 100644 index 000000000..c0b77bdf9 --- /dev/null +++ b/packages/transformers/src/utils/model_registry/get_available_dtypes.js @@ -0,0 +1,68 @@ +import { getSessionsConfig } from '../../models/modeling_utils.js'; +import { DEFAULT_DTYPE_SUFFIX_MAPPING } from '../dtypes.js'; +import { get_file_metadata } from './get_file_metadata.js'; +import { get_config } from './get_model_files.js'; +import { resolve_model_type } from './resolve_model_type.js'; + +/** + * @typedef {import('../../configs.js').PretrainedConfig} PretrainedConfig + */ + +/** + * The dtypes to probe for availability (excludes 'auto' which is not a concrete dtype). + * @type {string[]} + */ +const CONCRETE_DTYPES = Object.keys(DEFAULT_DTYPE_SUFFIX_MAPPING); + +/** + * Detects which quantization levels (dtypes) are available for a model + * by checking which ONNX files exist on the hub or locally. + * + * A dtype is considered available if *all* required model session files + * exist for that dtype. For example, a Seq2Seq model needs both an encoder + * and decoder file — the dtype is only listed if both are present. 
+ * + * @param {string} modelId The model id (e.g., "onnx-community/all-MiniLM-L6-v2-ONNX") + * @param {Object} [options] Optional parameters + * @param {PretrainedConfig} [options.config=null] Pre-loaded model config (optional, will be fetched if not provided) + * @param {string} [options.model_file_name=null] Override the model file name (excluding .onnx suffix) + * @param {string} [options.revision='main'] Model revision + * @param {string} [options.cache_dir=null] Custom cache directory + * @param {boolean} [options.local_files_only=false] Only check local files + * @returns {Promise<string[]>} Array of available dtype strings (e.g., ['fp32', 'fp16', 'q4', 'q8']) + */ +export async function get_available_dtypes( + modelId, + { config = null, model_file_name = null, revision = 'main', cache_dir = null, local_files_only = false } = {}, +) { + config = await get_config(modelId, { config, cache_dir, local_files_only, revision }); + + const subfolder = 'onnx'; + + const modelType = resolve_model_type(config); + + const { sessions } = getSessionsConfig(modelType, config, { model_file_name }); + + // Get all base names for model session files + const baseNames = Object.values(sessions); + + // For each dtype, check if all session files exist + const metadataOptions = { revision, cache_dir, local_files_only }; + + // Probe all (dtype, baseName) combinations in parallel + const probeResults = await Promise.all( + CONCRETE_DTYPES.map(async (dtype) => { + const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[dtype] ?? 
''; + const allExist = await Promise.all( + baseNames.map(async (baseName) => { + const filename = `${subfolder}/${baseName}${suffix}.onnx`; + const metadata = await get_file_metadata(modelId, filename, metadataOptions); + return metadata.exists; + }), + ); + return { dtype, available: allExist.every(Boolean) }; + }), + ); + + return probeResults.filter((r) => r.available).map((r) => r.dtype); +} diff --git a/packages/transformers/src/utils/model_registry/get_model_files.js b/packages/transformers/src/utils/model_registry/get_model_files.js index afa81254b..72c4731bb 100644 --- a/packages/transformers/src/utils/model_registry/get_model_files.js +++ b/packages/transformers/src/utils/model_registry/get_model_files.js @@ -1,16 +1,10 @@ import { DEFAULT_DTYPE_SUFFIX_MAPPING, selectDtype } from '../dtypes.js'; import { selectDevice } from '../devices.js'; import { resolveExternalDataFormat, getExternalDataChunkNames } from '../model-loader.js'; -import { - MODEL_TYPES, - MODEL_TYPE_MAPPING, - MODEL_MAPPING_NAMES, - getSessionsConfig, -} from '../../models/modeling_utils.js'; +import { getSessionsConfig } from '../../models/modeling_utils.js'; import { AutoConfig } from '../../configs.js'; -import { GITHUB_ISSUE_URL } from '../constants.js'; -import { logger } from '../logger.js'; import { memoizePromise } from '../memoize_promise.js'; +import { resolve_model_type } from './resolve_model_type.js'; /** * @typedef {import('../../configs.js').PretrainedConfig} PretrainedConfig @@ -32,7 +26,10 @@ import { memoizePromise } from '../memoize_promise.js'; * @param {string} [options.revision='main'] Git branch, tag, or commit SHA. 
* @returns {Promise} */ -function get_config(modelId, { config = null, cache_dir = null, local_files_only = false, revision = 'main' } = {}) { +export function get_config( + modelId, + { config = null, cache_dir = null, local_files_only = false, revision = 'main' } = {}, +) { // When a pre-loaded config is provided, skip memoization — no fetch occurs // and there is no meaningful key to deduplicate on. if (config !== null) { @@ -78,57 +75,7 @@ export async function get_model_files( let dtype = overrideDtype ?? custom_config.dtype; // Infer model type from config - let modelType; - - // @ts-ignore - architectures is set via Object.assign in PretrainedConfig constructor - const architectures = /** @type {string[]} */ (config.architectures || []); - - // Try to find a known architecture in MODEL_TYPE_MAPPING - // This ensures we use the same logic as from_pretrained() - let foundInMapping = false; - for (const arch of architectures) { - const mappedType = MODEL_TYPE_MAPPING.get(arch); - if (mappedType !== undefined) { - modelType = mappedType; - foundInMapping = true; - break; - } - } - - // If not found by architecture, try model_type (handles custom models with no architectures) - if (!foundInMapping && config.model_type) { - const mappedType = MODEL_TYPE_MAPPING.get(config.model_type); - if (mappedType !== undefined) { - modelType = mappedType; - foundInMapping = true; - } - - if (!foundInMapping) { - // As a last resort, map model_type based on MODEL_MAPPING_NAMES - for (const mapping of Object.values(MODEL_MAPPING_NAMES)) { - if (mapping.has(config.model_type)) { - modelType = MODEL_TYPE_MAPPING.get(mapping.get(config.model_type)); - foundInMapping = true; - break; - } - } - } - } - - // Fall back to EncoderOnly if not found in mapping - if (!foundInMapping) { - const archList = architectures.length > 0 ? 
architectures.join(', ') : '(none)'; - logger.warn( - `[get_model_files] Architecture(s) not found in MODEL_TYPE_MAPPING: [${archList}] ` + - `for model type '${config.model_type}'. Falling back to EncoderOnly (single model.onnx file). ` + - `If you encounter issues, please report at: ${GITHUB_ISSUE_URL}`, - ); - - // Always fallback to EncoderOnly (single model.onnx file) - // Other model types (Vision2Seq, Musicgen, etc.) require specific file structures - // and should be properly registered in MODEL_TYPE_MAPPING if they are valid. - modelType = MODEL_TYPES.EncoderOnly; - } + const modelType = resolve_model_type(config); const add_model_file = (fileName, baseName = null) => { baseName = baseName ?? fileName; diff --git a/packages/transformers/src/utils/model_registry/resolve_model_type.js b/packages/transformers/src/utils/model_registry/resolve_model_type.js new file mode 100644 index 000000000..5dcee62d7 --- /dev/null +++ b/packages/transformers/src/utils/model_registry/resolve_model_type.js @@ -0,0 +1,66 @@ +import { MODEL_MAPPING_NAMES, MODEL_TYPES, MODEL_TYPE_MAPPING } from '../../models/modeling_utils.js'; +import { GITHUB_ISSUE_URL } from '../constants.js'; +import { logger } from '../logger.js'; + +/** + * @typedef {import('../../configs.js').PretrainedConfig} PretrainedConfig + */ + +/** + * Resolves the model type (e.g., EncoderOnly, DecoderOnly, Seq2Seq, …) from a + * model config by checking architectures and model_type against the known + * MODEL_TYPE_MAPPING. + * + * Resolution order: + * 1. `config.architectures` entries looked up in MODEL_TYPE_MAPPING + * 2. `config.model_type` looked up directly in MODEL_TYPE_MAPPING + * 3. `config.model_type` looked up via MODEL_MAPPING_NAMES → architecture → MODEL_TYPE_MAPPING + * 4. Fallback to `MODEL_TYPES.EncoderOnly` + * + * @param {PretrainedConfig} config The model config object. + * @param {{ warn?: boolean }} [options] Set `warn` to false to suppress the + * fallback warning (defaults to true). 
+ * @returns {number} One of the MODEL_TYPES enum values. + */ +export function resolve_model_type(config, { warn = true } = {}) { + // @ts-ignore - architectures is set via Object.assign in PretrainedConfig constructor + const architectures = /** @type {string[]} */ (config.architectures || []); + + // 1. Try architectures against MODEL_TYPE_MAPPING + for (const arch of architectures) { + const mappedType = MODEL_TYPE_MAPPING.get(arch); + if (mappedType !== undefined) { + return mappedType; + } + } + + // 2. Try config.model_type directly + if (config.model_type) { + const mappedType = MODEL_TYPE_MAPPING.get(config.model_type); + if (mappedType !== undefined) { + return mappedType; + } + + // 3. Try MODEL_MAPPING_NAMES as a last resort + for (const mapping of Object.values(MODEL_MAPPING_NAMES)) { + if (mapping.has(config.model_type)) { + const resolved = MODEL_TYPE_MAPPING.get(mapping.get(config.model_type)); + if (resolved !== undefined) { + return resolved; + } + } + } + } + + // 4. Fallback + if (warn) { + const archList = architectures.length > 0 ? architectures.join(', ') : '(none)'; + logger.warn( + `[resolve_model_type] Architecture(s) not found in MODEL_TYPE_MAPPING: [${archList}] ` + + `for model type '${config.model_type}'. Falling back to EncoderOnly (single model.onnx file). 
` + + `If you encounter issues, please report at: ${GITHUB_ISSUE_URL}`, + ); + } + + return MODEL_TYPES.EncoderOnly; +} diff --git a/packages/transformers/tests/utils/model_registry.test.js b/packages/transformers/tests/utils/model_registry.test.js new file mode 100644 index 000000000..cb9e6408a --- /dev/null +++ b/packages/transformers/tests/utils/model_registry.test.js @@ -0,0 +1,183 @@ +import { jest } from "@jest/globals"; + +// Mock get_file_metadata before importing the module under test +const mockGetFileMetadata = jest.fn(); +jest.unstable_mockModule("../../src/utils/model_registry/get_file_metadata.js", () => ({ + get_file_metadata: mockGetFileMetadata, +})); + +// Import registry to populate MODEL_TYPE_MAPPING (side-effect import) +await import("../../src/models/registry.js"); + +// Dynamic import after mock setup (required for ESM) +const { get_available_dtypes } = await import("../../src/utils/model_registry/get_available_dtypes.js"); + +// A minimal config that mimics a BERT-like encoder-only model +const ENCODER_ONLY_CONFIG = { + architectures: ["BertModel"], + model_type: "bert", +}; + +// A minimal config for a decoder-only (causal LM) model +const DECODER_ONLY_CONFIG = { + architectures: ["LlamaForCausalLM"], + model_type: "llama", +}; + +// A minimal config for a Seq2Seq model (encoder + decoder) +const SEQ2SEQ_CONFIG = { + architectures: ["T5ForConditionalGeneration"], + model_type: "t5", +}; + +// A config with an unknown architecture (falls back to EncoderOnly) +const UNKNOWN_ARCH_CONFIG = { + architectures: ["SomeUnknownArchitecture"], + model_type: "unknown_type", +}; + +/** + * Helper: given a set of files that "exist", returns a mock implementation + * for get_file_metadata that resolves { exists: true } for those files. 
+ * @param {string[]} existingFiles + */ +function setupExistingFiles(...existingFiles) { + mockGetFileMetadata.mockImplementation((_modelId, filename, _options) => { + return Promise.resolve({ + exists: existingFiles.includes(filename), + fromCache: false, + }); + }); +} + +describe("get_available_dtypes", () => { + beforeEach(() => { + mockGetFileMetadata.mockReset(); + }); + + it("should detect fp32 and q4 for an encoder-only model", async () => { + setupExistingFiles( + "onnx/model.onnx", // fp32 + "onnx/model_q4.onnx", // q4 + ); + + const dtypes = await get_available_dtypes("test/model", { config: ENCODER_ONLY_CONFIG }); + + expect(dtypes).toContain("fp32"); + expect(dtypes).toContain("q4"); + expect(dtypes).not.toContain("fp16"); + expect(dtypes).not.toContain("q8"); + expect(dtypes).not.toContain("int8"); + }); + + it("should detect all dtypes when all files exist", async () => { + setupExistingFiles( + "onnx/model.onnx", // fp32 + "onnx/model_fp16.onnx", // fp16 + "onnx/model_int8.onnx", // int8 + "onnx/model_uint8.onnx", // uint8 + "onnx/model_quantized.onnx", // q8 + "onnx/model_q4.onnx", // q4 + "onnx/model_q4f16.onnx", // q4f16 + "onnx/model_bnb4.onnx", // bnb4 + ); + + const dtypes = await get_available_dtypes("test/model", { config: ENCODER_ONLY_CONFIG }); + + expect(dtypes).toEqual(["fp32", "fp16", "int8", "uint8", "q8", "q4", "q4f16", "bnb4"]); + }); + + it("should return empty array when no ONNX files exist", async () => { + setupExistingFiles(); + const dtypes = await get_available_dtypes("test/model", { config: ENCODER_ONLY_CONFIG }); + + expect(dtypes).toEqual([]); + }); + + it("should require all session files for seq2seq models", async () => { + // Only encoder has q4, decoder does not — q4 should NOT be available + setupExistingFiles( + "onnx/encoder_model.onnx", // fp32 encoder + "onnx/decoder_model_merged.onnx", // fp32 decoder + "onnx/encoder_model_q4.onnx", // q4 encoder (but no q4 decoder) + ); + + const dtypes = await 
get_available_dtypes("test/model", { config: SEQ2SEQ_CONFIG }); + + expect(dtypes).toContain("fp32"); + expect(dtypes).not.toContain("q4"); + }); + + it("should list dtype only when all session files exist for seq2seq", async () => { + // Both encoder and decoder have fp32 and q8 + setupExistingFiles("onnx/encoder_model.onnx", "onnx/decoder_model_merged.onnx", "onnx/encoder_model_quantized.onnx", "onnx/decoder_model_merged_quantized.onnx"); + + const dtypes = await get_available_dtypes("test/model", { config: SEQ2SEQ_CONFIG }); + + expect(dtypes).toContain("fp32"); + expect(dtypes).toContain("q8"); + expect(dtypes).not.toContain("fp16"); + expect(dtypes).not.toContain("q4"); + }); + + it("should handle decoder-only models", async () => { + setupExistingFiles("onnx/model.onnx", "onnx/model_q4.onnx", "onnx/model_q4f16.onnx"); + + const dtypes = await get_available_dtypes("test/model", { config: DECODER_ONLY_CONFIG }); + + expect(dtypes).toContain("fp32"); + expect(dtypes).toContain("q4"); + expect(dtypes).toContain("q4f16"); + expect(dtypes).toHaveLength(3); + }); + + it("should fall back to EncoderOnly for unknown architectures", async () => { + setupExistingFiles("onnx/model.onnx", "onnx/model_fp16.onnx"); + + const dtypes = await get_available_dtypes("test/model", { config: UNKNOWN_ARCH_CONFIG }); + + expect(dtypes).toContain("fp32"); + expect(dtypes).toContain("fp16"); + expect(dtypes).toHaveLength(2); + }); + + it("should support custom model_file_name", async () => { + setupExistingFiles("onnx/custom_model.onnx", "onnx/custom_model_q4.onnx"); + + const dtypes = await get_available_dtypes("test/model", { + config: ENCODER_ONLY_CONFIG, + model_file_name: "custom_model", + }); + + expect(dtypes).toContain("fp32"); + expect(dtypes).toContain("q4"); + expect(dtypes).not.toContain("fp16"); + }); + + it("should pass revision and cache_dir to get_file_metadata", async () => { + setupExistingFiles("onnx/model.onnx"); + + await get_available_dtypes("test/model", { + 
config: ENCODER_ONLY_CONFIG, + revision: "v2", + cache_dir: "/tmp/cache", + }); + + // Verify that metadata calls received the correct options + for (const call of mockGetFileMetadata.mock.calls) { + expect(call[0]).toBe("test/model"); + expect(call[2]).toMatchObject({ revision: "v2", cache_dir: "/tmp/cache" }); + } + }); + + it("should only return valid dtype strings", async () => { + setupExistingFiles("onnx/model.onnx", "onnx/model_fp16.onnx"); + + const dtypes = await get_available_dtypes("test/model", { config: ENCODER_ONLY_CONFIG }); + + const validDtypes = ["fp32", "fp16", "int8", "uint8", "q8", "q4", "q4f16", "bnb4"]; + for (const dtype of dtypes) { + expect(validDtypes).toContain(dtype); + } + }); +});