@@ -286,50 +286,57 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {
286286
287287 prepare_inputs_for_generation ( input_ids , model_inputs , generation_config ) {
288288 // Overwritten -- in specific circumstances we don't want to forward image inputs to the model
289- if ( model_inputs . attention_mask && ! model_inputs . position_ids ) {
290- // Calculate position_ids and rope_deltas
291- if ( ! model_inputs . past_key_values ) {
292- [ model_inputs . position_ids , model_inputs . rope_deltas ] = this . get_rope_index (
289+ if ( ! model_inputs . attention_mask || model_inputs . position_ids ) {
290+ return model_inputs ;
291+ }
292+
293+ const session = this . sessions [ 'decoder_model_merged' ] ?? this . sessions [ 'model' ] ;
294+ if ( ! session . inputNames . includes ( 'position_ids' ) ) {
295+ return model_inputs ;
296+ }
297+
298+ // Calculate position_ids and rope_deltas
299+ if ( ! model_inputs . past_key_values ) {
300+ [ model_inputs . position_ids , model_inputs . rope_deltas ] = this . get_rope_index (
301+ model_inputs . input_ids ,
302+ model_inputs . image_grid_thw ,
303+ model_inputs . video_grid_thw ,
304+ model_inputs . attention_mask ,
305+ ) ;
306+ } else {
307+ model_inputs . pixel_values = null ;
308+ // model_inputs.pixel_values_videos = null;
309+
310+ const past_length = model_inputs . past_key_values . get_seq_length ( ) ;
311+
312+ if ( past_length < model_inputs . input_ids . dims [ 1 ] ) {
313+ // Externally provided `past_key_values` with full input_ids:
314+ // Compute full position_ids, then slice to only the new (unprocessed) tokens.
315+ const [ full_position_ids , rope_deltas ] = this . get_rope_index (
293316 model_inputs . input_ids ,
294317 model_inputs . image_grid_thw ,
295318 model_inputs . video_grid_thw ,
296319 model_inputs . attention_mask ,
297320 ) ;
321+ model_inputs . rope_deltas = rope_deltas ;
322+ model_inputs . position_ids = full_position_ids . slice ( null , null , [ past_length , null ] ) ;
323+ model_inputs . input_ids = model_inputs . input_ids . slice ( null , [ past_length , null ] ) ;
298324 } else {
299- model_inputs . pixel_values = null ;
300- // model_inputs.pixel_values_videos = null;
301-
302- const past_length = model_inputs . past_key_values . get_seq_length ( ) ;
303-
304- if ( past_length < model_inputs . input_ids . dims [ 1 ] ) {
305- // Externally provided `past_key_values` with full input_ids:
306- // Compute full position_ids, then slice to only the new (unprocessed) tokens.
307- const [ full_position_ids , rope_deltas ] = this . get_rope_index (
325+ // Auto-regressive case: single new token.
326+ // `rope_deltas` may be absent when generation starts from externally provided `past_key_values`.
327+ // In that case, recompute from current inputs instead of relying on persisted model state.
328+ if ( ! model_inputs . rope_deltas ) {
329+ [ , model_inputs . rope_deltas ] = this . get_rope_index (
308330 model_inputs . input_ids ,
309331 model_inputs . image_grid_thw ,
310332 model_inputs . video_grid_thw ,
311333 model_inputs . attention_mask ,
312334 ) ;
313- model_inputs . rope_deltas = rope_deltas ;
314- model_inputs . position_ids = full_position_ids . slice ( null , null , [ past_length , null ] ) ;
315- model_inputs . input_ids = model_inputs . input_ids . slice ( null , [ past_length , null ] ) ;
316- } else {
317- // Auto-regressive case: single new token.
318- // `rope_deltas` may be absent when generation starts from externally provided `past_key_values`.
319- // In that case, recompute from current inputs instead of relying on persisted model state.
320- if ( ! model_inputs . rope_deltas ) {
321- [ , model_inputs . rope_deltas ] = this . get_rope_index (
322- model_inputs . input_ids ,
323- model_inputs . image_grid_thw ,
324- model_inputs . video_grid_thw ,
325- model_inputs . attention_mask ,
326- ) ;
327- }
328-
329- const delta = BigInt ( past_length ) ;
330- const rope_deltas_list = model_inputs . rope_deltas . map ( ( x ) => delta + x ) ;
331- model_inputs . position_ids = stack ( [ rope_deltas_list , rope_deltas_list , rope_deltas_list ] , 0 ) ;
332335 }
336+
337+ const delta = BigInt ( past_length ) ;
338+ const rope_deltas_list = model_inputs . rope_deltas . map ( ( x ) => delta + x ) ;
339+ model_inputs . position_ids = stack ( [ rope_deltas_list , rope_deltas_list , rope_deltas_list ] , 0 ) ;
333340 }
334341 }
335342
0 commit comments