Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions packages/transformers/src/generation/logits_process.js
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,37 @@ export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
}
}

/**
 * A LogitsProcessor that suppresses a list of tokens throughout generation.
 * Sets their log probs to `-inf` so that they are not generated.
 */
export class SuppressTokensLogitsProcessor extends LogitsProcessor {
    /**
     * Create a SuppressTokensLogitsProcessor.
     * @param {number[]} suppress_tokens The IDs of the tokens to suppress.
     */
    constructor(suppress_tokens) {
        super();
        this.suppress_tokens = suppress_tokens;
    }

    /**
     * Write `-Infinity` into the logit slot of every suppressed token,
     * for every batch item, so those tokens can never be sampled.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Tensor} The modified logits.
     */
    _call(input_ids, logits) {
        for (let batch = 0; batch < input_ids.length; ++batch) {
            const data = /** @type {Float32Array} */ (logits[batch].data);
            this.suppress_tokens.forEach((token_id) => {
                data[token_id] = -Infinity;
            });
        }
        return logits;
    }
}

/**
* A LogitsProcessor that suppresses a list of tokens as soon as the `generate` function starts
* generating using `begin_index` tokens. This should ensure that the tokens defined by
Expand Down Expand Up @@ -271,9 +302,8 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
// suppress <|notimestamps|> which is handled by without_timestamps
batch_logits_data[this.no_timestamps_token_id] = -Infinity;

if (input_ids[i].length === this.begin_index - 1) {
batch_logits_data.fill(-Infinity);
batch_logits_data[this.timestamp_begin] = 0;
if (input_ids[i].length === this.begin_index) {
batch_logits_data.subarray(0, this.timestamp_begin).fill(-Infinity);
continue;
}

Expand Down
7 changes: 4 additions & 3 deletions packages/transformers/src/models/modeling_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
LogitsProcessorList,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
SuppressTokensLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
Expand Down Expand Up @@ -547,9 +548,9 @@ export class PreTrainedModel extends Callable {
// ));
// }

// if (generation_config.suppress_tokens !== null) {
// processors.push(new SuppressTokensLogitsProcessor(generation_config.suppress_tokens));
// }
if (generation_config.suppress_tokens !== null) {
processors.push(new SuppressTokensLogitsProcessor(generation_config.suppress_tokens));
}

if (generation_config.begin_suppress_tokens !== null) {
const begin_index =
Expand Down
222 changes: 219 additions & 3 deletions packages/transformers/src/models/whisper/modeling_whisper.js
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,20 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
generation_config.return_dict_in_generate = true;
}

// For timestamp mode, use seek-based sequential generation.
// This matches Python's WhisperForConditionalGeneration.generate() which uses a seek loop
// to handle audio that the model doesn't fully transcribe in a single pass.
// Skip the seek loop when max_new_tokens is explicitly set (e.g., for prefix token tests).
if (generation_config.return_timestamps && !kwargs.max_new_tokens) {
return this._generate_with_seek({
inputs,
generation_config,
logits_processor,
init_tokens,
kwargs,
});
}

const outputs = await super.generate({
inputs,
generation_config,
Expand All @@ -160,12 +174,193 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
outputs,
generation_config.alignment_heads,
generation_config.num_frames,
0.02,
init_tokens.length,
);
}

return outputs;
}

/**
 * Generates with a seek loop for timestamp mode, re-encoding and generating
 * for each segment until all audio frames are consumed.
 * This matches Python's WhisperForConditionalGeneration.generate() behavior.
 * @param {Object} options
 * @param {Tensor} options.inputs Log-mel input features, indexed below as
 *     [batch, n_mels, total_frames]; batch is assumed to be 1 — TODO confirm with callers.
 * @param {Object} options.generation_config Generation config; must provide
 *     `no_timestamps_token_id` and `eos_token_id`, and optionally
 *     `return_token_timestamps` / `alignment_heads`.
 * @param {Object} options.logits_processor Iterable list of logits processors (or null);
 *     any processor exposing a `begin_index` property is reset for each segment.
 * @param {number[]} options.init_tokens Decoder prompt tokens prepended to every segment.
 * @param {Object} options.kwargs Remaining `generate()` keyword arguments, forwarded as-is.
 * @returns {Promise<Tensor|{sequences: Tensor, token_timestamps: Tensor}>} An int64
 *     [1, seq_len] token tensor, or — when `return_token_timestamps` is set — a dict
 *     with `sequences` and a matching float32 `token_timestamps` tensor.
 * @private
 */
