From 182cbce27c275f74d20ba3e4633e655f27d88f31 Mon Sep 17 00:00:00 2001 From: jhlee111 Date: Sat, 21 Feb 2026 16:44:25 -0800 Subject: [PATCH] feat: add prompt_ids support for Whisper generation Implement prompt_ids handling in WhisperForConditionalGeneration.generate() to support initial prompt conditioning, matching the Python transformers library behavior. When prompt_ids is provided via generation config, it is prepended to init_tokens following the Whisper training format: [<|startofprev|>, ...prompt_text..., <|startoftranscript|>, <|lang|>, <|task|>, ...] The prompt tokens are stripped from output sequences after generation to prevent them from appearing in transcription results. Closes #923 Closes #1028 --- .../src/models/whisper/modeling_whisper.js | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/packages/transformers/src/models/whisper/modeling_whisper.js b/packages/transformers/src/models/whisper/modeling_whisper.js index dfca02672..75c14bceb 100644 --- a/packages/transformers/src/models/whisper/modeling_whisper.js +++ b/packages/transformers/src/models/whisper/modeling_whisper.js @@ -116,7 +116,13 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { }) { generation_config = this._prepare_generation_config(generation_config, kwargs); - const init_tokens = kwargs.decoder_input_ids ?? this._retrieve_init_tokens(generation_config); + let init_tokens = kwargs.decoder_input_ids ?? this._retrieve_init_tokens(generation_config); + + // Prepend prompt_ids to init_tokens if provided (matches Whisper training format): + // [<|startofprev|>, ...prompt_text..., <|startoftranscript|>, <|lang|>, <|task|>, ...] + if (generation_config.prompt_ids) { + init_tokens = [...generation_config.prompt_ids, ...init_tokens]; + } if (generation_config.return_timestamps) { logits_processor ??= new LogitsProcessorList(); @@ -146,7 +152,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { generation_config.return_dict_in_generate = true; } - const outputs = await super.generate({ + let outputs = await super.generate({ inputs, generation_config, logits_processor, @@ -154,6 +160,19 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { ...kwargs, }); + // Strip prompt_ids from output sequences so they don't appear in transcription + if (generation_config.prompt_ids) { + const prompt_len = generation_config.prompt_ids.length; + // @ts-expect-error TS2339 + if (outputs.sequences) { + // @ts-expect-error TS2339 + outputs.sequences = outputs.sequences.slice(null, [prompt_len, null]); + } else { + // @ts-expect-error TS2339 + outputs = outputs.slice(null, [prompt_len, null]); + } + } + if (generation_config.return_token_timestamps) { outputs['token_timestamps'] = this._extract_token_timestamps( // @ts-expect-error TS2345