From 0ba527c52bcb4e2dc953017abc9b21f3830efdf1 Mon Sep 17 00:00:00 2001 From: Joshua Lochner <26504141+xenova@users.noreply.github.com> Date: Mon, 23 Mar 2026 22:02:07 -0400 Subject: [PATCH] Add support for qwen3_5_text <-> Qwen3_5ForCausalLM --- .../transformers/src/models/modeling_utils.js | 8 ++- .../src/models/qwen2_vl/modeling_qwen2_vl.js | 71 ++++++++++--------- packages/transformers/src/models/registry.js | 2 +- 3 files changed, 46 insertions(+), 35 deletions(-) diff --git a/packages/transformers/src/models/modeling_utils.js b/packages/transformers/src/models/modeling_utils.js index 41020e2e3..88521c671 100644 --- a/packages/transformers/src/models/modeling_utils.js +++ b/packages/transformers/src/models/modeling_utils.js @@ -292,8 +292,12 @@ function resolveTypeConfig(modelName, config) { // Detect cross-architecture loading: e.g., ForCausalLM class loading a ForConditionalGeneration model. // In this case, use the native architecture's type config (for forward/sessions) in text-only mode. const nativeArch = config?.architectures?.[0]; - if (nativeArch && nativeArch !== modelName - && modelName?.endsWith('ForCausalLM') && nativeArch.endsWith('ForConditionalGeneration')) { + if ( + nativeArch && + nativeArch !== modelName && + modelName?.endsWith('ForCausalLM') && + nativeArch.endsWith('ForConditionalGeneration') + ) { const nativeType = MODEL_TYPE_MAPPING.get(nativeArch); if (nativeType !== undefined) { modelType = nativeType; diff --git a/packages/transformers/src/models/qwen2_vl/modeling_qwen2_vl.js b/packages/transformers/src/models/qwen2_vl/modeling_qwen2_vl.js index 93e540c8c..3848bcaf7 100644 --- a/packages/transformers/src/models/qwen2_vl/modeling_qwen2_vl.js +++ b/packages/transformers/src/models/qwen2_vl/modeling_qwen2_vl.js @@ -286,50 +286,57 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel { prepare_inputs_for_generation(input_ids, model_inputs, generation_config) { // Overwritten -- in specific circumstances we don't want to forward image inputs to the model - if (model_inputs.attention_mask && !model_inputs.position_ids) { - // Calculate position_ids and rope_deltas - if (!model_inputs.past_key_values) { - [model_inputs.position_ids, model_inputs.rope_deltas] = this.get_rope_index( + if (!model_inputs.attention_mask || model_inputs.position_ids) { + return model_inputs; + } + + const session = this.sessions['decoder_model_merged'] ?? this.sessions['model']; + if (!session.inputNames.includes('position_ids')) { + return model_inputs; + } + + // Calculate position_ids and rope_deltas + if (!model_inputs.past_key_values) { + [model_inputs.position_ids, model_inputs.rope_deltas] = this.get_rope_index( + model_inputs.input_ids, + model_inputs.image_grid_thw, + model_inputs.video_grid_thw, + model_inputs.attention_mask, + ); + } else { + model_inputs.pixel_values = null; + // model_inputs.pixel_values_videos = null; + + const past_length = model_inputs.past_key_values.get_seq_length(); + + if (past_length < model_inputs.input_ids.dims[1]) { + // Externally provided `past_key_values` with full input_ids: + // Compute full position_ids, then slice to only the new (unprocessed) tokens. + const [full_position_ids, rope_deltas] = this.get_rope_index( model_inputs.input_ids, model_inputs.image_grid_thw, model_inputs.video_grid_thw, model_inputs.attention_mask, ); + model_inputs.rope_deltas = rope_deltas; + model_inputs.position_ids = full_position_ids.slice(null, null, [past_length, null]); + model_inputs.input_ids = model_inputs.input_ids.slice(null, [past_length, null]); } else { - model_inputs.pixel_values = null; - // model_inputs.pixel_values_videos = null; - - const past_length = model_inputs.past_key_values.get_seq_length(); - - if (past_length < model_inputs.input_ids.dims[1]) { - // Externally provided `past_key_values` with full input_ids: - // Compute full position_ids, then slice to only the new (unprocessed) tokens. - const [full_position_ids, rope_deltas] = this.get_rope_index( + // Auto-regressive case: single new token. + // `rope_deltas` may be absent when generation starts from externally provided `past_key_values`. + // In that case, recompute from current inputs instead of relying on persisted model state. + if (!model_inputs.rope_deltas) { + [, model_inputs.rope_deltas] = this.get_rope_index( model_inputs.input_ids, model_inputs.image_grid_thw, model_inputs.video_grid_thw, model_inputs.attention_mask, ); - model_inputs.rope_deltas = rope_deltas; - model_inputs.position_ids = full_position_ids.slice(null, null, [past_length, null]); - model_inputs.input_ids = model_inputs.input_ids.slice(null, [past_length, null]); - } else { - // Auto-regressive case: single new token. - // `rope_deltas` may be absent when generation starts from externally provided `past_key_values`. - // In that case, recompute from current inputs instead of relying on persisted model state. - if (!model_inputs.rope_deltas) { - [, model_inputs.rope_deltas] = this.get_rope_index( - model_inputs.input_ids, - model_inputs.image_grid_thw, - model_inputs.video_grid_thw, - model_inputs.attention_mask, - ); - } - - const delta = BigInt(past_length); - const rope_deltas_list = model_inputs.rope_deltas.map((x) => delta + x); - model_inputs.position_ids = stack([rope_deltas_list, rope_deltas_list, rope_deltas_list], 0); } + + const delta = BigInt(past_length); + const rope_deltas_list = model_inputs.rope_deltas.map((x) => delta + x); + model_inputs.position_ids = stack([rope_deltas_list, rope_deltas_list, rope_deltas_list], 0); } } diff --git a/packages/transformers/src/models/registry.js b/packages/transformers/src/models/registry.js index ca4a64061..ee9a7a726 100644 --- a/packages/transformers/src/models/registry.js +++ b/packages/transformers/src/models/registry.js @@ -298,6 +298,7 @@ export const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([ ['qwen3_vl', 'Qwen3VLForCausalLM'], ['qwen3_vl_moe', 'Qwen3VLMoeForCausalLM'], ['qwen3_5', 'Qwen3_5ForCausalLM'], + ['qwen3_5_text', 'Qwen3_5ForCausalLM'], ['qwen3_5_moe', 'Qwen3_5MoeForCausalLM'], ['gemma3n', 'Gemma3nForCausalLM'], ['phi', 'PhiForCausalLM'], @@ -619,7 +620,6 @@ const CUSTOM_MAPPING = [ ['SupertonicForConditionalGeneration', ALL_MODEL_FILES.SupertonicForConditionalGeneration, MODEL_TYPES.Supertonic], ['ChatterboxModel', ALL_MODEL_FILES.ChatterboxModel, MODEL_TYPES.Chatterbox], - [ 'VoxtralRealtimeForConditionalGeneration', ALL_MODEL_FILES.VoxtralRealtimeForConditionalGeneration,