Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import { ImageProcessor } from '../../image_processors_utils.js';

/**
 * Image processor for Gemma3 vision-language models.
 * Inherits all preprocessing behavior unchanged from the generic
 * `ImageProcessor` base class (configuration comes from the model's
 * `preprocessor_config.json` at load time).
 */
export class Gemma3ImageProcessor extends ImageProcessor {}
5 changes: 4 additions & 1 deletion packages/transformers/src/models/gemma3/modeling_gemma3.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { PreTrainedModel } from '../modeling_utils.js';
import { LlavaForConditionalGeneration } from '../llava/modeling_llava.js';

/**
* The bare Gemma3 Model outputting raw hidden-states without any specific head on top.
Expand All @@ -10,4 +11,6 @@ export class Gemma3PreTrainedModel extends PreTrainedModel {}
*/
export class Gemma3Model extends Gemma3PreTrainedModel {}

export class Gemma3ForCausalLM extends Gemma3PreTrainedModel {}
/**
 * Gemma3 vision-language model for conditional generation (image + text → text).
 * Reuses the Llava implementation unchanged — only the class name differs,
 * so checkpoints with the `Gemma3ForConditionalGeneration` architecture resolve here.
 */
export class Gemma3ForConditionalGeneration extends LlavaForConditionalGeneration {}

/**
 * Text-only Gemma3 causal language model.
 * Extends the conditional-generation class so that text-only checkpoints and
 * multimodal checkpoints loaded in text-only mode share one implementation;
 * `resolveTypeConfig` in modeling_utils.js handles the cross-architecture case
 * (a *ForCausalLM class loading a *ForConditionalGeneration checkpoint).
 */
export class Gemma3ForCausalLM extends Gemma3ForConditionalGeneration {}
45 changes: 45 additions & 0 deletions packages/transformers/src/models/gemma3/processing_gemma3.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { Processor } from '../../processing_utils.js';
import { AutoImageProcessor } from '../auto/image_processing_auto.js';
import { AutoTokenizer } from '../auto/tokenization_auto.js';

/**
 * Processor for Gemma3 vision-language models.
 * Pairs a tokenizer with an image processor and expands each begin-of-image
 * marker in the prompt into the model's full per-image token sequence.
 */
export class Gemma3Processor extends Processor {
    static tokenizer_class = AutoTokenizer;
    static image_processor_class = AutoImageProcessor;
    static uses_processor_config = true;
    static uses_chat_template_file = true;

    constructor(config, components, chat_template) {
        super(config, components, chat_template);

        // Number of placeholder tokens each image expands into (from the processor config).
        this.image_seq_length = this.config.image_seq_length;

        // Special-token strings are taken from the tokenizer config.
        const tokenizer_config = this.tokenizer.config;
        this.boi_token = tokenizer_config.boi_token;
        this.image_token = tokenizer_config.image_token;
        this.eoi_token = tokenizer_config.eoi_token;

        // Pre-build the full per-image sequence: "\n\n<boi><image>*N<eoi>\n\n".
        const repeated_image_tokens = this.image_token.repeat(this.image_seq_length);
        this.full_image_sequence = ['\n\n', this.boi_token, repeated_image_tokens, this.eoi_token, '\n\n'].join('');
    }

    /**
     * Prepares text (and optionally images) as model inputs.
     * @param {string|string[]} text Prompt or list of prompts; a single string is wrapped in an array.
     * @param {import('../../utils/image.js').RawImage|import('../../utils/image.js').RawImage[]} [images] Optional image(s) to preprocess.
     * @param {Object} [options] Forwarded to both the image processor and the tokenizer.
     * @returns {Promise<Object>} Merged tokenizer outputs and (when images are given) image-processor outputs.
     */
    async _call(text, images = null, options = {}) {
        let prompts = typeof text === 'string' ? [text] : text;

        let image_inputs;
        if (images) {
            image_inputs = await this.image_processor(images, options);
            // Replace every begin-of-image marker with the expanded image-token sequence.
            prompts = prompts.map((prompt) => prompt.replaceAll(this.boi_token, this.full_image_sequence));
        }

        const text_inputs = this.tokenizer(prompts, options);
        return {
            ...text_inputs,
            ...image_inputs,
        };
    }
}
1 change: 1 addition & 0 deletions packages/transformers/src/models/image_processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export * from './dinov3_vit/image_processing_dinov3_vit.js';
export * from './donut/image_processing_donut.js';
export * from './dpt/image_processing_dpt.js';
export * from './efficientnet/image_processing_efficientnet.js';
export * from './gemma3/image_processing_gemma3.js';
export * from './glm46v/image_processing_glm46v.js';
export * from './glpn/image_processing_glpn.js';
export * from './grounding_dino/image_processing_grounding_dino.js';
Expand Down
67 changes: 42 additions & 25 deletions packages/transformers/src/models/modeling_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ export const MODEL_TYPES = {
ImageAudioTextToText: 13,
Supertonic: 14,
Chatterbox: 15,
MultimodalLanguageModelOnly: 16,
VoxtralRealtime: 17,
VoxtralRealtime: 16,
};

