Skip to content

Commit 99a8eeb

Browse files
authored
Add support for qwen3_5_text <-> Qwen3_5ForCausalLM (#1602)
1 parent a97b51b commit 99a8eeb

3 files changed

Lines changed: 46 additions & 35 deletions

File tree

packages/transformers/src/models/modeling_utils.js

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -292,8 +292,12 @@ function resolveTypeConfig(modelName, config) {
292292
// Detect cross-architecture loading: e.g., ForCausalLM class loading a ForConditionalGeneration model.
293293
// In this case, use the native architecture's type config (for forward/sessions) in text-only mode.
294294
const nativeArch = config?.architectures?.[0];
295-
if (nativeArch && nativeArch !== modelName
296-
&& modelName?.endsWith('ForCausalLM') && nativeArch.endsWith('ForConditionalGeneration')) {
295+
if (
296+
nativeArch &&
297+
nativeArch !== modelName &&
298+
modelName?.endsWith('ForCausalLM') &&
299+
nativeArch.endsWith('ForConditionalGeneration')
300+
) {
297301
const nativeType = MODEL_TYPE_MAPPING.get(nativeArch);
298302
if (nativeType !== undefined) {
299303
modelType = nativeType;

packages/transformers/src/models/qwen2_vl/modeling_qwen2_vl.js

Lines changed: 39 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -286,50 +286,57 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {
286286

287287
prepare_inputs_for_generation(input_ids, model_inputs, generation_config) {
288288
// Overwritten -- in specific circumstances we don't want to forward image inputs to the model
289-
if (model_inputs.attention_mask && !model_inputs.position_ids) {
290-
// Calculate position_ids and rope_deltas
291-
if (!model_inputs.past_key_values) {
292-
[model_inputs.position_ids, model_inputs.rope_deltas] = this.get_rope_index(
289+
if (!model_inputs.attention_mask || model_inputs.position_ids) {
290+
return model_inputs;
291+
}
292+
293+
const session = this.sessions['decoder_model_merged'] ?? this.sessions['model'];
294+
if (!session.inputNames.includes('position_ids')) {
295+
return model_inputs;
296+
}
297+
298+
// Calculate position_ids and rope_deltas
299+
if (!model_inputs.past_key_values) {
300+
[model_inputs.position_ids, model_inputs.rope_deltas] = this.get_rope_index(
301+
model_inputs.input_ids,
302+
model_inputs.image_grid_thw,
303+
model_inputs.video_grid_thw,
304+
model_inputs.attention_mask,
305+
);
306+
} else {
307+
model_inputs.pixel_values = null;
308+
// model_inputs.pixel_values_videos = null;
309+
310+
const past_length = model_inputs.past_key_values.get_seq_length();
311+
312+
if (past_length < model_inputs.input_ids.dims[1]) {
313+
// Externally provided `past_key_values` with full input_ids:
314+
// Compute full position_ids, then slice to only the new (unprocessed) tokens.
315+
const [full_position_ids, rope_deltas] = this.get_rope_index(
293316
model_inputs.input_ids,
294317
model_inputs.image_grid_thw,
295318
model_inputs.video_grid_thw,
296319
model_inputs.attention_mask,
297320
);
321+
model_inputs.rope_deltas = rope_deltas;
322+
model_inputs.position_ids = full_position_ids.slice(null, null, [past_length, null]);
323+
model_inputs.input_ids = model_inputs.input_ids.slice(null, [past_length, null]);
298324
} else {
299-
model_inputs.pixel_values = null;
300-
// model_inputs.pixel_values_videos = null;
301-
302-
const past_length = model_inputs.past_key_values.get_seq_length();
303-
304-
if (past_length < model_inputs.input_ids.dims[1]) {
305-
// Externally provided `past_key_values` with full input_ids:
306-
// Compute full position_ids, then slice to only the new (unprocessed) tokens.
307-
const [full_position_ids, rope_deltas] = this.get_rope_index(
325+
// Auto-regressive case: single new token.
326+
// `rope_deltas` may be absent when generation starts from externally provided `past_key_values`.
327+
// In that case, recompute from current inputs instead of relying on persisted model state.
328+
if (!model_inputs.rope_deltas) {
329+
[, model_inputs.rope_deltas] = this.get_rope_index(
308330
model_inputs.input_ids,
309331
model_inputs.image_grid_thw,
310332
model_inputs.video_grid_thw,
311333
model_inputs.attention_mask,
312334
);
313-
model_inputs.rope_deltas = rope_deltas;
314-
model_inputs.position_ids = full_position_ids.slice(null, null, [past_length, null]);
315-
model_inputs.input_ids = model_inputs.input_ids.slice(null, [past_length, null]);
316-
} else {
317-
// Auto-regressive case: single new token.
318-
// `rope_deltas` may be absent when generation starts from externally provided `past_key_values`.
319-
// In that case, recompute from current inputs instead of relying on persisted model state.
320-
if (!model_inputs.rope_deltas) {
321-
[, model_inputs.rope_deltas] = this.get_rope_index(
322-
model_inputs.input_ids,
323-
model_inputs.image_grid_thw,
324-
model_inputs.video_grid_thw,
325-
model_inputs.attention_mask,
326-
);
327-
}
328-
329-
const delta = BigInt(past_length);
330-
const rope_deltas_list = model_inputs.rope_deltas.map((x) => delta + x);
331-
model_inputs.position_ids = stack([rope_deltas_list, rope_deltas_list, rope_deltas_list], 0);
332335
}
336+
337+
const delta = BigInt(past_length);
338+
const rope_deltas_list = model_inputs.rope_deltas.map((x) => delta + x);
339+
model_inputs.position_ids = stack([rope_deltas_list, rope_deltas_list, rope_deltas_list], 0);
333340
}
334341
}
335342

packages/transformers/src/models/registry.js

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -298,6 +298,7 @@ export const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
298298
['qwen3_vl', 'Qwen3VLForCausalLM'],
299299
['qwen3_vl_moe', 'Qwen3VLMoeForCausalLM'],
300300
['qwen3_5', 'Qwen3_5ForCausalLM'],
301+
['qwen3_5_text', 'Qwen3_5ForCausalLM'],
301302
['qwen3_5_moe', 'Qwen3_5MoeForCausalLM'],
302303
['gemma3n', 'Gemma3nForCausalLM'],
303304
['phi', 'PhiForCausalLM'],
@@ -619,7 +620,6 @@ const CUSTOM_MAPPING = [
619620
['SupertonicForConditionalGeneration', ALL_MODEL_FILES.SupertonicForConditionalGeneration, MODEL_TYPES.Supertonic],
620621
['ChatterboxModel', ALL_MODEL_FILES.ChatterboxModel, MODEL_TYPES.Chatterbox],
621622

622-
623623
[
624624
'VoxtralRealtimeForConditionalGeneration',
625625
ALL_MODEL_FILES.VoxtralRealtimeForConditionalGeneration,

0 commit comments

Comments
 (0)