From fe43e7e5dba7faff87fe43057da79c4caa820d6c Mon Sep 17 00:00:00 2001 From: Dan Levy Date: Fri, 20 Feb 2026 14:21:58 -0700 Subject: [PATCH 1/6] Add beam search support with BeamHypotheses and BeamSearchScorer classes --- .../src/generation/beam_search.js | 231 +++++++++++++++ .../transformers/src/models/modeling_utils.js | 268 +++++++++++++++--- packages/transformers/src/transformers.js | 1 + packages/transformers/src/utils/tensor.js | 67 +++++ .../tests/utils/generation.test.js | 147 ++++++++++ 5 files changed, 674 insertions(+), 40 deletions(-) create mode 100644 packages/transformers/src/generation/beam_search.js diff --git a/packages/transformers/src/generation/beam_search.js b/packages/transformers/src/generation/beam_search.js new file mode 100644 index 000000000..e6475db8f --- /dev/null +++ b/packages/transformers/src/generation/beam_search.js @@ -0,0 +1,231 @@ +/** + * @module generation/beam_search + */ + +/** + * Stores completed beam search hypotheses for a single batch element. + */ +export class BeamHypotheses { + /** + * @param {number} num_beams Number of beams. + * @param {number} length_penalty Exponential penalty to the length. + * @param {boolean|"never"} early_stopping Whether to stop when enough hypotheses are finished. + */ + constructor(num_beams, length_penalty = 1.0, early_stopping = false) { + this.num_beams = num_beams; + this.length_penalty = length_penalty; + this.early_stopping = early_stopping; + + /** @type {{score: number, tokens: bigint[]}[]} */ + this.beams = []; + this.worst_score = 1e9; + } + + get length() { + return this.beams.length; + } + + /** + * Add a new hypothesis to the list. + * @param {number} sum_logprobs Sum of log probabilities of the hypothesis. + * @param {bigint[]} tokens The token ids of the hypothesis. + */ + add(sum_logprobs, tokens) { + const score = sum_logprobs / (tokens.length ** this.length_penalty); + if (this.beams.length < this.num_beams || score > this.worst_score) { + this.beams.push({ score, tokens }); + if (this.beams.length > this.num_beams) { + // Remove worst hypothesis + let worst_idx = 0; + for (let i = 1; i < this.beams.length; ++i) { + if (this.beams[i].score < this.beams[worst_idx].score) { + worst_idx = i; + } + } + this.beams.splice(worst_idx, 1); + } + this.worst_score = this.beams.length === this.num_beams + ? Math.min(...this.beams.map(b => b.score)) + : -1e9; + } + } + + /** + * Check whether adding more beams can possibly improve the hypotheses. + * @param {number} best_sum_logprobs Best sum of log probs among active beams. + * @param {number} cur_len Current length of generated tokens. + * @returns {boolean} + */ + is_done(best_sum_logprobs, cur_len) { + if (this.beams.length < this.num_beams) return false; + + if (this.early_stopping === true) { + return true; + } else if (this.early_stopping === 'never') { + return false; + } else { + // Heuristic: check if the best possible score for the next step + // could beat the worst completed hypothesis + const highest_attainable_score = best_sum_logprobs / (cur_len ** this.length_penalty); + return this.worst_score >= highest_attainable_score; + } + } +} + +/** + * Implements beam search scoring and beam management. + */ +export class BeamSearchScorer { + /** + * @param {number} batch_size + * @param {number} num_beams + * @param {Object} options + * @param {number} [options.length_penalty] + * @param {boolean|"never"} [options.early_stopping] + * @param {number} [options.num_return_sequences] + * @param {number|number[]|null} [options.eos_token_id] + * @param {number|null} [options.pad_token_id] + */ + constructor(batch_size, num_beams, { + length_penalty = 1.0, + early_stopping = false, + num_return_sequences = 1, + eos_token_id = null, + pad_token_id = null, + } = {}) { + this.batch_size = batch_size; + this.num_beams = num_beams; + this.length_penalty = length_penalty; + this.early_stopping = early_stopping; + this.num_return_sequences = num_return_sequences; + this.eos_token_ids = eos_token_id === null ? [] + : (Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]); + this.pad_token_id = pad_token_id ?? 0; + + if (num_return_sequences > num_beams) { + throw new Error( + `num_return_sequences (${num_return_sequences}) must be <= num_beams (${num_beams}).`, + ); + } + + /** @type {BeamHypotheses[]} */ + this._beam_hyps = Array.from( + { length: batch_size }, + () => new BeamHypotheses(num_beams, length_penalty, early_stopping), + ); + /** @type {boolean[]} */ + this._done = new Array(batch_size).fill(false); + } + + get is_done() { + return this._done.every(Boolean); + } + + /** + * Process the current beam candidates and select the next set of beams. + * + * @param {bigint[][]} all_input_ids All sequences so far, shape [batch_size * num_beams, seq_len]. + * @param {number[]} beam_scores Cumulative scores, shape [batch_size * num_beams]. + * @param {bigint[]} next_tokens Candidate tokens, shape [batch_size * 2 * num_beams]. + * @param {number[]} next_indices Which beam each candidate came from (relative to batch group), shape [batch_size * 2 * num_beams]. + * @param {number[]} next_scores New cumulative scores for candidates, shape [batch_size * 2 * num_beams]. + * @returns {{ next_beam_scores: number[], next_beam_tokens: bigint[], next_beam_indices: number[] }} + */ + process(all_input_ids, beam_scores, next_tokens, next_indices, next_scores) { + const cur_len = all_input_ids[0].length; + const total_beams = this.batch_size * this.num_beams; + + const next_beam_scores = new Array(total_beams).fill(0); + const next_beam_tokens = new Array(total_beams).fill(0n); + const next_beam_indices = new Array(total_beams).fill(0); + + for (let batch_idx = 0; batch_idx < this.batch_size; ++batch_idx) { + if (this._done[batch_idx]) { + // Pad finished batches + for (let beam_idx = 0; beam_idx < this.num_beams; ++beam_idx) { + const flat_idx = batch_idx * this.num_beams + beam_idx; + next_beam_scores[flat_idx] = 0; + next_beam_tokens[flat_idx] = BigInt(this.pad_token_id); + next_beam_indices[flat_idx] = batch_idx * this.num_beams; + } + continue; + } + + let beam_idx = 0; + const num_candidates = 2 * this.num_beams; + for (let j = 0; j < num_candidates; ++j) { + const cand_idx = batch_idx * num_candidates + j; + const beam_token = next_tokens[cand_idx]; + const beam_score = next_scores[cand_idx]; + const beam_source = next_indices[cand_idx]; // relative to batch + const abs_beam_source = batch_idx * this.num_beams + beam_source; + + const is_eos = this.eos_token_ids.some(id => BigInt(id) === beam_token); + + if (is_eos) { + // Add completed hypothesis + const hypothesis = [...all_input_ids[abs_beam_source], beam_token]; + this._beam_hyps[batch_idx].add(beam_score, hypothesis); + } else { + // Add to next active beams + const out_idx = batch_idx * this.num_beams + beam_idx; + next_beam_scores[out_idx] = beam_score; + next_beam_tokens[out_idx] = beam_token; + next_beam_indices[out_idx] = abs_beam_source; + beam_idx++; + } + + if (beam_idx === this.num_beams) break; + } + + // If we couldn't fill all beams (too many EOS), pad with last valid + if (beam_idx < this.num_beams) { + const last_valid = batch_idx * this.num_beams + Math.max(0, beam_idx - 1); + for (; beam_idx < this.num_beams; ++beam_idx) { + const out_idx = batch_idx * this.num_beams + beam_idx; + next_beam_scores[out_idx] = next_beam_scores[last_valid] ?? 0; + next_beam_tokens[out_idx] = next_beam_tokens[last_valid] ?? 0n; + next_beam_indices[out_idx] = next_beam_indices[last_valid] ?? (batch_idx * this.num_beams); + } + } + + // Check if done for this batch using next-step scores + const start = batch_idx * this.num_beams; + const end = start + this.num_beams; + const best_sum_logprobs = Math.max(...next_beam_scores.slice(start, end)); + this._done[batch_idx] = this._beam_hyps[batch_idx].is_done(best_sum_logprobs, cur_len + 1); + } + + return { next_beam_scores, next_beam_tokens, next_beam_indices }; + } + + /** + * Finalize: select best hypotheses. + * @param {bigint[][]} all_input_ids Final sequences, shape [batch_size * num_beams, seq_len]. + * @param {number[]} beam_scores Final cumulative scores. + * @returns {bigint[][]} Best sequences, shape [batch_size * num_return_sequences, seq_len]. + */ + finalize(all_input_ids, beam_scores) { + // For each batch, ensure we have enough hypotheses + for (let batch_idx = 0; batch_idx < this.batch_size; ++batch_idx) { + const hyps = this._beam_hyps[batch_idx]; + if (hyps.length < this.num_beams) { + for (let beam_idx = 0; beam_idx < this.num_beams; ++beam_idx) { + const flat_idx = batch_idx * this.num_beams + beam_idx; + hyps.add(beam_scores[flat_idx], all_input_ids[flat_idx]); + } + } + } + + // Select top num_return_sequences per batch + const results = []; + for (let batch_idx = 0; batch_idx < this.batch_size; ++batch_idx) { + const sorted = [...this._beam_hyps[batch_idx].beams] + .sort((a, b) => b.score - a.score); + for (let i = 0; i < this.num_return_sequences; ++i) { + results.push(sorted[i].tokens); + } + } + return results; + } +} diff --git a/packages/transformers/src/models/modeling_utils.js b/packages/transformers/src/models/modeling_utils.js index d5a85a448..07275fa49 100644 --- a/packages/transformers/src/models/modeling_utils.js +++ b/packages/transformers/src/models/modeling_utils.js @@ -1,7 +1,7 @@ import { Callable } from '../utils/generic.js'; import { constructSessions, sessionRun } from './session.js'; import { AutoConfig, getCacheShapes } from '../configs.js'; -import { Tensor, full_like, cat, zeros_like, ones_like, ones } from '../utils/tensor.js'; +import { Tensor, full_like, cat, zeros_like, ones_like, ones, index_select, index_select_async } from '../utils/tensor.js'; import { DataTypeMap } from '../utils/dtypes.js'; // These will be populated by registry.js @@ -33,7 +33,9 @@ import { import { GenerationConfig } from '../generation/configuration_utils.js'; import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from '../generation/stopping_criteria.js'; import { LogitsSampler } from '../generation/logits_sampler.js'; +import { BeamSearchScorer } from '../generation/beam_search.js'; import { pick } from '../utils/core.js'; +import { log_softmax } from '../utils/maths.js'; import { ModelOutput } from './modeling_outputs.js'; /** @@ -1079,27 +1081,85 @@ export class PreTrainedModel extends Callable { // } const numInputs = model_inputs[model_input_name].dims.at(0); + const num_beams = generation_config.num_beams; + const is_beam_search = num_beams > 1; - // TODO: - // done is a list of booleans to keep track of which inputs are done - // const done = new Array(numInputs).fill(false); - // For efficiency purposes, we remove completed rows from model_inputs - // when the beam is complete, and we keep track of the row index - // const rowIndexToBatchIndex = new Map(); + let beam_scorer = null; + /** @type {number[]} */ + let beam_scores; + + if (is_beam_search) { + // Validate beam search configuration + if (generation_config.num_beam_groups > 1) { + throw new Error('Diverse beam search (num_beam_groups > 1) is not yet supported.'); + } + if (generation_config.do_sample) { + throw new Error('Beam sampling (num_beams > 1 with do_sample = true) is not yet supported.'); + } + if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) { + throw new Error('Classifier-free guidance (guidance_scale > 1) is not supported with beam search.'); + } + if (streamer) { + throw new Error('Streaming is not supported with beam search.'); + } + + beam_scorer = new BeamSearchScorer(numInputs, num_beams, { + length_penalty: generation_config.length_penalty, + early_stopping: generation_config.early_stopping, + num_return_sequences: generation_config.num_return_sequences, + eos_token_id: generation_config.eos_token_id, + pad_token_id: generation_config.pad_token_id, + }); + } const sampler = LogitsSampler.getSampler(generation_config); - // TODO make > numInputs - const scores = new Array(numInputs).fill(0); + const scores = new Array(is_beam_search ? numInputs * num_beams : numInputs).fill(0); /** @type {bigint[][]} */ const all_input_ids = input_ids.tolist(); + + if (is_beam_search) { + // Expand all_input_ids: [A, B] -> [A, A, A, B, B, B] for num_beams=3 + const expanded_ids = []; + for (const ids of all_input_ids) { + for (let b = 0; b < num_beams; ++b) { + expanded_ids.push([...ids]); + } + } + all_input_ids.length = 0; + all_input_ids.push(...expanded_ids); + + // Expand model_inputs tensors by repeating each row num_beams times + for (const key of [model_input_name, 'attention_mask', 'decoder_attention_mask', 'encoder_outputs']) { + const tensor = model_inputs[key]; + if (tensor instanceof Tensor) { + const [batch, ...rest] = tensor.dims; + const rowSize = rest.length > 0 ? rest.reduce((a, b) => a * b, 1) : 1; + // @ts-ignore + const newData = new tensor.data.constructor(batch * num_beams * rowSize); + for (let i = 0; i < batch; ++i) { + const srcOffset = i * rowSize; + for (let b = 0; b < num_beams; ++b) { + newData.set( + tensor.data.subarray(srcOffset, srcOffset + rowSize), + (i * num_beams + b) * rowSize, + ); + } + } + model_inputs[key] = new Tensor(tensor.type, newData, [batch * num_beams, ...rest]); + } + } + + // Initialize beam_scores: first beam of each input = 0, rest = -1e9 + beam_scores = new Array(numInputs * num_beams).fill(-1e9); + for (let i = 0; i < numInputs; ++i) { + beam_scores[i * num_beams] = 0; + } + } + if (streamer) { streamer.put(all_input_ids); } - // const all_generated_input_ids = Array.from({ length: numInputs }, () => []); - - // NOTE: For now, we don't support spawning new beams - // TODO: when we do, we simply copy past key values and accumulate into single large tensor //////////////////////////////////////////////////// // Generic search which handles 4 generation modes: @@ -1144,39 +1204,125 @@ export class PreTrainedModel extends Callable { /** @type {[bigint][]} */ const generated_input_ids = []; - // const new_kv_cache = [];// NOTE: Only used for beam search when concatenating new kv - // Loop over each batch - for (let batch_idx = 0; batch_idx < next_tokens_scores.dims.at(0); ++batch_idx) { - const logs = next_tokens_scores[batch_idx]; - - const sampledTokens = await sampler(logs); - for (const [newTokenId, logProb] of sampledTokens) { - const bigint = BigInt(newTokenId); - // TODO: If branching, use previous beam as a starting point - // update generated ids, model inputs, and length for next step - scores[batch_idx] += logProb; - all_input_ids[batch_idx].push(bigint); - generated_input_ids.push([bigint]); - - // TODO: Support beam search - break; + + if (is_beam_search) { + // Beam search: score all candidates across beams, select top 2*num_beams per batch + const vocab_size = next_tokens_scores.dims.at(-1); + const total_beams = numInputs * num_beams; + + const all_next_tokens = []; + const all_next_indices = []; + const all_next_scores = []; + + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + // Collect candidates across all beams for this batch item + /** @type {{score: number, token: bigint, beam_idx: number}[]} */ + const candidates = []; + + for (let beam_idx = 0; beam_idx < num_beams; ++beam_idx) { + const flat_beam = batch_idx * num_beams + beam_idx; + const logits_data = /** @type {Float32Array} */ (next_tokens_scores[flat_beam].data); + const log_probs = log_softmax(logits_data); + + // Find top 2*num_beams candidates per beam for efficiency + /** @type {{score: number, token: bigint, beam_idx: number}[]} */ + const beam_candidates = []; + for (let v = 0; v < vocab_size; ++v) { + beam_candidates.push({ + score: beam_scores[flat_beam] + log_probs[v], + token: BigInt(v), + beam_idx, + }); + } + beam_candidates.sort((a, b) => b.score - a.score); + candidates.push(...beam_candidates.slice(0, 2 * num_beams)); + } + + // Select top 2*num_beams candidates globally for this batch item + candidates.sort((a, b) => b.score - a.score); + const top_candidates = candidates.slice(0, 2 * num_beams); + + for (const c of top_candidates) { + all_next_tokens.push(c.token); + all_next_indices.push(c.beam_idx); + all_next_scores.push(c.score); + } + } + + // Let beam_scorer process: route EOS to hypotheses, select continuing beams + const { next_beam_scores, next_beam_tokens, next_beam_indices } = + beam_scorer.process(all_input_ids, beam_scores, all_next_tokens, all_next_indices, all_next_scores); + + // Reorder all_input_ids based on beam indices and append new tokens + const new_all_input_ids = []; + for (let i = 0; i < total_beams; ++i) { + new_all_input_ids.push([...all_input_ids[next_beam_indices[i]], next_beam_tokens[i]]); + generated_input_ids.push([next_beam_tokens[i]]); + } + all_input_ids.length = 0; + all_input_ids.push(...new_all_input_ids); + + // Update beam scores + for (let i = 0; i < total_beams; ++i) { + beam_scores[i] = next_beam_scores[i]; + } + + // Reorder KV cache to match beam reordering + model_inputs['past_key_values'] = await this._reorder_cache( + this.getPastKeyValues(outputs, model_inputs.past_key_values), + next_beam_indices, + ); + } else { + // Greedy / sample: existing behavior + for (let batch_idx = 0; batch_idx < next_tokens_scores.dims.at(0); ++batch_idx) { + const logs = next_tokens_scores[batch_idx]; + + const sampledTokens = await sampler(logs); + for (const [newTokenId, logProb] of sampledTokens) { + const bigint = BigInt(newTokenId); + scores[batch_idx] += logProb; + all_input_ids[batch_idx].push(bigint); + generated_input_ids.push([bigint]); + break; + } } } + if (streamer) { streamer.put(generated_input_ids); } + // Check stopping conditions + if (is_beam_search && beam_scorer.is_done) { + break; + } const stop = prepared_stopping_criteria(all_input_ids); if (stop.every((x) => x)) { break; } - model_inputs = this._update_model_kwargs_for_generation({ - generated_input_ids, - outputs, - model_inputs, - is_encoder_decoder, - }); + if (is_beam_search) { + // For beam search, we already updated past_key_values during reordering. + // Now update the remaining model inputs. + model_inputs['input_ids'] = new Tensor('int64', generated_input_ids.flat(), [generated_input_ids.length, 1]); + + if (!is_encoder_decoder) { + model_inputs.attention_mask = cat( + [model_inputs.attention_mask, ones([model_inputs.attention_mask.dims[0], 1])], + 1, + ); + } + + // Force recreate position_ids in next iteration + model_inputs['position_ids'] = null; + } else { + model_inputs = this._update_model_kwargs_for_generation({ + generated_input_ids, + outputs, + model_inputs, + is_encoder_decoder, + }); + } } if (streamer) { @@ -1186,8 +1332,23 @@ export class PreTrainedModel extends Callable { // Retrieve and dispose all final past key values (including encoder attentions) const past_key_values = this.getPastKeyValues(outputs, model_inputs.past_key_values, true); - // TODO: ensure all_input_ids is padded correctly... - const sequences = new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]); + let sequences; + if (is_beam_search) { + // Finalize beam search: select best hypotheses + const best_sequences = beam_scorer.finalize(all_input_ids, beam_scores); + + // Pad sequences to equal length for tensor creation + const max_len = Math.max(...best_sequences.map(s => s.length)); + const pad_id = BigInt(generation_config.pad_token_id ?? 0); + const padded = best_sequences.map(s => { + const p = [...s]; + while (p.length < max_len) p.push(pad_id); + return p; + }); + sequences = new Tensor('int64', padded.flat(), [padded.length, max_len]); + } else { + sequences = new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]); + } if (generation_config.return_dict_in_generate) { return { @@ -1195,9 +1356,6 @@ export class PreTrainedModel extends Callable { past_key_values, ...attentions, ...return_dict_items, - // TODO: - // scores, - // logits, }; } else { // Dispose all remaining tensors @@ -1253,6 +1411,36 @@ export class PreTrainedModel extends Callable { return pkvs; } + /** + * Reorder the past key values cache to match beam reordering. + * @param {Object} past_key_values The past key values object. + * @param {number[]} beam_indices Indices indicating which beam each new position came from. + * @returns {Promise} Reordered past key values. + */ + async _reorder_cache(past_key_values, beam_indices) { + if (!past_key_values) return past_key_values; + const reordered = Object.create(null); + for (const key in past_key_values) { + const tensor = past_key_values[key]; + const location = tensor.location; + if (key.includes('encoder')) { + // Encoder PKVs are shared across beams; pass through + reordered[key] = tensor; + } else { + if (tensor.location === 'cpu' || tensor.location === 'cpu-pinned') { + reordered[key] = index_select(tensor, beam_indices); + } else { + reordered[key] = await index_select_async(tensor, beam_indices); + } + // Dispose old tensor if it owns GPU resources + if (location === 'gpu-buffer' || location === 'texture' || location === 'ml-tensor') { + tensor.dispose(); + } + } + } + return reordered; + } + /** * Returns an object containing attentions from the given model output object. * diff --git a/packages/transformers/src/transformers.js b/packages/transformers/src/transformers.js index b6083583e..a1f1c405d 100644 --- a/packages/transformers/src/transformers.js +++ b/packages/transformers/src/transformers.js @@ -45,6 +45,7 @@ export { PretrainedConfig, AutoConfig } from './configs.js'; export * from './generation/streamers.js'; export * from './generation/stopping_criteria.js'; export * from './generation/logits_process.js'; +export * from './generation/beam_search.js'; export { read_audio, RawAudio } from './utils/audio.js'; export { load_image, RawImage } from './utils/image.js'; diff --git a/packages/transformers/src/utils/tensor.js b/packages/transformers/src/utils/tensor.js index e0f34f305..83f53a27a 100644 --- a/packages/transformers/src/utils/tensor.js +++ b/packages/transformers/src/utils/tensor.js @@ -1344,6 +1344,73 @@ export function cat(tensors, dim = 0) { return new Tensor(resultType, result, resultDims); } +/** + * Select entries from a tensor along dimension 0 using an index array. + * Equivalent to `torch.index_select(tensor, 0, indices)`. + * @param {Tensor} tensor The source tensor. + * @param {number[]} indices Array of row indices to select. + * @returns {Tensor} A new tensor with selected rows. + */ +function _index_select_from_data(tensor, indices, data) { + const [batchSize, ...restDims] = tensor.dims; + const rowSize = restDims.length > 0 ? restDims.reduce((a, b) => a * b, 1) : 1; + const newBatchSize = indices.length; + + let out; + if (ArrayBuffer.isView(data) && typeof data.subarray === 'function' && typeof data.set === 'function') { + // @ts-ignore + out = new data.constructor(newBatchSize * rowSize); + for (let i = 0; i < newBatchSize; ++i) { + const srcOffset = indices[i] * rowSize; + const dstOffset = i * rowSize; + out.set(data.subarray(srcOffset, srcOffset + rowSize), dstOffset); + } + } else { + out = new Array(newBatchSize * rowSize); + for (let i = 0; i < newBatchSize; ++i) { + const srcOffset = indices[i] * rowSize; + const dstOffset = i * rowSize; + for (let j = 0; j < rowSize; ++j) { + out[dstOffset + j] = data[srcOffset + j]; + } + } + } + + return new Tensor(tensor.type, out, [newBatchSize, ...restDims]); +} + +export function index_select(tensor, indices) { + if (tensor.location !== 'cpu' && tensor.location !== 'cpu-pinned') { + throw new Error( + `index_select only supports CPU tensors. Got location: ${tensor.location}. ` + + 'Use index_select_async to handle GPU tensors.', + ); + } + + return _index_select_from_data(tensor, indices, tensor.data); +} + +/** + * Async version of index_select that supports GPU-backed tensors by downloading data to CPU. + * @param {Tensor} tensor The source tensor. + * @param {number[]} indices Array of row indices to select. + * @returns {Promise} A new tensor with selected rows (CPU-backed). + */ +export async function index_select_async(tensor, indices) { + if (tensor.location === 'cpu' || tensor.location === 'cpu-pinned') { + return index_select(tensor, indices); + } + + // Download GPU/ML data to CPU. This will materialize CPU data and may release GPU resources. + const ort_tensor = tensor.ort_tensor; + if (!ort_tensor?.getData) { + throw new Error(`Tensor does not support getData() for location: ${tensor.location}.`); + } + + const data = await ort_tensor.getData(true); + return _index_select_from_data(tensor, indices, data); +} + /** * Stack an array of tensors along a specified dimension. * @param {Tensor[]} tensors The array of tensors to stack. diff --git a/packages/transformers/tests/utils/generation.test.js b/packages/transformers/tests/utils/generation.test.js index fc7575292..f10b09bc1 100644 --- a/packages/transformers/tests/utils/generation.test.js +++ b/packages/transformers/tests/utils/generation.test.js @@ -16,6 +16,8 @@ import { // Other TextStreamer, RawImage, + BeamSearchScorer, + BeamHypotheses, } from "../../src/transformers.js"; import { init, MAX_TEST_EXECUTION_TIME, MAX_MODEL_LOAD_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../init.js"; @@ -545,3 +547,148 @@ describe("PKV caching", () => { }, MAX_MODEL_DISPOSE_TIME); }); }); + +describe("Beam search", () => { + describe(`encoder-decoder`, () => { + const model_id = "hf-internal-testing/tiny-random-T5ForConditionalGeneration"; + const DUMMY_TEXT = "hello"; + + let model; + let tokenizer; + beforeAll(async () => { + model = await AutoModelForSeq2SeqLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS); + tokenizer = await AutoTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "basic beam search (num_beams=4)", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + max_new_tokens: 5, + }); + // Output should have shape [1, seq_len] (1 return sequence) + expect(outputs.dims[0]).toEqual(1); + expect(outputs.dims[1]).toBeGreaterThanOrEqual(2); // at least BOS + 1 token + expect(outputs.dims[1]).toBeLessThanOrEqual(6); // at most BOS + 5 tokens + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "num_return_sequences > 1", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + num_return_sequences: 3, + max_new_tokens: 5, + }); + // Output should have shape [3, seq_len] + expect(outputs.dims[0]).toEqual(3); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "beam search produces different output from greedy", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const greedy_outputs = await model.generate({ + ...inputs, + max_new_tokens: 5, + }); + const beam_outputs = await model.generate({ + ...inputs, + num_beams: 4, + max_new_tokens: 5, + }); + // Both should succeed and produce valid tensors + expect(greedy_outputs.dims[0]).toEqual(1); + expect(beam_outputs.dims[0]).toEqual(1); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "beam search rejects classifier-free guidance", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + await expect( + model.generate({ + ...inputs, + num_beams: 4, + guidance_scale: 2, + max_new_tokens: 5, + }), + ).rejects.toThrow("Classifier-free guidance (guidance_scale > 1) is not supported with beam search."); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe(`decoder-only`, () => { + const model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"; + const DUMMY_TEXT = "hello"; + + let model; + let tokenizer; + beforeAll(async () => { + model = await AutoModelForCausalLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS); + tokenizer = await AutoTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "basic beam search (num_beams=4)", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + max_new_tokens: 5, + }); + // Output should have shape [1, seq_len] + expect(outputs.dims[0]).toEqual(1); + // seq_len = prompt_len + generated (up to max_new_tokens) + expect(outputs.dims[1]).toBeGreaterThanOrEqual(3); // at least BOS + prompt + 1 token + expect(outputs.dims[1]).toBeLessThanOrEqual(7); // at most BOS + prompt + 5 tokens + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "num_return_sequences > 1", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + num_return_sequences: 3, + max_new_tokens: 5, + }); + // Output should have shape [3, seq_len] + expect(outputs.dims[0]).toEqual(3); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("error cases", () => { + it("num_return_sequences > num_beams throws", () => { + expect(() => { + new BeamSearchScorer(1, 4, { num_return_sequences: 5 }); + }).toThrow("num_return_sequences (5) must be <= num_beams (4)"); + }); + }); +}); From cc27cb084d62fa38221dbfcd63919d319199ab5a Mon Sep 17 00:00:00 2001 From: Dan Levy Date: Fri, 20 Feb 2026 15:53:38 -0700 Subject: [PATCH 2/6] Add diverse beam search support --- .../src/generation/beam_search.js | 12 +- .../transformers/src/models/modeling_utils.js | 301 +++++++++++++++--- .../tests/utils/generation.test.js | 38 +++ 3 files changed, 303 insertions(+), 48 deletions(-) diff --git a/packages/transformers/src/generation/beam_search.js b/packages/transformers/src/generation/beam_search.js index e6475db8f..661de7fdb 100644 --- a/packages/transformers/src/generation/beam_search.js +++ b/packages/transformers/src/generation/beam_search.js @@ -206,6 +206,16 @@ export class BeamSearchScorer { * @returns {bigint[][]} Best sequences, shape [batch_size * num_return_sequences, seq_len]. */ finalize(all_input_ids, beam_scores) { + return this.finalize_with_scores(all_input_ids, beam_scores).map((x) => x.tokens); + } + + /** + * Finalize: select best hypotheses and return scores. + * @param {bigint[][]} all_input_ids Final sequences, shape [batch_size * num_beams, seq_len]. + * @param {number[]} beam_scores Final cumulative scores. + * @returns {{tokens: bigint[], score: number}[]} Best sequences with scores. + */ + finalize_with_scores(all_input_ids, beam_scores) { // For each batch, ensure we have enough hypotheses for (let batch_idx = 0; batch_idx < this.batch_size; ++batch_idx) { const hyps = this._beam_hyps[batch_idx]; @@ -223,7 +233,7 @@ export class BeamSearchScorer { const sorted = [...this._beam_hyps[batch_idx].beams] .sort((a, b) => b.score - a.score); for (let i = 0; i < this.num_return_sequences; ++i) { - results.push(sorted[i].tokens); + results.push(sorted[i]); } } return results; diff --git a/packages/transformers/src/models/modeling_utils.js b/packages/transformers/src/models/modeling_utils.js index 07275fa49..b468dfefc 100644 --- a/packages/transformers/src/models/modeling_utils.js +++ b/packages/transformers/src/models/modeling_utils.js @@ -1082,16 +1082,37 @@ export class PreTrainedModel extends Callable { const numInputs = model_inputs[model_input_name].dims.at(0); const num_beams = generation_config.num_beams; + const num_beam_groups = generation_config.num_beam_groups; const is_beam_search = num_beams > 1; + const is_group_beam_search = is_beam_search && num_beam_groups > 1; let beam_scorer = null; + let beam_scorers = null; /** @type {number[]} */ let beam_scores; if (is_beam_search) { // Validate beam search configuration - if (generation_config.num_beam_groups > 1) { - throw new Error('Diverse beam search (num_beam_groups > 1) is not yet supported.'); + if (generation_config.num_return_sequences > num_beams) { + throw new Error( + `num_return_sequences (${generation_config.num_return_sequences}) must be <= num_beams (${num_beams}).`, + ); + } + if (num_beam_groups < 1) { + throw new Error(`num_beam_groups must be >= 1, but got ${num_beam_groups}.`); + } + if (num_beam_groups > num_beams) { + throw new Error( + `num_beam_groups (${num_beam_groups}) must be <= num_beams (${num_beams}).`, + ); + } + if (num_beams % num_beam_groups !== 0) { + throw new Error( + `num_beams (${num_beams}) must be divisible by num_beam_groups (${num_beam_groups}).`, + ); + } + if (is_group_beam_search && generation_config.do_sample) { + throw new Error('Diverse beam sampling (num_beam_groups > 1 with do_sample = true) is not yet supported.'); } if (generation_config.do_sample) { throw new Error('Beam sampling (num_beams > 1 with do_sample = true) is not yet supported.'); @@ -1103,13 +1124,24 @@ export class PreTrainedModel extends Callable { throw new Error('Streaming is not supported with beam search.'); } - beam_scorer = new BeamSearchScorer(numInputs, num_beams, { - length_penalty: generation_config.length_penalty, - early_stopping: generation_config.early_stopping, - num_return_sequences: generation_config.num_return_sequences, - eos_token_id: generation_config.eos_token_id, - pad_token_id: generation_config.pad_token_id, - }); + if (is_group_beam_search) { + const group_size = num_beams / num_beam_groups; + beam_scorers = Array.from({ length: num_beam_groups }, () => new BeamSearchScorer(numInputs, group_size, { + length_penalty: generation_config.length_penalty, + early_stopping: generation_config.early_stopping, + num_return_sequences: group_size, + eos_token_id: generation_config.eos_token_id, + pad_token_id: generation_config.pad_token_id, + })); + } else { + beam_scorer = new BeamSearchScorer(numInputs, num_beams, { + length_penalty: generation_config.length_penalty, + early_stopping: generation_config.early_stopping, + num_return_sequences: generation_config.num_return_sequences, + eos_token_id: generation_config.eos_token_id, + pad_token_id: generation_config.pad_token_id, + }); + } } const sampler = LogitsSampler.getSampler(generation_config); @@ -1153,7 +1185,10 @@ export class PreTrainedModel extends Callable { // Initialize beam_scores: first beam of each input = 0, rest = -1e9 beam_scores = new Array(numInputs * num_beams).fill(-1e9); for (let i = 0; i < numInputs; ++i) { - beam_scores[i * num_beams] = 0; + const group_size = num_beams / num_beam_groups; + for (let g = 0; g < num_beam_groups; ++g) { + beam_scores[i * num_beams + g * group_size] = 0; + } } } @@ -1210,48 +1245,172 @@ export class PreTrainedModel extends Callable { const vocab_size = next_tokens_scores.dims.at(-1); const total_beams = numInputs * num_beams; - const all_next_tokens = []; - const all_next_indices = []; - const all_next_scores = []; + /** @type {number[]} */ + let next_beam_scores; + /** @type {bigint[]} */ + let next_beam_tokens; + /** @type {number[]} */ + let next_beam_indices; + + if (is_group_beam_search) { + const group_size = num_beams / num_beam_groups; + next_beam_scores = new Array(total_beams).fill(0); + next_beam_tokens = new Array(total_beams).fill(0n); + next_beam_indices = new Array(total_beams).fill(0); + + const diversity_penalty = generation_config.diversity_penalty ?? 0; + /** @type {Map[]} */ + const prev_group_tokens = Array.from({ length: numInputs }, () => new Map()); + + for (let group_idx = 0; group_idx < num_beam_groups; ++group_idx) { + const group_offset = group_idx * group_size; + + const group_all_input_ids = new Array(numInputs * group_size); + const group_beam_scores = new Array(numInputs * group_size); + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + for (let beam_idx = 0; beam_idx < group_size; ++beam_idx) { + const global = batch_idx * num_beams + group_offset + beam_idx; + const local = batch_idx * group_size + beam_idx; + group_all_input_ids[local] = all_input_ids[global]; + group_beam_scores[local] = beam_scores[global]; + } + } - for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { - // Collect candidates across all beams for this batch item - /** @type {{score: number, token: bigint, beam_idx: number}[]} */ - const candidates = []; + const group_next_tokens = []; + const group_next_indices = []; + const group_next_scores = []; + + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + // Collect candidates across all beams for this batch item + /** @type {{score: number, token: bigint, beam_idx: number}[]} */ + const candidates = []; + + for (let beam_idx = 0; beam_idx < group_size; ++beam_idx) { + const flat_beam = batch_idx * num_beams + group_offset + beam_idx; + const logits_data = /** @type {Float32Array} */ (next_tokens_scores[flat_beam].data); + const log_probs = log_softmax(logits_data); + const beam_score = beam_scores[flat_beam]; + + // Find top 2*group_size candidates per beam for efficiency + /** @type {{score: number, token: bigint, beam_idx: number}[]} */ + const beam_candidates = []; + for (let v = 0; v < vocab_size; ++v) { + const token = BigInt(v); + let score = beam_score + log_probs[v]; + const penalty_count = prev_group_tokens[batch_idx].get(token) ?? 0; + if (penalty_count > 0) { + score -= diversity_penalty * penalty_count; + } + beam_candidates.push({ + score, + token, + beam_idx, + }); + } + beam_candidates.sort((a, b) => b.score - a.score); + candidates.push(...beam_candidates.slice(0, 2 * group_size)); + } + + // Select top 2*group_size candidates globally for this batch item + candidates.sort((a, b) => b.score - a.score); + const top_candidates = candidates.slice(0, 2 * group_size); + + for (const c of top_candidates) { + group_next_tokens.push(c.token); + group_next_indices.push(c.beam_idx); + group_next_scores.push(c.score); + } + } - for (let beam_idx = 0; beam_idx < num_beams; ++beam_idx) { - const flat_beam = batch_idx * num_beams + beam_idx; - const logits_data = /** @type {Float32Array} */ (next_tokens_scores[flat_beam].data); - const log_probs = log_softmax(logits_data); + // Let beam_scorer process: route EOS to hypotheses, select continuing beams + const { + next_beam_scores: group_next_beam_scores, + next_beam_tokens: group_next_beam_tokens, + next_beam_indices: group_next_beam_indices, + } = beam_scorers[group_idx].process( + group_all_input_ids, + group_beam_scores, + group_next_tokens, + group_next_indices, + group_next_scores, + ); - // Find top 2*num_beams candidates per beam for efficiency - /** @type {{score: number, token: bigint, beam_idx: number}[]} */ - const beam_candidates = []; - for (let v = 0; v < vocab_size; ++v) { - beam_candidates.push({ - score: beam_scores[flat_beam] + log_probs[v], - token: BigInt(v), - beam_idx, - }); + // Map group outputs into global beam arrays + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + for (let beam_idx = 0; beam_idx < group_size; ++beam_idx) { + const global = batch_idx * num_beams + group_offset + beam_idx; + const local = batch_idx * group_size + beam_idx; + next_beam_scores[global] = group_next_beam_scores[local]; + next_beam_tokens[global] = group_next_beam_tokens[local]; + const local_source = group_next_beam_indices[local]; + const source_in_batch = local_source - batch_idx * group_size; + next_beam_indices[global] = batch_idx * num_beams + group_offset + source_in_batch; + } + } + + // Track tokens selected by this group for diversity penalty + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + const token_counts = prev_group_tokens[batch_idx]; + for (let beam_idx = 0; beam_idx < group_size; ++beam_idx) { + const global = batch_idx * num_beams + group_offset + beam_idx; + const token = next_beam_tokens[global]; + token_counts.set(token, (token_counts.get(token) ?? 0) + 1); + } } - beam_candidates.sort((a, b) => b.score - a.score); - candidates.push(...beam_candidates.slice(0, 2 * num_beams)); } + } else { + const all_next_tokens = []; + const all_next_indices = []; + const all_next_scores = []; - // Select top 2*num_beams candidates globally for this batch item - candidates.sort((a, b) => b.score - a.score); - const top_candidates = candidates.slice(0, 2 * num_beams); + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + // Collect candidates across all beams for this batch item + /** @type {{score: number, token: bigint, beam_idx: number}[]} */ + const candidates = []; + + for (let beam_idx = 0; beam_idx < num_beams; ++beam_idx) { + const flat_beam = batch_idx * num_beams + beam_idx; + const logits_data = /** @type {Float32Array} */ (next_tokens_scores[flat_beam].data); + const log_probs = log_softmax(logits_data); + + // Find top 2*num_beams candidates per beam for efficiency + /** @type {{score: number, token: bigint, beam_idx: number}[]} */ + const beam_candidates = []; + for (let v = 0; v < vocab_size; ++v) { + beam_candidates.push({ + score: beam_scores[flat_beam] + log_probs[v], + token: BigInt(v), + beam_idx, + }); + } + beam_candidates.sort((a, b) => b.score - a.score); + candidates.push(...beam_candidates.slice(0, 2 * num_beams)); + } + + // Select top 2*num_beams candidates globally for this batch item + candidates.sort((a, b) => b.score - a.score); + const top_candidates = candidates.slice(0, 2 * num_beams); - for (const c of top_candidates) { - all_next_tokens.push(c.token); - all_next_indices.push(c.beam_idx); - all_next_scores.push(c.score); + for (const c of top_candidates) { + all_next_tokens.push(c.token); + all_next_indices.push(c.beam_idx); + all_next_scores.push(c.score); + } } - } - // Let beam_scorer process: route EOS to hypotheses, select continuing beams - const { next_beam_scores, next_beam_tokens, next_beam_indices } = - beam_scorer.process(all_input_ids, beam_scores, all_next_tokens, all_next_indices, all_next_scores); + // Let beam_scorer process: route EOS to hypotheses, select continuing beams + ({ + next_beam_scores, + next_beam_tokens, + next_beam_indices, + } = beam_scorer.process( + all_input_ids, + beam_scores, + all_next_tokens, + all_next_indices, + all_next_scores, + )); + } // Reorder all_input_ids based on beam indices and append new tokens const new_all_input_ids = []; @@ -1293,8 +1452,13 @@ export class PreTrainedModel extends Callable { } // Check stopping conditions - if (is_beam_search && beam_scorer.is_done) { - break; + if (is_beam_search) { + const done = is_group_beam_search + ? beam_scorers.every((scorer) => scorer.is_done) + : beam_scorer.is_done; + if (done) { + break; + } } const stop = prepared_stopping_criteria(all_input_ids); if (stop.every((x) => x)) { @@ -1334,8 +1498,51 @@ export class PreTrainedModel extends Callable { let sequences; if (is_beam_search) { - // Finalize beam search: select best hypotheses - const best_sequences = beam_scorer.finalize(all_input_ids, beam_scores); + /** @type {bigint[][]} */ + let best_sequences; + if (is_group_beam_search) { + const group_size = num_beams / num_beam_groups; + /** @type {{tokens: bigint[], score: number}[][]} */ + const grouped = Array.from({ length: numInputs }, () => []); + + for (let group_idx = 0; group_idx < num_beam_groups; ++group_idx) { + const group_offset = group_idx * group_size; + const group_all_input_ids = new Array(numInputs * group_size); + const group_beam_scores = new Array(numInputs * group_size); + + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + for (let beam_idx = 0; beam_idx < group_size; ++beam_idx) { + const global = batch_idx * num_beams + group_offset + beam_idx; + const local = batch_idx * group_size + beam_idx; + group_all_input_ids[local] = all_input_ids[global]; + group_beam_scores[local] = beam_scores[global]; + } + } + + const group_results = beam_scorers[group_idx].finalize_with_scores( + group_all_input_ids, + group_beam_scores, + ); + + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + const offset = batch_idx * group_size; + for (let i = 0; i < group_size; ++i) { + grouped[batch_idx].push(group_results[offset + i]); + } + } + } + + best_sequences = []; + for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { + const sorted = grouped[batch_idx].sort((a, b) => b.score - a.score); + for (let i = 0; i < generation_config.num_return_sequences; ++i) { + best_sequences.push(sorted[i].tokens); + } + } + } else { + // Finalize beam search: select best hypotheses + best_sequences = beam_scorer.finalize(all_input_ids, beam_scores); + } // Pad sequences to equal length for tensor creation const max_len = Math.max(...best_sequences.map(s => s.length)); diff --git a/packages/transformers/tests/utils/generation.test.js b/packages/transformers/tests/utils/generation.test.js index f10b09bc1..5dc83fa20 100644 --- a/packages/transformers/tests/utils/generation.test.js +++ b/packages/transformers/tests/utils/generation.test.js @@ -577,6 +577,25 @@ describe("Beam search", () => { MAX_TEST_EXECUTION_TIME, ); + it( + "diverse beam search (num_beam_groups=2)", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + num_beam_groups: 2, + diversity_penalty: 0.5, + max_new_tokens: 5, + }); + // Output should have shape [1, seq_len] (1 return sequence) + expect(outputs.dims[0]).toEqual(1); + expect(outputs.dims[1]).toBeGreaterThanOrEqual(2); + expect(outputs.dims[1]).toBeLessThanOrEqual(6); + }, + MAX_TEST_EXECUTION_TIME, + ); + it( "num_return_sequences > 1", async () => { @@ -663,6 +682,25 @@ describe("Beam search", () => { MAX_TEST_EXECUTION_TIME, ); + it( + "diverse beam search (num_beam_groups=2)", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + num_beam_groups: 2, + diversity_penalty: 0.5, + max_new_tokens: 5, + }); + // Output should have shape [1, seq_len] + expect(outputs.dims[0]).toEqual(1); + expect(outputs.dims[1]).toBeGreaterThanOrEqual(3); + expect(outputs.dims[1]).toBeLessThanOrEqual(7); + }, + MAX_TEST_EXECUTION_TIME, + ); + it( "num_return_sequences > 1", async () => { From 2b1f5aeca002c8e9e6a7ac0fabd7b5161d6ce596 Mon Sep 17 00:00:00 2001 From: Dan Levy Date: Fri, 20 Feb 2026 15:55:11 -0700 Subject: [PATCH 3/6] Add beam sampling for beam search --- .../transformers/src/models/modeling_utils.js | 60 ++++++++++++++----- .../tests/utils/generation.test.js | 38 ++++++++++++ 2 files changed, 83 insertions(+), 15 deletions(-) diff --git a/packages/transformers/src/models/modeling_utils.js b/packages/transformers/src/models/modeling_utils.js index b468dfefc..937ad44ad 100644 --- a/packages/transformers/src/models/modeling_utils.js +++ b/packages/transformers/src/models/modeling_utils.js @@ -1114,9 +1114,6 @@ export class PreTrainedModel extends Callable { if (is_group_beam_search && generation_config.do_sample) { throw new Error('Diverse beam sampling (num_beam_groups > 1 with do_sample = true) is not yet supported.'); } - if (generation_config.do_sample) { - throw new Error('Beam sampling (num_beams > 1 with do_sample = true) is not yet supported.'); - } if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) { throw new Error('Classifier-free guidance (guidance_scale > 1) is not supported with beam search.'); } @@ -1362,6 +1359,8 @@ export class PreTrainedModel extends Callable { const all_next_tokens = []; const all_next_indices = []; const all_next_scores = []; + const do_beam_sample = generation_config.do_sample; + const top_k = generation_config.top_k; for (let batch_idx = 0; batch_idx < numInputs; ++batch_idx) { // Collect candidates across all beams for this batch item @@ -1371,20 +1370,51 @@ export class PreTrainedModel extends Callable { for (let beam_idx = 0; beam_idx < num_beams; ++beam_idx) { const flat_beam = batch_idx * num_beams + beam_idx; const logits_data = /** @type {Float32Array} */ (next_tokens_scores[flat_beam].data); - const log_probs = log_softmax(logits_data); + if (do_beam_sample) { + const beam_score = beam_scores[flat_beam]; + const beam_vocab = logits_data.length; + let k = top_k > 0 ? Math.min(top_k, beam_vocab) : beam_vocab; + const indices = Array.from({ length: beam_vocab }, (_, i) => i); + indices.sort((a, b) => logits_data[b] - logits_data[a]); + if (k < indices.length) { + indices.length = k; + } + const maxLogit = logits_data[indices[0]]; + const probs = new Array(indices.length); + let sum = 0; + for (let i = 0; i < indices.length; ++i) { + const val = Math.exp(logits_data[indices[i]] - maxLogit); + probs[i] = val; + sum += val; + } + for (let i = 0; i < probs.length; ++i) { + probs[i] /= sum; + } + for (let s = 0; s < num_beams; ++s) { + const sampledIndex = sampler.randomSelect(probs); + const tokenId = indices[sampledIndex]; + candidates.push({ + score: beam_score + Math.log(probs[sampledIndex]), + token: BigInt(tokenId), + beam_idx, + }); + } + } else { + const log_probs = log_softmax(logits_data); - // Find top 2*num_beams candidates per beam for efficiency - /** @type {{score: number, token: bigint, beam_idx: number}[]} */ - const beam_candidates = []; - for (let v = 0; v < vocab_size; ++v) { - beam_candidates.push({ - score: beam_scores[flat_beam] + log_probs[v], - token: BigInt(v), - beam_idx, - }); + // Find top 2*num_beams candidates per beam for efficiency + /** @type {{score: number, token: bigint, beam_idx: number}[]} */ + const beam_candidates = []; + for (let v = 0; v < vocab_size; ++v) { + beam_candidates.push({ + score: beam_scores[flat_beam] + log_probs[v], + token: BigInt(v), + beam_idx, + }); + } + beam_candidates.sort((a, b) => b.score - a.score); + candidates.push(...beam_candidates.slice(0, 2 * num_beams)); } - beam_candidates.sort((a, b) => b.score - a.score); - candidates.push(...beam_candidates.slice(0, 2 * num_beams)); } // Select top 2*num_beams candidates globally for this batch item diff --git a/packages/transformers/tests/utils/generation.test.js b/packages/transformers/tests/utils/generation.test.js index 5dc83fa20..3ded59a07 100644 --- a/packages/transformers/tests/utils/generation.test.js +++ b/packages/transformers/tests/utils/generation.test.js @@ -596,6 +596,25 @@ describe("Beam search", () => { MAX_TEST_EXECUTION_TIME, ); + it( + "beam sampling (do_sample=true)", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + do_sample: true, + top_k: 10, + max_new_tokens: 5, + }); + // Output should have shape [1, seq_len] (1 return sequence) + expect(outputs.dims[0]).toEqual(1); + expect(outputs.dims[1]).toBeGreaterThanOrEqual(2); + expect(outputs.dims[1]).toBeLessThanOrEqual(6); + }, + MAX_TEST_EXECUTION_TIME, + ); + it( "num_return_sequences > 1", async () => { @@ -701,6 +720,25 @@ describe("Beam search", () => { MAX_TEST_EXECUTION_TIME, ); + it( + "beam sampling (do_sample=true)", + async () => { + const inputs = tokenizer(DUMMY_TEXT); + const outputs = await model.generate({ + ...inputs, + num_beams: 4, + do_sample: true, + top_k: 10, + max_new_tokens: 5, + }); + // Output should have shape [1, seq_len] + expect(outputs.dims[0]).toEqual(1); + expect(outputs.dims[1]).toBeGreaterThanOrEqual(3); + expect(outputs.dims[1]).toBeLessThanOrEqual(7); + }, + MAX_TEST_EXECUTION_TIME, + ); + it( "num_return_sequences > 1", async () => { From 1cad78c0e9e6f088ab0486262182c44260de43ab Mon Sep 17 00:00:00 2001 From: Dan Levy Date: Tue, 24 Feb 2026 20:26:03 -0700 Subject: [PATCH 4/6] Refactor model input expansion for beam search to use index selection --- .../transformers/src/models/modeling_utils.js | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/packages/transformers/src/models/modeling_utils.js b/packages/transformers/src/models/modeling_utils.js index 937ad44ad..55624ac01 100644 --- a/packages/transformers/src/models/modeling_utils.js +++ b/packages/transformers/src/models/modeling_utils.js @@ -1159,23 +1159,20 @@ export class PreTrainedModel extends Callable { all_input_ids.push(...expanded_ids); // Expand model_inputs tensors by repeating each row num_beams times + const expanded_indices = []; + for (let i = 0; i < numInputs; ++i) { + for (let b = 0; b < num_beams; ++b) { + expanded_indices.push(i); + } + } for (const key of [model_input_name, 'attention_mask', 'decoder_attention_mask', 'encoder_outputs']) { const tensor = model_inputs[key]; if (tensor instanceof Tensor) { - const [batch, ...rest] = tensor.dims; - const rowSize = rest.length > 0 ? rest.reduce((a, b) => a * b, 1) : 1; - // @ts-ignore - const newData = new tensor.data.constructor(batch * num_beams * rowSize); - for (let i = 0; i < batch; ++i) { - const srcOffset = i * rowSize; - for (let b = 0; b < num_beams; ++b) { - newData.set( - tensor.data.subarray(srcOffset, srcOffset + rowSize), - (i * num_beams + b) * rowSize, - ); - } + if (tensor.location === 'cpu' || tensor.location === 'cpu-pinned') { + model_inputs[key] = index_select(tensor, expanded_indices); + } else { + model_inputs[key] = await index_select_async(tensor, expanded_indices); } - model_inputs[key] = new Tensor(tensor.type, newData, [batch * num_beams, ...rest]); } } From b9d908e4823262a0b1228df5d4bd18d778f14615 Mon Sep 17 00:00:00 2001 From: Dan Levy Date: Wed, 25 Feb 2026 11:26:04 -0700 Subject: [PATCH 5/6] Handle past_key_values for beam search in PreTrainedModel to prevent caching issues --- packages/transformers/src/models/modeling_utils.js | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/packages/transformers/src/models/modeling_utils.js b/packages/transformers/src/models/modeling_utils.js index 55624ac01..98ff442e4 100644 --- a/packages/transformers/src/models/modeling_utils.js +++ b/packages/transformers/src/models/modeling_utils.js @@ -1585,9 +1585,17 @@ export class PreTrainedModel extends Callable { } if (generation_config.return_dict_in_generate) { + let past_key_values_for_return = past_key_values; + if (is_beam_search) { + console.warn( + 'Beam search does not return aligned past_key_values for finalized sequences. ' + + 'past_key_values will be null; re-generate without beam search if you need caching.', + ); + past_key_values_for_return = null; + } return { sequences, - past_key_values, + past_key_values: past_key_values_for_return, ...attentions, ...return_dict_items, }; From ccf1a43fcac524247c480994e2abd4f88b69a100 Mon Sep 17 00:00:00 2001 From: Dan Levy Date: Wed, 25 Feb 2026 12:18:49 -0700 Subject: [PATCH 6/6] Simplify past_key_values handling in PreTrainedModel for beam search --- packages/transformers/src/models/modeling_utils.js | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/packages/transformers/src/models/modeling_utils.js b/packages/transformers/src/models/modeling_utils.js index 98ff442e4..6bc72695d 100644 --- a/packages/transformers/src/models/modeling_utils.js +++ b/packages/transformers/src/models/modeling_utils.js @@ -1585,14 +1585,7 @@ export class PreTrainedModel extends Callable { } if (generation_config.return_dict_in_generate) { - let past_key_values_for_return = past_key_values; - if (is_beam_search) { - console.warn( - 'Beam search does not return aligned past_key_values for finalized sequences. ' + - 'past_key_values will be null; re-generate without beam search if you need caching.', - ); - past_key_values_for_return = null; - } + const past_key_values_for_return = is_beam_search ? null : past_key_values; return { sequences, past_key_values: past_key_values_for_return,