const MODEL_TYPE_CONFIG = {
Expand Down Expand Up @@ -158,12 +157,12 @@ const MODEL_TYPE_CONFIG = {
can_generate: true,
forward: image_text_to_text_forward,
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
sessions: (config) => {
sessions: (config, options, textOnly) => {
const s = {
embed_tokens: 'embed_tokens',
vision_encoder: 'vision_encoder',
decoder_model_merged: 'decoder_model_merged',
};
if (!textOnly) s['vision_encoder'] = 'vision_encoder';
if (config.is_encoder_decoder) s['model'] = 'encoder_model';
return s;
},
Expand All @@ -185,12 +184,17 @@ const MODEL_TYPE_CONFIG = {
[MODEL_TYPES.ImageAudioTextToText]: {
can_generate: true,
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
sessions: () => ({
embed_tokens: 'embed_tokens',
audio_encoder: 'audio_encoder',
vision_encoder: 'vision_encoder',
decoder_model_merged: 'decoder_model_merged',
}),
sessions: (config, options, textOnly) => {
const s = {
embed_tokens: 'embed_tokens',
decoder_model_merged: 'decoder_model_merged',
};
if (!textOnly) {
s['audio_encoder'] = 'audio_encoder';
s['vision_encoder'] = 'vision_encoder';
}
return s;
},
optional_configs: { generation_config: 'generation_config.json' },
},
[MODEL_TYPES.Phi3V]: {
Expand Down Expand Up @@ -241,14 +245,6 @@ const MODEL_TYPE_CONFIG = {
cache_sessions: { model: true },
optional_configs: { generation_config: 'generation_config.json' },
},
[MODEL_TYPES.MultimodalLanguageModelOnly]: {
can_generate: true,
forward: image_text_to_text_forward,
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
sessions: () => ({ embed_tokens: 'embed_tokens', decoder_model_merged: 'decoder_model_merged' }),
cache_sessions: { decoder_model_merged: true },
optional_configs: { generation_config: 'generation_config.json' },
},
[MODEL_TYPES.VoxtralRealtime]: {
can_generate: true,
prepare_inputs: decoder_prepare_inputs_for_generation,
Expand Down Expand Up @@ -283,6 +279,31 @@ export function getSessionsConfig(modelType, config, options = {}) {
};
}

/**
* Resolves the model type config for a given class name and config.
* @param {string} modelName The name of the class being used to load.
* @param {Object} config The model config.
* @returns {{ typeConfig: Object, textOnly: boolean, modelType: number|undefined }}
*/
/**
 * Resolves the model type config for a given class name and config.
 * @param {string} modelName The name of the class being used to load.
 * @param {Object} config The model config.
 * @returns {{ typeConfig: Object, textOnly: boolean, modelType: number|undefined }}
 */
function resolveTypeConfig(modelName, config) {
    let modelType = MODEL_TYPE_MAPPING.get(modelName);
    let textOnly = false;

    // Cross-architecture loading: a *ForCausalLM class loading a checkpoint whose
    // native architecture is *ForConditionalGeneration. Borrow the native
    // architecture's type config (forward/sessions) and run in text-only mode.
    const nativeArch = config?.architectures?.[0];
    const isCrossArchLoad =
        nativeArch
        && nativeArch !== modelName
        && modelName?.endsWith('ForCausalLM')
        && nativeArch.endsWith('ForConditionalGeneration');

    if (isCrossArchLoad) {
        const nativeType = MODEL_TYPE_MAPPING.get(nativeArch);
        if (nativeType !== undefined) {
            modelType = nativeType;
            textOnly = true;
        }
    }

    const typeConfig = MODEL_TYPE_CONFIG[modelType] ?? MODEL_TYPE_CONFIG.default;
    return { typeConfig, textOnly, modelType };
}

export const MODEL_TYPE_MAPPING = new Map();
export const MODEL_NAME_TO_CLASS_MAPPING = new Map();
export const MODEL_CLASS_TO_NAME_MAPPING = new Map();
Expand All @@ -309,10 +330,7 @@ export class PreTrainedModel extends Callable {
this.configs = configs;

const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
const modelType = MODEL_TYPE_MAPPING.get(modelName);

// Get configuration for this model type
const typeConfig = MODEL_TYPE_CONFIG[modelType] ?? MODEL_TYPE_CONFIG.default;
const { typeConfig } = resolveTypeConfig(modelName, config);

this.can_generate = typeConfig.can_generate;
this._forward = typeConfig.forward;
Expand Down Expand Up @@ -385,11 +403,10 @@ export class PreTrainedModel extends Callable {
};

const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
const modelType = MODEL_TYPE_MAPPING.get(modelName);

config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);

const typeConfig = MODEL_TYPE_CONFIG[modelType] ?? MODEL_TYPE_CONFIG.default;
const { typeConfig, textOnly, modelType } = resolveTypeConfig(modelName, config);

if (modelType === undefined) {
const type = modelName ?? config?.model_type;
Expand All @@ -400,7 +417,7 @@ export class PreTrainedModel extends Callable {
}
}

const sessions = typeConfig.sessions(config, options);
const sessions = typeConfig.sessions(config, options, textOnly);
const promises = [
constructSessions(pretrained_model_name_or_path, sessions, options, typeConfig.cache_sessions),
];
Expand Down
1 change: 1 addition & 0 deletions packages/transformers/src/models/processors.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export * from './chatterbox/processing_chatterbox.js';
export * from './florence2/processing_florence2.js';
export * from './gemma3/processing_gemma3.js';
export * from './gemma3n/processing_gemma3n.js';
export * from './glm46v/processing_glm46v.js';
export * from './granite_speech/processing_granite_speech.js';
Expand Down
8 changes: 1 addition & 7 deletions packages/transformers/src/models/registry.js
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
['smolvlm', 'SmolVLMForConditionalGeneration'],
['paligemma', 'PaliGemmaForConditionalGeneration'],
['llava_qwen2', 'LlavaQwen2ForCausalLM'],
['gemma3', 'Gemma3ForConditionalGeneration'],
['gemma3n', 'Gemma3nForConditionalGeneration'],
['mistral3', 'Mistral3ForConditionalGeneration'],
['lighton_ocr', 'LightOnOcrForConditionalGeneration'],
Expand Down Expand Up @@ -618,13 +619,6 @@ const CUSTOM_MAPPING = [
['SupertonicForConditionalGeneration', ALL_MODEL_FILES.SupertonicForConditionalGeneration, MODEL_TYPES.Supertonic],
['ChatterboxModel', ALL_MODEL_FILES.ChatterboxModel, MODEL_TYPES.Chatterbox],

['Qwen2VLForCausalLM', ALL_MODEL_FILES.Qwen2VLForCausalLM, MODEL_TYPES.MultimodalLanguageModelOnly],
['Qwen2_5_VLForCausalLM', ALL_MODEL_FILES.Qwen2_5_VLForCausalLM, MODEL_TYPES.MultimodalLanguageModelOnly],
['Qwen3VLForCausalLM', ALL_MODEL_FILES.Qwen3VLForCausalLM, MODEL_TYPES.MultimodalLanguageModelOnly],
['Qwen3VLMoeForCausalLM', ALL_MODEL_FILES.Qwen3VLMoeForCausalLM, MODEL_TYPES.MultimodalLanguageModelOnly],
['Qwen3_5ForCausalLM', ALL_MODEL_FILES.Qwen3_5ForCausalLM, MODEL_TYPES.MultimodalLanguageModelOnly],
['Qwen3_5MoeForCausalLM', ALL_MODEL_FILES.Qwen3_5MoeForCausalLM, MODEL_TYPES.MultimodalLanguageModelOnly],
['Gemma3nForCausalLM', ALL_MODEL_FILES.Gemma3nForCausalLM, MODEL_TYPES.MultimodalLanguageModelOnly],

[
'VoxtralRealtimeForConditionalGeneration',
Expand Down
128 changes: 128 additions & 0 deletions packages/transformers/tests/models/gemma3/test_modeling_gemma3.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import { Gemma3ForConditionalGeneration, Gemma3ForCausalLM, AutoProcessor, AutoTokenizer, RawImage } from "../../../src/transformers.js";

import { MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../../init.js";

// Test-suite factory for Gemma3 models, invoked by the shared test harness.
// Covers both the multimodal class (text-only and text+image paths) and the
// text-only causal-LM class loaded from a dedicated tiny checkpoint.
export default () => {
  // Minimal text-only chat used to build prompts via the chat template.
  const CONVERSATION = [
    {
      role: "user",
      content: [{ type: "text", text: "Hello" }],
    },
  ];

  // Same chat, but with an image placeholder preceding the text.
  const CONVERSATION_WITH_IMAGE = [
    {
      role: "user",
      content: [{ type: "image" }, { type: "text", text: "Describe this image." }],
    },
  ];

  // Empty white image (all channels 255) used as a deterministic vision input.
  const dims = [224, 224, 3];
  const image = new RawImage(new Uint8ClampedArray(dims[0] * dims[1] * dims[2]).fill(255), ...dims);

  describe("Gemma3ForConditionalGeneration", () => {
    const model_id = "onnx-internal-testing/tiny-random-Gemma3ForConditionalGeneration";

    /** @type {Gemma3ForConditionalGeneration} */
    let model;
    /** @type {AutoProcessor} */
    let processor;
    beforeAll(async () => {
      model = await Gemma3ForConditionalGeneration.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
      processor = await AutoProcessor.from_pretrained(model_id);
    }, MAX_MODEL_LOAD_TIME);

    it(
      "text-only forward",
      async () => {
        const text = processor.apply_chat_template(CONVERSATION, { add_generation_prompt: true });
        const inputs = await processor(text);
        const { logits } = await model(inputs);
        // Expected dims: [batch, sequence length, vocab size]; mean pins numerics.
        expect(logits.dims).toEqual([1, 11, 262208]);
        expect(logits.mean().item()).toBeCloseTo(-0.004435515962541103, 5);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    it(
      "text + image forward",
      async () => {
        const text = processor.apply_chat_template(CONVERSATION_WITH_IMAGE, { add_generation_prompt: true });
        const inputs = await processor(text, image);
        const { logits } = await model(inputs);
        // Longer sequence (21 vs 11) reflects the expanded image-token span.
        expect(logits.dims).toEqual([1, 21, 262208]);
        expect(logits.mean().item()).toBeCloseTo(-0.0029795959126204252, 5);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    it(
      "text-only (batch_size=1)",
      async () => {
        const text = processor.apply_chat_template(CONVERSATION, { add_generation_prompt: true });
        const inputs = await processor(text);
        const generate_ids = await model.generate({
          ...inputs,
          max_new_tokens: 10,
          do_sample: false,
        });
        // Compare only the newly generated tokens (slice off the prompt).
        const new_tokens = generate_ids.slice(null, [inputs.input_ids.dims.at(-1), null]);
        expect(new_tokens.tolist()).toEqual([[107n, 107n, 107n, 107n, 107n, 107n, 107n, 107n, 107n, 107n]]);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    it(
      "text + image (batch_size=1)",
      async () => {
        const text = processor.apply_chat_template(CONVERSATION_WITH_IMAGE, { add_generation_prompt: true });
        const inputs = await processor(text, image);
        const generate_ids = await model.generate({
          ...inputs,
          max_new_tokens: 10,
          do_sample: false,
        });
        const new_tokens = generate_ids.slice(null, [inputs.input_ids.dims.at(-1), null]);
        expect(new_tokens.tolist()).toEqual([[107n, 107n, 107n, 107n, 107n, 107n, 107n, 107n, 107n, 107n]]);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    afterAll(async () => {
      await model?.dispose();
    }, MAX_MODEL_DISPOSE_TIME);
  });

  describe("Gemma3ForCausalLM", () => {
    const model_id = "onnx-internal-testing/tiny-random-Gemma3ForCausalLM";

    /** @type {Gemma3ForCausalLM} */
    let model;
    /** @type {AutoTokenizer} */
    let tokenizer;
    beforeAll(async () => {
      model = await Gemma3ForCausalLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
      tokenizer = await AutoTokenizer.from_pretrained(model_id);
    }, MAX_MODEL_LOAD_TIME);

    it(
      "batch_size=1",
      async () => {
        const inputs = tokenizer("hello");
        const outputs = await model.generate({
          ...inputs,
          max_new_tokens: 5,
          do_sample: false,
        });
        const new_tokens = outputs.slice(null, [inputs.input_ids.dims.at(-1), null]);
        expect(new_tokens.tolist()).toEqual([[23391n, 23391n, 23391n, 23391n, 23391n]]);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    afterAll(async () => {
      await model?.dispose();
    }, MAX_MODEL_DISPOSE_TIME);
  });
};
Loading
Loading