async _generate_with_seek({ inputs, generation_config, logits_processor, init_tokens, kwargs }) {
    // Timestamp token IDs sit directly after <|notimestamps|> in the vocabulary.
    const timestamp_begin = generation_config.no_timestamps_token_id + 1;
    const eos_token_id = Array.isArray(generation_config.eos_token_id)
        ? generation_config.eos_token_id[0]
        : generation_config.eos_token_id;
    const return_token_timestamps = generation_config.return_token_timestamps;

    // input_features shape: [batch=1, n_mels, total_frames]
    const input_features = inputs;
    const total_frames = input_features.dims[2];

    // The encoder downsamples by input_stride (=2 for whisper), so:
    // num_segment_frames = input_stride * max_source_positions = 3000 mel frames per segment
    // Timestamp token T maps to mel frame position T * input_stride
    const input_stride = 2;
    // @ts-expect-error ts(2339)
    const max_source_positions = /** @type {number} */ (this.config.max_source_positions);
    const num_segment_frames = input_stride * max_source_positions;

    // `seek` is the current position in mel frames; tokens/timestamps accumulate
    // across all segments and are stitched together at the end.
    let seek = 0;
    const allTokens = [];
    const allTokenTimestamps = [];

    while (seek < total_frames) {
        // Slice input features for this segment
        const seek_end = Math.min(seek + num_segment_frames, total_frames);
        const segment_input = input_features.slice(null, null, [seek, seek_end]);

        // Pad to full segment size if needed (whisper expects fixed-length input)
        // NOTE(review): pads with zeros — confirm this matches the feature
        // extractor's padding value for silence.
        let segment_features;
        const segment_frames = segment_input.dims[2];
        if (segment_frames < num_segment_frames) {
            const n_mels = input_features.dims[1];
            const padded_data = new Float32Array(n_mels * num_segment_frames);
            const src = /** @type {Float32Array} */ (segment_input.data);
            // Copy each mel band row separately to handle the stride difference
            for (let m = 0; m < n_mels; ++m) {
                padded_data.set(src.subarray(m * segment_frames, (m + 1) * segment_frames), m * num_segment_frames);
            }
            segment_features = new Tensor('float32', padded_data, [1, n_mels, num_segment_frames]);
        } else {
            segment_features = segment_input;
        }

        // Reset logits processor begin_index for each segment
        // (each segment restarts decoding from the init_tokens prompt).
        if (logits_processor) {
            for (const proc of logits_processor) {
                if ('begin_index' in proc) {
                    proc.begin_index = init_tokens.length;
                }
            }
        }

        const outputs = /** @type {any} */ (
            await super.generate({
                inputs: segment_features,
                generation_config,
                logits_processor,
                decoder_input_ids: init_tokens,
                ...kwargs,
            })
        );

        // Extract tokens (skip init_tokens prefix)
        const raw_sequence = return_token_timestamps ? outputs.sequences : /** @type {Tensor} */ (outputs);
        const generated_tokens = raw_sequence[0].tolist().map(Number).slice(init_tokens.length);

        // Extract token-level timestamps for this seek pass if needed
        let seek_token_timestamps;
        if (return_token_timestamps) {
            outputs['token_timestamps'] = this._extract_token_timestamps(
                outputs,
                generation_config.alignment_heads,
                Math.floor((seek_end - seek) / input_stride),
                0.02,
                init_tokens.length,
            );
            // Shift per-segment DTW times by the absolute start time of this segment.
            const time_offset = (seek / input_stride) * 0.02;
            seek_token_timestamps = outputs.token_timestamps[0]
                .tolist()
                .slice(init_tokens.length)
                .map((/** @type {number} */ t) => t + time_offset);
        }

        // Remove trailing EOS
        if (generated_tokens.length > 0 && generated_tokens.at(-1) === eos_token_id) {
            generated_tokens.pop();
        }

        if (generated_tokens.length === 0) {
            // No tokens generated — skip the rest of the audio
            break;
        }

        // Determine seek advancement using the same logic as Python's _retrieve_segment:
        // 1. Find consecutive timestamp token pairs (segment boundaries)
        // 2. If the sequence ends with a single timestamp (no speech after it),
        //    consume all remaining frames in this segment
        // 3. Otherwise, seek to the last complete segment boundary
        const is_timestamp = generated_tokens.map((t) => t >= timestamp_begin);

        // Check for single_timestamp_ending: last token is timestamp, second-to-last is not
        const single_timestamp_ending =
            generated_tokens.length >= 2 &&
            is_timestamp[generated_tokens.length - 1] &&
            !is_timestamp[generated_tokens.length - 2];

        // Find consecutive timestamp pairs (segment boundaries)
        const segment_boundary_indices = [];
        for (let i = 0; i < generated_tokens.length - 1; ++i) {
            if (is_timestamp[i] && is_timestamp[i + 1]) {
                segment_boundary_indices.push(i + 1); // index of the second token in the pair
            }
        }

        let segment_offset;
        let tokens_to_keep = generated_tokens.length;
        if (segment_boundary_indices.length > 0) {
            if (single_timestamp_ending) {
                // Ends with a single timestamp after the last pair — no more speech
                segment_offset = seek_end - seek;
            } else {
                // Ends mid-segment — seek to the last pair's end timestamp
                // Discard tokens after the last pair (they're from an incomplete segment)
                // Keep up to the first token of the last pair (the end-of-segment timestamp),
                // excluding the second token (the start-of-next-segment marker)
                const last_boundary = segment_boundary_indices.at(-1);
                const last_ts_pos = generated_tokens[last_boundary - 1] - timestamp_begin;
                segment_offset = last_ts_pos * input_stride;
                tokens_to_keep = last_boundary;
            }
        } else {
            // No consecutive pairs found — consume entire segment
            segment_offset = seek_end - seek;
        }

        // Offset timestamp tokens by the current seek position so they're
        // monotonically increasing across segments. Cap at the maximum valid
        // timestamp token (30.00s = 1500 positions) to stay within the token vocab.
        const timestamp_offset = Math.floor(seek / input_stride);
        const max_timestamp_token = timestamp_begin + 1500;
        for (let i = 0; i < tokens_to_keep; ++i) {
            if (generated_tokens[i] >= timestamp_begin) {
                generated_tokens[i] = Math.min(generated_tokens[i] + timestamp_offset, max_timestamp_token);
            }
        }

        allTokens.push(...generated_tokens.slice(0, tokens_to_keep));
        if (seek_token_timestamps) {
            allTokenTimestamps.push(...seek_token_timestamps.slice(0, tokens_to_keep));
        }
        // NOTE(review): if the model emits a timestamp pair whose end position equals
        // the segment start, segment_offset is 0 and this loop would not advance —
        // TODO confirm this cannot occur or add a minimum-advance guard.
        seek += segment_offset;
    }

    // Add EOS back
    allTokens.push(eos_token_id);

    // Reconstruct output
    const full_sequence = [...init_tokens, ...allTokens];
    if (return_token_timestamps) {
        // Return dict format with sequences and token_timestamps
        const sequences = new Tensor('int64', full_sequence.map(BigInt), [1, full_sequence.length]);
        // Pad token_timestamps to match full_sequence (init_tokens get 0.0)
        // — the trailing 0 covers the appended EOS token.
        const full_timestamps = [...new Array(init_tokens.length).fill(0), ...allTokenTimestamps, 0];
        const token_timestamps = new Tensor('float32', new Float32Array(full_timestamps), [
            1,
            full_timestamps.length,
        ]);
        return { sequences, token_timestamps };
    }
    return new Tensor('int64', full_sequence.map(BigInt), [1, full_sequence.length]);
}

