Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions packages/transformers/docs/source/guides/dtypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,31 @@ const output = await generator(messages, { max_new_tokens: 128 });
console.log(output[0].generated_text.at(-1).content);
```

## Detecting available dtypes

Not sure which quantizations a model offers? Use `ModelRegistry.get_available_dtypes()` to probe the repository and find out:

```js
import { ModelRegistry } from "@huggingface/transformers";

const dtypes = await ModelRegistry.get_available_dtypes("onnx-community/all-MiniLM-L6-v2-ONNX");
console.log(dtypes); // e.g., [ 'fp32', 'fp16', 'int8', 'uint8', 'q8', 'q4' ]
```

This checks which ONNX files exist on the Hugging Face Hub for each dtype. For multi-session models (e.g., encoder-decoder), a dtype is only listed if **all** required session files are present.

You can use this to build UIs that let users pick a quantization level, or to automatically select the smallest available dtype:

```js
const dtypes = await ModelRegistry.get_available_dtypes("onnx-community/Qwen3-0.6B-ONNX");

// Pick the smallest available quantization, falling back to fp32
const preferred = ["q4", "q8", "fp16", "fp32"];
const dtype = preferred.find((d) => dtypes.includes(d)) ?? "fp32";

const generator = await pipeline("text-generation", "onnx-community/Qwen3-0.6B-ONNX", { dtype });
```

## Per-module dtypes

Some encoder-decoder models, like Whisper or Florence-2, are extremely sensitive to quantization settings: especially of the encoder. For this reason, we added the ability to select per-module dtypes, which can be done by providing a mapping from module name to dtype.
Expand Down
36 changes: 36 additions & 0 deletions packages/transformers/src/utils/model_registry/ModelRegistry.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*
* Provides static methods for:
* - Discovering which files a model needs
* - Detecting available quantization levels (dtypes)
* - Getting file metadata
* - Checking cache status
*
Expand Down Expand Up @@ -35,6 +36,16 @@
* console.log(processorFiles); // [ ]
* ```
*
* **Example:** Detect available quantization levels for a model
* ```javascript
* const dtypes = await ModelRegistry.get_available_dtypes("onnx-community/all-MiniLM-L6-v2-ONNX");
* console.log(dtypes); // [ 'fp32', 'fp16', 'int8', 'uint8', 'q8', 'q4' ]
*
* // Use the result to pick the best available dtype
* const preferredDtype = dtypes.includes("q4") ? "q4" : "fp32";
* const files = await ModelRegistry.get_files("onnx-community/all-MiniLM-L6-v2-ONNX", { dtype: preferredDtype });
* ```
*
* **Example:** Check file metadata without downloading
* ```javascript
* const metadata = await ModelRegistry.get_file_metadata(
Expand Down Expand Up @@ -98,6 +109,7 @@ import { get_processor_files } from './get_processor_files.js';
import { is_cached, is_cached_files, is_pipeline_cached, is_pipeline_cached_files } from './is_cached.js';
import { get_file_metadata } from './get_file_metadata.js';
import { clear_cache, clear_pipeline_cache } from './clear_cache.js';
import { get_available_dtypes } from './get_available_dtypes.js';

/**
* Static class for cache and file management operations.
Expand Down Expand Up @@ -193,6 +205,30 @@ export class ModelRegistry {
return get_processor_files(modelId);
}

/**
* Detects which quantization levels (dtypes) are available for a model
* by checking which ONNX files exist on the hub or locally.
*
* A dtype is considered available if all required model session files
* exist for that dtype.
*
* @param {string} modelId - The model id (e.g., "onnx-community/all-MiniLM-L6-v2-ONNX")
* @param {Object} [options] - Optional parameters
* @param {import('../../configs.js').PretrainedConfig} [options.config=null] - Pre-loaded config
* @param {string} [options.model_file_name=null] - Override the model file name (excluding .onnx suffix)
* @param {string} [options.revision='main'] - Model revision
* @param {string} [options.cache_dir=null] - Custom cache directory
* @param {boolean} [options.local_files_only=false] - Only check local files
* @returns {Promise<string[]>} Array of available dtype strings (e.g., ['fp32', 'fp16', 'q4', 'q8'])
*
* @example
* const dtypes = await ModelRegistry.get_available_dtypes('onnx-community/all-MiniLM-L6-v2-ONNX');
* console.log(dtypes); // ['fp32', 'fp16', 'int8', 'uint8', 'q8', 'q4']
*/
static async get_available_dtypes(modelId, options = {}) {
return get_available_dtypes(modelId, options);
}

/**
* Quickly checks if a model is fully cached by verifying `config.json` is present,
* then confirming all required files are cached.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import { getSessionsConfig } from '../../models/modeling_utils.js';
import { DEFAULT_DTYPE_SUFFIX_MAPPING } from '../dtypes.js';
import { get_file_metadata } from './get_file_metadata.js';
import { get_config } from './get_model_files.js';
import { resolve_model_type } from './resolve_model_type.js';

/**
* @typedef {import('../../configs.js').PretrainedConfig} PretrainedConfig
*/

/**
* The dtypes to probe for availability (excludes 'auto' which is not a concrete dtype).
* @type {string[]}
*/
const CONCRETE_DTYPES = Object.keys(DEFAULT_DTYPE_SUFFIX_MAPPING);

/**
* Detects which quantization levels (dtypes) are available for a model
* by checking which ONNX files exist on the hub or locally.
*
* A dtype is considered available if *all* required model session files
* exist for that dtype. For example, a Seq2Seq model needs both an encoder
* and decoder file — the dtype is only listed if both are present.
*
* @param {string} modelId The model id (e.g., "onnx-community/all-MiniLM-L6-v2-ONNX")
* @param {Object} [options] Optional parameters
* @param {PretrainedConfig} [options.config=null] Pre-loaded model config (optional, will be fetched if not provided)
* @param {string} [options.model_file_name=null] Override the model file name (excluding .onnx suffix)
* @param {string} [options.revision='main'] Model revision
* @param {string} [options.cache_dir=null] Custom cache directory
* @param {boolean} [options.local_files_only=false] Only check local files
* @returns {Promise<string[]>} Array of available dtype strings (e.g., ['fp32', 'fp16', 'q4', 'q8'])
*/
export async function get_available_dtypes(
modelId,
{ config = null, model_file_name = null, revision = 'main', cache_dir = null, local_files_only = false } = {},
) {
config = await get_config(modelId, { config, cache_dir, local_files_only, revision });

const subfolder = 'onnx';

const modelType = resolve_model_type(config);

const { sessions } = getSessionsConfig(modelType, config, { model_file_name });

// Get all base names for model session files
const baseNames = Object.values(sessions);

// For each dtype, check if all session files exist
const metadataOptions = { revision, cache_dir, local_files_only };

// Probe all (dtype, baseName) combinations in parallel
const probeResults = await Promise.all(
CONCRETE_DTYPES.map(async (dtype) => {
const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[dtype] ?? '';
const allExist = await Promise.all(
baseNames.map(async (baseName) => {
const filename = `${subfolder}/${baseName}${suffix}.onnx`;
const metadata = await get_file_metadata(modelId, filename, metadataOptions);
return metadata.exists;
}),
);
return { dtype, available: allExist.every(Boolean) };
}),
);

return probeResults.filter((r) => r.available).map((r) => r.dtype);
}
67 changes: 7 additions & 60 deletions packages/transformers/src/utils/model_registry/get_model_files.js
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import { DEFAULT_DTYPE_SUFFIX_MAPPING, selectDtype } from '../dtypes.js';
import { selectDevice } from '../devices.js';
import { resolveExternalDataFormat, getExternalDataChunkNames } from '../model-loader.js';
import {
MODEL_TYPES,
MODEL_TYPE_MAPPING,
MODEL_MAPPING_NAMES,
getSessionsConfig,
} from '../../models/modeling_utils.js';
import { getSessionsConfig } from '../../models/modeling_utils.js';
import { AutoConfig } from '../../configs.js';
import { GITHUB_ISSUE_URL } from '../constants.js';
import { logger } from '../logger.js';
import { memoizePromise } from '../memoize_promise.js';
import { resolve_model_type } from './resolve_model_type.js';

/**
* @typedef {import('../../configs.js').PretrainedConfig} PretrainedConfig
Expand All @@ -32,7 +26,10 @@ import { memoizePromise } from '../memoize_promise.js';
* @param {string} [options.revision='main'] Git branch, tag, or commit SHA.
* @returns {Promise<PretrainedConfig>}
*/
function get_config(modelId, { config = null, cache_dir = null, local_files_only = false, revision = 'main' } = {}) {
export function get_config(
modelId,
{ config = null, cache_dir = null, local_files_only = false, revision = 'main' } = {},
) {
// When a pre-loaded config is provided, skip memoization — no fetch occurs
// and there is no meaningful key to deduplicate on.
if (config !== null) {
Expand Down Expand Up @@ -78,57 +75,7 @@ export async function get_model_files(
let dtype = overrideDtype ?? custom_config.dtype;

// Infer model type from config
let modelType;

// @ts-ignore - architectures is set via Object.assign in PretrainedConfig constructor
const architectures = /** @type {string[]} */ (config.architectures || []);

// Try to find a known architecture in MODEL_TYPE_MAPPING
// This ensures we use the same logic as from_pretrained()
let foundInMapping = false;
for (const arch of architectures) {
const mappedType = MODEL_TYPE_MAPPING.get(arch);
if (mappedType !== undefined) {
modelType = mappedType;
foundInMapping = true;
break;
}
}

// If not found by architecture, try model_type (handles custom models with no architectures)
if (!foundInMapping && config.model_type) {
const mappedType = MODEL_TYPE_MAPPING.get(config.model_type);
if (mappedType !== undefined) {
modelType = mappedType;
foundInMapping = true;
}

if (!foundInMapping) {
// As a last resort, map model_type based on MODEL_MAPPING_NAMES
for (const mapping of Object.values(MODEL_MAPPING_NAMES)) {
if (mapping.has(config.model_type)) {
modelType = MODEL_TYPE_MAPPING.get(mapping.get(config.model_type));
foundInMapping = true;
break;
}
}
}
}

// Fall back to EncoderOnly if not found in mapping
if (!foundInMapping) {
const archList = architectures.length > 0 ? architectures.join(', ') : '(none)';
logger.warn(
`[get_model_files] Architecture(s) not found in MODEL_TYPE_MAPPING: [${archList}] ` +
`for model type '${config.model_type}'. Falling back to EncoderOnly (single model.onnx file). ` +
`If you encounter issues, please report at: ${GITHUB_ISSUE_URL}`,
);

// Always fallback to EncoderOnly (single model.onnx file)
// Other model types (Vision2Seq, Musicgen, etc.) require specific file structures
// and should be properly registered in MODEL_TYPE_MAPPING if they are valid.
modelType = MODEL_TYPES.EncoderOnly;
}
const modelType = resolve_model_type(config);

const add_model_file = (fileName, baseName = null) => {
baseName = baseName ?? fileName;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import { MODEL_MAPPING_NAMES, MODEL_TYPES, MODEL_TYPE_MAPPING } from '../../models/modeling_utils.js';
import { GITHUB_ISSUE_URL } from '../constants.js';
import { logger } from '../logger.js';

/**
* @typedef {import('../../configs.js').PretrainedConfig} PretrainedConfig
*/

/**
* Resolves the model type (e.g., EncoderOnly, DecoderOnly, Seq2Seq, …) from a
* model config by checking architectures and model_type against the known
* MODEL_TYPE_MAPPING.
*
* Resolution order:
* 1. `config.architectures` entries looked up in MODEL_TYPE_MAPPING
* 2. `config.model_type` looked up directly in MODEL_TYPE_MAPPING
* 3. `config.model_type` looked up via MODEL_MAPPING_NAMES → architecture → MODEL_TYPE_MAPPING
* 4. Fallback to `MODEL_TYPES.EncoderOnly`
*
* @param {PretrainedConfig} config The model config object.
* @param {{ warn?: boolean }} [options] Set `warn` to false to suppress the
* fallback warning (defaults to true).
* @returns {number} One of the MODEL_TYPES enum values.
*/
export function resolve_model_type(config, { warn = true } = {}) {
// @ts-ignore - architectures is set via Object.assign in PretrainedConfig constructor
const architectures = /** @type {string[]} */ (config.architectures || []);

// 1. Try architectures against MODEL_TYPE_MAPPING
for (const arch of architectures) {
const mappedType = MODEL_TYPE_MAPPING.get(arch);
if (mappedType !== undefined) {
return mappedType;
}
}

// 2. Try config.model_type directly
if (config.model_type) {
const mappedType = MODEL_TYPE_MAPPING.get(config.model_type);
if (mappedType !== undefined) {
return mappedType;
}

// 3. Try MODEL_MAPPING_NAMES as a last resort
for (const mapping of Object.values(MODEL_MAPPING_NAMES)) {
if (mapping.has(config.model_type)) {
const resolved = MODEL_TYPE_MAPPING.get(mapping.get(config.model_type));
if (resolved !== undefined) {
return resolved;
}
}
}
}

// 4. Fallback
if (warn) {
const archList = architectures.length > 0 ? architectures.join(', ') : '(none)';
logger.warn(
`[resolve_model_type] Architecture(s) not found in MODEL_TYPE_MAPPING: [${archList}] ` +
`for model type '${config.model_type}'. Falling back to EncoderOnly (single model.onnx file). ` +
`If you encounter issues, please report at: ${GITHUB_ISSUE_URL}`,
);
}

return MODEL_TYPES.EncoderOnly;
}
Loading
Loading