Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions packages/transformers/src/models/modeling_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,12 @@ function resolveTypeConfig(modelName, config) {
// Detect cross-architecture loading: e.g., ForCausalLM class loading a ForConditionalGeneration model.
// In this case, use the native architecture's type config (for forward/sessions) in text-only mode.
const nativeArch = config?.architectures?.[0];
if (nativeArch && nativeArch !== modelName
&& modelName?.endsWith('ForCausalLM') && nativeArch.endsWith('ForConditionalGeneration')) {
if (
nativeArch &&
nativeArch !== modelName &&
modelName?.endsWith('ForCausalLM') &&
nativeArch.endsWith('ForConditionalGeneration')
) {
const nativeType = MODEL_TYPE_MAPPING.get(nativeArch);
if (nativeType !== undefined) {
modelType = nativeType;
Expand Down
71 changes: 39 additions & 32 deletions packages/transformers/src/models/qwen2_vl/modeling_qwen2_vl.js
Original file line number Diff line number Diff line change
Expand Up @@ -286,50 +286,57 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {

prepare_inputs_for_generation(input_ids, model_inputs, generation_config) {
// Overwritten -- in specific circumstances we don't want to forward image inputs to the model
if (model_inputs.attention_mask && !model_inputs.position_ids) {
// Calculate position_ids and rope_deltas
if (!model_inputs.past_key_values) {
[model_inputs.position_ids, model_inputs.rope_deltas] = this.get_rope_index(
if (!model_inputs.attention_mask || model_inputs.position_ids) {
return model_inputs;
}

const session = this.sessions['decoder_model_merged'] ?? this.sessions['model'];
if (!session.inputNames.includes('position_ids')) {
return model_inputs;
}

// Calculate position_ids and rope_deltas
if (!model_inputs.past_key_values) {
[model_inputs.position_ids, model_inputs.rope_deltas] = this.get_rope_index(
model_inputs.input_ids,
model_inputs.image_grid_thw,
model_inputs.video_grid_thw,
model_inputs.attention_mask,
);
} else {
model_inputs.pixel_values = null;
// model_inputs.pixel_values_videos = null;

const past_length = model_inputs.past_key_values.get_seq_length();

if (past_length < model_inputs.input_ids.dims[1]) {
// Externally provided `past_key_values` with full input_ids:
// Compute full position_ids, then slice to only the new (unprocessed) tokens.
const [full_position_ids, rope_deltas] = this.get_rope_index(
model_inputs.input_ids,
model_inputs.image_grid_thw,
model_inputs.video_grid_thw,
model_inputs.attention_mask,
);
model_inputs.rope_deltas = rope_deltas;
model_inputs.position_ids = full_position_ids.slice(null, null, [past_length, null]);
model_inputs.input_ids = model_inputs.input_ids.slice(null, [past_length, null]);
} else {
model_inputs.pixel_values = null;
// model_inputs.pixel_values_videos = null;

const past_length = model_inputs.past_key_values.get_seq_length();

if (past_length < model_inputs.input_ids.dims[1]) {
// Externally provided `past_key_values` with full input_ids:
// Compute full position_ids, then slice to only the new (unprocessed) tokens.
const [full_position_ids, rope_deltas] = this.get_rope_index(
// Auto-regressive case: single new token.
// `rope_deltas` may be absent when generation starts from externally provided `past_key_values`.
// In that case, recompute from current inputs instead of relying on persisted model state.
if (!model_inputs.rope_deltas) {
[, model_inputs.rope_deltas] = this.get_rope_index(
model_inputs.input_ids,
model_inputs.image_grid_thw,
model_inputs.video_grid_thw,
model_inputs.attention_mask,
);
model_inputs.rope_deltas = rope_deltas;
model_inputs.position_ids = full_position_ids.slice(null, null, [past_length, null]);
model_inputs.input_ids = model_inputs.input_ids.slice(null, [past_length, null]);
} else {
// Auto-regressive case: single new token.
// `rope_deltas` may be absent when generation starts from externally provided `past_key_values`.
// In that case, recompute from current inputs instead of relying on persisted model state.
if (!model_inputs.rope_deltas) {
[, model_inputs.rope_deltas] = this.get_rope_index(
model_inputs.input_ids,
model_inputs.image_grid_thw,
model_inputs.video_grid_thw,
model_inputs.attention_mask,
);
}

const delta = BigInt(past_length);
const rope_deltas_list = model_inputs.rope_deltas.map((x) => delta + x);
model_inputs.position_ids = stack([rope_deltas_list, rope_deltas_list, rope_deltas_list], 0);
}

const delta = BigInt(past_length);
const rope_deltas_list = model_inputs.rope_deltas.map((x) => delta + x);
model_inputs.position_ids = stack([rope_deltas_list, rope_deltas_list, rope_deltas_list], 0);
}
}

Expand Down
2 changes: 1 addition & 1 deletion packages/transformers/src/models/registry.js
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ export const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
['qwen3_vl', 'Qwen3VLForCausalLM'],
['qwen3_vl_moe', 'Qwen3VLMoeForCausalLM'],
['qwen3_5', 'Qwen3_5ForCausalLM'],
['qwen3_5_text', 'Qwen3_5ForCausalLM'],
['qwen3_5_moe', 'Qwen3_5MoeForCausalLM'],
['gemma3n', 'Gemma3nForCausalLM'],
['phi', 'PhiForCausalLM'],
Expand Down Expand Up @@ -619,7 +620,6 @@ const CUSTOM_MAPPING = [
['SupertonicForConditionalGeneration', ALL_MODEL_FILES.SupertonicForConditionalGeneration, MODEL_TYPES.Supertonic],
['ChatterboxModel', ALL_MODEL_FILES.ChatterboxModel, MODEL_TYPES.Chatterbox],


[
'VoxtralRealtimeForConditionalGeneration',
ALL_MODEL_FILES.VoxtralRealtimeForConditionalGeneration,
Expand Down
Loading