/**
* Calculates token-level timestamps using the encoder-decoder cross-attentions and
* dynamic time-warping (DTW) to map each output token to a position in the input audio.
Expand All @@ -176,9 +371,16 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
* @param {number[][]} alignment_heads Alignment heads of the model
* @param {number} [num_frames=null] Number of frames in the input audio.
* @param {number} [time_precision=0.02] Precision of the timestamps in seconds
* @param {number} [num_input_ids=0] Number of decoder input ids (prefix tokens) to skip in DTW
* @returns {Tensor} tensor containing the timestamps in seconds for each predicted token
*/
_extract_token_timestamps(generate_outputs, alignment_heads, num_frames = null, time_precision = 0.02) {
_extract_token_timestamps(
generate_outputs,
alignment_heads,
num_frames = null,
time_precision = 0.02,
num_input_ids = 0,
) {
if (!generate_outputs.cross_attentions) {
throw new Error(
'Model outputs must contain cross attentions to extract timestamps. ' +
Expand Down Expand Up @@ -253,8 +455,14 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
}
}

// Skip decoder_input_ids in the cross-attention weights
const croppedWeights =
num_input_ids > 0
? smoothedWeights.slice(null, null, [num_input_ids, smoothedWeights.dims[2]], null)
: smoothedWeights;

// Average the different cross-attention heads.
const batchedMatrices = [mean(smoothedWeights, 1)];
const batchedMatrices = [mean(croppedWeights, 1)];

const timestampsShape = generate_outputs.sequences.dims;

Expand Down Expand Up @@ -284,7 +492,15 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
jump_times.push(time_indices[i] * time_precision);
}
}
timestamps[batch_idx].data.set(jump_times, 1);

// Pad with num_input_ids zeros at the start (for prefix tokens),
// then DTW jump_times, then duplicate last value (for eos token)
const padded = new Array(num_input_ids).fill(0);
padded.push(...jump_times);
if (jump_times.length > 0) {
padded.push(jump_times.at(-1));
}
timestamps[batch_idx].data.set(padded);
}

return timestamps;
Expand Down
Loading
Loading