diff --git a/packages/transformers/src/models/whisper/modeling_whisper.js b/packages/transformers/src/models/whisper/modeling_whisper.js index dfca02672..75c14bceb 100644 --- a/packages/transformers/src/models/whisper/modeling_whisper.js +++ b/packages/transformers/src/models/whisper/modeling_whisper.js @@ -116,7 +116,13 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { }) { generation_config = this._prepare_generation_config(generation_config, kwargs); - const init_tokens = kwargs.decoder_input_ids ?? this._retrieve_init_tokens(generation_config); + let init_tokens = kwargs.decoder_input_ids ?? this._retrieve_init_tokens(generation_config); + + // Prepend prompt_ids to init_tokens if provided (matches Whisper training format): + // [<|startofprev|>, ...prompt_text..., <|startoftranscript|>, <|lang|>, <|task|>, ...] + if (generation_config.prompt_ids) { + init_tokens = [...generation_config.prompt_ids, ...init_tokens]; + } if (generation_config.return_timestamps) { logits_processor ??= new LogitsProcessorList(); @@ -146,7 +152,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { generation_config.return_dict_in_generate = true; } - const outputs = await super.generate({ + let outputs = await super.generate({ inputs, generation_config, logits_processor, @@ -154,6 +160,19 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { ...kwargs, }); + // Strip prompt_ids from output sequences so they don't appear in transcription + if (generation_config.prompt_ids) { + const prompt_len = generation_config.prompt_ids.length; + // @ts-expect-error TS2339 + if (outputs.sequences) { + // @ts-expect-error TS2339 + outputs.sequences = outputs.sequences.slice(null, [prompt_len, null]); + } else { + // @ts-expect-error TS2339 + outputs = outputs.slice(null, [prompt_len, null]); + } + } + if (generation_config.return_token_timestamps) { outputs['token_timestamps'] = this._extract_token_timestamps( // @ts-expect-error TS2345